VAE(変分オートエンコーダ)の理論とELBOを導出する

VAE(Variational Autoencoder, 変分オートエンコーダ)は、深層生成モデルの代表的な手法のひとつです。データの背後にある潜在的な構造を学習し、新たなデータを生成できるだけでなく、潜在空間における意味のある表現を獲得できる点が大きな特徴です。

VAEは画像生成、テキスト生成、異常検知、データ拡張など幅広い応用を持ち、後続の拡散モデルなどの基盤にもなっています。本記事では、VAEの理論的な基盤を数式で丁寧に導出し、PyTorchによる実装まで一貫して解説します。

本記事の内容

  • 生成モデルと潜在変数モデルの概要
  • 周辺尤度の計算困難性と変分推論の導入
  • ELBO(Evidence Lower Bound)の数学的導出
  • ELBOの分解(再構成誤差 + KL正則化項)
  • 再パラメータ化トリックの導出と必要性
  • PyTorchによるMNISTでのVAE実装と潜在空間の可視化

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

生成モデルの目的

生成モデルの目的は、観測データ $\bm{x}$ の確率分布 $p(\bm{x})$ をモデル化し、その分布からの新しいサンプルを生成することです。

訓練データ $\{\bm{x}^{(1)}, \bm{x}^{(2)}, \dots, \bm{x}^{(N)}\}$ が与えられたとき、データの対数尤度を最大化するパラメータ $\bm{\theta}$ を求めることが基本的な方針になります。

$$ \max_{\bm{\theta}} \sum_{i=1}^{N} \log p_{\bm{\theta}}(\bm{x}^{(i)}) $$

潜在変数モデル

多くの生成モデルでは、観測データ $\bm{x}$ の背後に潜在変数(latent variable) $\bm{z}$ が存在すると仮定します。潜在変数 $\bm{z}$ はデータの本質的な要因(手書き数字であれば、数字の種類、傾き、太さなど)を表すと考えます。

このとき、観測データの周辺尤度(marginal likelihood)は以下のように書けます。

$$ p_{\bm{\theta}}(\bm{x}) = \int p_{\bm{\theta}}(\bm{x} | \bm{z}) \, p(\bm{z}) \, d\bm{z} $$

ここで、$p(\bm{z})$ は潜在変数の事前分布であり、一般には標準正規分布 $\mathcal{N}(\bm{0}, \bm{I})$ を仮定します。$p_{\bm{\theta}}(\bm{x} | \bm{z})$ はデコーダ(生成器)が与える尤度関数です。

周辺尤度の計算困難性

上の積分を直接計算するには、すべての可能な $\bm{z}$ について $p_{\bm{\theta}}(\bm{x} | \bm{z}) \, p(\bm{z})$ を積分する必要があります。しかし、$p_{\bm{\theta}}(\bm{x} | \bm{z})$ がニューラルネットワークで表現される場合、この積分は解析的に解けません。

また、事後分布 $p_{\bm{\theta}}(\bm{z} | \bm{x})$ もベイズの定理から以下のように書けますが、

$$ p_{\bm{\theta}}(\bm{z} | \bm{x}) = \frac{p_{\bm{\theta}}(\bm{x} | \bm{z}) \, p(\bm{z})}{p_{\bm{\theta}}(\bm{x})} $$

分母に計算困難な周辺尤度 $p_{\bm{\theta}}(\bm{x})$ が現れるため、事後分布の計算もまた困難です。

変分推論の導入

計算困難な事後分布 $p_{\bm{\theta}}(\bm{z} | \bm{x})$ を、パラメータ $\bm{\phi}$ を持つ扱いやすい分布 $q_{\bm{\phi}}(\bm{z} | \bm{x})$ で近似するのが変分推論(Variational Inference)のアイデアです。

VAEでは、この近似事後分布をガウス分布で表現します。

