Preface
I previously wrote a summary post on the VAE; this time, let's look at its discrete counterpart.
VAE (Variational Autoencoder): brief notes
Paper: Neural Discrete Representation Learning
Code: https://gitee.com/mirrors_ritheshkumar95/pytorch-vqvae
v2: Generating Diverse High-Fidelity Images with VQ-VAE-2
Principle
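In short, VQ-VAE replaces the VAE's continuous Gaussian latent with a discrete codebook of num_embeddings vectors. The encoder produces a continuous feature map z_e(x); each spatial vector in it is snapped to its nearest codebook entry (a nearest-neighbor lookup), and the decoder reconstructs the input from the quantized map z_q(x). Since argmin is not differentiable, the gradient is copied from the decoder's input straight to the encoder's output via a straight-through estimator. The training objective is a reconstruction loss (here an MSE) plus two quantization terms:

L = || x - x_recon ||² + || sg[z_e(x)] - e ||² + β || z_e(x) - sg[e] ||²

where sg[·] is the stop-gradient operator and β is the commitment_cost in the code below. The second term pulls the codebook vectors toward the encoder outputs, while the third keeps the encoder committed to the codes it gets assigned.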
Code
The code here is taken from 生成模型之VQ-VAE. I'm pasting it over because I want to annotate it with some notes of my own; it is for learning purposes only…
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
"""
VQ-VAE layer: Input any tensor to be quantized.
Args:
embedding_dim (int): the dimensionality of the tensors in the
quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int): the number of vectors in the quantized space.
commitment_cost (float): scalar which controls the weighting of the loss terms (see
equation 4 in the paper - this variable is Beta).
"""
def __init__(self, embedding_dim, num_embeddings, commitment_cost):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
# initialize embeddings
self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
def forward(self, x):
# [B, C, H, W] -> [B, H, W, C]
x = x.permute(0, 2, 3, 1).contiguous()
# [B, H, W, C] -> [BHW, C]
flat_x = x.reshape(-1, self.embedding_dim)
encoding_indices = self.get_code_indices(flat_x)
quantized = self.quantize(encoding_indices)
quantized = quantized.view_as(x) # [B, H, W, C]
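        # at inference time, skip the loss terms and return the quantized map directly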
if not self.training:
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized
# embedding loss: move the embeddings towards the encoder's output
q_latent_loss = F.mse_loss(quantized, x.detach())
# commitment loss
e_latent_loss = F.mse_loss(x, quantized.detach())
loss = q_latent_loss + self.commitment_cost * e_latent_loss
# Straight Through Estimator
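        # forward pass outputs the quantized values; backward pass copies gradients
        # through to x unchanged, since (quantized - x).detach() carries no gradient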
quantized = x + (quantized - x).detach()
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized, loss
def get_code_indices(self, flat_x):
# compute L2 distance
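        # ||x - e||^2 = ||x||^2 - 2 x·e + ||e||^2, expanded so that all
        # [BHW, M] pairwise distances are computed in a single matmul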
distances = (
torch.sum(flat_x ** 2, dim=1, keepdim=True) +
torch.sum(self.embeddings.weight ** 2, dim=1) -
2. * torch.matmul(flat_x, self.embeddings.weight.t())
) # [N, M]
encoding_indices = torch.argmin(distances, dim=1) # [N,]
return encoding_indices
def quantize(self, encoding_indices):
"""Returns embedding tensor for a batch of indices."""
return self.embeddings(encoding_indices)
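As a quick sanity check (my own addition, not part of the original post), the layer can be exercised with a random feature map; in training mode it returns the quantized tensor plus the VQ loss, in eval mode only the quantized tensor:

# Smoke test for VectorQuantizer (illustrative, not from the original post)
vq = VectorQuantizer(embedding_dim=16, num_embeddings=128, commitment_cost=0.25)
x = torch.randn(2, 16, 8, 8)  # [B, C, H, W], C must equal embedding_dim

vq.train()
quantized, vq_loss = vq(x)
print(quantized.shape, vq_loss.item())  # torch.Size([2, 16, 8, 8]), scalar loss

vq.eval()
print(vq(x).shape)  # torch.Size([2, 16, 8, 8]), no loss returned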
class Encoder(nn.Module):
"""Encoder of VQ-VAE"""
def __init__(self, in_dim=3, latent_dim=16):
super().__init__()
self.in_dim = in_dim
self.latent_dim = latent_dim
self.convs = nn.Sequential(
nn.Conv2d(in_dim, 32, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, latent_dim, 1),
)
def forward(self, x):
return self.convs(x)
class Decoder(nn.Module):
"""Decoder of VQ-VAE"""
def __init__(self, out_dim=1, latent_dim=16):
super().__init__()
self.out_dim = out_dim
self.latent_dim = latent_dim
self.convs = nn.Sequential(
nn.ConvTranspose2d(latent_dim, 64, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(32, out_dim, 3, padding=1),
)
def forward(self, x):
return self.convs(x)
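Because the encoder halves the spatial size twice (two stride-2 convolutions) and the decoder mirrors it with two stride-2 transposed convolutions, an input comes back at its original resolution. A quick shape check (my own addition):

# Shape round trip (illustrative): 32x32 -> 8x8 latents -> 32x32
enc = Encoder(in_dim=1, latent_dim=16)
dec = Decoder(out_dim=1, latent_dim=16)
x = torch.randn(4, 1, 32, 32)
z = enc(x)
print(z.shape)       # torch.Size([4, 16, 8, 8])
print(dec(z).shape)  # torch.Size([4, 1, 32, 32])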
class VQVAE(nn.Module):
"""VQ-VAE"""
def __init__(self, in_dim, embedding_dim, num_embeddings, data_variance,
commitment_cost=0.25):
super().__init__()
self.in_dim = in_dim
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.data_variance = data_variance
self.encoder = Encoder(in_dim, embedding_dim)
self.vq_layer = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost)
self.decoder = Decoder(in_dim, embedding_dim)
def forward(self, x):
z = self.encoder(x)
if not self.training:
e = self.vq_layer(z)
x_recon = self.decoder(e)
return e, x_recon
e, e_q_loss = self.vq_layer(z)
x_recon = self.decoder(e)
recon_loss = F.mse_loss(x_recon, x) / self.data_variance
return e_q_loss + recon_loss
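Finally, a minimal training sketch to show how the pieces fit together. This is my own addition, assuming MNIST scaled to [0, 1]; data_variance is the pixel variance of the training set, which VQVAE.forward divides the reconstruction loss by:

# Hypothetical training sketch (not from the original post)
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_set = datasets.MNIST('data', train=True, download=True,
                           transform=transforms.ToTensor())
# pixel variance of the training images, used to scale the reconstruction loss
data_variance = np.var(train_set.data.numpy() / 255.0)

model = VQVAE(in_dim=1, embedding_dim=16, num_embeddings=128,
              data_variance=data_variance)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    for images, _ in DataLoader(train_set, batch_size=256, shuffle=True):
        loss = model(images)  # reconstruction loss + VQ losses, one scalar
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()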