条件付きGAN(cGAN)の理論と実装を解説

通常のGANはランダムノイズから画像を生成しますが、「どのようなデータを生成するか」を制御することはできません。例えば、MNISTで学習したGANに「数字の7を生成してほしい」と指定することはできないのです。

条件付きGAN(Conditional GAN, cGAN)は、2014年にMirzaとOsinderoによって提案された手法で、生成器と判別器に条件情報 $\bm{y}$(クラスラベル、テキスト、画像など)を追加入力することで、この問題を解決します。cGANはPix2Pix、ACGAN、InfoGANなど多くの発展手法の基盤となっています。

本記事の内容

  • cGANの動機と定式化
  • 条件付きミニマックス目的関数の導出
  • 条件情報の入力方法(連結・埋め込み・射影)
  • ACGANとの比較
  • Pix2Pixの概要とPatchGAN判別器
  • PyTorchによるMNISTでの条件付き数字生成

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

cGANの動機

通常のGANの目的関数は以下でした。

$$ \min_G \max_D V(D, G) = \mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}}[\log D(\bm{x})] + \mathbb{E}_{\bm{z} \sim p_z}[\log(1 – D(G(\bm{z})))] $$

このとき、生成器 $G$ への入力はノイズ $\bm{z}$ のみです。$G$ がどのクラスの画像を生成するかは制御できず、$\bm{z}$ のランダム性に依存します。

cGANでは、条件情報 $\bm{y}$ を生成器と判別器の両方に追加入力します。$\bm{y}$ はクラスラベル(one-hotベクトル)、テキストの埋め込みベクトル、別の画像など、任意の補助情報を表します。

条件付きミニマックス目的関数

cGANの目的関数は以下のように定式化されます。

$$ \min_G \max_D V(D, G) = \mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}}[\log D(\bm{x} | \bm{y})] + \mathbb{E}_{\bm{z} \sim p_z}[\log(1 – D(G(\bm{z} | \bm{y}) | \bm{y}))] $$

ここで、

  • $G(\bm{z} | \bm{y})$: 条件 $\bm{y}$ の下でノイズ $\bm{z}$ から生成されたデータ
  • $D(\bm{x} | \bm{y})$: 条件 $\bm{y}$ の下でデータ $\bm{x}$ が真である確率

より正確に書くと、$G$ と $D$ はそれぞれ $(\bm{z}, \bm{y})$ と $(\bm{x}, \bm{y})$ を入力として受け取ります。

$$ \boxed{\min_G \max_D V(D, G) = \mathbb{E}_{\bm{x}, \bm{y} \sim p_{\mathrm{data}}}[\log D(\bm{x}, \bm{y})] + \mathbb{E}_{\bm{z} \sim p_z, \bm{y}}[\log(1 – D(G(\bm{z}, \bm{y}), \bm{y}))]} $$

最適判別器

通常のGANと同様の議論により、条件付き分布の下での最適判別器が導出できます。

各 $\bm{y}$ に対して固定すると、通常のGANの最適判別器の議論がそのまま適用でき、

$$ D^*(\bm{x}, \bm{y}) = \frac{p_{\mathrm{data}}(\bm{x} | \bm{y})}{p_{\mathrm{data}}(\bm{x} | \bm{y}) + p_g(\bm{x} | \bm{y})} $$

同様に、最適判別器の下での$G$の最適化は、条件付き分布間のJensen-Shannonダイバージェンスの最小化に対応します。

$$ V(D^*, G) = 2 \cdot \mathbb{E}_{\bm{y}}[\mathrm{JSD}(p_{\mathrm{data}}(\bm{x} | \bm{y}) \| p_g(\bm{x} | \bm{y}))] – 2\log 2 $$

大域的最適解は $p_g(\bm{x} | \bm{y}) = p_{\mathrm{data}}(\bm{x} | \bm{y})$ であり、すべての条件 $\bm{y}$ に対して生成分布が真の条件付き分布と一致する状態です。

条件情報の入力方法

条件情報 $\bm{y}$ をネットワークに入力する方法にはいくつかのアプローチがあります。

方法1: 連結(Concatenation)

最もシンプルな方法です。条件情報を入力と直接連結します。

生成器の場合: ノイズ $\bm{z} \in \mathbb{R}^{n_z}$ と条件ベクトル $\bm{y} \in \mathbb{R}^{n_y}$ を連結して $[\bm{z}; \bm{y}] \in \mathbb{R}^{n_z + n_y}$ を入力にします。

$$ G(\bm{z}, \bm{y}) = G([\bm{z}; \bm{y}]) $$

判別器の場合: 画像特徴 $\bm{x}$ と条件ベクトル $\bm{y}$ を連結します。画像入力の場合は、$\bm{y}$ をone-hotベクトルから空間的に拡張(各チャンネルに同じ値を持つ特徴マップに変換)して、画像テンソルにチャンネル方向で連結します。