$$ q_{\bm{\phi}}(\bm{z} | \bm{x}) = \mathcal{N}(\bm{z}; \bm{\mu}_{\bm{\phi}}(\bm{x}), \mathrm{diag}(\bm{\sigma}_{\bm{\phi}}^2(\bm{x}))) $$

ここで $\bm{\mu}_{\bm{\phi}}(\bm{x})$ と $\bm{\sigma}_{\bm{\phi}}^2(\bm{x})$ はニューラルネットワーク(エンコーダ)が出力する平均ベクトルと分散ベクトルです。

ELBO(Evidence Lower Bound)の導出

ここからが本記事の核心です。対数周辺尤度 $\log p_{\bm{\theta}}(\bm{x})$ の下界であるELBOを導出します。

導出法1: KLダイバージェンスの非負性から

真の事後分布 $p_{\bm{\theta}}(\bm{z} | \bm{x})$ と近似事後分布 $q_{\bm{\phi}}(\bm{z} | \bm{x})$ のKLダイバージェンスは以下で定義されます。

$$ D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p_{\bm{\theta}}(\bm{z} | \bm{x})) = \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \log \frac{q_{\bm{\phi}}(\bm{z} | \bm{x})}{p_{\bm{\theta}}(\bm{z} | \bm{x})} \, d\bm{z} $$

ここでベイズの定理 $p_{\bm{\theta}}(\bm{z} | \bm{x}) = \frac{p_{\bm{\theta}}(\bm{x} | \bm{z}) p(\bm{z})}{p_{\bm{\theta}}(\bm{x})}$ を代入します。

$$ \begin{align} D_{\mathrm{KL}}(q \| p) &= \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \log \frac{q_{\bm{\phi}}(\bm{z} | \bm{x})}{\frac{p_{\bm{\theta}}(\bm{x} | \bm{z}) p(\bm{z})}{p_{\bm{\theta}}(\bm{x})}} \, d\bm{z} \\ &= \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \log \frac{q_{\bm{\phi}}(\bm{z} | \bm{x}) \, p_{\bm{\theta}}(\bm{x})}{p_{\bm{\theta}}(\bm{x} | \bm{z}) \, p(\bm{z})} \, d\bm{z} \\ &= \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \left[ \log \frac{q_{\bm{\phi}}(\bm{z} | \bm{x})}{p(\bm{z})} – \log p_{\bm{\theta}}(\bm{x} | \bm{z}) + \log p_{\bm{\theta}}(\bm{x}) \right] d\bm{z} \end{align} $$

3行目では、対数の積を和に分解しました。$\log p_{\bm{\theta}}(\bm{x})$ は $\bm{z}$ に依存しないので、積分の外に出せます。

$$ \begin{align} D_{\mathrm{KL}}(q \| p) &= \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \log \frac{q_{\bm{\phi}}(\bm{z} | \bm{x})}{p(\bm{z})} \, d\bm{z} – \int q_{\bm{\phi}}(\bm{z} | \bm{x}) \log p_{\bm{\theta}}(\bm{x} | \bm{z}) \, d\bm{z} + \log p_{\bm{\theta}}(\bm{x}) \\ &= D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z})) – \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] + \log p_{\bm{\theta}}(\bm{x}) \end{align} $$

これを $\log p_{\bm{\theta}}(\bm{x})$ について整理します。

$$ \log p_{\bm{\theta}}(\bm{x}) = \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] – D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z})) + D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p_{\bm{\theta}}(\bm{z} | \bm{x})) $$

KLダイバージェンスは常に非負($D_{\mathrm{KL}} \geq 0$)であるため、右辺の第3項を落とすことで下界が得られます。

$$ \log p_{\bm{\theta}}(\bm{x}) \geq \underbrace{\mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] – D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z}))}_{\text{ELBO}(\bm{\theta}, \bm{\phi}; \bm{x})} $$

これがELBO(Evidence Lower Bound)です。

導出法2: イェンセンの不等式から

別の導出方法も見ておきましょう。対数周辺尤度に $q_{\bm{\phi}}(\bm{z} | \bm{x})$ を導入します。

