GAN(敵対的生成ネットワーク)の理論とミニマックスゲームを解説

GAN(Generative Adversarial Network, 敵対的生成ネットワーク)は、2014年にIan Goodfellowらによって提案された生成モデルです。生成器(Generator)と判別器(Discriminator)の2つのネットワークを競わせる「敵対的学習」により、驚くほどリアルなデータを生成できることで注目を集めました。

GANは画像生成のみならず、テキスト生成、音声合成、画像変換、データ拡張など多岐にわたる応用を持ちます。本記事では、GANの理論的基盤であるミニマックスゲームの定式化から、最適判別器の導出、Jensen-Shannonダイバージェンスとの関係の証明まで、数式を省略せず丁寧に解説します。

本記事の内容

  • GANの基本構成(Generator vs Discriminator)
  • ミニマックス目的関数の定式化
  • 最適判別器 $D^*(\bm{x})$ の導出
  • 最適判別器の下でのGの最適化がJSDの最小化であることの証明
  • 大域的最適解の導出
  • 訓練の不安定性と対処法
  • Pythonで2次元ガウス混合分布のGAN実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

GANの基本構成

GANは以下の2つのネットワークから構成されます。

生成器(Generator) $G$: ランダムノイズ $\bm{z} \sim p_z(\bm{z})$ を入力として、データ空間のサンプル $G(\bm{z})$ を生成します。$G$ の目的は、真のデータ分布 $p_{\mathrm{data}}(\bm{x})$ に従うようなサンプルを生成することです。

判別器(Discriminator) $D$: データ $\bm{x}$ を入力として、それが真のデータである確率 $D(\bm{x}) \in [0, 1]$ を出力します。$D$ の目的は、真のデータと生成データを正しく見分けることです。

大雑把に言うと、$G$ は偽札を作る「偽造者」、$D$ はそれを見破る「鑑定士」に例えられます。偽造者は鑑定士を騙せるほど精巧な偽札を作ろうとし、鑑定士は偽札を見抜く能力を高めようとします。この競争を通じて、$G$ が生成するデータの品質が向上していきます。

ミニマックス目的関数

GANの学習は以下のミニマックスゲームとして定式化されます。

$$ \min_G \max_D V(D, G) = \mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}(\bm{x})}[\log D(\bm{x})] + \mathbb{E}_{\bm{z} \sim p_z(\bm{z})}[\log(1 – D(G(\bm{z})))] $$

この目的関数の各項を理解しましょう。

第1項: $\mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}}[\log D(\bm{x})]$

真のデータ $\bm{x}$ に対する $D$ の出力の対数の期待値です。$D$ が正しく判別するとき $D(\bm{x}) \to 1$ なので $\log D(\bm{x}) \to 0$(最大値)になります。判別器 $D$ はこの項を大きくしたいので、真のデータに対して高い確率を出力しようとします。

第2項: $\mathbb{E}_{\bm{z} \sim p_z}[\log(1 – D(G(\bm{z})))]$

生成データ $G(\bm{z})$ に対する判別の対数の期待値です。$D$ が正しく判別するとき $D(G(\bm{z})) \to 0$ なので $\log(1 – D(G(\bm{z}))) \to 0$(最大値)になります。

$D$ はこの目的関数を最大化しようとし、$G$ はこの目的関数を最小化しようとします。

最適判別器の導出

任意の固定された $G$ に対して、最適な判別器 $D^*$ を求めます。

$V(D, G)$ を積分形式で書き直すと、

$$ V(D, G) = \int_{\bm{x}} p_{\mathrm{data}}(\bm{x}) \log D(\bm{x}) \, d\bm{x} + \int_{\bm{x}} p_g(\bm{x}) \log(1 – D(\bm{x})) \, d\bm{x} $$

ここで $p_g(\bm{x})$ は生成器 $G$ が誘導するデータ空間上の分布です。被積分関数をまとめると、

$$ V(D, G) = \int_{\bm{x}} \left[ p_{\mathrm{data}}(\bm{x}) \log D(\bm{x}) + p_g(\bm{x}) \log(1 – D(\bm{x})) \right] d\bm{x} $$

各点 $\bm{x}$ において、$D(\bm{x})$ に関して最大化する問題を解きます。$a = p_{\mathrm{data}}(\bm{x})$, $b = p_g(\bm{x})$, $y = D(\bm{x})$ とおくと、

$$ f(y) = a \log y + b \log(1 – y) $$

を $y \in (0, 1)$ で最大化する問題です。$y$ で微分して0とおくと、

