Pytorch之经典神经网络Generative Model(二) —— VAE (MNIST)

2023-11-09

      变分编码器(Variational AutoEncoder)是自动编码器的升级版本, 其结构跟自动编码器是类似的, 也由编码器和解码器构成。

      回忆一下, 自动编码器有个问题, 就是并不能任意生成图片, 因为我们没有办法自己去构造隐藏向量, 需要通过一张图片输入编码我们才知道得到的隐含向量是什么, 这时我们就可以通过变分自动编码器来解决这个问题。

      其实原理特别简单, 只需要在编码过程给它增加一些限制, 迫使其生成的隐含向量能够粗略的遵循一个标准正态分布, 这就是其与一般的自动编码器最大的不同。这样我们生成一张新图片就很简单了, 我们只需要给它一个标准正态分布的随机隐含向量, 这样通过解码器就能够生成我们想要的图片, 而不需要给它一张原始图片先编码。

      一般来讲, 我们通过 encoder 得到的隐含向量并不是一个标准的正态分布, 为了衡量两种分布的相似程度, 我们使用 KL divergence, 这是用来衡量两种分布相似程度的统计量,它越小,表示两种概率分布越接近。

       在实际情况中,需要在模型的准确率和encoder得到的隐含向量服从标准正态分布之间做一个权衡,所谓模型的准确率就是指解码器生成的图片与原始图片的相似程度。可以让神经网络自己做这个决定,只需要将两者都做一个loss,然后求和作为总的loss,这样网络就能够自己选择如何做才能使这个总的loss下降。

      为了避免计算 KL divergence 中的积分, 我们使用重参数的技巧, 不是每次产生一个隐含向量, 而是生成两个向量, 一个表示均值, 一个表示标准差, 这里我们默认编码之后的隐含向量服从一个正态分布的之后, 就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布, 最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布, 也就是希望均值为 0, 方差为 1

      所以标准的变分自动编码器VAE如下

import os
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
from visdom import Visdom

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean 均值
        self.fc22 = nn.Linear(400, 20) # var  标准差

        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        x = self.fc1(x)
        h1 = F.relu(x)
        mean = self.fc21(h1)
        var = self.fc22(h1)
        return mean, var

    #重参数化
    def reparametrize(self, mean, logvar):
        std = logvar.mul(0.5).exp_()
        normal = torch.FloatTensor(std.size()).normal_() #生成标准正态分布
        if torch.cuda.is_available():
            normal = torch.tensor(normal.cuda())
        else:
            normal = torch.tensor(normal)
        return normal.mul(std).add_(mean)  #标准正态分布乘上标准差再加上均值
        #这里返回的结果就是我们encoder得到的编码,也就是我们decoder要decode的编码

    def decode(self, z):
        z = self.fc3(z)
        z = F.relu(z)
        z = self.fc4(z)
        z = torch.tanh(z)
        return z

    def forward(self, x):
        mean, logvar = self.encode(x) # 编码
        z = self.reparametrize(mean, logvar) # 重新参数化成正态分布
        return self.decode(z), mean, logvar # 解码, 同时输出均值方差

def loss_function(recon_image, image, mean, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    reconstruction_function = nn.MSELoss(reduction='sum')
    MSE = reconstruction_function(recon_image, image)

    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mean.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD


def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]) # 标准化
])

train_set = MNIST(
                root='dataset/', 
                transform=img_transforms
)
train_data = DataLoader(
                dataset=train_set, 
                batch_size=128, 
                shuffle=True
)


net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
viz = Visdom()
viz.line([0.], [0.], win='loss', opts=dict(title='loss'))



for epoch in range(100):
    for image, _ in train_data:
        image = image.view(image.shape[0], -1)
        image = torch.tensor(image)
        if torch.cuda.is_available():
            image = image.cuda()
        recon_image, mean, logvar = net(image)
        loss = loss_function(recon_image, image, mean, logvar) / image.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    print('epoch: {}, Loss: {:.4f}'.format(epoch, loss.item()))
    save = to_img(recon_image.cpu().data)
    if not os.path.exists('./vae_img'):
        os.mkdir('./vae_img')
    save_image(save, './vae_img/image_{}.png'.format(epoch))

    viz.line([loss.item()], [epoch], win='loss', update='append')

