VAE(変分オートエンコーダ)のELBO導出と潜在空間の理解

VAE(Variational Autoencoder、変分オートエンコーダ)は、2013年にKingmaとWellingが提案した生成モデルです。論文 “Auto-Encoding Variational Bayes” で発表され、深層学習と変分推論を組み合わせることで、複雑なデータの生成モデルを効率的に学習する方法を確立しました。

VAEは単なるオートエンコーダではなく、確率的生成モデルです。潜在空間が標準正規分布に従うよう正則化されることで、連続的で滑らかな潜在空間が形成され、新しいデータの生成や補間が可能になります。

Stable Diffusionなどの現代の画像生成モデルでも、画像を効率的な潜在表現に変換するためにVAEが使用されています。

本記事の内容

  • VAEの問題設定と確率的解釈
  • ELBO(エビデンス下界)の導出
  • 再パラメータ化トリック
  • 潜在空間の性質と可視化
  • PyTorchによるスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

VAEの問題設定

生成モデルとしての定式化

データ $\bm{x}$(例:画像)の生成過程を、以下のような潜在変数モデルとして考えます。

  1. 潜在変数 $\bm{z}$ を事前分布 $p(\bm{z})$ からサンプリング
  2. $\bm{z}$ から条件付き分布 $p_\theta(\bm{x}|\bm{z})$ に従ってデータ $\bm{x}$ を生成

ここで $p(\bm{z})$ は標準正規分布 $\mathcal{N}(\bm{0}, \bm{I})$ を仮定します。

データの周辺尤度は以下のように表されます。

$$ p_\theta(\bm{x}) = \int p_\theta(\bm{x}|\bm{z}) p(\bm{z}) d\bm{z} $$

最尤推定の困難さ

理想的には対数尤度 $\log p_\theta(\bm{x})$ を最大化したいのですが、上の積分は $\bm{z}$ の空間が高次元の場合、計算が困難です。また、真の事後分布 $p_\theta(\bm{z}|\bm{x})$ も直接計算できません。

VAEは、この問題を変分推論を使って解決します。

ELBOの導出

変分推論のアイデア

計算困難な真の事後分布 $p_\theta(\bm{z}|\bm{x})$ を、パラメータ $\phi$ を持つ扱いやすい分布 $q_\phi(\bm{z}|\bm{x})$ で近似します。この $q_\phi(\bm{z}|\bm{x})$ を認識モデル(または推論ネットワーク、エンコーダ)と呼びます。

VAEでは $q_\phi(\bm{z}|\bm{x})$ を対角共分散を持つガウス分布として定義します。

$$ q_\phi(\bm{z}|\bm{x}) = \mathcal{N}(\bm{z}; \bm{\mu}_\phi(\bm{x}), \text{diag}(\bm{\sigma}_\phi^2(\bm{x}))) $$

ここで $\bm{\mu}_\phi(\bm{x})$, $\bm{\sigma}_\phi(\bm{x})$ はニューラルネットワーク(エンコーダ)で計算されます。

ELBOの導出

対数尤度 $\log p_\theta(\bm{x})$ に対して、以下の分解を行います。

$$ \begin{align} \log p_\theta(\bm{x}) &= \log p_\theta(\bm{x}) \int q_\phi(\bm{z}|\bm{x}) d\bm{z} \\ &= \int q_\phi(\bm{z}|\bm{x}) \log p_\theta(\bm{x}) d\bm{z} \\ &= \int q_\phi(\bm{z}|\bm{x}) \log \frac{p_\theta(\bm{x}, \bm{z})}{p_\theta(\bm{z}|\bm{x})} d\bm{z} \\ &= \int q_\phi(\bm{z}|\bm{x}) \log \frac{p_\theta(\bm{x}, \bm{z}) q_\phi(\bm{z}|\bm{x})}{p_\theta(\bm{z}|\bm{x}) q_\phi(\bm{z}|\bm{x})} d\bm{z} \\ &= \int q_\phi(\bm{z}|\bm{x}) \log \frac{p_\theta(\bm{x}, \bm{z})}{q_\phi(\bm{z}|\bm{x})} d\bm{z} + \int q_\phi(\bm{z}|\bm{x}) \log \frac{q_\phi(\bm{z}|\bm{x})}{p_\theta(\bm{z}|\bm{x})} d\bm{z} \\ &= \mathcal{L}(\theta, \phi; \bm{x}) + D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p_\theta(\bm{z}|\bm{x})) \end{align} $$

ここで、

$$ \mathcal{L}(\theta, \phi; \bm{x}) = \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log \frac{p_\theta(\bm{x}, \bm{z})}{q_\phi(\bm{z}|\bm{x})}\right] $$

ELBO(Evidence Lower Bound、エビデンス下界)と呼ばれます。

ELBOが下界である理由