$$ \begin{align} \frac{df}{dy} &= \frac{a}{y} – \frac{b}{1-y} = 0 \\ \frac{a}{y} &= \frac{b}{1-y} \\ a(1-y) &= by \\ a – ay &= by \\ a &= (a+b)y \\ y &= \frac{a}{a+b} \end{align} $$

2次導関数を確認すると、

$$ \frac{d^2f}{dy^2} = -\frac{a}{y^2} – \frac{b}{(1-y)^2} < 0 $$

確かに極大値です。したがって、最適判別器は

$$ \boxed{D^*(\bm{x}) = \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})}} $$

直感的には、ある点 $\bm{x}$ が真のデータ分布に由来する割合を返しています。$p_{\mathrm{data}}(\bm{x}) = p_g(\bm{x})$ のとき $D^*(\bm{x}) = \frac{1}{2}$ となり、判別器は真偽を区別できなくなります。

Jensen-Shannonダイバージェンスとの関係

最適判別器 $D^*$ を $V(D, G)$ に代入して、$G$ の目的関数を導出します。

Jensen-Shannonダイバージェンスの定義

まず、2つの分布 $P$ と $Q$ のJensen-Shannonダイバージェンス(JSD)を定義します。

$$ \mathrm{JSD}(P \| Q) = \frac{1}{2} D_{\mathrm{KL}}\left(P \left\| \frac{P+Q}{2}\right.\right) + \frac{1}{2} D_{\mathrm{KL}}\left(Q \left\| \frac{P+Q}{2}\right.\right) $$

JSDはKLダイバージェンスと異なり、対称で常に有限値を取ります。値の範囲は $0 \leq \mathrm{JSD} \leq \log 2$ です。

証明

$D^*(\bm{x}) = \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})}$ を $V(D^*, G)$ に代入します。

$$ \begin{align} V(D^*, G) &= \int p_{\mathrm{data}}(\bm{x}) \log \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})} \, d\bm{x} + \int p_g(\bm{x}) \log \frac{p_g(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})} \, d\bm{x} \end{align} $$

ここで $M(\bm{x}) = \frac{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})}{2}$ とおくと、

$$ \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})} = \frac{p_{\mathrm{data}}(\bm{x})}{2M(\bm{x})} $$

したがって、

$$ \begin{align} V(D^*, G) &= \int p_{\mathrm{data}}(\bm{x}) \log \frac{p_{\mathrm{data}}(\bm{x})}{2M(\bm{x})} \, d\bm{x} + \int p_g(\bm{x}) \log \frac{p_g(\bm{x})}{2M(\bm{x})} \, d\bm{x} \end{align} $$

対数を分解します。

$$ \begin{align} V(D^*, G) &= \int p_{\mathrm{data}}(\bm{x}) \left[ \log \frac{p_{\mathrm{data}}(\bm{x})}{M(\bm{x})} – \log 2 \right] d\bm{x} + \int p_g(\bm{x}) \left[ \log \frac{p_g(\bm{x})}{M(\bm{x})} – \log 2 \right] d\bm{x} \\ &= D_{\mathrm{KL}}(p_{\mathrm{data}} \| M) – \log 2 + D_{\mathrm{KL}}(p_g \| M) – \log 2 \\ &= D_{\mathrm{KL}}(p_{\mathrm{data}} \| M) + D_{\mathrm{KL}}(p_g \| M) – 2\log 2 \end{align} $$

JSDの定義と比較すると、

$$ \begin{align} V(D^*, G) &= 2 \cdot \mathrm{JSD}(p_{\mathrm{data}} \| p_g) – 2\log 2 \end{align} $$

これは非常に重要な結果です。最適判別器の下で $G$ がミニマックスゲームの目的関数を最小化することは、$p_{\mathrm{data}}$ と $p_g$ のJensen-Shannonダイバージェンスを最小化することと等価なのです。

大域的最適解の導出

$\mathrm{JSD}(p_{\mathrm{data}} \| p_g) \geq 0$ であり、等号は $p_{\mathrm{data}} = p_g$ のときかつそのときに限り成立します。したがって、

$$ V(D^*, G) \geq -2\log 2 $$

等号成立条件は $p_g = p_{\mathrm{data}}$ です。このとき、

$$ D^*(\bm{x}) = \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_{\mathrm{data}}(\bm{x})} = \frac{1}{2} $$

つまり、生成器が完全に真のデータ分布を学習した大域的最適解では、

$$ \boxed{V(D^*, G^*) = -\log 4} $$

であり、判別器はすべてのデータに対して $\frac{1}{2}$ を出力します。これは真偽を全く区別できない状態に対応しています。