运行100个eopch之后,可以看出来结果比自动编码器清晰一点,本质上VAE就是在encoder的结果添加了高斯噪声,通过训练要使得decoder对噪声有一定的鲁棒性,这样的话我们生成一张图片就没有必须用一张图片先做编码了,可以想象,我们只需要利用训练好的encoder对一张图片编码得到其分布后,符合这个分布的隐含向量理论上都可以通过decoder得到类似这张图片的图片。

KL越小,噪声越大(可以这麽理解,我们强行让z的分布符合正态分布,其和N(0,1)越接近,KL越小,相当于我们添加的噪声越大),所以直觉上来想loss合并后的训练过程:

  • 当 decoder 还没有训练好时(重构误差远大于 KL loss),就会适当降低噪声(KL loss 增加),使得拟合起来容易一些(重构误差开始下降);
  • 反之,如果 decoder 训练得还不错时(重构误差小于 KL loss),这时候噪声就会增加(KL loss 减少),使得拟合更加困难了(重构误差又开始增加),这时候 decoder 就要想办法提高它的生成能力了。

      变分自动编码器虽然比一般的自动编码器效果要好, 而且也限制了其输出的编码(code) 的概率分布, 但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss, 这个方式并不好。

在之后生成对抗网络中, 我们会讲一讲这种方式计算 loss 的局限性, 然后会介绍一种新的训练办法, 就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

变分自编码器VAE:原来是这么一回事 | 附开源代码 - 知乎

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch之经典神经网络Generative Model(二) —— VAE (MNIST) 的相关文章

