好像还挺好玩的GAN5——Keras搭建COGAN耦合生成式对抗网络

2023-11-04

好像还挺好玩的GAN5——Keras搭建COGAN耦合生成式对抗网络

学习前言

发现一个挺有意思的结构,可以通过同一个输入,生成不同的内容。
在这里插入图片描述

什么是COGAN

COGAN是一种耦合生成式对抗网络,其内部具有一定的耦合,可以对同一个输入有不同的输出。

其具体实现方式就是:
1、建立两个生成模型,两个判别模型。
2、两个生成模型的特征提取部分有一定的重合,在最后生成图片的部分分开,以生成不同类型的图片。
3、两个判别模型的特征提取部分有一定的重合,在最后判别真伪的部分分开,以判别不同类型的图片。

神经网络构建

1、Generator

生成模型的输入是一个N维度的符合正太分布的随机数,输出是一个28,28,1的Mnist手写体。
一共存在两个生成模型,两个生成模型的特征提取部分有一定的重合,在最后生成图片的部分分开,以生成不同类型的图片。

即:权值部分有一定的共享。

def build_generators(self):
    # 共享权值部分
    noise = Input(shape=(self.latent_dim,))
    x = Dense(32 * 7 * 7, activation="relu", input_dim=self.latent_dim)(noise)
    x = Reshape((7, 7, 32))(x)
    
    x = Conv2D(64, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)

    x = UpSampling2D()(x)
    x = Conv2D(128, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)

    x = UpSampling2D()(x)
    x = Conv2D(128, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    feature_repr = Activation("relu")(x)
    model = Model(noise,feature_repr)

    noise = Input(shape=(self.latent_dim,))
    feature_repr = model(noise)
    # 生成模型1
    g1 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
    g1 = BatchNormalization(momentum=0.8)(g1)
    g1 = Activation("relu")(g1)

    g1 = Conv2D(64, kernel_size=3, padding="same")(g1)
    g1 = BatchNormalization(momentum=0.8)(g1)
    g1 = Activation("relu")(g1)

    g1 = Conv2D(64, kernel_size=1, padding="same")(g1)
    g1 = BatchNormalization(momentum=0.8)(g1)
    g1 = Activation("relu")(g1)

    g1 = Conv2D(self.channels, kernel_size=1, padding="same")(g1)
    img1 = Activation("tanh")(g1)

    # 生成模型2
    g2 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
    g2 = BatchNormalization(momentum=0.8)(g2)
    g2 = Activation("relu")(g2)
    
    g2 = Conv2D(64, kernel_size=3, padding="same")(g2)
    g2 = BatchNormalization(momentum=0.8)(g2)
    g2 = Activation("relu")(g2)

    g2 = Conv2D(64, kernel_size=1, padding="same")(g2)
    g2 = BatchNormalization(momentum=0.8)(g2)
    g2 = Activation("relu")(g2)

    g2 = Conv2D(self.channels, kernel_size=1, padding="same")(g2)
    img2 = Activation("tanh")(g2)

    return Model(noise, img1), Model(noise, img2)

2、Discriminator

判别模型的输入一个28,28,1维的图片,输出一个是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。

一共存在两个判别模型,两个生成模型的特征提取部分有一定的重合,在最后判别真伪的部分分开,以判别不同类型的图片。

def build_discriminators(self):

    # 共享权值部分
    img = Input(shape=self.img_shape)
    x = Conv2D(64, kernel_size=3, strides=2, padding="same")(img)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)

    x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)

    x = Conv2D(64, kernel_size=3, strides=2, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = GlobalAveragePooling2D()(x)
    feature_repr = Activation("relu")(x)

    model = Model(img,feature_repr)

    img1 = Input(shape=self.img_shape)
    img2 = Input(shape=self.img_shape)
    img1_embedding = model(img1)
    img2_embedding = model(img2)
    # 生成评价模型1
    validity1 = Dense(1, activation='sigmoid')(img1_embedding)
    # 生成评价模型2
    validity2 = Dense(1, activation='sigmoid')(img2_embedding)

    return Model(img1, validity1), Model(img2, validity2)

训练思路

COGAN的训练思路分为如下几个步骤:
1、创建两个风格不同的数据集。
2、随机生成batch_size个N维向量,利用两个不同的生成模型生成图片。
3、利用两个判别模型分别对两个不同的生成模型的生成图片进行判别、对两个风格不同的数据集进行随机选取并进行判别。
4、根据两个判别模型的结果与1对比,对两个生成模型进行训练。

实现全部代码

from __future__ import print_function, division
import scipy

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, GlobalAveragePooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import numpy as np

class COGAN():
    def __init__(self):
        # 输入shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # 分十类
        self.num_classes = 10
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)
        # 生成两个判别器
        self.d1, self.d2 = self.build_discriminators()
        self.d1.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.d2.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        # 建立生成器
        self.g1, self.g2 = self.build_generators()

        z = Input(shape=(self.latent_dim,))
        img1 = self.g1(z)
        img2 = self.g2(z)

        self.d1.trainable = False
        self.d2.trainable = False

        valid1 = self.d1(img1)
        valid2 = self.d2(img2)

        self.combined = Model(z, [valid1, valid2])
        self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
                                    optimizer=optimizer)

    def build_generators(self):
        # 共享权值部分
        noise = Input(shape=(self.latent_dim,))
        x = Dense(32 * 7 * 7, activation="relu", input_dim=self.latent_dim)(noise)
        x = Reshape((7, 7, 32))(x)
        
        x = Conv2D(64, kernel_size=3, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Activation("relu")(x)

        x = UpSampling2D()(x)
        x = Conv2D(128, kernel_size=3, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Activation("relu")(x)

        x = UpSampling2D()(x)
        x = Conv2D(128, kernel_size=3, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        feature_repr = Activation("relu")(x)
        model = Model(noise,feature_repr)

        noise = Input(shape=(self.latent_dim,))
        feature_repr = model(noise)
        # 生成模型1
        g1 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
        g1 = BatchNormalization(momentum=0.8)(g1)
        g1 = Activation("relu")(g1)

        g1 = Conv2D(64, kernel_size=3, padding="same")(g1)
        g1 = BatchNormalization(momentum=0.8)(g1)
        g1 = Activation("relu")(g1)

        g1 = Conv2D(64, kernel_size=1, padding="same")(g1)
        g1 = BatchNormalization(momentum=0.8)(g1)
        g1 = Activation("relu")(g1)

        g1 = Conv2D(self.channels, kernel_size=1, padding="same")(g1)
        img1 = Activation("tanh")(g1)

        # 生成模型2
        g2 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)
        g2 = BatchNormalization(momentum=0.8)(g2)
        g2 = Activation("relu")(g2)
        
        g2 = Conv2D(64, kernel_size=3, padding="same")(g2)
        g2 = BatchNormalization(momentum=0.8)(g2)
        g2 = Activation("relu")(g2)

        g2 = Conv2D(64, kernel_size=1, padding="same")(g2)
        g2 = BatchNormalization(momentum=0.8)(g2)
        g2 = Activation("relu")(g2)

        g2 = Conv2D(self.channels, kernel_size=1, padding="same")(g2)
        img2 = Activation("tanh")(g2)

        return Model(noise, img1), Model(noise, img2)


    def build_discriminators(self):

        # 共享权值部分
        img = Input(shape=self.img_shape)
        x = Conv2D(64, kernel_size=3, strides=2, padding="same")(img)
        x = BatchNormalization(momentum=0.8)(x)
        x = Activation("relu")(x)

        x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Activation("relu")(x)

        x = Conv2D(64, kernel_size=3, strides=2, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = GlobalAveragePooling2D()(x)
        feature_repr = Activation("relu")(x)

        model = Model(img,feature_repr)

        img1 = Input(shape=self.img_shape)
        img2 = Input(shape=self.img_shape)
        img1_embedding = model(img1)
        img2_embedding = model(img2)
        # 生成评价模型1
        validity1 = Dense(1, activation='sigmoid')(img1_embedding)
        # 生成评价模型2
        validity2 = Dense(1, activation='sigmoid')(img2_embedding)

        return Model(img1, validity1), Model(img2, validity2)

    def train(self, epochs, batch_size=128, sample_interval=50):
        (X_train, _), (_, _) = mnist.load_data()

        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        X1 = X_train[:int(X_train.shape[0]/2)]
        X2 = X_train[int(X_train.shape[0]/2):]
        X2 = scipy.ndimage.interpolation.rotate(X2, 90, axes=(1, 2))

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------- #
            #  训练评价者
            # ---------------------- #

            idx = np.random.randint(0, X1.shape[0], batch_size)
            imgs1 = X1[idx]
            imgs2 = X2[idx]

            noise = np.random.normal(0, 1, (batch_size, 100))

            gen_imgs1 = self.g1.predict(noise)
            gen_imgs2 = self.g2.predict(noise)

            d1_loss_real = self.d1.train_on_batch(imgs1, valid)
            d2_loss_real = self.d2.train_on_batch(imgs2, valid)
            d1_loss_fake = self.d1.train_on_batch(gen_imgs1, fake)
            d2_loss_fake = self.d2.train_on_batch(gen_imgs2, fake)
            d1_loss = 0.5 * np.add(d1_loss_real, d1_loss_fake)
            d2_loss = 0.5 * np.add(d2_loss_real, d2_loss_fake)


            # ------------------ #
            #  训练生成模型
            # ------------------ #

            g_loss = self.combined.train_on_batch(noise, [valid, valid])

            print ("%d [D1 loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]" \
                % (epoch, d1_loss[0], 100*d1_loss[1], d2_loss[0], 100*d2_loss[1], g_loss[0]))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 4, 4
        noise = np.random.normal(0, 1, (r * int(c/2), 100))
        gen_imgs1 = self.g1.predict(noise)
        gen_imgs2 = self.g2.predict(noise)
        gen_imgs = np.concatenate([gen_imgs1, gen_imgs2])

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = COGAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)

实现效果为:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

好像还挺好玩的GAN5——Keras搭建COGAN耦合生成式对抗网络 的相关文章

  • QT的UI界面效果预览快捷键

    QtCreator的界面预览 Shift Alt R 运行快捷键 Ctrl R 只构建快捷键 Ctrl B 怕忘了 自己记录一下
  • 西门子S7-200PLC温度程序 最近看到同行发的各种帮做程序

    西门子S7 200PLC温度程序 最近看到同行发的各种帮做程序 留言里总有初学者求温度 PID程序 所以我发出来给需要的人学习 适合初学者的模拟量温度处理程序 包含一份CAD图纸 一份PDF图纸 一套PLC程序 id 56733527955
  • 实训一 思科交换机基础配置

    交换机的命令行操作模式 用户模式 特权模式 全局配置模式 端口模式 1 命令行操作模式的进入 用户模式switch gt 输入enable 进入特权模式switch 输入configure terminal 进入全局配置模式switch c

随机推荐

  • 双节点文件服务器,在 Windows Server 2008 中配置双节点文件服务器故障转移群集的循序渐进指南.doc...

    在 Windows Server 2008 中配置双节点文件服务器故障转移群集的循序渐进指南 在 Windows Server 2008 中配置双节点文件服务器故障转移群集的循序渐进指南 更新时间 2007年4月 应用到 Windows S
  • win7系统下利用VS Code断点调试C/C++源码

    1 系统配置 win7 64系统 2 VS Code版本 1 70 2 System setup 3 安装包下载官网 Visual Studio Code July 2022 由于计算机是win7版的 故选择能支持win7系统版本的VS C
  • Rancher和K8s关系

    产品介绍 K8s Kubernetes 为企业提供了一种一致的方式来管理任何计算基础架构 百度百科 Rancher则是用于管理位于任何位置的Kubernetes集群的完整平台 如果用户是自己手动部署K8s集群 流程还是比较复杂的 需要掌握一
  • Idea中打包jar包(mavan项目)

    一 操作环境 IntelliJ IDEA 2020 3 3 Win10 Mavan项目 二 操作 1 确保打包方式为jar 打开pom xml文件 找到
  • MAC搜索不到蓝牙设备解决方案

    简单的解决方案就是 在打开的窗口中找到 com apple Bluetooth plist 删掉 如果还看到 com apple Bluetooth plist lockfile 也一并删 如下图 然后重启 也可以参考 http bbs p
  • [Unity][安卓]Unity和Android Studio 3.0 交互通讯(1)Android Studio 3.0 设置

    安卓 Android Studio 3 0 JDK安卓环境配置 2017 10 http blog csdn net bulademian article details 78387052 安卓 Android Studio 3 0 安装包
  • 计算机中丢失dasfj_v1.2.dll,S7DasBrMenu.dll(修复丢失S7DasBrMenu.dll文件)V1.0 正式版

    S7DasBrMenu dll 修复丢失S7DasBrMenu dll文件 是针对S7DasBrMenu dll文件的一款很好用的修复工具 使用电脑时遇到S7DasBrMenu dll文件丢失怎么办 没关系 小编带来的这款S7DasBrMe
  • 深入理解设计原则之KISS/YAGNI/DRY原则【软件架构设计】

    系列文章目录 C 高性能优化编程系列 软件架构设计系列 深入理解设计模式系列 高级C 并发线程编程 如果敌人使你生气 说明你没有胜过他的信心 If the enemy makes you angry that means you have
  • js 判断字符串中是否包含某个字符串

    可通过str includes 和str indexOf 1 includes 语法 arr includes searchElement arr includes searchElement fromIndex 参数说明 参数 描述 se
  • 数据结构-线性结构之线性表

    什么是线性表 线性表 Linear List 由同类型数据元素构成的有序序列的线性结构 1 表中元素个数称为线性表的长度 2 线性表没有元素时 称为空表 3 表起始位置称为表头 表的结束位置称为表尾 线性表的抽象数据类型描述 类型名称 线性
  • 使用maven 打成可提供给第三方使用的jar包

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 1 集成环境idea 2 目标第三方可直接引入使用 步骤如下 0 组件项目打包一定要把build 元素注释掉 否则别人无法引入 jar包 pom 例子如下
  • Django笔记总结

    1 web框架的本质 web通信流程 web我们这里指的就是通过浏览器去访问服务端 请求页面或者数据的通信方式 属于B S架构 就是我们常见的网站 浏览器与服务端的通信流程 浏览器客户端发送一个请求信息 数据 发送到我们服务端 服务端接受这
  • 使用强随机数

    伪随机数易被攻击者破解而找到其数序生成规律 伪随机数不能用于安全敏感应用 常见安全敏感应用 SessionID的生成 挑战算法中的随机数生成 验证码的随机数生成 生成重要随机文件的文件名 生成密钥相关的随机数等 对于安全敏感应用 应该使用强
  • H3C交换机堆叠(IRF)

    目录 1 IRF简介 1 1 实验环境 1 2 添加交换机 1 3 添加连接线 1 4 启动设备 1 5 修改设备名称 1 6 关闭IRF物理端口 1 7 设置IRF域编号 1 8 设置member成员编号 1 9 配置IRF端口并与物理端
  • 使用ESP8266和Blynk应用程序的远程房间恒温器

    该项目可通过您的手机通过家庭WiFi或移动网络的任何地方控制您的家庭供暖单元 基本上 它可以用远程控制器代替普通的可编程房间恒温器 作为遥控器 它使用安装了Blynk App并配置为可以满足所有需求的智能手机 智能手机和供暖系统之间的通信由
  • Android 9(P)应用进程创建流程大揭秘

    Android 9 P 应用进程创建流程大揭秘 Android 9 P 系统启动及进程创建源码分析目录 Android 9 P 之init进程启动源码分析指南之一 Android 9 P 之init进程启动源码分析指南之二 Android
  • 【面试题】-java分布式及微服务面试题汇总

    目录 1 CAP理论 2 BASE理论 3 接口的幂等性问题 4 消息中间件如何解决消息丢失问题 5 什么是分布式事务 分布式事务的类型有哪些 6 分布式事务的解决方案有哪些 7 Dubbo的服务请求失败怎么处理 10 Soa和微服务架构有
  • 二叉树中序线索化与遍历(c)

    对下图的二叉树进行 1 创建一个带线索域的二叉树 数据类型如下 typedef struct 当 tag 为0时 代表child 正常指向下一个节点 当tag为1时 child为线索 其中 lchild 为指向中序遍历前驱 rchild 为
  • 中级职称的计算机考试题库,中级职称计算机考试模拟题库及答案

    资料仅供参考 中级职称计算机考试模拟题库及答案 1 计算机中数据的表示形式是 C A 八进制 B 十进制 C 二进制 D 十六进制 2 硬盘工作时应特别注意避免 B A 噪声 B 震动 C 潮湿 D 日光 3 下面列出的四种存储器中 易失性
  • 好像还挺好玩的GAN5——Keras搭建COGAN耦合生成式对抗网络

    好像还挺好玩的GAN5 Keras搭建COGAN耦合生成式对抗网络 学习前言 什么是COGAN 神经网络构建 1 Generator 2 Discriminator 训练思路 实现全部代码 学习前言 发现一个挺有意思的结构 可以通过同一个输