訓練の不安定性と対処法

GANの理論は美しいですが、実際の訓練には以下のような問題が知られています。

モード崩壊(Mode Collapse)

生成器がデータ分布の一部のモードのみを生成し、多様性を欠く問題です。例えば、MNISTの10種類の数字のうち特定の数字しか生成しなくなるといった現象が起こります。

勾配消失

判別器が完璧になると、$D(G(\bm{z})) \to 0$ となり、$G$ の損失 $\log(1 – D(G(\bm{z}))) \to 0$ の勾配が消失します。学習の初期段階で生成器の出力が明らかに偽物であるとき、この問題が顕著になります。

対策として、$G$ の損失を $\log(1 – D(G(\bm{z})))$ の最小化ではなく $-\log D(G(\bm{z}))$ の最小化に置き換える非飽和(non-saturating)損失がよく使われます。この2つは同じ均衡点を持ちますが、学習初期の勾配が大きく異なります。

学習のコツ

手法 説明
交互更新 $D$ を $k$ ステップ更新した後 $G$ を1ステップ更新する($k=1$ が一般的)
ラベルスムージング 真のデータのラベルを1.0ではなく0.9程度にする
バッチ正規化 訓練を安定させる(ただしDの最終層には使わない)
適切な学習率 $D$ と $G$ で異なる学習率を使う場合もある
スペクトル正規化 $D$ の重みのスペクトルノルムを制約する

Pythonでの実装

ここでは2次元のガウス混合分布をGANで学習する実装を示します。低次元のデータを使うことで、学習過程を2次元平面上で可視化できます。

ターゲット分布の定義

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

# 再現性のためシードを固定
torch.manual_seed(42)
np.random.seed(42)

def sample_target_distribution(n_samples):
    """
    ターゲット分布: 8つのガウス分布を円形に配置した混合分布
    各ガウス分布は標準偏差0.05で、半径2の円上に均等に配置
    """
    n_per_mode = n_samples // 8
    samples = []
    for i in range(8):
        angle = 2 * np.pi * i / 8
        center = np.array([2 * np.cos(angle), 2 * np.sin(angle)])
        s = np.random.randn(n_per_mode, 2) * 0.05 + center
        samples.append(s)
    samples = np.concatenate(samples, axis=0)
    # 端数調整
    if len(samples) < n_samples:
        extra = n_samples - len(samples)
        samples = np.concatenate([samples, samples[:extra]], axis=0)
    return torch.tensor(samples, dtype=torch.float32)

GeneratorとDiscriminatorの定義

class Generator(nn.Module):
    """生成器: ノイズ z -> 2次元データ"""
    def __init__(self, noise_dim=2, hidden_dim=128):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)  # 出力: 2次元座標
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """判別器: 2次元データ -> 真偽確率"""
    def __init__(self, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

訓練ループ

# ハイパーパラメータ
noise_dim = 2
batch_size = 512
n_epochs = 10000
lr_g = 1e-4
lr_d = 1e-4
k_disc = 1  # 判別器の更新回数/生成器1回

# モデル初期化
G = Generator(noise_dim=noise_dim)
D = Discriminator()
opt_G = optim.Adam(G.parameters(), lr=lr_g, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr_d, betas=(0.5, 0.999))

criterion = nn.BCELoss()

# 学習過程の記録
g_losses = []
d_losses = []
snapshots = []  # 学習過程のスナップショット

for epoch in range(1, n_epochs + 1):
    # === 判別器の更新 ===
    for _ in range(k_disc):
        # 真のデータ
        real_data = sample_target_distribution(batch_size)
        real_labels = torch.ones(batch_size, 1) * 0.9  # ラベルスムージング

        # 偽のデータ
        z = torch.randn(batch_size, noise_dim)
        fake_data = G(z).detach()  # Gの勾配は不要
        fake_labels = torch.zeros(batch_size, 1)

        # 判別器の損失
        d_real = D(real_data)
        d_fake = D(fake_data)
        loss_d = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)

        opt_D.zero_grad()
        loss_d.backward()
        opt_D.step()

    # === 生成器の更新 ===
    z = torch.randn(batch_size, noise_dim)
    fake_data = G(z)
    d_fake = D(fake_data)
    # 非飽和損失: -log D(G(z)) を最小化
    loss_g = criterion(d_fake, torch.ones(batch_size, 1))

    opt_G.zero_grad()
    loss_g.backward()
    opt_G.step()

    g_losses.append(loss_g.item())
    d_losses.append(loss_d.item())

    # スナップショット保存(可視化用)
    if epoch in [1, 500, 2000, 5000, 10000]:
        with torch.no_grad():
            z_vis = torch.randn(2000, noise_dim)
            gen_vis = G(z_vis).numpy()
            snapshots.append((epoch, gen_vis.copy()))

    if epoch % 2000 == 0:
        print(f'Epoch {epoch}, D Loss: {loss_d.item():.4f}, G Loss: {loss_g.item():.4f}')

