SRGAN 图像超分辨率重建(Keras)

2023-11-02


前言

SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络。训练完网络后。只用生成器来重建低分辨率图像。网络结构主要使用生成器(Generator)和判别器(Discriminator)。训练过程不太稳定。一般用于卫星图像,遥感图像的图像重建,人脸图像超分重建。
这里我们使用的高分辨率的数据集 (DIV2K)
数据集下载链接:链接:https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg 提取码:luly
github代码地址:https://github.com/jiantenggei/srgan
重制版代码仓库:https://github.com/jiantenggei/Srgan_

一、SRGAN

1.训练步骤

SRGAN 网络的训练思路如下图所示:
在这里插入图片描述

训练步骤如下:
(1) 将低分辨率输入到生成网络,生成高分辨率图像。
(2) 将高分辨率图像输入的判别网络判别真假,与0和1进行对比
(3) 将原始高分辨率图像和生成的高分辨率图像分别用VGG19 的前9层提取特征,将提取的特征计算loss。
(4). 将loss返回给生成器继续训练。
这就是SRGAN 的训练流程了。
接下来我们一一去实现上述步骤。

2.生成器

生成器网络结构如下图所示:
在这里插入图片描述
生成器主要有两部分构成,第一部分是residual block 残差块(图中红色方块),第二部分是上采样部分(图中蓝色方块)用来上采样特征图。
残差块:包含一个两个3x3的卷积 BN,PReLu
上采样:使用UpSampling2D,这里可能与原模型不同实现
生成器代码如下所示:

# 生成器中的残差块
def res_block_gen(x, kernal_size, filters, strides):
    
    gen = x
    
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
    # Using Parametric ReLU
    x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
        
    x = add([gen, x])
    
    return x

#上采样样块
def up_sampling_block(x, kernal_size, filters, strides):
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = UpSampling2D(size = 2)(x)
    x = LeakyReLU(alpha = 0.2)(x)
    
    return x
#--------------------------------------
# 亚像素卷积上采样块
# 生成器 还是用的 UpSampling2D
# 如果有需要可以自己更改
# -------------------------------------
def SubpixelConv2D(input_shape, scale=4):
    def subpixel_shape(input_shape):
        dims = [input_shape[0],input_shape[1] * scale,input_shape[2] * scale,int(input_shape[3] / (scale ** 2))]
        output_shape = tuple(dims)
        return output_shape
    
    def subpixel(x):
        return tf.compat.v1.depth_to_space(x, scale)
        
    return Lambda(subpixel, output_shape=subpixel_shape)
    
def Generator(input_shape=[128,128,3]):
    gen_input = Input(input_shape)
    x = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
    x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
	    
    gen_x = x
        
    # 16 个残差快
    for index in range(16):
            x = res_block_gen(x, 3, 64, 1)
	    
    x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
    x = add([gen_x, x])
	    
	#两个上采样 -> 放大四倍
    for index in range(2):
        x = up_sampling_block(x, 3, 256, 1)
	    
    x = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(x)
    x = Activation('tanh')(x)
	   
    generator_x = Model(inputs = gen_input, outputs = x)
        
    return generator_x

3.判别器

判别器主要用于判断生成图片的真假。与0和1比较,1代表真图片,0代表假图片。这里的0和1 是与判别器输出大小想用的向量,而不是单纯的0,1,判别器网络结果如下所示:
在这里插入图片描述

判别网络由一个个包含卷积、BN、和LeakyRelu 激活函数的块组成,最后输出1或0, 实际上就相当于是一个二分类的分类网络,代码如下所示:

#判别器中的卷积块
def discriminator_block(x, filters, kernel_size, strides):
    
    x = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
    x = LeakyReLU(alpha = 0.2)(x)
    
    return x
    
def Discriminator(image_shape=[512,512,3]):
        
        dis_input = Input(image_shape)
        
        x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(dis_input)
        x = LeakyReLU(alpha = 0.2)(x)
        
        x = discriminator_block(x, 64, 3, 2)
        x = discriminator_block(x, 128, 3, 1)
        x = discriminator_block(x, 128, 3, 2)
        x = discriminator_block(x, 256, 3, 1)
        x = discriminator_block(x, 256, 3, 2)
        x = discriminator_block(x, 512, 3, 1)
        x = discriminator_block(x, 512, 3, 2)
        
        #x = Flatten()(x) # 这里采用Flatten 太浪费现存了 改为 全局池化
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024)(x)
        x = LeakyReLU(alpha = 0.2)(x)
       
        x = Dense(1)(x)
        x = Activation('sigmoid')(x) 
        
        discriminator_x = Model(inputs = dis_input, outputs = x)
        
        return discriminator_x