$$ \begin{align} \log p_{\bm{\theta}}(\bm{x}) &= \log \int p_{\bm{\theta}}(\bm{x}, \bm{z}) \, d\bm{z} \\ &= \log \int \frac{p_{\bm{\theta}}(\bm{x}, \bm{z})}{q_{\bm{\phi}}(\bm{z} | \bm{x})} q_{\bm{\phi}}(\bm{z} | \bm{x}) \, d\bm{z} \\ &= \log \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})} \left[ \frac{p_{\bm{\theta}}(\bm{x}, \bm{z})}{q_{\bm{\phi}}(\bm{z} | \bm{x})} \right] \end{align} $$

$\log$ は凹関数なので、イェンセンの不等式 $\log \mathbb{E}[X] \geq \mathbb{E}[\log X]$ を適用すると、

$$ \begin{align} \log p_{\bm{\theta}}(\bm{x}) &\geq \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})} \left[ \log \frac{p_{\bm{\theta}}(\bm{x}, \bm{z})}{q_{\bm{\phi}}(\bm{z} | \bm{x})} \right] \\ &= \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})} \left[ \log p_{\bm{\theta}}(\bm{x} | \bm{z}) + \log p(\bm{z}) – \log q_{\bm{\phi}}(\bm{z} | \bm{x}) \right] \\ &= \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] – D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z})) \end{align} $$

最後の行では、$\mathbb{E}_{q}[\log p(\bm{z}) – \log q(\bm{z} | \bm{x})] = -D_{\mathrm{KL}}(q \| p)$ の関係を用いました。どちらの方法でも同じELBOが得られることが確認できます。

ELBOの解釈: 再構成誤差とKL正則化

ELBOは2つの項から構成されています。

$$ \mathrm{ELBO} = \underbrace{\mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})]}_{\text{再構成誤差}} – \underbrace{D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z}))}_{\text{KL正則化項}} $$

第1項: 再構成誤差(Reconstruction Error)

エンコーダ $q_{\bm{\phi}}(\bm{z} | \bm{x})$ でサンプルした $\bm{z}$ を使って、デコーダ $p_{\bm{\theta}}(\bm{x} | \bm{z})$ が元の $\bm{x}$ を復元できるかを測る項です。この項を大きくする(負の対数尤度を小さくする)ことで、再構成の精度が向上します。

デコーダがベルヌーイ分布の場合はバイナリクロスエントロピー、正規分布の場合は二乗誤差に帰着します。

第2項: KL正則化項

近似事後分布 $q_{\bm{\phi}}(\bm{z} | \bm{x})$ が事前分布 $p(\bm{z}) = \mathcal{N}(\bm{0}, \bm{I})$ に近くなるよう正則化する項です。これにより潜在空間が滑らかに構造化され、意味のある補間が可能になります。

ガウス分布のKLダイバージェンスの解析解

$q_{\bm{\phi}}(\bm{z} | \bm{x}) = \mathcal{N}(\bm{\mu}, \mathrm{diag}(\bm{\sigma}^2))$ と $p(\bm{z}) = \mathcal{N}(\bm{0}, \bm{I})$ の場合、KLダイバージェンスは解析的に計算できます。$d$ を潜在空間の次元として、

$$ D_{\mathrm{KL}}(q \| p) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 – \mu_j^2 – \sigma_j^2 \right) $$

この式は、各次元 $j$ について独立にKLダイバージェンスの寄与を計算できることを示しています。導出は以下のとおりです。

1次元の場合を考えます。$q = \mathcal{N}(\mu, \sigma^2)$, $p = \mathcal{N}(0, 1)$ として、

$$ \begin{align} D_{\mathrm{KL}}(q \| p) &= \int q(z) \log \frac{q(z)}{p(z)} dz \\ &= \int q(z) \left[ \log q(z) – \log p(z) \right] dz \\ &= \mathbb{E}_q[\log q(z)] – \mathbb{E}_q[\log p(z)] \end{align} $$

