VQ-VAE

2023-10-31

前言

之前总结了一篇VAE的,这次来个它的离散版本。
VAE(Variational Autoencoder)简单记录

论文: Neural Discrete Representation Learning

代码: https://gitee.com/mirrors_ritheshkumar95/pytorch-vqvae

v2: Generating Diverse High-Fidelity Images with VQ-VAE-2



原理



代码

这里选取 生成模型之VQ-VAE 的代码。之所以粘贴过来是因为我想写一些笔记啥的,只用于学习用途…


class VectorQuantizer(nn.Module):
    """
    VQ-VAE layer: Input any tensor to be quantized. 
    Args:
        embedding_dim (int): the dimensionality of the tensors in the
          quantized space. Inputs to the modules must be in this format as well.
        num_embeddings (int): the number of vectors in the quantized space.
        commitment_cost (float): scalar which controls the weighting of the loss terms (see
          equation 4 in the paper - this variable is Beta).
    """
    def __init__(self, embedding_dim, num_embeddings, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        
        # initialize embeddings
        self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        
    def forward(self, x):
        # [B, C, H, W] -> [B, H, W, C]
        x = x.permute(0, 2, 3, 1).contiguous()
        # [B, H, W, C] -> [BHW, C]
        flat_x = x.reshape(-1, self.embedding_dim)
        
        encoding_indices = self.get_code_indices(flat_x)
        quantized = self.quantize(encoding_indices)
        quantized = quantized.view_as(x) # [B, H, W, C]
        
        if not self.training:
            quantized = quantized.permute(0, 3, 1, 2).contiguous()
            return quantized
        
        # embedding loss: move the embeddings towards the encoder's output
        q_latent_loss = F.mse_loss(quantized, x.detach())
        # commitment loss
        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = x + (quantized - x).detach()
        
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss
    
    def get_code_indices(self, flat_x):
        # compute L2 distance
        distances = (
            torch.sum(flat_x ** 2, dim=1, keepdim=True) +
            torch.sum(self.embeddings.weight ** 2, dim=1) -
            2. * torch.matmul(flat_x, self.embeddings.weight.t())
        ) # [N, M]
        encoding_indices = torch.argmin(distances, dim=1) # [N,]
        return encoding_indices
    
    def quantize(self, encoding_indices):
        """Returns embedding tensor for a batch of indices."""
        return self.embeddings(encoding_indices) 

class Encoder(nn.Module):
    """Encoder of VQ-VAE"""
    
    def __init__(self, in_dim=3, latent_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.latent_dim = latent_dim
        
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, latent_dim, 1),
        )
        
    def forward(self, x):
        return self.convs(x)
class Decoder(nn.Module):
    """Decoder of VQ-VAE"""
    
    def __init__(self, out_dim=1, latent_dim=16):
        super().__init__()
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        
        self.convs = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, out_dim, 3, padding=1),
        )
        
    def forward(self, x):
        return self.convs(x)

class VQVAE(nn.Module):
    """VQ-VAE"""
    
    def __init__(self, in_dim, embedding_dim, num_embeddings, data_variance, 
                 commitment_cost=0.25):
        super().__init__()
        self.in_dim = in_dim
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.data_variance = data_variance
        
        self.encoder = Encoder(in_dim, embedding_dim)
        self.vq_layer = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost)
        self.decoder = Decoder(in_dim, embedding_dim)
        
    def forward(self, x):
        z = self.encoder(x)
        if not self.training:
            e = self.vq_layer(z)
            x_recon = self.decoder(e)
            return e, x_recon
        
        e, e_q_loss = self.vq_layer(z)
        x_recon = self.decoder(e)
        
        recon_loss = F.mse_loss(x_recon, x) / self.data_variance
        
        return e_q_loss + recon_loss
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

VQ-VAE 的相关文章