网络主要分为生成器和判别器,训练时相互对抗,以达到一个很好的平衡为目的。

二、其他准备

1.数据读取

训练时,输入的高分辨率图像一般为很大的图片。需要将其随机裁剪为预设的大小。再将裁剪的图像,下采样作为低分辨率图像。代码过长,不全部贴出来了。

class SRganDataset(keras.utils.Sequence):
    def __init__(self, train_lines, lr_shape, hr_shape, batch_size):
        super(SRganDataset, self).__init__()

        self.train_lines    = train_lines
        self.train_batches  = len(train_lines)

        self.lr_shape       = lr_shape
        self.hr_shape       = hr_shape

        self.batch_size     = batch_size

    def __len__(self):
        return math.ceil(self.train_batches / float(self.batch_size))

    def __getitem__(self, index):
        if index == 0:
            self.on_epoch_begin()

        images_l = []
        images_h = []
        for i in range(index * self.batch_size, (index + 1) * self.batch_size):  
            i = i % self.train_batches

            image_origin = Image.open(self.train_lines[i].split()[0])
            if self.rand()<.5:
                img_h = self.get_random_data(image_origin, self.hr_shape)
            else:
                img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])
            img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)

            img_h = preprocess_input(np.array(img_h, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5])
            img_l = preprocess_input(np.array(img_l, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5])

            images_h.append(img_h)
            images_l.append(img_l)
        return np.array(images_l), np.array(images_h)

    def on_epoch_begin(self):
        shuffle(self.train_lines)

    def rand(self, a=0, b=1):
        return np.random.rand()*(b-a) + a

2.VGG19提取特征

VGG19提取生成高分辨率的图像特征与真实高分辨率图像特征进行比较。计算LOSS。

class VGG_LOSS(object):

    def __init__(self, image_shape):
        
        self.image_shape = image_shape

    # 用VGG19 计算 高清图和生成的高清图之间的差别
    def vgg_loss(self, y_true, y_pred):
    
        vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.image_shape)
        vgg19.trainable = False
        # Make trainable as False
        for l in vgg19.layers:
            l.trainable = False
        model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_conv4').output)
        model.trainable = False
    
        return K.mean(K.square(model(y_true) - model(y_pred)))

4.训练过程代码

def train(epochs, batch_size, model_save_dir):

    train_annotation_path = 'dataset.txt'
    #下采样倍数
    downscale_factor = 4

    #输入图片形状
    hr_shape = (384,384,3)
    #加载数据集
    with open(train_annotation_path, encoding='utf-8') as f:
         train_lines = f.readlines()
    #计算 生成图片 和 原高清图 之间的loss
    loss = VGG_LOSS(hr_shape) 
    #打乱 
    random.shuffle(train_lines)
    batch_count = int(len(train_lines)/ batch_size)
    lr_shape = (hr_shape[0]//downscale_factor, hr_shape[1]//downscale_factor, hr_shape[2])
    
    generator = Generator(lr_shape)
    discriminator = Discriminator(hr_shape)

    optimizer =tf.optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    gen                 = SRganDataset(train_lines, lr_shape[:2], hr_shape[:2], batch_size)
    gan = get_gan(discriminator, lr_shape, generator, optimizer,loss.vgg_loss)
    loss_file = open(model_save_dir + 'losses.txt' , 'w+')
    loss_file.close()
    
    for epoch in range(0,epochs):
        print ('-'*15, 'Epoch %d' % epoch, '-'*15)
        with tqdm(total=batch_count,desc=f'Epoch {epoch + 1}/{epochs}',postfix=dict,mininterval=0.3) as pbar:
            for iteration, batch in enumerate(gen):
                if iteration >= batch_count:
                    break
                imgs_lr, imgs_hr        = batch
                #生成器生成图片
                gen_img = generator.predict(imgs_lr)

                real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
                fake_data_Y = np.random.random_sample(batch_size)*0.2
                
                discriminator.trainable = True
                
                d_loss_real = discriminator.train_on_batch(imgs_hr, real_data_Y)
                d_loss_fake = discriminator.train_on_batch(gen_img, fake_data_Y)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            

                gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
                discriminator.trainable = False
                gan_loss = gan.train_on_batch(imgs_lr, [imgs_hr,gan_Y])
                pbar.set_postfix(**{'G_loss'        : gan_loss[0] , 
                                    'D_loss'        : discriminator_loss,
                                    'PSNR'          : gan_loss[4]
                                    },)
                pbar.update(1)  
            print("discriminator_loss : %f" % discriminator_loss)
            print("gan_loss :", gan_loss)
            gan_loss = str(gan_loss)
            
            loss_file = open(model_save_dir + 'losses.txt' , 'a')
            loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %(epoch, gan_loss, discriminator_loss) )
            loss_file.close()

            
            show_result(epoch,generator,imgs_lr,imgs_hr)
            
            generator.save(model_save_dir + 'gen_model%d.h5' % epoch)
            discriminator.save(model_save_dir + 'dis_model%d.h5' % epoch)