第1項($q$ のエントロピーの符号反転):

$$ \begin{align} \mathbb{E}_q[\log q(z)] &= \mathbb{E}_q\left[ -\frac{1}{2}\log(2\pi\sigma^2) – \frac{(z-\mu)^2}{2\sigma^2} \right] \\ &= -\frac{1}{2}\log(2\pi\sigma^2) – \frac{1}{2} \quad (\because \mathbb{E}_q[(z-\mu)^2] = \sigma^2) \end{align} $$

第2項:

$$ \begin{align} \mathbb{E}_q[\log p(z)] &= \mathbb{E}_q\left[ -\frac{1}{2}\log(2\pi) – \frac{z^2}{2} \right] \\ &= -\frac{1}{2}\log(2\pi) – \frac{1}{2}\mathbb{E}_q[z^2] \\ &= -\frac{1}{2}\log(2\pi) – \frac{1}{2}(\mu^2 + \sigma^2) \quad (\because \mathbb{E}_q[z^2] = \mu^2 + \sigma^2) \end{align} $$

差を取ると、

$$ \begin{align} D_{\mathrm{KL}} &= \left( -\frac{1}{2}\log(2\pi\sigma^2) – \frac{1}{2} \right) – \left( -\frac{1}{2}\log(2\pi) – \frac{\mu^2 + \sigma^2}{2} \right) \\ &= -\frac{1}{2}\log\sigma^2 – \frac{1}{2} + \frac{\mu^2 + \sigma^2}{2} \\ &= \frac{1}{2}\left( \mu^2 + \sigma^2 – \log\sigma^2 – 1 \right) \end{align} $$

$d$ 次元で各次元が独立なので、総和を取れば先の式が得られます。

再パラメータ化トリック

VAEの学習では、エンコーダのパラメータ $\bm{\phi}$ とデコーダのパラメータ $\bm{\theta}$ を同時に勾配降下法で最適化します。しかし、ELBOの第1項には $q_{\bm{\phi}}(\bm{z} | \bm{x})$ からのサンプリングが含まれています。

$$ \mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] \approx \frac{1}{L}\sum_{l=1}^{L} \log p_{\bm{\theta}}(\bm{x} | \bm{z}^{(l)}), \quad \bm{z}^{(l)} \sim q_{\bm{\phi}}(\bm{z} | \bm{x}) $$

サンプリング操作は確率的であり、$\bm{\phi}$ に関する勾配を直接計算することができません。サンプリングという離散的な操作が計算グラフを「切断」してしまうためです。

再パラメータ化のアイデア

この問題を解決するのが再パラメータ化トリック(reparameterization trick)です。

$\bm{z} \sim \mathcal{N}(\bm{\mu}, \mathrm{diag}(\bm{\sigma}^2))$ を直接サンプリングする代わりに、ノイズ $\bm{\varepsilon} \sim \mathcal{N}(\bm{0}, \bm{I})$ をサンプリングし、以下の決定論的な変換を適用します。

$$ \bm{z} = \bm{\mu}_{\bm{\phi}}(\bm{x}) + \bm{\sigma}_{\bm{\phi}}(\bm{x}) \odot \bm{\varepsilon}, \quad \bm{\varepsilon} \sim \mathcal{N}(\bm{0}, \bm{I}) $$

ここで $\odot$ は要素ごとの積です。

この変換が正しいことは、確率変数の変換の理論から確認できます。$\varepsilon_j \sim \mathcal{N}(0, 1)$ のとき、$z_j = \mu_j + \sigma_j \varepsilon_j$ は以下の分布に従います。

$$ \begin{align} \mathbb{E}[z_j] &= \mu_j + \sigma_j \mathbb{E}[\varepsilon_j] = \mu_j \\ \mathrm{Var}[z_j] &= \sigma_j^2 \mathrm{Var}[\varepsilon_j] = \sigma_j^2 \end{align} $$

したがって $z_j \sim \mathcal{N}(\mu_j, \sigma_j^2)$ となり、元のサンプリングと等価です。