随机推荐

  • The type org.springframework.dao.DataAccessException cannot be resolved. It is indirectly referenced

    今天使用Spring Cloud Mybatis Plus3 x 搭建微服务项目时 提示如下错误信息 The type org springframework dao DataAccessException cannot be resolv
  • vue-cli3搭建多入口应用项目搭建以及webpack配置

    我们平时开发 vue项目的时候 就有一种感觉就是 vue就像是专门为单页应用而诞生的 因为人家的官方文档中也说了 其实不是的 因为vue在工程化开发的时候依赖于 webpack 而webpack是将所有的资源整合到一块后形成一个html文件
  • Python基础(三)_函数和代码复用

    三 函数和代码复用 一 函数的基本使用 1 函数的定义 函数是一段具有特定功能的 可重用的语句组 用函数名来表示并通过函数名进行功能调用 函数也可以看作是一段具有名字的子程序 可以在需要的地方调用执行 不需要在每个执行的地方重复编写这些语句
  • HJ32 密码截取(java详解)(动态规划)

    hello world 你好 世界 想要了解这题的动态规划 提议先了解这题的 中心扩散法 解题思路 最长回文子串的中心扩散法 遍历每个字符作为中间位 进行左右比较 算法流程 从右到左 对每个字符进行遍历处理 并且每个字符要处理两次 因为回文
  • 二手打印机如何挑选?

    打印机作为生产力工具 最重要的是 稳定性 可靠性 以及使用成本 常用的打印机分为三种 分别是 激光打印机 喷墨打印机 针式打印机 不管你是去网店还是实体店铺购买打印机 首先你要了解自己的需求 打印机作为商品 没有好与不好 只有适不适合你 一
  • Python编程基础题(20-宇宙无敌加法器)

    Description Input 输入首先在第一行给出一个 N 位的进制表 0 lt N 20 以回车结束 随后两行 每行给出一个不超过 N 位的非负的 PAT 数 Output 在一行中输出两个 PAT 数之和 Sample Input
  • 点位运动

    梯形速度规划是最简单的速度规划方法 加速度是常数 规划过程中只需要控制速度和位移与时间的关系 如图所示 整个过程分为 加速段 匀速段 减速段 每一个轴在规划静止时都可以设置为点位运动 在点位运动模式下 各轴可以独立设置目标位置 目标速度 加
  • linux/windows下查看目标文件.a/.lib的函数符号名称

    1 linux下 objdump t 查看对象文件所有的符号列表 例如 objdump t libtest o 2 nm列出目标文件 o 的符号清单 例如 nm s filename a filename o a out 3 列出所有定义的
  • jq中快速返回祖先元素

    div class one div class two div class three div class focus 我是这个div div div div div
  • 解决页面favicon.ico文件不存在提示404问题

    所谓favicon 即Favorites Icon的缩写 顾名思义 便是其可以让浏览器的收藏夹中除显示相应的标题外 还以图标的方式区别不同的网站 当然 这不是Favicon的全部 根据浏览器的不同 Favicon显示也有所区别 在大多数主流
  • 逗号和分号

    上面的程序
  • 将python代码打包成可执行文件

    文章目录 打包工具 使用 pyinstaller 安装pyinstaller库 打包 Python是一种高级编程语言 它具有易学易用 跨平台等优点 因此在开发中得到了广泛的应用 然而 Python代码需要在Python解释器中运行 这对于一
  • UML类图几种关系的总结

    UML类图几种关系的总结 转载链接 http blog csdn net sunboy 2050 article details 9211457 UML类图 描述对象和类之间相互关系的方式包括 依赖 Dependency 关联 Associ
  • mac生成树形结构

    第一步 安装tree brew install tree 第二步 在要展示树结构的文件里面打开终端 运行命令 tree d 只显示文件夹 tree L n 显示项目的层级 n表示层级数 比如想要显示项目三层结构 可以用tree l 3 tr
  • firefox安装selenium插件

    1 目前新版类似Firefox58不兼容 打开 https addons mozilla org en US firefox addon selenium ide 网址 显示add to firefox为灰色 下载Firefox48即可 2
  • R:RStudio和RStudio Server

    RStudio是R语言开发中的利器 是最好用的R语言IDE集成环境 RStudio Server更是利器中的神器 不仅提供了Web的功能 可以安装到远程服务器上 通过Web进行访问 还支持多用户的协作开发 RStudio 是一个强大的 免费
  • IDEA——手把手教你mybatis的使用(新手教程)

    说到Mybatis 很多人不知道这是用来干什么的 简单来说就是用来优化JDBC的使用 我们可以理解为一个这样的流程 数据库 gt JDBC gt MyBatis gt Java 今天来教一下简单的mybatis使用方法 对于初学者很友好 目
  • C++基础(11)类模板

    1 类模板 类模板和函数模板的定义和使用类似 我们已经进行了介绍 有时 有两个或多个类 其功能是相同的 仅仅是数据类型不同 类模板用于实现类所需数据的类型参数化 include
  • Java并发工具之CyclicBarrier

    一 简介 摘自 Java并发编程的艺术 一书中 CyclicBarrier的字面意思是可循环使用 Cyclic 的屏障 Barrier 它要做的事情是 让一组线程到达一个屏障 也可以叫同步点 时被阻塞 直到最后一个线程到达屏障时 屏障才会开
  • VQ-VAE

    前言 之前总结了一篇VAE的 这次来个它的离散版本 VAE Variational Autoencoder 简单记录 论文 Neural Discrete Representation Learning 代码 https gitee com