训练时,在目录result 目录下会出现这样的图片。
在这里插入图片描述
lr_images : 低分辨率图
Fake_Hr_Images:生成的高分辨率图像
True_Hr_Images:远高分图像

5. 预测过程

预测部分代码:



from pickle import NONE
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from nets.nets import Generator
before_image = Image.open(r"0.jpg")

before_image = before_image.convert("RGB")
gen_model = Generator([None,None,3])
gen_model.load_weights('loss\gen_model99.h5')
# gen_model.summary()
new_img = Image.new('RGB', before_image.size, (128, 128, 128))
new_img.paste(before_image)
# plt.imshow(new_img)
# plt.show()

new_image = np.array(new_img)/127.5 - 1
# 三维变4维  因为神经网络的输入是四维的
new_image = np.expand_dims(new_image, axis=0)  # [batch_size,w,h,c]
fake = (gen_model.predict(new_image)*0.5 + 0.5)*255
#将np array 形式的图片转换为unit8  把数据转换为图
fake = Image.fromarray(np.uint8(fake[0]))

fake.save("out.png")
titles = ['Generated', 'Original']
plt.subplot(1, 2, 1)
plt.imshow(before_image)
plt.subplot(1, 2, 2)
plt.imshow(fake)
plt.show()

重建效果:
在这里插入图片描述

参考链接

https://github.com/bubbliiiing/srgan-keras
https://github.com/deepak112/Keras-SRGAN
https://github.com/JustinhoCHN/SRGAN_Wasserstein
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

SRGAN 图像超分辨率重建(Keras) 的相关文章