重要なのは、この変換が $\bm{\mu}$ と $\bm{\sigma}$ に関して微分可能であることです。$\bm{\varepsilon}$ はパラメータに依存しないノイズ源なので、計算グラフを通じて $\bm{\phi}$ への勾配が正しく伝搬します。

$$ \frac{\partial \bm{z}}{\partial \bm{\mu}} = \bm{I}, \quad \frac{\partial \bm{z}}{\partial \bm{\sigma}} = \mathrm{diag}(\bm{\varepsilon}) $$

VAEの全体構成

VAEの全体構成を整理しましょう。

  1. エンコーダ(推論ネットワーク): $q_{\bm{\phi}}(\bm{z} | \bm{x})$ を出力。入力 $\bm{x}$ から $\bm{\mu}$ と $\log \bm{\sigma}^2$ を出力する
  2. 再パラメータ化: $\bm{z} = \bm{\mu} + \bm{\sigma} \odot \bm{\varepsilon}$ でサンプリング
  3. デコーダ(生成ネットワーク): $p_{\bm{\theta}}(\bm{x} | \bm{z})$ を出力。潜在変数 $\bm{z}$ から $\bm{x}$ を再構成する

実装上は $\log \bm{\sigma}^2$ を出力する設計が一般的です。これは $\sigma^2$ が常に正であることを自然に保証でき、数値的にも安定するためです。

損失関数は ELBO の符号を反転させたものになります。

$$ \mathcal{L}(\bm{\theta}, \bm{\phi}; \bm{x}) = -\mathbb{E}_{q_{\bm{\phi}}(\bm{z} | \bm{x})}[\log p_{\bm{\theta}}(\bm{x} | \bm{z})] + D_{\mathrm{KL}}(q_{\bm{\phi}}(\bm{z} | \bm{x}) \| p(\bm{z})) $$

Pythonでの実装

ここからはPyTorchを用いてMNISTデータセットでVAEを実装します。

モデル定義

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class VAE(nn.Module):
    """変分オートエンコーダ"""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super(VAE, self).__init__()
        # エンコーダ: 入力 -> 隠れ層 -> (平均, 対数分散)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # 平均 μ
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)    # 対数分散 log(σ^2)

        # デコーダ: 潜在変数 -> 隠れ層 -> 再構成
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """エンコーダ: x -> (μ, log σ^2)"""
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """再パラメータ化トリック: z = μ + σ ⊙ ε"""
        std = torch.exp(0.5 * logvar)  # σ = exp(0.5 * log σ^2)
        eps = torch.randn_like(std)     # ε ~ N(0, I)
        return mu + std * eps

    def decode(self, z):
        """デコーダ: z -> x の再構成"""
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

損失関数の定義

def vae_loss(recon_x, x, mu, logvar):
    """
    VAE損失関数 = 再構成誤差 + KL正則化項

    再構成誤差: バイナリクロスエントロピー(MNISTはピクセル値[0,1]なので)
    KL項: -0.5 * Σ(1 + log σ^2 - μ^2 - σ^2)
    """
    # 再構成誤差(バイナリクロスエントロピー、要素ごとに計算して総和)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL正則化項の解析解
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

訓練ループ

# ハイパーパラメータ
batch_size = 128
epochs = 30
latent_dim = 2  # 2次元にして可視化しやすくする
learning_rate = 1e-3

# データの準備(MNIST)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True,
                               download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False,
                              download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# モデルとオプティマイザの初期化
model = VAE(input_dim=784, hidden_dim=400, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 訓練
train_losses = []
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784).to(device)  # 28x28 -> 784に平坦化
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()

    avg_loss = epoch_loss / len(train_loader.dataset)
    train_losses.append(avg_loss)
    if epoch % 5 == 0:
        print(f'Epoch {epoch}/{epochs}, Loss: {avg_loss:.4f}')

学習曲線の可視化

