1、高斯分布采样
我们现在得到了有样本X得到的分布X ~ N(
μ
\mu
μ,
σ
\sigma
σ^2),通过采样我们得到确定的隐变量向量,从而作为解码器的输入。采样这个操作本身是不可导的,但是我们可以通过重参数化技巧,将简单分布的采样结果变换到特定分布中,如此一来则可以对变换过程进行求导。具体而言,我们从标准高斯分布中采样,并将其变换到X ~ N(
μ
\mu
μ,
σ
\sigma
σ^2),过程如下:
ϵ
\epsilon
ϵ ~
N
(
0
,
I
)
N(0, I)
N(0,I)
Z
=
μ
+
σ
×
ϵ
Z=\mu +\sigma × \epsilon
Z=μ+σ×ϵ
也就是说,从 N(
μ
\mu
μ,
σ
\sigma
σ^2) 采样
Z
Z
Z ,等同于从
ϵ
\epsilon
ϵ ~
N
(
0
,
I
)
N(0, I)
N(0,I)中采样高斯噪声
ϵ
\epsilon
ϵ,再将其按
Z
=
μ
+
σ
×
ϵ
Z=\mu +\sigma × \epsilon
Z=μ+σ×ϵ 变换。
import torch
def reparametrize(mean,lg_var):
std = lg_var.exp().sqrt()
eps = torch.FloatTensor(std.size()).normal_()
return eps.mul(std).add_(mean)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)