随机推荐

  • RX8025T RTC读写与秒中断

    目录 一 精度 二 读写时序 三 写RTC对其内部ms计数的影响 四 在FPGA中用GPS校正RTC 五 ms维护 六 IIC防锁死计数清零 七 日期计算星期公式 一 精度 二 读写时序 接口为IIC 读写时序如下图 注意 1 写操作 写从
  • PHP常见问题总结

    1 为什么会出现这种情况 端口什么的都设置正确了 解决方法 请将本机的IIS服务关闭 开启Apache服务 IIS服务的关闭方法可参见 https jingyan baidu com article 0f5fb099e0d7216d8334
  • 理解JPEG文件头的格式

    1 JPEG 1 why jpeg jpeg作为图片传输格式使用最为普遍 压缩比最高 每天我们都会产出和传输大量的jpeg格式数据 手机拍出来的格式默认是jpeg 朋友圈各种分享 磁盘上积累了大量的jpeg 因此本人一直对jpeg头部数据非
  • CLIP:创建图像分类器

    介绍 假设需要对人们是否戴眼镜进行分类 但是没有数据或资源训练自定义模型 在本教程中 你将学习如何使用预训练的CLIP模型创建自定义分类器 无需任何训练 这种方法称为零快照图像分类 它使得能够对在原始CLIP模型训练期间未明确观察到的的类进
  • 并发基础知识(二)[进程间通信·信号]

    1 信号 信号是进程间通信的一种方式 这种方式没有数据传输 只是在内核中传递一个信号 整数 信号的表示是一个整数 不同的信号值 代表不同的含义 当然用户可以自定义信号 那么自定义的信号的含义和值由程序员来定和解释 Term Terminat
  • DVWA-15.Open HTTP Redirect

    OWASP将其定义为 当 Web 应用程序接受不受信任的输入时 可能会导致 Web 应用程序将请求重定向到不受信任输入中包含的 URL 则可能会出现未经验证的重定向和转发 通过修改恶意站点的不受信任的 URL 输入 攻击者可以成功发起网络钓
  • OpenGL ES基本流程总结

    作为一个学习总结 绘制了OpenGL ES中完成一次渲染所需要的一些基本步骤 离屏渲染 此处是以离屏渲染为例 离屏渲染是不直接上屏的 而是渲染到缓冲区中 那么这块缓冲区就需要我们手动创建 也就是上图所示的Framebuffer 其中需要三个
  • 车载以太网入门

    车载以太网入门 以太网的首要优势之一在于支持多种网络介质 因此可以在汽车领域进行使用 同时由于物理介质与协议无关 因此可以在汽车领域可以做相应的调整与拓展 形成一整套车载以太网协议 该协议将会在未来不断发展并长期使用 车载以太网总体架构 正
  • spring事务传播机制使用及原理

    事务 事务是逻辑上的一组操作 要么都执行 要么都不执行 事务的四大特性 原子性 构成事务的所有操作 要么都执行完成 要么全部不执行 不可能出现部分成功部分失 败的情况 一致性 在事务执行前后 数据库的一致性约束没有被破坏 隔离性 数据库中的
  • Games104 引擎工具链笔记

    一 GUI体系 1 Immediate Mode 比如UnityUGUI 优点 直接快速 缺点 逻辑比重大 2 Retained Mode 把要绘制的指令存到一个buffer中 统一绘制 优点 把游戏逻辑和UI渲染分开 扩展性强 例子 Un
  • StackOverflow 这么大,它的架构是怎么样的?

    伯乐在线补充 Nick Craver 是 StackOverflow 的软件工程师 网站可靠性工程师 这是 解密 Stack Overflow 架构 系列的第一篇 本系列会有非常多的内容 欢迎阅读并保持关注 为了便于理解本文涉及到的东西到底
  • QT5:VS创建的QT项目头文件标红和控件对象无法调用

    最近使用VS QT编写代码 除了界面和调试比较舒服以外 感觉的很不习惯 小问题不断 问题1 解决方法 1 系统环境变量中添加bin路径 2 属性配置中添加包含目录和库目录 问题2 上面这个问题 会导致ui调用不了添加的控件对象 搞到怀疑人生
  • SourceTree使用教程(七)--合并某次提交

    概述 在Git的实际使用场景中 未必都是很规矩的拉一个分支 开发一个功能 等功能测试完成后 合并到主分支 有很多的场景都是很多人在同一个开发分支上开发 然后按照上线的实际需要 依次去上传自己的功能模块 这个功能模块的提交记录很可能是交叉提交
  • vasp-自旋轨道耦合(SOC)计算步骤

    在VASP中执行自旋轨道耦合 SOC 计算 具体的计算步骤如下 结构优化获取CONTCAR文件 自洽计算 collinear normal VASP calculations 获取CHGCAR文件 能带结构计算 在此步骤中 将KPOINTS
  • 第二十六节:class和焦点的操作管理

    1 关于class的操作 IE9以下的getElementsByClassName 方法兼容问题 p Hello World p p class a 增加样式 World p ul ul function getClass classA i
  • 想要成为网络hacker黑客?先来学习这十方面的知识

    黑客 一词来源于 hacker 在英语中它实际是个中性词 本身并没有褒贬之分 指的是精通编程 计算机 网络的人 另外专门有一个词 cracker 指那些利用计算机技术侵入他人系统从事非法活动的人 但在国内这两个词都被翻译为 黑客 导致在大部
  • Unity制作Live2D(一)模型导入

    目录 序言 前期准备 导入模型 序言 在许多游戏当中 Live2D展现出来了优秀的游戏体验 通过Live2D效果 让平面的游戏人物看起来更加生动 玩家也会感受到更多乐趣 前期准备 前往Live2DCubism官网下载Unity需要的SDKC
  • cenos6.4 mongodb shell模式 常用指令

    如果还没有安装mongodb DB服务端 用户可以参考该篇文章尝试安装mongodb http blog csdn net zhouzhiwengang article details 51441638 我们的实验环境为 操作系统 cent
  • Flutter中Provider的一般用法(一)

    在flutter中Provider是比较常用的Widget Provider通常用来管理value的生命周期 通过Create和Dispose 它们是成对出现的 可以在Create进行value的初始化操作 在dispose进行value的
  • Pytorch之经典神经网络Generative Model(二) —— VAE (MNIST)

    变分编码器 Variational AutoEncoder 是自动编码器的升级版本 其结构跟自动编码器是类似的 也由编码器和解码器构成 回忆一下 自动编码器有个问题 就是并不能任意生成图片 因为我们没有办法自己去构造隐藏向量 需要通过一张图