KLダイバージェンスは常に非負なので、

$$ D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p_\theta(\bm{z}|\bm{x})) \geq 0 $$

したがって、

$$ \log p_\theta(\bm{x}) \geq \mathcal{L}(\theta, \phi; \bm{x}) $$

ELBOは対数尤度の下界になっており、ELBOを最大化することで対数尤度を間接的に最大化できます。

ELBOの2つの解釈

ELBOを変形すると、以下の2つの項に分解できます。

$$ \begin{align} \mathcal{L}(\theta, \phi; \bm{x}) &= \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log p_\theta(\bm{x}, \bm{z}) – \log q_\phi(\bm{z}|\bm{x})\right] \\ &= \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log p_\theta(\bm{x}|\bm{z})\right] + \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log \frac{p(\bm{z})}{q_\phi(\bm{z}|\bm{x})}\right] \\ &= \underbrace{\mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log p_\theta(\bm{x}|\bm{z})\right]}_{\text{再構成項}} – \underbrace{D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p(\bm{z}))}_{\text{正則化項}} \end{align} $$

再構成項:$\bm{z}$ から $\bm{x}$ を再構成できる度合い。デコーダの性能を表す。

正則化項(KL項):エンコーダの出力分布 $q_\phi(\bm{z}|\bm{x})$ を事前分布 $p(\bm{z}) = \mathcal{N}(\bm{0}, \bm{I})$ に近づける。潜在空間の正則化。

VAEの損失関数

ELBOの符号を反転した損失関数を最小化します。

$$ \mathcal{L}_{\text{VAE}} = -\mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log p_\theta(\bm{x}|\bm{z})\right] + D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p(\bm{z})) $$

$p_\theta(\bm{x}|\bm{z})$ をガウス分布と仮定すると、再構成項はMSE損失(またはBernoulli分布を仮定するとバイナリクロスエントロピー)に対応します。

KL項の解析的計算

$q_\phi(\bm{z}|\bm{x}) = \mathcal{N}(\bm{\mu}, \text{diag}(\bm{\sigma}^2))$ と $p(\bm{z}) = \mathcal{N}(\bm{0}, \bm{I})$ のKLダイバージェンスは解析的に計算できます。

潜在次元を $d$ として、

$$ D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p(\bm{z})) = \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 – \log \sigma_j^2 – 1 \right) $$

導出

2つの多変量ガウス分布間のKLダイバージェンスの公式から導出できます。

$$ D_{\text{KL}}(\mathcal{N}(\bm{\mu}_1, \bm{\Sigma}_1) \| \mathcal{N}(\bm{\mu}_2, \bm{\Sigma}_2)) = \frac{1}{2}\left[ \text{tr}(\bm{\Sigma}_2^{-1}\bm{\Sigma}_1) + (\bm{\mu}_2 – \bm{\mu}_1)^\top \bm{\Sigma}_2^{-1} (\bm{\mu}_2 – \bm{\mu}_1) – d + \log \frac{|\bm{\Sigma}_2|}{|\bm{\Sigma}_1|} \right] $$

$\bm{\mu}_1 = \bm{\mu}$, $\bm{\Sigma}_1 = \text{diag}(\bm{\sigma}^2)$, $\bm{\mu}_2 = \bm{0}$, $\bm{\Sigma}_2 = \bm{I}$ を代入すると、

$$ \begin{align} D_{\text{KL}} &= \frac{1}{2}\left[ \text{tr}(\text{diag}(\bm{\sigma}^2)) + \bm{\mu}^\top \bm{\mu} – d + \log \frac{1}{\prod_j \sigma_j^2} \right] \\ &= \frac{1}{2}\left[ \sum_j \sigma_j^2 + \sum_j \mu_j^2 – d – \sum_j \log \sigma_j^2 \right] \\ &= \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 – \log \sigma_j^2 – 1 \right) \end{align} $$

再パラメータ化トリック

問題:サンプリングの微分

ELBOの再構成項は、

$$ \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[\log p_\theta(\bm{x}|\bm{z})\right] $$

という期待値を含んでいます。これを最適化するには、$q_\phi(\bm{z}|\bm{x})$ からサンプリングした $\bm{z}$ を使ってこの期待値を近似し、$\phi$ について勾配を計算する必要があります。

しかし、「$q_\phi(\bm{z}|\bm{x})$ からのサンプリング」という操作は微分可能ではありません。$\bm{z} \sim \mathcal{N}(\bm{\mu}_\phi, \bm{\sigma}_\phi^2)$ というサンプリング操作を通じて $\phi$ への勾配を伝播できません。

解決策:再パラメータ化

サンプリング操作を、パラメータに依存しないノイズ $\bm{\epsilon} \sim \mathcal{N}(\bm{0}, \bm{I})$ を使って表現し直します。

