重参数化技巧:高斯分布采样

2023-10-26

1、高斯分布采样

我们现在得到了有样本X得到的分布X ~ N( μ \mu μ, σ \sigma σ^2),通过采样我们得到确定的隐变量向量,从而作为解码器的输入。采样这个操作本身是不可导的,但是我们可以通过重参数化技巧,将简单分布的采样结果变换到特定分布中,如此一来则可以对变换过程进行求导。具体而言,我们从标准高斯分布中采样,并将其变换到X ~ N( μ \mu μ, σ \sigma σ^2),过程如下:

ϵ \epsilon ϵ ~ N ( 0 , I ) N(0, I) N(0,I)
Z = μ + σ × ϵ Z=\mu +\sigma × \epsilon Z=μ+σ×ϵ

也就是说,从 N( μ \mu μ, σ \sigma σ^2) 采样 Z Z Z ,等同于从 ϵ \epsilon ϵ ~ N ( 0 , I ) N(0, I) N(0,I)中采样高斯噪声 ϵ \epsilon ϵ,再将其按 Z = μ + σ × ϵ Z=\mu +\sigma × \epsilon Z=μ+σ×ϵ 变换。

import torch

def reparametrize(mean,lg_var): # 采样器方法:对方差(lg_var)进行还原,并从高斯分布中采样,将采样数值映射到编码器输出的数据分布中。
        std = lg_var.exp().sqrt()
        # torch.FloatTensor(std.size())的作用是,生成一个与std形状一样的张量。然后,调用该张量的normal_()方法,系统会对该张量中的每个元素在标准高斯空间(均值为0、方差为1)中进行采样。
        eps = torch.FloatTensor(std.size()).normal_() # 随机张量方法normal_(),完成高斯空间的采样过程。
        return eps.mul(std).add_(mean)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

重参数化技巧:高斯分布采样 的相关文章

随机推荐