$$ D(\bm{x}, \bm{y}) = D([\bm{x}; \bm{y}_{\text{spatial}}]) $$

ここで $\bm{y}_{\text{spatial}} \in \mathbb{R}^{n_y \times H \times W}$ は各チャンネルが対応するone-hotの値で埋められたテンソルです。

方法2: 埋め込み(Embedding)

クラスラベル $y \in \{0, 1, \dots, C-1\}$ を学習可能な埋め込みベクトル $\bm{e}_y \in \mathbb{R}^{d}$ に変換してから連結します。

$$ \bm{e}_y = \bm{W}_{\text{emb}}[y, :], \quad \bm{W}_{\text{emb}} \in \mathbb{R}^{C \times d} $$

one-hotベクトルよりも低次元で情報を表現でき、クラス間の関係性も学習できるメリットがあります。

方法3: 射影(Projection)

Miyatoら(2018)が提案した射影判別器(Projection Discriminator)では、判別器の中間特徴ベクトル $\bm{h}$ と条件埋め込み $\bm{e}_y$ の内積を取ります。

$$ D(\bm{x}, y) = \sigma\left(\bm{w}^T \bm{h} + \bm{e}_y^T \bm{h}\right) $$

ここで $\bm{w}$ は学習可能な重みベクトル、$\sigma$ はシグモイド関数です。第1項はデータの真偽を判定し、第2項はデータとラベルの整合性を評価します。

この方法は連結よりも理論的根拠が明確で、条件付き分布のモデル化が自然に行えるため、高品質な生成結果が得られることが知られています。

ACGAN(Auxiliary Classifier GAN)との比較

ACGAN(Odenaら, 2017)はcGANの別のアプローチで、判別器に補助分類器(Auxiliary Classifier)を追加します。

ACGANの判別器は2つの出力を持ちます。

  1. 真偽判定: $D_{\text{src}}(\bm{x})$ — データが真か偽かの確率
  2. クラス分類: $D_{\text{cls}}(\bm{x})$ — データのクラスの確率分布

ACGANの目的関数は以下のようになります。

$$ \begin{align} \mathcal{L}_S &= \mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}}[\log D_{\text{src}}(\bm{x})] + \mathbb{E}_{\bm{z}}[\log(1 – D_{\text{src}}(G(\bm{z}, y)))] \\ \mathcal{L}_C &= \mathbb{E}_{\bm{x} \sim p_{\mathrm{data}}}[\log P(C = y | \bm{x})] + \mathbb{E}_{\bm{z}}[\log P(C = y | G(\bm{z}, y))] \end{align} $$

$D$ は $\mathcal{L}_S + \mathcal{L}_C$ を最大化し、$G$ は $-\mathcal{L}_S + \mathcal{L}_C$ を最大化します。

cGANとACGANの主な違いをまとめます。

特徴 cGAN ACGAN
条件情報の入力 $D$ に $\bm{y}$ を入力 $D$ が $\bm{y}$ を予測
$D$ の出力 真偽確率のみ 真偽確率 + クラス確率
クラス数が多い場合 比較的安定 品質が低下しやすい
理論的基盤 条件付きJSD最小化 追加の分類損失

Pix2Pixの概要

Pix2Pix(Isolaら, 2017)は、cGANのフレームワークを画像から画像への変換(Image-to-Image Translation)に応用した手法です。条件 $\bm{y}$ として入力画像を使い、対応する出力画像を生成します。

例えば、 – セグメンテーションマップ → 写真 – エッジ画像 → 写真 – 白黒画像 → カラー画像 – 昼の写真 → 夜の写真

Pix2Pixの目的関数

Pix2Pixでは、cGANの目的関数にL1損失を追加します。

$$ \mathcal{L}_{\text{Pix2Pix}} = \mathcal{L}_{\text{cGAN}}(G, D) + \lambda \mathcal{L}_{L1}(G) $$

ここで、

$$ \mathcal{L}_{L1}(G) = \mathbb{E}_{\bm{x}, \bm{y}}[\|\bm{x} – G(\bm{z}, \bm{y})\|_1] $$

L1損失を加える理由は、cGAN損失だけでは生成画像がシャープだがアーティファクトが多くなりがちで、L1損失だけでは画像がぼやけがちであるためです。両者を組み合わせることで、シャープかつ構造的に正しい画像が生成できます。

$\lambda$ は2つの損失のバランスを制御するハイパーパラメータで、論文では $\lambda = 100$ が使われています。

PatchGAN判別器

Pix2Pixのもうひとつの重要な貢献はPatchGAN判別器です。通常の判別器は画像全体に対して1つのスカラーを出力しますが、PatchGANは画像を $N \times N$ のパッチに分割し、各パッチごとに真偽を判定します。