$$ \bm{z} = \bm{\mu}_\phi(\bm{x}) + \bm{\sigma}_\phi(\bm{x}) \odot \bm{\epsilon}, \quad \bm{\epsilon} \sim \mathcal{N}(\bm{0}, \bm{I}) $$

ここで $\odot$ は要素ごとの積です。

この変換により、$\bm{z}$ は $\bm{\mu}_\phi$ と $\bm{\sigma}_\phi$ の決定論的な関数(+ ノイズ)となり、勾配を $\phi$ に伝播できるようになります。

対数分散の使用

数値安定性のため、エンコーダは $\bm{\sigma}$ ではなく $\log \bm{\sigma}^2$ を出力することが多いです。

$$ \bm{z} = \bm{\mu} + \exp\left(\frac{\log \bm{\sigma}^2}{2}\right) \odot \bm{\epsilon} = \bm{\mu} + \bm{\sigma} \odot \bm{\epsilon} $$

潜在空間の性質

連続性

KL正則化により、エンコーダの出力分布が標準正規分布に近づくことで、潜在空間は連続的になります。つまり、潜在空間の任意の点 $\bm{z}$ をデコードすると、意味のある出力が得られます。

通常のオートエンコーダでは、学習データの潜在表現の周りにしか意味のある構造がなく、それ以外の点をデコードすると崩壊した出力になりがちです。

補間

潜在空間が連続的であるため、2つのデータ点 $\bm{x}_1$, $\bm{x}_2$ の潜在表現 $\bm{z}_1$, $\bm{z}_2$ を線形補間することで、滑らかな遷移を生成できます。

$$ \bm{z}_{\alpha} = (1 – \alpha) \bm{z}_1 + \alpha \bm{z}_2, \quad \alpha \in [0, 1] $$

生成

事前分布 $p(\bm{z}) = \mathcal{N}(\bm{0}, \bm{I})$ からサンプリングした $\bm{z}$ をデコードすることで、新しいデータを生成できます。

PyTorchによる実装

エンコーダ

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """VAEのエンコーダ"""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """
        Args:
            x: (batch_size, input_dim)
        Returns:
            mu: (batch_size, latent_dim)
            logvar: (batch_size, latent_dim)
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

デコーダ

class Decoder(nn.Module):
    """VAEのデコーダ"""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """
        Args:
            z: (batch_size, latent_dim)
        Returns:
            x_recon: (batch_size, output_dim)
        """
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        x_recon = torch.sigmoid(self.fc_out(h))  # 画像の場合は[0,1]にスケール
        return x_recon

VAE全体

class VAE(nn.Module):
    """変分オートエンコーダ"""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """再パラメータ化トリック"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """
        Args:
            x: (batch_size, input_dim)
        Returns:
            x_recon: (batch_size, input_dim)
            mu: (batch_size, latent_dim)
            logvar: (batch_size, latent_dim)
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def sample(self, num_samples, device):
        """事前分布からサンプリングして画像を生成"""
        z = torch.randn(num_samples, self.encoder.fc_mu.out_features, device=device)
        return self.decoder(z)

損失関数

def vae_loss(x, x_recon, mu, logvar, beta=1.0):
    """
    VAEの損失関数(ELBO)

    Args:
        x: 入力データ
        x_recon: 再構成データ
        mu: エンコーダの平均出力
        logvar: エンコーダの対数分散出力
        beta: KL項の重み(beta-VAEの場合に調整)
    Returns:
        total_loss: 総損失
        recon_loss: 再構成損失
        kl_loss: KL損失
    """
    # 再構成損失(バイナリクロスエントロピー)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

    # KL損失(解析的計算)
    # D_KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss

学習ループ

def train_epoch(model, dataloader, optimizer, device):
    """1エポックの学習"""
    model.train()
    total_loss = 0
    total_recon = 0
    total_kl = 0

    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.view(-1, 784).to(device)

        optimizer.zero_grad()
        x_recon, mu, logvar = model(data)
        loss, recon, kl = vae_loss(data, x_recon, mu, logvar)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon.item()
        total_kl += kl.item()

    n = len(dataloader.dataset)
    return total_loss / n, total_recon / n, total_kl / n

潜在空間の可視化

import matplotlib.pyplot as plt
import numpy as np


def visualize_latent_space(model, dataloader, device):
    """2D潜在空間を可視化(latent_dim=2の場合)"""
    model.eval()
    z_list = []
    label_list = []

    with torch.no_grad():
        for data, labels in dataloader:
            data = data.view(-1, 784).to(device)
            mu, _ = model.encoder(data)
            z_list.append(mu.cpu().numpy())
            label_list.append(labels.numpy())

    z = np.concatenate(z_list, axis=0)
    labels = np.concatenate(label_list, axis=0)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='tab10', alpha=0.5, s=2)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('$z_1$')
    plt.ylabel('$z_2$')
    plt.title('VAE Latent Space')
    plt.grid(True, alpha=0.3)
    plt.show()