plt.figure(figsize=(8, 5))
plt.plot(range(1, epochs + 1), train_losses, 'b-', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Average Loss', fontsize=12)
plt.title('VAE Training Loss on MNIST', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

潜在空間の可視化

潜在変数を2次元に設定しているので、テストデータをエンコードして潜在空間にプロットし、数字ごとの分布を確認します。

# テストデータの潜在表現を取得
model.eval()
z_list = []
label_list = []
with torch.no_grad():
    for data, labels in test_loader:
        data = data.view(-1, 784).to(device)
        mu, logvar = model.encode(data)
        z_list.append(mu.cpu().numpy())
        label_list.append(labels.numpy())

z_all = np.concatenate(z_list, axis=0)
labels_all = np.concatenate(label_list, axis=0)

# 潜在空間の散布図(数字ごとに色分け)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(z_all[:, 0], z_all[:, 1],
                      c=labels_all, cmap='tab10', s=2, alpha=0.6)
plt.colorbar(scatter, label='Digit')
plt.xlabel('$z_1$', fontsize=14)
plt.ylabel('$z_2$', fontsize=14)
plt.title('VAE Latent Space Visualization (MNIST)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

このプロットでは、同じ数字が潜在空間上で近い領域にクラスタを形成することが確認できます。また、隣接するクラスタ間では滑らかな遷移が見られ、KL正則化項の効果が視覚的にわかります。

画像の再構成と生成

# 入力画像と再構成画像の比較
model.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data.view(-1, 784).to(device)
    recon, _, _ = model(data)
    data = data.cpu().numpy()
    recon = recon.cpu().numpy()

fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for i in range(10):
    # 元画像
    axes[0, i].imshow(data[i].reshape(28, 28), cmap='gray')
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Input', fontsize=12)
    # 再構成画像
    axes[1, i].imshow(recon[i].reshape(28, 28), cmap='gray')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstructed', fontsize=12)
plt.suptitle('VAE Reconstruction Results', fontsize=14)
plt.tight_layout()
plt.show()

潜在空間からの画像生成

潜在空間上の格子点からデコードし、生成される画像のマニフォールド(多様体)を可視化します。

# 潜在空間の格子から画像を生成
n = 20  # 格子の分割数
digit_size = 28

# 標準正規分布の分位点を使って格子点を生成
from scipy.stats import norm
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

canvas = np.zeros((digit_size * n, digit_size * n))

model.eval()
with torch.no_grad():
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z = torch.tensor([[xi, yi]], dtype=torch.float32).to(device)
            decoded = model.decode(z).cpu().numpy()
            canvas[i * digit_size:(i + 1) * digit_size,
                   j * digit_size:(j + 1) * digit_size] = decoded.reshape(digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(canvas, cmap='gray')
plt.xlabel('$z_1$', fontsize=14)
plt.ylabel('$z_2$', fontsize=14)
plt.title('VAE Generated Digits (Latent Space Manifold)', fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.show()

この可視化では、潜在空間を連続的に移動すると、生成される数字が滑らかに変化していく様子が確認できます。これはVAEが連続的で意味のある潜在表現を獲得していることの証拠です。

まとめ

本記事では、VAE(変分オートエンコーダ)の理論をELBOの導出から丁寧に解説し、PyTorchによる実装まで行いました。

  • 生成モデルにおける周辺尤度 $p(\bm{x}) = \int p(\bm{x} | \bm{z}) p(\bm{z}) d\bm{z}$ の計算困難性を、変分推論を用いて解決する
  • ELBO(Evidence Lower Bound)は再構成誤差KL正則化項に分解され、対数周辺尤度の下界を与える
  • 再パラメータ化トリック $\bm{z} = \bm{\mu} + \bm{\sigma} \odot \bm{\varepsilon}$ により、サンプリングを含む計算グラフでも勾配が伝搬する
  • 2次元の潜在空間を可視化すると、同じ数字が近くに集まり、空間が滑らかに構造化されていることが確認できた

次のステップとして、以下の記事も参考にしてください。