PatchGANの出力は $N \times N$ の行列 $\bm{D}(\bm{x}) \in \mathbb{R}^{N \times N}$ であり、各要素 $D_{i,j}(\bm{x})$ は受容野(receptive field)に対応する局所パッチの真偽確率を表します。

$$ \mathcal{L}_D = -\frac{1}{N^2} \sum_{i,j} \left[ \log D_{i,j}(\bm{x}_{\text{real}}) + \log(1 – D_{i,j}(\bm{x}_{\text{fake}})) \right] $$

PatchGANのメリットは以下のとおりです。

  • パラメータ数が少ない(全結合層がない)
  • 任意サイズの画像に適用可能
  • 高周波のテクスチャ情報をよく捉える(低周波はL1損失が担当)

典型的なパッチサイズは $70 \times 70$ ピクセルの受容野を持つ設計です。

Pythonでの実装

PyTorchを用いて、MNISTデータセットでクラス指定の数字生成を行う条件付きGANを実装します。

モデル定義

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ハイパーパラメータ
n_classes = 10      # MNISTのクラス数
latent_dim = 100    # ノイズの次元
embed_dim = 50      # クラス埋め込みの次元
img_shape = (1, 28, 28)
img_dim = 1 * 28 * 28


class ConditionalGenerator(nn.Module):
    """条件付き生成器: (z, y) -> 画像"""
    def __init__(self):
        super(ConditionalGenerator, self).__init__()
        # クラスラベルの埋め込み
        self.label_embedding = nn.Embedding(n_classes, embed_dim)

        # 入力: ノイズ(100) + 埋め込み(50) = 150次元
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # ラベルを埋め込みベクトルに変換
        label_emb = self.label_embedding(labels)
        # ノイズと埋め込みを連結: [z; e_y]
        gen_input = torch.cat([z, label_emb], dim=1)
        img = self.model(gen_input)
        return img.view(img.size(0), *img_shape)


class ConditionalDiscriminator(nn.Module):
    """条件付き判別器: (画像, y) -> 真偽"""
    def __init__(self):
        super(ConditionalDiscriminator, self).__init__()
        # クラスラベルの埋め込み
        self.label_embedding = nn.Embedding(n_classes, embed_dim)

        # 入力: 画像(784) + 埋め込み(50) = 834次元
        self.model = nn.Sequential(
            nn.Linear(img_dim + embed_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # 画像を平坦化
        img_flat = img.view(img.size(0), -1)
        # ラベルを埋め込みベクトルに変換
        label_emb = self.label_embedding(labels)
        # 画像と埋め込みを連結: [x; e_y]
        d_input = torch.cat([img_flat, label_emb], dim=1)
        validity = self.model(d_input)
        return validity

訓練ループ

# データの準備
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='./data', train=True,
                         download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

# モデルの初期化
G = ConditionalGenerator().to(device)
D = ConditionalDiscriminator().to(device)
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# 訓練
n_epochs = 50
G_losses, D_losses = [], []

for epoch in range(n_epochs):
    epoch_g_loss, epoch_d_loss = 0, 0
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size_curr = imgs.size(0)
        imgs = imgs.to(device)
        labels = labels.to(device)

        # ラベル(ラベルスムージング適用)
        real_label = torch.full((batch_size_curr, 1), 0.9, device=device)
        fake_label = torch.full((batch_size_curr, 1), 0.0, device=device)

        # === 判別器の更新 ===
        optimizer_D.zero_grad()

        # 真のデータの判別
        validity_real = D(imgs, labels)
        loss_real = criterion(validity_real, real_label)

        # 偽のデータの生成と判別
        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_labels = torch.randint(0, n_classes, (batch_size_curr,), device=device)
        gen_imgs = G(z, gen_labels)
        validity_fake = D(gen_imgs.detach(), gen_labels)
        loss_fake = criterion(validity_fake, fake_label)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # === 生成器の更新 ===
        optimizer_G.zero_grad()

        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_labels = torch.randint(0, n_classes, (batch_size_curr,), device=device)
        gen_imgs = G(z, gen_labels)
        validity = D(gen_imgs, gen_labels)
        loss_G = criterion(validity, torch.ones(batch_size_curr, 1, device=device))

        loss_G.backward()
        optimizer_G.step()

        epoch_g_loss += loss_G.item()
        epoch_d_loss += loss_D.item()

    n_batches = len(dataloader)
    G_losses.append(epoch_g_loss / n_batches)
    D_losses.append(epoch_d_loss / n_batches)

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{n_epochs}] '
              f'D Loss: {D_losses[-1]:.4f} G Loss: {G_losses[-1]:.4f}')

学習曲線の可視化