def visualize_reconstruction(model, dataloader, device, num_images=10):
    """入力と再構成を並べて表示"""
    model.eval()
    data, _ = next(iter(dataloader))
    data = data[:num_images].view(-1, 784).to(device)

    with torch.no_grad():
        x_recon, _, _ = model(data)

    fig, axes = plt.subplots(2, num_images, figsize=(num_images * 1.5, 3))

    for i in range(num_images):
        # 入力画像
        axes[0, i].imshow(data[i].cpu().view(28, 28), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Input')

        # 再構成画像
        axes[1, i].imshow(x_recon[i].cpu().view(28, 28), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Recon')

    plt.tight_layout()
    plt.show()


def visualize_generated(model, device, num_images=10):
    """生成画像を表示"""
    model.eval()

    with torch.no_grad():
        samples = model.sample(num_images, device)

    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 1.5, 1.5))

    for i in range(num_images):
        axes[i].imshow(samples[i].cpu().view(28, 28), cmap='gray')
        axes[i].axis('off')

    plt.suptitle('Generated Samples')
    plt.tight_layout()
    plt.show()

潜在空間のグリッド探索

def visualize_latent_grid(model, device, grid_size=20, z_range=(-3, 3)):
    """2D潜在空間をグリッド状に探索して画像を生成"""
    model.eval()

    z1 = np.linspace(z_range[0], z_range[1], grid_size)
    z2 = np.linspace(z_range[0], z_range[1], grid_size)

    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))

    for i, z1_val in enumerate(z1):
        for j, z2_val in enumerate(z2):
            z = torch.tensor([[z1_val, z2_val]], dtype=torch.float32, device=device)
            with torch.no_grad():
                x = model.decoder(z)
            axes[grid_size - 1 - j, i].imshow(x[0].cpu().view(28, 28), cmap='gray')
            axes[grid_size - 1 - j, i].axis('off')

    plt.suptitle('Latent Space Grid')
    plt.tight_layout()
    plt.show()

畳み込みVAE

画像データに対しては、畳み込み層を使ったVAEがより効果的です。

class ConvEncoder(nn.Module):
    """畳み込みエンコーダ"""
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.fc = nn.Linear(128 * 4 * 4, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = h.view(h.size(0), -1)
        h = F.relu(self.fc(h))
        return self.fc_mu(h), self.fc_logvar(h)


class ConvDecoder(nn.Module):
    """畳み込みデコーダ"""
    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, out_channels, 3, stride=2, padding=1, output_padding=1)

    def forward(self, z):
        h = F.relu(self.fc(z))
        h = h.view(h.size(0), 128, 4, 4)
        h = F.relu(self.deconv1(h))
        h = F.relu(self.deconv2(h))
        return torch.sigmoid(self.deconv3(h))

beta-VAE

beta-VAEは、KL項に重み $\beta > 1$ をかけることで、より独立した潜在変数を学習する手法です。

$$ \mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q_\phi(\bm{z}|\bm{x})}\left[-\log p_\theta(\bm{x}|\bm{z})\right] + \beta \cdot D_{\text{KL}}(q_\phi(\bm{z}|\bm{x}) \| p(\bm{z})) $$

$\beta > 1$ とすることで、潜在変数間の独立性(disentanglement)が促進されます。これにより、各潜在次元が解釈可能な意味を持つようになることが期待されます。

Stable Diffusionでの使用

Stable Diffusionでは、VAEは以下の役割を果たします。

  1. エンコード: 画像 $\bm{x} \in \mathbb{R}^{512 \times 512 \times 3}$ を潜在表現 $\bm{z} \in \mathbb{R}^{64 \times 64 \times 4}$ に圧縮
  2. デコード: 生成された潜在表現 $\bm{z}$ を画像 $\bm{x}$ に復元

拡散過程はこの圧縮された潜在空間で行われるため、計算効率が大幅に向上します。

まとめ

本記事では、VAE(変分オートエンコーダ)の理論を解説しました。

  • 確率的生成モデル: VAEはデータの生成過程を潜在変数モデルとして定式化
  • ELBO: 計算困難な対数尤度の代わりに、その下界(ELBO)を最大化
  • 損失関数: 再構成項(デコーダの性能)とKL項(潜在空間の正則化)の和
  • 再パラメータ化トリック: サンプリング操作を微分可能な形に変換
  • 潜在空間: KL正則化により、連続的で滑らかな潜在空間が形成

VAEは、画像生成、表現学習、異常検知など多くの応用を持ち、Stable Diffusionなどの現代的な生成モデルの重要なコンポーネントとなっています。

次のステップとして、以下の記事も参考にしてください。