学習過程の可視化

# 真のデータ分布
real_samples = sample_target_distribution(2000).numpy()

fig, axes = plt.subplots(1, len(snapshots) + 1, figsize=(4 * (len(snapshots) + 1), 4))

# 真のデータ
axes[0].scatter(real_samples[:, 0], real_samples[:, 1], s=3, alpha=0.5, c='blue')
axes[0].set_title('Target Distribution', fontsize=11)
axes[0].set_xlim(-3.5, 3.5)
axes[0].set_ylim(-3.5, 3.5)
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

# 各スナップショット
for i, (epoch, gen_data) in enumerate(snapshots):
    axes[i + 1].scatter(gen_data[:, 0], gen_data[:, 1], s=3, alpha=0.5, c='red')
    axes[i + 1].set_title(f'Epoch {epoch}', fontsize=11)
    axes[i + 1].set_xlim(-3.5, 3.5)
    axes[i + 1].set_ylim(-3.5, 3.5)
    axes[i + 1].set_aspect('equal')
    axes[i + 1].grid(True, alpha=0.3)

plt.suptitle('GAN Training Progress on 2D Gaussian Mixture', fontsize=14)
plt.tight_layout()
plt.show()

損失曲線の可視化

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# 損失曲線(移動平均で平滑化)
window = 100
g_smooth = np.convolve(g_losses, np.ones(window)/window, mode='valid')
d_smooth = np.convolve(d_losses, np.ones(window)/window, mode='valid')

ax1.plot(g_smooth, label='Generator Loss', alpha=0.8)
ax1.plot(d_smooth, label='Discriminator Loss', alpha=0.8)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training Loss (Moving Average)', fontsize=13)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# 判別器の判別面の可視化
xx, yy = np.meshgrid(np.linspace(-3.5, 3.5, 200), np.linspace(-3.5, 3.5, 200))
grid_points = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
with torch.no_grad():
    d_values = D(grid_points).numpy().reshape(xx.shape)

contour = ax2.contourf(xx, yy, d_values, levels=50, cmap='RdYlBu_r')
plt.colorbar(contour, ax=ax2, label='D(x)')
ax2.scatter(real_samples[:, 0], real_samples[:, 1], s=2, alpha=0.3, c='blue', label='Real')
with torch.no_grad():
    final_gen = G(torch.randn(2000, noise_dim)).numpy()
ax2.scatter(final_gen[:, 0], final_gen[:, 1], s=2, alpha=0.3, c='red', label='Generated')
ax2.set_title('Discriminator Decision Boundary', fontsize=13)
ax2.legend(fontsize=10)
ax2.set_xlim(-3.5, 3.5)
ax2.set_ylim(-3.5, 3.5)
ax2.set_aspect('equal')

plt.tight_layout()
plt.show()

この実装により、GANが徐々に8つのモードを学習していく様子と、判別器の判別面を視覚的に確認できます。理想的には学習が進むと判別器の出力が全体的に $0.5$ に近づき、生成分布がターゲット分布と一致していきます。

まとめ

本記事では、GAN(敵対的生成ネットワーク)の理論をミニマックスゲームの定式化から丁寧に解説しました。

  • GANは生成器 $G$ と判別器 $D$ のミニマックスゲーム $\min_G \max_D V(D, G)$ として定式化される
  • 最適判別器は $D^*(\bm{x}) = \frac{p_{\mathrm{data}}(\bm{x})}{p_{\mathrm{data}}(\bm{x}) + p_g(\bm{x})}$ で与えられる
  • 最適判別器の下では、$G$ の最適化はJensen-Shannonダイバージェンスの最小化に等価であり、$V(D^*, G) = 2 \cdot \mathrm{JSD}(p_{\mathrm{data}} \| p_g) – 2\log 2$ となる
  • 大域的最適解は $p_g = p_{\mathrm{data}}$ であり、そのとき $V = -\log 4$、$D^* = \frac{1}{2}$ となる
  • 実際の訓練ではモード崩壊や勾配消失が問題になり、非飽和損失やラベルスムージングなどの対処法が必要である

次のステップとして、以下の記事も参考にしてください。