WGAN算法从理论层面分析了GAN训练不稳定的原因,并提出了有效的解决方法。那么是什么原因导致了GAN训练如此不稳定呢?WGAN提出是因为JS散度在不重叠的分布
p
p
p和
q
q
q上的梯度曲面是恒定为0的。如下图所示。当分布p和q不重叠时,JS散度的梯度值始终为0,从而导致此时GAN的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。
图1. JS散度出现梯度弥散现象
接下来我们将详细阐述JS散度的缺陷以及怎么解决此缺陷。
1. JS散度的缺陷
为了避免过多的理论推导,我们这里通过一个简单的分布实例来解释JS散度的缺陷。 考虑完全不重叠(
θ
≠
0
θ≠0
θ=0)的两个分布
p
p
p和
q
q
q,其中
p
p
p为:
∀
(
x
,
y
)
∈
p
,
x
=
0
,
y
∼
U
(
0
,
1
)
∀(x,y)∈p,x=0,y\sim\text{U}(0,1)
∀(x,y)∈p,x=0,y∼U(0,1) 分布
q
q
q为:
∀
(
x
,
y
)
∈
q
,
x
=
θ
,
y
∼
U
(
0
,
1
)
∀(x,y)∈q,x=θ,y\sim\text{U}(0,1)
∀(x,y)∈q,x=θ,y∼U(0,1) 其中
θ
∈
R
θ∈R
θ∈R,当
θ
=
0
θ=0
θ=0时,分布
p
p
p和
q
q
q重叠,两者相等;当
θ
≠
0
θ≠0
θ=0时,分布
p
p
p和
q
q
q不重叠。
图2. 分布$p$和$q$示意图
我们来分析上述分布
p
p
p和
q
q
q之间的JS散度随
θ
θ
θ的变化情况。根据KL散度与JS散度的定义,计算
θ
=
0
θ=0
θ=0时的JS散度
D
J
S
(
p
∣
∣
q
)
D_{JS} (p||q)
DJS(p∣∣q):
D
K
L
(
p
∣
∣
q
)
=
∑
x
=
0
,
y
∼
U
(
0
,
1
)
1
⋅
log
1
0
=
+
∞
D_{KL} (p||q)=∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{0}=+∞
DKL(p∣∣q)=x=0,y∼U(0,1)∑1⋅log01=+∞
D
K
L
(
q
∣
∣
p
)
=
∑
x
=
θ
,
y
∼
U
(
0
,
1
)
1
⋅
log
1
0
=
+
∞
D_{KL} (q||p)=∑_{x=θ,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{0}=+∞
DKL(q∣∣p)=x=θ,y∼U(0,1)∑1⋅log01=+∞
D
J
S
(
p
∣
∣
q
)
=
1
2
(
∑
x
=
0
,
y
∼
U
(
0
,
1
)
1
⋅
log
1
1
/
2
+
∑
x
=
0
,
y
∼
U
(
0
,
1
)
1
⋅
log
1
1
/
2
)
=
log
2
D_{JS} (p||q)=\frac{1}{2} \bigg(∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}+∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}\bigg)=\text{log}2
DJS(p∣∣q)=21(x=0,y∼U(0,1)∑1⋅log1/21+x=0,y∼U(0,1)∑1⋅log1/21)=log2 当
θ
=
0
θ=0
θ=0时,两个分布完全重叠,此时的JS散度和KL散度都取得最小值,即0:
D
K
L
(
p
∣
∣
q
)
=
D
K
L
(
q
∣
∣
p
)
=
D
J
S
(
p
∣
∣
q
)
=
0
D_{KL} (p||q)=D_{KL} (q||p)=D_{JS} (p||q)=0
DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0 从上面的推导,我们可以得到
D
J
S
(
p
∣
∣
q
)
D_{JS} (p||q)
DJS(p∣∣q)随
θ
θ
θ的变化趋势:
D
J
S
(
p
∣
∣
q
)
=
{
log
2
θ
≠
0
0
θ
=
0
D_{JS} (p||q) = \begin{cases} \text{log}2 &\text{} θ≠0 \\ 0 &\text{} θ=0 \end{cases}
DJS(p∣∣q)={log20θ=0θ=0 也就是说,当两个分布完全不重叠时,无论发布之间的距离远近,JS散度为恒定值
log
2
\text{log}2
log2,此时JS散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS散度采会平滑变动,产生有效梯度信息;当完全重合后,JS散度取得最小值0.如下图所示,红色的曲线分割两个正态分布,由于两个分布没有重叠,生成样本位置处的梯度值始终为0,无法更新生成网络的参数,从而出现网络训练困难的现象。
图3. JS散度出现梯度弥散现象
因此,JS散度在分布
p
p
p和
q
q
q不重叠时是无法平滑地衡量分布之间的距离,从而导致此位置上无法产生有效梯度信息,出现GAN训练不稳定的情况。要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布
p
p
p和
q
q
q不重叠时,也能平滑反映分布之间的真实距离变化。
2. EM距离
WGAN论文发现了JS散度导致GAN训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein距离,也叫推土机距离(Earth-Mover Distance,简称EM距离),它表示了从一个分布变换到另一个分布的最小代价,定义为:
W
(
p
,
q
)
=
inf
γ
∼
∏
(
p
,
q
)
E
(
x
,
y
)
∼
γ
[
∥
x
−
y
∥
]
W(p,q)=\underset{γ\sim∏(p,q)}{\text{inf}}\mathbb E_{(x,y)\simγ} [\|x-y\|]
W(p,q)=γ∼∏(p,q)infE(x,y)∼γ[∥x−y∥] 其中
∏
(
p
,
q
)
∏(p,q)
∏(p,q)是分布
p
p
p和
q
q
q组合起来的所有可能的联合分布的集合,对于每个可能的联合分布
γ
∼
∏
(
p
,
q
)
γ\sim∏(p,q)
γ∼∏(p,q),计算距离
∥
x
−
y
∥
\|x-y\|
∥x−y∥的期望
E
(
x
,
y
)
∼
γ
[
∥
x
−
y
∥
]
\mathbb E_{(x,y)\simγ} [\|x-y\|]
E(x,y)∼γ[∥x−y∥],其中
(
x
,
y
)
(x,y)
(x,y)采样自联合分布
γ
γ
γ。不同的联合分布
γ
γ
γ由不同的期望
E
(
x
,
y
)
∼
γ
[
∥
x
−
y
∥
]
\mathbb E_{(x,y)\simγ} [\|x-y\|]
E(x,y)∼γ[∥x−y∥],这些期望中的下确界即定义为分布
p
p
p和
q
q
q的Wasserstein距离。其中
inf
{
⋅
}
\text{inf}\{\cdot\}
inf{⋅}表示集合的下确界,例如
{
x
∣
1
<
x
<
3
,
x
∈
R
}
\{x|1<x<3,x∈R\}
{x∣1<x<3,x∈R}的下确界为1。
继续考虑图2中的例子,我们直接给出分布
p
p
p和
q
q
q之间的EM距离的表达式:
W
(
p
,
q
)
=
∣
θ
∣
W(p,q)=|θ|
W(p,q)=∣θ∣ 绘制出JS散度和EM距离的曲线,如下图所示,可以看到,JS散度在
θ
=
0
θ=0
θ=0处不连续,其他位置导数均为0,而EM距离总能够产生有效的导数信息,因此EM距离相对于JS散度更适合直到GAN网络的训练。
图4. JS散度和EM距离随$θ$变换曲线
3. WGAN-GP
考虑到几乎不可能遍历所有的联合分布
γ
γ
γ去计算距离
∥
x
−
y
∥
\|x-y\|
∥x−y∥的期望
E
(
x
,
y
)
∼
γ
[
∥
x
−
y
∥
]
\mathbb E_{(x,y)\simγ} [\|x-y\|]
E(x,y)∼γ[∥x−y∥],因此直接计算生成网络分布
p
g
p_g
pg与真实数据数据分布
p
r
p_r
pr的距离
W
(
p
r
,
p
g
)
W(p_r,p_g )
W(pr,pg)距离是不现实的,WGAN作者基于Kantorchovich-Rubin对偶性将直接求
W
(
p
r
,
p
g
)
W(p_r,p_g )
W(pr,pg)转换为求:
W
(
p
r
,
p
g
)
=
1
K
sup
∥
f
∥
L
≤
K
E
x
∼
p
r
[
f
(
x
)
]
−
E
x
∼
p
g
[
f
(
x
)
]
W(p_r,p_g )=\frac{1}{K} \underset{\|f\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [f(x)]-\mathbb E_{x\sim p_g} [f(x)]
W(pr,pg)=K1∥f∥L≤KsupEx∼pr[f(x)]−Ex∼pg[f(x)] 其中
sup
{
⋅
}
\text{sup}\{\cdot\}
sup{⋅}表示集合的上确界,
∥
f
∥
L
≤
K
\|f\|_L≤K
∥f∥L≤K表示函数
f
:
R
→
R
f:R→R
f:R→R满足K阶-Lipschitz连续性,即满足
∣
f
(
x
1
)
−
f
(
x
2
)
∣
≤
K
⋅
∣
x
1
−
x
2
∣
|f(x_1 )-f(x_2)|≤K\cdot|x_1-x_2 |
∣f(x1)−f(x2)∣≤K⋅∣x1−x2∣ 于是,我们使用判别网络
D
θ
(
x
)
D_θ (\boldsymbol x)
Dθ(x)参数化
f
(
x
)
f(\boldsymbol x)
f(x)函数,在
D
θ
D_θ
Dθ满足1阶-Lipschitz约束条件下,即
K
=
1
K=1
K=1,此时:
W
(
p
r
,
p
g
)
=
1
K
sup
∥
D
θ
∥
L
≤
K
E
x
∼
p
r
[
D
θ
(
x
)
]
−
E
x
∼
p
g
[
D
θ
(
x
)
]
W(p_r,p_g )=\frac{1}{K} \underset{\|D_θ\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]
W(pr,pg)=K1∥Dθ∥L≤KsupEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)] 因此求解
W
(
p
r
,
p
g
)
W(p_r,p_g )
W(pr,pg)的问题可以转化为:
max
θ
E
x
∼
p
r
[
D
θ
(
x
)
]
−
E
x
∼
p
g
[
D
θ
(
x
)
]
\underset{θ}{\text{max}}\ \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]
θmaxEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)] 这就是判别器D的优化目标。判别网络函数D_θ (x)需要满足1阶-Lipschitz约束:
∇
x
^
D
(
x
^
)
≤
1
∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})≤1
∇x^D(x^)≤1 在WGAN-GP论文中,作者提出采用增加梯度惩罚项(Gradient Penalty)方法来迫使判别网络满足1阶-Lipschitz函数约束,同时作者发现将梯度值约束在1周围时工程效果更好,因此梯度惩罚项定义为:
G
P
≜
E
x
^
∼
P
x
^
[
(
∥
∇
x
^
D
(
x
^
)
∥
2
−
1
)
2
]
GP≜\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]
GP≜Ex^∼Px^[(∥∇x^D(x^)∥2−1)2] 因此WGAN的判别器D的训练目标为:
max
θ
L
(
G
,
D
)
=
E
x
r
∼
p
r
[
D
(
x
r
)
]
−
E
x
f
∼
p
g
[
D
(
x
f
)
]
⏟
E
M
距
离
−
λ
E
x
^
∼
P
x
^
[
(
∥
∇
x
^
D
(
x
^
)
∥
2
−
1
)
2
]
⏟
G
P
惩
罚
项
\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}-\underbrace{λ\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]}_{GP惩罚项}
θmaxL(G,D)=EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)]−GP惩罚项λEx^∼Px^[(∥∇x^D(x^)∥2−1)2] 其中
x
^
\hat{\boldsymbol x}
x^来自于
x
r
\boldsymbol x_r
xr与
x
f
\boldsymbol x_f
xf的线性差值:
x
^
=
t
x
r
+
(
1
−
t
)
x
f
,
t
∈
[
0
,
1
]
\hat{\boldsymbol x}=t\boldsymbol x_r+(1-t) \boldsymbol x_f,t∈[0,1]
x^=txr+(1−t)xf,t∈[0,1] 判别器D的优化目标是最小化上述的误差
L
(
G
,
D
)
\mathcal L(G,D)
L(G,D),即迫使生成器G的分布
p
g
p_g
pg与真实分布
p
r
p_r
pr之间的EM距离
E
x
r
∼
p
r
[
D
(
x
r
)
]
−
E
x
f
∼
p
g
[
D
(
x
f
)
]
\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]
Exr∼pr[D(xr)]−Exf∼pg[D(xf)]项尽可能大,
∥
∇
x
^
D
(
x
^
)
∥
2
\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2
∥∇x^D(x^)∥2逼近于1。
WGAN的生成器G的训练目标为:
max
θ
L
(
G
,
D
)
=
E
x
r
∼
p
r
[
D
(
x
r
)
]
−
E
x
f
∼
p
g
[
D
(
x
f
)
]
⏟
E
M
距
离
\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}
θmaxL(G,D)=EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)] 即使得生成器的分布
p
g
p_g
pg与真实分布
p
r
p_r
pr之间的EM距离越小越好。考虑到
E
x
r
∼
p
r
[
D
(
x
r
)
]
\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]
Exr∼pr[D(xr)]一项与生成器无关,因此生成器的训练目标简写为:
max
θ
L
(
G
,
D
)
=
−
E
x
f
∼
p
g
[
D
(
x
f
)
]
=
−
E
z
∼
p
z
(
⋅
)
[
D
(
G
(
z
)
)
]
\begin{aligned}\underset{θ}{\text{max}} \mathcal L(G,D)&=-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]\\ &=-E_{\boldsymbol z\sim p_\boldsymbol z (\cdot)} [D(G(\boldsymbol z))]\end{aligned}
θmaxL(G,D)=−Exf∼pg[D(xf)]=−Ez∼pz(⋅)[D(G(z))] 从现实来看,判别网络D的输出不需要添加Sigmoid激活函数,这是因为原始版本的判别器的功能是作为二分类网络,添加Sigmoid函数获得类别的概率;而WGAN中判别器作为EM距离的度量网络,其目标是衡量生成网络的分布
p
g
p_g
pg和真实分布
p
r
p_r
pr之间的EM距离,属于实数空间,因此不需要添加Sigmoid激活函数。在误差函数计算时,WGAN也没有
log
\text{log}
log函数存在。在训练WGAN时,WGAN作者推荐使用RMSProp或SGD等不带动量的优化器。