通常のGAN(Vanilla GAN)はJensen-Shannonダイバージェンス(JSD)を最小化することで真のデータ分布を学習しますが、実際の訓練では勾配消失やモード崩壊といった深刻な問題が発生します。2017年にArjovskyらが提案したWasserstein GAN(WGAN)は、JSDの代わりにWasserstein距離(Earth Mover’s Distance)を用いることで、これらの問題を理論的に解決するアプローチです。
WGANは「なぜGANの訓練は不安定なのか」という根本的な問いに対する理論的に明快な回答を提示し、後続のGAN研究に大きな影響を与えました。本記事では、通常GANの問題点の分析からWasserstein距離の定義、Kantorovich-Rubinstein双対定理、そしてWGAN-GPの勾配ペナルティまで、数式を省略せずに丁寧に導出します。
本記事の内容
- 通常GANの問題点の理論的分析
- Wasserstein距離(Earth Mover’s Distance)の定義
- Kantorovich-Rubinstein双対定理の解説
- Lipschitz制約の必要性
- 重みクリッピングの問題点
- WGAN-GP(勾配ペナルティ)の導出
- PyTorchでのWGAN-GP実装と通常GANとの学習安定性比較
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
通常GANの問題点
通常GANの理論的な問題を理解するために、JSDの性質を分析します。
台が重ならない分布のJSD
2つの分布 $P$ と $Q$ の台(support、確率密度が正の値を取る領域)が重ならない場合を考えます。これは高次元のデータ空間では珍しいことではなく、むしろ一般的な状況です。画像のような高次元データの分布は、低次元の多様体(manifold)上に集中しており、2つの低次元多様体が高次元空間で重なる確率は測度0です。
台が重ならない($\mathrm{supp}(P) \cap \mathrm{supp}(Q) = \emptyset$)とき、各ダイバージェンスは以下のようになります。
KLダイバージェンスの場合:
$P$ の台に含まれる $\bm{x}$ で $Q(\bm{x}) = 0$ となる点が存在すると、
$$ D_{\mathrm{KL}}(P \| Q) = \int p(\bm{x}) \log \frac{p(\bm{x})}{q(\bm{x})} d\bm{x} = +\infty $$
KLダイバージェンスは無限大に発散します。
JSDの場合:
$$ M = \frac{P + Q}{2} $$
とおくと、$P$ の台上では $M = \frac{P}{2}$($Q = 0$ なので)、$Q$ の台上では $M = \frac{Q}{2}$ です。
$$ \begin{align} D_{\mathrm{KL}}(P \| M) &= \int p(\bm{x}) \log \frac{p(\bm{x})}{p(\bm{x})/2} d\bm{x} = \int p(\bm{x}) \log 2 \, d\bm{x} = \log 2 \\ D_{\mathrm{KL}}(Q \| M) &= \int q(\bm{x}) \log \frac{q(\bm{x})}{q(\bm{x})/2} d\bm{x} = \int q(\bm{x}) \log 2 \, d\bm{x} = \log 2 \end{align} $$
したがって、
$$ \mathrm{JSD}(P \| Q) = \frac{1}{2}(\log 2 + \log 2) = \log 2 $$
つまり、台が重ならない2つの分布のJSDは、分布間の距離がどれだけ近くても常に $\log 2$ という定数値を取ります。
勾配消失の問題
これがGANの訓練にどう影響するかを考えましょう。
学習の初期段階では、生成分布 $p_g$ と真のデータ分布 $p_{\mathrm{data}}$ の台はほとんど重なりません。このとき、JSDは定数 $\log 2$ であり、$G$ のパラメータに関する勾配は
$$ \nabla_{\bm{\theta}} \mathrm{JSD}(p_{\mathrm{data}} \| p_g) = \bm{0} $$
となります。生成器は学習のための勾配信号を受け取れず、パラメータが更新されません。これが勾配消失問題です。
最適判別器は、台が重ならない場合に $D^*(\bm{x}) = 1$(真のデータ上)または $D^*(\bm{x}) = 0$(生成データ上)と完璧に分類します。判別器が完璧になると生成器の勾配が消失するため、実用上は判別器を「強くしすぎない」ように訓練ステップ数を調整する必要があり、これがGAN訓練のデリケートさの原因です。
Wasserstein距離(Earth Mover’s Distance)
Wasserstein距離は、これらの問題を解決する距離尺度です。直感的には、分布 $P$ を分布 $Q$ に変形するために必要な「仕事量」を測ります。
定義
2つの分布 $P_r$ と $P_g$ の間の 1-Wasserstein距離(Earth Mover’s Distance, EMD)は以下で定義されます。
$$ W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(\bm{x}, \bm{y}) \sim \gamma}[\|\bm{x} – \bm{y}\|] $$
ここで、$\Pi(P_r, P_g)$ は $P_r$ と $P_g$ を周辺分布に持つすべての同時分布(輸送計画)の集合です。つまり、$\gamma(\bm{x}, \bm{y})$ は以下を満たす同時分布です。
$$ \int \gamma(\bm{x}, \bm{y}) \, d\bm{y} = P_r(\bm{x}), \quad \int \gamma(\bm{x}, \bm{y}) \, d\bm{x} = P_g(\bm{y}) $$
$\gamma(\bm{x}, \bm{y})$ は「$\bm{x}$ にある質量をどれだけ $\bm{y}$ に移動させるか」を表す輸送計画であり、$\|\bm{x} – \bm{y}\|$ は移動コスト(距離)です。Wasserstein距離は、すべての可能な輸送計画の中で最小コストのものを選びます。
具体例: 1次元の場合
2つのデルタ分布 $P_r = \delta_0$(質量が0にある)と $P_g = \delta_\theta$(質量が $\theta$ にある)を考えます。
各距離尺度を計算すると、
$$ \begin{align} D_{\mathrm{KL}}(P_r \| P_g) &= \begin{cases} 0 & (\theta = 0) \\ +\infty & (\theta \neq 0) \end{cases} \\ \mathrm{JSD}(P_r \| P_g) &= \begin{cases} 0 & (\theta = 0) \\ \log 2 & (\theta \neq 0) \end{cases} \\ W(P_r, P_g) &= |\theta| \end{align} $$
KLダイバージェンスは $\theta \neq 0$ で無限大、JSDは不連続な定数値ですが、Wasserstein距離は $\theta$ について連続かつ微分可能です。$\theta$ に関する勾配は
$$ \frac{\partial W}{\partial \theta} = \mathrm{sgn}(\theta) $$
で、$\theta = 0$ の近傍を除いて常に非零の勾配を持ちます。
これが、Wasserstein距離を用いるとGANの訓練が安定する根本的な理由です。台が重ならない場合でも意味のある勾配信号を生成器に伝えることができます。
Wasserstein距離の性質
Wasserstein距離は以下の重要な性質を持ちます。
- 非負性: $W(P_r, P_g) \geq 0$、等号は $P_r = P_g$ のときかつそのとき
- 対称性: $W(P_r, P_g) = W(P_g, P_r)$
- 三角不等式: $W(P_r, P_g) \leq W(P_r, P_h) + W(P_h, P_g)$
- 連続性: パラメトリックな分布族において、パラメータに関して連続(KLやJSDにはない性質)
Kantorovich-Rubinstein双対定理
Wasserstein距離の定義には最適輸送計画を求める必要がありますが、これは一般に計算困難です。Kantorovich-Rubinstein双対定理は、この問題を別の最適化問題に変換します。
$$ \boxed{W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \left\{ \mathbb{E}_{\bm{x} \sim P_r}[f(\bm{x})] – \mathbb{E}_{\bm{x} \sim P_g}[f(\bm{x})] \right\}} $$
ここで上限は、1-Lipschitz連続な関数 $f$ 全体にわたって取ります。
Lipschitz連続の定義
関数 $f$ が $K$-Lipschitz連続であるとは、すべての $\bm{x}_1, \bm{x}_2$ に対して
$$ |f(\bm{x}_1) – f(\bm{x}_2)| \leq K \|\bm{x}_1 – \bm{x}_2\| $$
が成り立つことを言います。$K = 1$ のとき、1-Lipschitz連続です。
直感的には、Lipschitz制約は関数の「傾き」に上限を設けます。関数の値がどこでも急激に変化しないことを保証します。微分可能な関数の場合、1-Lipschitz連続は $\|\nabla f(\bm{x})\| \leq 1$ と等価です。
双対定理の直感的理解
元の輸送問題(主問題)は「質量を最小コストで移動する計画を見つける」問題でした。双対問題は「分布間の差を最も大きく検出できる(ただし変化が緩やかな)関数を見つける」問題です。
Lipschitz制約がないと、$f$ を任意に急峻にできるため、上限は無限大に発散します。1-Lipschitz制約は「関数の傾きが1以下」という制約であり、この制約の下で2つの分布の期待値の差を最大化する問題が、元のWasserstein距離と一致するのです。
WGANの定式化
Kantorovich-Rubinstein双対定理を用いて、WGANの目的関数を以下のように定式化します。
$$ \min_G \max_{f: \|f\|_L \leq 1} \left\{ \mathbb{E}_{\bm{x} \sim P_r}[f(\bm{x})] – \mathbb{E}_{\bm{z} \sim p_z}[f(G(\bm{z}))] \right\} $$
ここで $f$ は通常のGANの判別器 $D$ に相当しますが、Sigmoid活性化を持たず、実数値を出力するためCritic(批評家)と呼ばれます。
重要な違い:
| 項目 | 通常GAN | WGAN |
|---|---|---|
| $D$/Criticの出力 | 確率 $[0, 1]$ | 実数 $(-\infty, +\infty)$ |
| 最終活性化 | Sigmoid | なし |
| 損失関数 | BCELoss | 期待値の差 |
| 制約 | なし | Lipschitz制約 |
| 最適化する距離 | JSD | Wasserstein距離 |
Criticの損失は以下になります。
$$ \mathcal{L}_{\text{critic}} = -\mathbb{E}_{\bm{x} \sim P_r}[f(\bm{x})] + \mathbb{E}_{\bm{z} \sim p_z}[f(G(\bm{z}))] $$
生成器の損失は以下です。
$$ \mathcal{L}_G = -\mathbb{E}_{\bm{z} \sim p_z}[f(G(\bm{z}))] $$
重みクリッピングとその問題点
オリジナルのWGAN論文では、Lipschitz制約を実現するために重みクリッピングが提案されました。Criticのすべての重み $w$ を $[-c, c]$ の範囲にクリップします。
$$ w \leftarrow \mathrm{clip}(w, -c, c) $$
重みが有界であれば、ネットワーク全体のLipschitz定数も有界になります。これはシンプルで実装が容易ですが、以下の深刻な問題があります。
問題1: 容量の制約
$c$ が小さいと、Criticの表現力が大幅に制限されます。Lipschitz定数が実際に必要な値よりもはるかに小さくなり、Wasserstein距離を正確に推定できなくなります。
問題2: 勾配爆発/消失
$c$ が大きいとLipschitz制約が実質的に機能せず、$c$ が小さすぎると勾配が消失します。適切な $c$ の値を選ぶことは困難です。
問題3: 重みの二極化
訓練が進むと、多くの重みが $c$ または $-c$ に張り付く現象が観察されます。これは関数空間の探索が非常に限定的であることを意味します。
WGAN-GP(勾配ペナルティ)
Gulrajaniら(2017)は、重みクリッピングの問題を解決する勾配ペナルティ(Gradient Penalty, GP)を提案しました。これがWGAN-GPです。
勾配ペナルティの導出
1-Lipschitz連続な関数 $f$ は、微分可能な場合に $\|\nabla_{\bm{x}} f(\bm{x})\| \leq 1$ を満たします。最適なCritic $f^*$ では、$P_r$ と $P_g$ を結ぶ最適輸送経路上で
$$ \|\nabla_{\bm{x}} f^*(\bm{x})\| = 1 $$
が成り立つことが知られています(直感的には、最適なCriticは許容される最大の傾きで分布間の差を測る)。
この性質を利用し、勾配ノルムが1であることをペナルティ項として目的関数に追加します。
$$ \boxed{\mathcal{L}_{\text{critic}} = \underbrace{-\mathbb{E}_{\bm{x} \sim P_r}[f(\bm{x})] + \mathbb{E}_{\bm{z} \sim p_z}[f(G(\bm{z}))]}_{\text{Wasserstein距離の推定}} + \underbrace{\lambda \, \mathbb{E}_{\hat{\bm{x}} \sim P_{\hat{x}}}[(\|\nabla_{\hat{\bm{x}}} f(\hat{\bm{x}})\|_2 – 1)^2]}_{\text{勾配ペナルティ}}} $$
ペナルティ点 $\hat{\bm{x}}$ のサンプリング
勾配ペナルティは理想的にはすべての点で計算すべきですが、計算コストの問題から、$P_r$ と $P_g$ のサンプルを結ぶ直線上の点でサンプリングします。
$$ \hat{\bm{x}} = t \, \bm{x}_{\text{real}} + (1 – t) \, \bm{x}_{\text{fake}}, \quad t \sim \mathrm{Uniform}(0, 1) $$
ここで $\bm{x}_{\text{real}} \sim P_r$、$\bm{x}_{\text{fake}} = G(\bm{z})$、$\bm{z} \sim p_z$ です。
この選択の根拠は、最適輸送理論において最適なCriticの勾配ノルムが1となる点は、主に最適輸送経路上に存在するためです。$P_r$ と $P_g$ のサンプルを結ぶ直線は、最適輸送経路の近似として機能します。
ペナルティの形式の意味
勾配ペナルティの形式 $(\|\nabla f\|_2 – 1)^2$ について説明します。
- $\|\nabla f\|_2 > 1$: Lipschitz制約に違反しているので、ペナルティで抑制される
- $\|\nabla f\|_2 < 1$: Lipschitz制約は満たしているが、Criticが最適でないことを示す(最適Criticでは $= 1$)
- $\|\nabla f\|_2 = 1$: ペナルティが0。最適な状態
$\lambda$ はペナルティの強さを制御するハイパーパラメータで、論文では $\lambda = 10$ が推奨されています。
WGAN-GPの特徴
WGAN-GPでは、重みクリッピングとは異なり、以下の点が変更されます。
- バッチ正規化を使わない: バッチ正規化はバッチ内のサンプル間に依存関係を導入し、各サンプル独立に勾配ペナルティを計算する前提に反するためです。代わりに層正規化(Layer Normalization)を使います。
- Criticをより多く更新: 論文ではCriticを5回更新するごとにGeneratorを1回更新する設定が推奨されています。Wasserstein距離の推定が正確であるほど、Generatorの勾配も正確になるためです。
Pythonでの実装
PyTorchでWGAN-GPを実装し、通常のGANとの学習安定性を比較します。2次元のガウス混合分布をターゲットとして使用します。
共通のデータ生成
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(42)
np.random.seed(42)
def sample_target(n_samples):
"""ターゲット分布: 8つのガウス分布の混合"""
n_per_mode = n_samples // 8
samples = []
for i in range(8):
angle = 2 * np.pi * i / 8
center = np.array([2 * np.cos(angle), 2 * np.sin(angle)])
s = np.random.randn(n_per_mode, 2) * 0.05 + center
samples.append(s)
samples = np.concatenate(samples, axis=0)
idx = np.random.permutation(len(samples))[:n_samples]
return torch.tensor(samples[idx], dtype=torch.float32)
WGAN-GPのモデル定義
class WGANCritic(nn.Module):
"""WGAN-GPのCritic(判別器に相当、Sigmoidなし)"""
def __init__(self, hidden_dim=256):
super(WGANCritic, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, hidden_dim),
nn.LayerNorm(hidden_dim), # バッチ正規化の代わりに層正規化
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1) # Sigmoidなし: 実数値を出力
)
def forward(self, x):
return self.net(x)
class WGANGenerator(nn.Module):
"""WGAN-GPのGenerator"""
def __init__(self, noise_dim=2, hidden_dim=256):
super(WGANGenerator, self).__init__()
self.net = nn.Sequential(
nn.Linear(noise_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2)
)
def forward(self, z):
return self.net(z)
勾配ペナルティの実装
def compute_gradient_penalty(critic, real_data, fake_data, lambda_gp=10.0):
"""
WGAN-GPの勾配ペナルティを計算
1. 真のデータと偽のデータの間の補間点を生成
2. 補間点でのCriticの勾配を計算
3. 勾配ノルムが1からどれだけ離れているかをペナルティとして返す
"""
batch_size = real_data.size(0)
# 補間点のサンプリング: x_hat = t * x_real + (1-t) * x_fake
t = torch.rand(batch_size, 1)
interpolated = (t * real_data + (1 - t) * fake_data).requires_grad_(True)
# 補間点でのCriticの出力
d_interpolated = critic(interpolated)
# 補間点での勾配を計算
gradients = autograd.grad(
outputs=d_interpolated,
inputs=interpolated,
grad_outputs=torch.ones_like(d_interpolated),
create_graph=True,
retain_graph=True
)[0]
# 勾配のL2ノルム
gradient_norm = gradients.norm(2, dim=1)
# 勾配ペナルティ: λ * E[(||∇f(x̂)||₂ - 1)²]
gradient_penalty = lambda_gp * ((gradient_norm - 1.0) ** 2).mean()
return gradient_penalty
通常GANの実装(比較用)
class VanillaGenerator(nn.Module):
"""通常GANのGenerator"""
def __init__(self, noise_dim=2, hidden_dim=256):
super(VanillaGenerator, self).__init__()
self.net = nn.Sequential(
nn.Linear(noise_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2)
)
def forward(self, z):
return self.net(z)
class VanillaDiscriminator(nn.Module):
"""通常GANのDiscriminator"""
def __init__(self, hidden_dim=256):
super(VanillaDiscriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
訓練と比較
# === WGAN-GP の訓練 ===
noise_dim = 2
batch_size = 512
n_epochs = 15000
n_critic = 5 # Criticの更新回数/Generator1回
lambda_gp = 10.0
wgan_G = WGANGenerator(noise_dim=noise_dim)
wgan_C = WGANCritic()
opt_wgan_G = optim.Adam(wgan_G.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_wgan_C = optim.Adam(wgan_C.parameters(), lr=1e-4, betas=(0.0, 0.9))
wgan_w_distances = []
wgan_snapshots = []
for epoch in range(1, n_epochs + 1):
# Criticをn_critic回更新
for _ in range(n_critic):
real_data = sample_target(batch_size)
z = torch.randn(batch_size, noise_dim)
fake_data = wgan_G(z).detach()
# Wasserstein距離の推定(符号反転で損失に)
w_real = wgan_C(real_data).mean()
w_fake = wgan_C(fake_data).mean()
# 勾配ペナルティ
gp = compute_gradient_penalty(wgan_C, real_data, fake_data, lambda_gp)
# Critic損失 = -E[f(x_real)] + E[f(x_fake)] + λ * GP
loss_C = -w_real + w_fake + gp
opt_wgan_C.zero_grad()
loss_C.backward()
opt_wgan_C.step()
# Generator更新
z = torch.randn(batch_size, noise_dim)
fake_data = wgan_G(z)
loss_G = -wgan_C(fake_data).mean()
opt_wgan_G.zero_grad()
loss_G.backward()
opt_wgan_G.step()
# Wasserstein距離の推定値を記録
w_dist = (w_real - w_fake).item()
wgan_w_distances.append(w_dist)
if epoch in [1, 1000, 5000, 10000, 15000]:
with torch.no_grad():
z_vis = torch.randn(2000, noise_dim)
gen_vis = wgan_G(z_vis).numpy()
wgan_snapshots.append((epoch, gen_vis.copy()))
if epoch % 3000 == 0:
print(f'WGAN-GP Epoch {epoch}, W-distance: {w_dist:.4f}')
# === 通常GAN の訓練 ===
van_G = VanillaGenerator(noise_dim=noise_dim)
van_D = VanillaDiscriminator()
opt_van_G = optim.Adam(van_G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_van_D = optim.Adam(van_D.parameters(), lr=1e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()
van_g_losses = []
van_d_losses = []
van_snapshots = []
for epoch in range(1, n_epochs + 1):
# Discriminator更新
real_data = sample_target(batch_size)
z = torch.randn(batch_size, noise_dim)
fake_data = van_G(z).detach()
d_real = van_D(real_data)
d_fake = van_D(fake_data)
loss_D = bce(d_real, torch.ones_like(d_real) * 0.9) + \
bce(d_fake, torch.zeros_like(d_fake))
opt_van_D.zero_grad()
loss_D.backward()
opt_van_D.step()
# Generator更新
z = torch.randn(batch_size, noise_dim)
fake_data = van_G(z)
d_fake = van_D(fake_data)
loss_G = bce(d_fake, torch.ones_like(d_fake))
opt_van_G.zero_grad()
loss_G.backward()
opt_van_G.step()
van_g_losses.append(loss_G.item())
van_d_losses.append(loss_D.item())
if epoch in [1, 1000, 5000, 10000, 15000]:
with torch.no_grad():
z_vis = torch.randn(2000, noise_dim)
gen_vis = van_G(z_vis).numpy()
van_snapshots.append((epoch, gen_vis.copy()))
if epoch % 3000 == 0:
print(f'Vanilla GAN Epoch {epoch}, D Loss: {loss_D.item():.4f}, G Loss: {loss_G.item():.4f}')
学習過程の比較可視化
real_samples = sample_target(2000).numpy()
fig, axes = plt.subplots(3, len(wgan_snapshots) + 1,
figsize=(4 * (len(wgan_snapshots) + 1), 12))
# 行1: ターゲット分布
axes[0, 0].scatter(real_samples[:, 0], real_samples[:, 1], s=3, alpha=0.5, c='blue')
axes[0, 0].set_title('Target', fontsize=11)
axes[0, 0].set_xlim(-3.5, 3.5); axes[0, 0].set_ylim(-3.5, 3.5)
axes[0, 0].set_aspect('equal'); axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_ylabel('Target', fontsize=12)
for i in range(1, len(wgan_snapshots) + 1):
axes[0, i].scatter(real_samples[:, 0], real_samples[:, 1], s=3, alpha=0.5, c='blue')
axes[0, i].set_xlim(-3.5, 3.5); axes[0, i].set_ylim(-3.5, 3.5)
axes[0, i].set_aspect('equal'); axes[0, i].grid(True, alpha=0.3)
axes[0, i].set_title(f'Epoch {wgan_snapshots[i-1][0]}', fontsize=11)
# 行2: WGAN-GP
axes[1, 0].text(0.5, 0.5, 'WGAN-GP', ha='center', va='center',
fontsize=14, transform=axes[1, 0].transAxes)
axes[1, 0].axis('off')
for i, (epoch, data) in enumerate(wgan_snapshots):
axes[1, i + 1].scatter(data[:, 0], data[:, 1], s=3, alpha=0.5, c='green')
axes[1, i + 1].set_xlim(-3.5, 3.5); axes[1, i + 1].set_ylim(-3.5, 3.5)
axes[1, i + 1].set_aspect('equal'); axes[1, i + 1].grid(True, alpha=0.3)
# 行3: 通常GAN
axes[2, 0].text(0.5, 0.5, 'Vanilla\nGAN', ha='center', va='center',
fontsize=14, transform=axes[2, 0].transAxes)
axes[2, 0].axis('off')
for i, (epoch, data) in enumerate(van_snapshots):
axes[2, i + 1].scatter(data[:, 0], data[:, 1], s=3, alpha=0.5, c='red')
axes[2, i + 1].set_xlim(-3.5, 3.5); axes[2, i + 1].set_ylim(-3.5, 3.5)
axes[2, i + 1].set_aspect('equal'); axes[2, i + 1].grid(True, alpha=0.3)
plt.suptitle('WGAN-GP vs Vanilla GAN: Training Progress', fontsize=14, y=1.01)
plt.tight_layout()
plt.show()
Wasserstein距離の推移
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# WGAN-GP: Wasserstein距離の推移
window = 200
w_smooth = np.convolve(wgan_w_distances, np.ones(window)/window, mode='valid')
ax1.plot(w_smooth, color='green', linewidth=1.5)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Estimated Wasserstein Distance', fontsize=12)
ax1.set_title('WGAN-GP: Wasserstein Distance', fontsize=13)
ax1.grid(True, alpha=0.3)
# 通常GAN: D/G損失の推移
g_smooth = np.convolve(van_g_losses, np.ones(window)/window, mode='valid')
d_smooth = np.convolve(van_d_losses, np.ones(window)/window, mode='valid')
ax2.plot(g_smooth, label='Generator', alpha=0.8)
ax2.plot(d_smooth, label='Discriminator', alpha=0.8)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.set_title('Vanilla GAN: Training Losses', fontsize=13)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
WGAN-GPの大きな利点は、推定されるWasserstein距離が生成品質の指標として使えることです。通常GANの判別器の損失は生成品質と相関しないことが多いですが、Wasserstein距離は単調に減少し、生成品質の改善と対応します。この性質により、学習の進行度を定量的に監視でき、早期停止の判断も容易になります。
勾配ノルムの確認
# 勾配ペナルティが正しく機能しているか確認
real_data = sample_target(1000)
z = torch.randn(1000, noise_dim)
with torch.no_grad():
fake_data = wgan_G(z)
# 補間点での勾配ノルムを計算
t = torch.rand(1000, 1)
interpolated = (t * real_data + (1 - t) * fake_data).requires_grad_(True)
d_interp = wgan_C(interpolated)
gradients = autograd.grad(
outputs=d_interp, inputs=interpolated,
grad_outputs=torch.ones_like(d_interp),
create_graph=True
)[0]
grad_norms = gradients.norm(2, dim=1).detach().numpy()
plt.figure(figsize=(8, 5))
plt.hist(grad_norms, bins=50, density=True, alpha=0.7, color='green', edgecolor='black')
plt.axvline(x=1.0, color='red', linestyle='--', linewidth=2, label='Target norm = 1')
plt.xlabel('Gradient Norm $\\|\\nabla_x f(x)\\|_2$', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.title('Distribution of Gradient Norms (WGAN-GP)', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f'Mean gradient norm: {grad_norms.mean():.4f}')
print(f'Std gradient norm: {grad_norms.std():.4f}')
勾配ノルムの分布が1付近に集中していれば、勾配ペナルティが正しく機能し、Lipschitz制約が満たされていることを確認できます。
まとめ
本記事では、Wasserstein GAN(WGAN)の理論を通常GANの問題点の分析から勾配ペナルティの導出まで丁寧に解説し、PyTorchによる実装と比較実験を行いました。
- 通常GANのJSDは台が重ならない分布に対して定数値($\log 2$)を返すため、勾配が消失する
- Wasserstein距離(EMD)は台が重ならない場合でも連続的に変化し、意味のある勾配を提供する
- Kantorovich-Rubinstein双対定理により、$W(P_r, P_g) = \sup_{\|f\|_L \leq 1}\{\mathbb{E}[f(\bm{x})] – \mathbb{E}[f(G(\bm{z}))]\}$ と書ける
- Lipschitz制約は重みクリッピングでは不十分であり、勾配ペナルティ $\lambda \mathbb{E}[(\|\nabla f(\hat{\bm{x}})\| – 1)^2]$ が効果的
- WGAN-GPでは推定Wasserstein距離が生成品質の指標として使え、訓練の安定性も大幅に向上する
次のステップとして、以下の記事も参考にしてください。