随机推荐

  • Winsock状态说明及错误代码

    Winsock状态参数说明 常数 值 描述 sckClosed 0 缺省值 关闭 SckOpen 1 打开 SckListening 2 侦听 sckConnectionPending 3 连接挂起 sckResolvingHost 4 识
  • “定点打击”——XPath 使用细则(Just For Selenium WebDriver)

    该系列文章系个人读书笔记及总结性内容 任何组织和个人不得转载进行商业活动 Selenium WebDriver中有关元素定位的学习 需要XPath的支持 特此梳理 前言 XPath教程 XPath是一门在XML文档中查找信息的语言 XPat
  • 基于线性预测的语音编码原理解析

    早期的音频系统都是基于声音的模拟信号实现的 在声音的录制 编辑和播放过程中很容易引入各种噪声 从而导致信号的失真 随着信息技术的发展 数字信号处理技术在越来越多领域得到了应用 数字信号更是具备了易于存储和远距离传输 没有累积失真 抗干扰能力
  • SpringSecurity——OAuth2框架鉴权实现源码分析

    SpringSecurity OAuth2框架鉴权实现源码分析 一 ManagedFilter迭代过滤器链 1 4 springSecurityFilterChain 1 4 7 OAuth2AuthenticationProcessing
  • 【Redis】Redis在Windows下的使用(hiredis+Qt5.7.0+mingw5.3.0)

    1 下载hiredis https github com redis hiredis 得到hiredis master zip 解压后 得到hiredis master目录 可以看到CMakeLists txt 2 下载CMake http
  • Oracle12C 用户创建、修改、授权、删除、登录等操作

    1 以系统用户命令行登录 sqlplus sys sys as sysdba 2 确认选择CDB select name cdb from v database col pdb name for a30 select pdb id pdb
  • matlab绘图(三)绘制三维图像

    目录 一 绘制三维曲线 二 绘制三维曲面 1 meshgrid函数 2 mesh和surf函数 一 绘制三维曲线 1 最基本的绘制三维曲线的函数 plot3 plot3 x1 y1 z1 选项 1 x2 y2 z2 选项 2 xn yn z
  • 样本方差的分母为何为n-1而不是n之无偏估计

    设样本均值为 样本方差为 总体均值为 总体方差为 那么样本方差有如下公式 很多人可能都会有疑问 为什么要除以n 1 而不是n 但是翻阅资料 发现很多都是交代到 如果除以n 对样本方差的估计不是无偏估计 比总体方差要小 要想是无偏估计就要调小
  • 撸一撸Spring Cloud Ribbon的原理

    说起负载均衡一般都会想到服务端的负载均衡 常用产品包括LBS硬件或云服务 Nginx等 都是耳熟能详的产品 而Spring Cloud提供了让服务调用端具备负载均衡能力的Ribbon 通过和Eureka的紧密结合 不用在服务集群内再架设负载
  • html alert字体颜色,js里alert里的字体颜色怎么设置:字体颜色方法;fontcolor(color)...

    我的总结 alert应该是没办法改变的 只有自己写个弹出窗口才可以改变字体颜色 我的总结 alert应该是没办法改变的 只有自己写个弹出窗口才可以改变字体颜色 alert 投票总数不大于 不知道怎么改变字体所以查了下 找到下面的信息 好东西
  • 超强干货,Pytest自动化测试框架fixture固件使用,0-1精通实战

    前言 如果有以下场景 用例 1 需要先登录 用例 2 不需要登录 用例 3 需要先登录 很显然无法用 setup 和 teardown 来实现了 fixture 可以让我们自定义测试用例的前置条件 fixture 的优势 命名方式灵活 不局
  • std::thread的常用参数传递总结

    实参的生命周期 给std thread传递参数的时候要注意 参数是引用或者指针的情况下 要注意生命周期的问题 看代码 include
  • c++职工管理系统

    要求 代码 management system cpp main函数 include
  • hp服务器系统管理软件,惠普raid管理软件

    有HP服务器的RAID查看管理软件吗 1 创建分区 fdisk dev hda 创建两个要用做实验的新分区 比如hda5 hda6 2 cp usr share doc raidtools 1 0 0 2 raid0 conf sample
  • ASP.NET CORE 6.0实现SignalR

    当使用ASP NET Core Web API与SignalR结合 可以实现实时通信功能 以下是一些详细步骤 步骤一 创建ASP NET Core Web API项目 1 打开Visual Studio 选择创建新项目 2 选择 ASP N
  • C语言-一维数组、二维数组 (一篇文章带你彻底读懂!!!)

    数组简介 数组是相同数据类型的元素的集合 数组中的各元素的存储是有先后顺序的 它们在内存中按照这个先后顺序连续存放在一起 数组元素用整个数组的名字和它自己在数组中的顺序位置来表示 数组的特性 查询快 增删慢 为什么查询快 因为数组存在下标这
  • python 数据分析与挖局书籍

    之前一直有朋友叫我列一个数据科学的书单 说实话这件事情我是犹豫了很久的 有两个原因 其一是因为自己读书太少才疏学浅 其二我觉得基于我个人观点认为 好 的书其实可能对于很多人是不一定合适的 不过 明天正好是世界读书日 所以这里从一个 在读的统
  • 【LeetCode算法系列题解】第51~55题

    CONTENTS LeetCode 51 N 皇后 困难 LeetCode 52 N 皇后 II 困难 LeetCode 53 最大子序和 中等 LeetCode 54 螺旋矩阵 中等 LeetCode 55 跳跃游戏 中等 LeetCod
  • vue不同组件中监听Localstorage变化并实时更新---ing

    vue 不同组件中监听Localstorage变化并实时更新 一 解决问题 二 实现实例 1 控制组件 触发后向浏览器保存数据 2 main js全局引入 3 在所需要监听localstorage值的组件中写入监听事件 三 参考 一 解决问
  • SRGAN 图像超分辨率重建(Keras)

    文章目录 前言 一 SRGAN 1 训练步骤 2 生成器 3 判别器 二 其他准备 1 数据读取 2 VGG19提取特征 4 训练过程代码 5 预测过程 参考链接 前言 SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络 训练完网络