plt.figure(figsize=(10, 5))
plt.plot(G_losses, label='Generator', linewidth=2)
plt.plot(D_losses, label='Discriminator', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Conditional GAN Training Loss', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

指定数字の生成

cGANの最大の特徴である、クラスを指定した画像生成を行います。

def generate_specified_digits(generator, n_per_class=10):
    """
    0〜9の各数字を指定して生成
    """
    generator.eval()
    fig, axes = plt.subplots(n_classes, n_per_class,
                              figsize=(n_per_class * 1.2, n_classes * 1.2))

    with torch.no_grad():
        for digit in range(n_classes):
            z = torch.randn(n_per_class, latent_dim, device=device)
            labels = torch.full((n_per_class,), digit,
                               dtype=torch.long, device=device)
            gen_imgs = generator(z, labels).cpu()

            for j in range(n_per_class):
                img = gen_imgs[j].squeeze().numpy() * 0.5 + 0.5
                axes[digit, j].imshow(img, cmap='gray')
                axes[digit, j].axis('off')

            # 行ラベルとして数字を表示
            axes[digit, 0].set_ylabel(str(digit), fontsize=14,
                                       rotation=0, labelpad=20)

    plt.suptitle('Conditional GAN: Specified Digit Generation', fontsize=14)
    plt.tight_layout()
    plt.show()

generate_specified_digits(G)

同一条件での多様性の確認

同じクラスラベルで異なるノイズ $\bm{z}$ を入力したとき、多様な画像が生成されることを確認します。

def show_diversity(generator, digit=3, n_samples=20):
    """
    同じ数字で多様なサンプルを生成
    """
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim, device=device)
        labels = torch.full((n_samples,), digit,
                           dtype=torch.long, device=device)
        gen_imgs = generator(z, labels).cpu()

    fig, axes = plt.subplots(2, n_samples // 2,
                              figsize=(n_samples // 2 * 1.5, 3))
    for i in range(n_samples):
        row, col = i // (n_samples // 2), i % (n_samples // 2)
        img = gen_imgs[i].squeeze().numpy() * 0.5 + 0.5
        axes[row, col].imshow(img, cmap='gray')
        axes[row, col].axis('off')

    plt.suptitle(f'Diversity of Generated Digit "{digit}"', fontsize=14)
    plt.tight_layout()
    plt.show()

show_diversity(G, digit=3, n_samples=20)
show_diversity(G, digit=7, n_samples=20)

潜在空間の条件付き補間

同じクラスの条件の下で、2つのノイズベクトル間を補間し、画像がどう変化するか確認します。

def conditional_interpolation(generator, digit, n_steps=10):
    """条件付きの潜在空間補間"""
    generator.eval()
    z1 = torch.randn(1, latent_dim, device=device)
    z2 = torch.randn(1, latent_dim, device=device)
    labels = torch.full((1,), digit, dtype=torch.long, device=device)

    fig, axes = plt.subplots(1, n_steps, figsize=(n_steps * 1.5, 1.5))

    with torch.no_grad():
        for i, alpha in enumerate(np.linspace(0, 1, n_steps)):
            z = z1 * (1 - alpha) + z2 * alpha
            gen_img = generator(z, labels).cpu()
            img = gen_img.squeeze().numpy() * 0.5 + 0.5
            axes[i].imshow(img, cmap='gray')
            axes[i].axis('off')
            axes[i].set_title(f'{alpha:.1f}', fontsize=9)

    plt.suptitle(f'Latent Interpolation (digit={digit})', fontsize=13)
    plt.tight_layout()
    plt.show()

# 各数字で補間を表示
for digit in [0, 3, 7]:
    conditional_interpolation(G, digit)

この結果から、同じクラスの条件下では、ノイズの補間によって書体やスタイルが滑らかに変化することが確認できます。cGANは「何を生成するか」をクラスラベルで制御し、「どのように生成するか」をノイズ $\bm{z}$ で制御するという役割分担を実現しています。

まとめ

本記事では、条件付きGAN(cGAN)の理論を目的関数の定式化から解説し、PyTorchによる指定数字生成の実装を行いました。

  • cGANは通常のGANに条件情報 $\bm{y}$ を追加することで、生成するデータのクラスや属性を制御できる
  • 目的関数 $V(D,G) = \mathbb{E}[\log D(\bm{x}, \bm{y})] + \mathbb{E}[\log(1-D(G(\bm{z}, \bm{y}), \bm{y}))]$ は、条件付き分布間のJSD最小化に対応する
  • 条件情報の入力方法には連結埋め込み射影があり、射影判別器が理論的にも品質的にも優れている
  • Pix2PixはcGANを画像変換に応用し、L1損失の追加とPatchGAN判別器の導入で高品質な変換を実現した
  • 実装では、クラスラベルの埋め込みをノイズ/画像と連結する方法で、指定した数字を多様に生成できることを確認した

次のステップとして、以下の記事も参考にしてください。