Stable Diffusionの仕組み — 拡散モデルの理論と全体像

Stable Diffusionは、2022年にStability AI、CompVis、Runwayが共同で発表したテキストから画像を生成する拡散モデルです。論文 “High-Resolution Image Synthesis with Latent Diffusion Models”(Rombach et al.)で提案されたLatent Diffusion Model(LDM)をベースとしており、オープンソースで公開されたことで爆発的に普及しました。

従来の拡散モデルがピクセル空間で動作するのに対し、Stable Diffusionは潜在空間(latent space)で拡散過程を行うことで、計算効率を大幅に向上させながら高品質な画像生成を実現しています。

本記事の内容

  • 拡散モデルの基本原理
  • Stable Diffusionの全体アーキテクチャ
  • 潜在空間での拡散(VAEの役割)
  • U-Netによるノイズ予測
  • テキスト条件付け(CLIPテキストエンコーダ)
  • サンプリングアルゴリズム

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

拡散モデルの基本原理

拡散過程(Forward Process)

拡散モデルの核心的なアイデアは、「画像にノイズを徐々に加えていき、最終的に純粋なガウスノイズにする」という過程を逆転させることです。

元の画像 $\bm{x}_0$ に対して、$T$ ステップにわたってノイズを加えていく拡散過程(forward process)を定義します。

$$ q(\bm{x}_t | \bm{x}_{t-1}) = \mathcal{N}(\bm{x}_t; \sqrt{1 – \beta_t} \bm{x}_{t-1}, \beta_t \bm{I}) $$

ここで $\beta_t$ は各ステップでのノイズスケジュールで、$0 < \beta_1 < \beta_2 < \cdots < \beta_T < 1$ と増加するように設定されます。

この過程を繰り返すと、$\bm{x}_T$ はほぼ純粋なガウスノイズ $\mathcal{N}(\bm{0}, \bm{I})$ になります。

重要な性質:任意のステップへの直接サンプリング

$\alpha_t = 1 – \beta_t$、$\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ と定義すると、$\bm{x}_0$ から任意のステップ $\bm{x}_t$ を一度に計算できます。

$$ \bm{x}_t = \sqrt{\bar{\alpha}_t} \bm{x}_0 + \sqrt{1 – \bar{\alpha}_t} \bm{\epsilon}, \quad \bm{\epsilon} \sim \mathcal{N}(\bm{0}, \bm{I}) $$

これは学習時に非常に重要で、逐次的にノイズを加える必要がなく、ランダムなステップ $t$ を選んで直接学習できます。

逆拡散過程(Reverse Process)

画像生成は、ノイズから元の画像を復元する逆拡散過程(reverse process)で行います。

$$ p_\theta(\bm{x}_{t-1} | \bm{x}_t) = \mathcal{N}(\bm{x}_{t-1}; \bm{\mu}_\theta(\bm{x}_t, t), \sigma_t^2 \bm{I}) $$

真の逆過程 $q(\bm{x}_{t-1} | \bm{x}_t, \bm{x}_0)$ は計算可能で、

$$ q(\bm{x}_{t-1} | \bm{x}_t, \bm{x}_0) = \mathcal{N}(\bm{x}_{t-1}; \tilde{\bm{\mu}}_t(\bm{x}_t, \bm{x}_0), \tilde{\beta}_t \bm{I}) $$

ここで、

$$ \tilde{\bm{\mu}}_t(\bm{x}_t, \bm{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 – \bar{\alpha}_t} \bm{x}_0 + \frac{\sqrt{\alpha_t}(1 – \bar{\alpha}_{t-1})}{1 – \bar{\alpha}_t} \bm{x}_t $$

ノイズ予測としての定式化

$\bm{x}_0$ を直接予測する代わりに、加えられたノイズ $\bm{\epsilon}$ を予測するようにニューラルネットワークを学習します。これがDDPM(Denoising Diffusion Probabilistic Models)のアイデアです。

$\bm{x}_t = \sqrt{\bar{\alpha}_t} \bm{x}_0 + \sqrt{1 – \bar{\alpha}_t} \bm{\epsilon}$ から $\bm{x}_0$ を表すと、

$$ \bm{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(\bm{x}_t – \sqrt{1 – \bar{\alpha}_t} \bm{\epsilon}) $$

したがって、ノイズ予測モデル $\bm{\epsilon}_\theta(\bm{x}_t, t)$ を使って $\bm{x}_0$ を推定できます。

学習目標

最終的な学習目標は、予測ノイズと実際のノイズのMSE損失を最小化することです。

$$ \mathcal{L}_{\text{simple}} = \mathbb{E}_{t, \bm{x}_0, \bm{\epsilon}} \left[ \| \bm{\epsilon} – \bm{\epsilon}_\theta(\bm{x}_t, t) \|^2 \right] $$

ここで $t \sim \text{Uniform}\{1, \ldots, T\}$、$\bm{\epsilon} \sim \mathcal{N}(\bm{0}, \bm{I})$ です。

Stable Diffusionの全体アーキテクチャ

Stable Diffusionは、拡散モデルを潜在空間で実行することで効率化を図っています。

[テキストプロンプト]
    ↓
[CLIPテキストエンコーダ] → [テキスト埋め込み]
                                    ↓
[ランダムノイズ] → [U-Net(ノイズ予測)] ← [タイムステップ埋め込み]
    ↓            ↓(繰り返し)
[潜在表現 z_T] → [潜在表現 z_0]
                    ↓
                [VAEデコーダ]
                    ↓
                [生成画像]

主要コンポーネントは以下の3つです。

  1. VAE(Variational Autoencoder): 画像と潜在空間の変換
  2. U-Net: 潜在空間でのノイズ予測
  3. テキストエンコーダ(CLIP): テキスト条件の埋め込み

潜在空間での拡散の利点

ピクセル空間で $512 \times 512 \times 3 = 786,432$ 次元だったものが、潜在空間では $64 \times 64 \times 4 = 16,384$ 次元に圧縮されます(約48倍の圧縮)。

これにより、

  • 計算コストの大幅削減: U-Netの計算量が削減
  • メモリ使用量の削減: より大きなバッチサイズで学習可能
  • より本質的な特徴での拡散: ピクセルレベルの詳細ではなく、意味的な特徴を扱う

VAE:画像と潜在空間の変換

エンコーダ

VAEのエンコーダは画像を潜在表現に変換します。

$$ \bm{z} = \mathcal{E}(\bm{x}) $$

入力画像 $\bm{x} \in \mathbb{R}^{H \times W \times 3}$ を潜在表現 $\bm{z} \in \mathbb{R}^{h \times w \times c}$ に変換します。Stable Diffusionでは $H = W = 512$、$h = w = 64$、$c = 4$ が標準的です。

デコーダ

デコーダは潜在表現を画像に復元します。

$$ \tilde{\bm{x}} = \mathcal{D}(\bm{z}) $$

KL正則化

VAEは通常のオートエンコーダと異なり、潜在変数が標準正規分布に従うようKL正則化されています。これにより、潜在空間がスムーズで連続的になり、拡散過程が安定します。

$$ \mathcal{L}_{\text{VAE}} = \| \bm{x} – \mathcal{D}(\mathcal{E}(\bm{x})) \|^2 + \lambda \cdot D_{\text{KL}}(q(\bm{z}|\bm{x}) \| p(\bm{z})) $$

Stable DiffusionのVAEは画像の再構成に特化して事前学習されており、拡散過程の学習時は凍結(固定)されます。

U-Net:潜在空間でのノイズ予測

U-Netは、ノイズが加えられた潜在表現 $\bm{z}_t$ からノイズ $\bm{\epsilon}$ を予測するネットワークです。

U-Netの構造

[入力 z_t] ─→ [ダウンサンプル] ─→ [ミドル] ─→ [アップサンプル] ─→ [出力 ε]
                ↓                              ↑
                └─────── スキップ接続 ─────────┘

U-Netは以下の特徴を持ちます。

  1. エンコーダ・デコーダ構造: 解像度を下げながら特徴を抽出し、上げながら詳細を復元
  2. スキップ接続: エンコーダの各層の出力をデコーダの対応する層に接続し、細部の情報を保持
  3. Residual Block: 残差接続による安定した勾配伝播
  4. Self-Attention層: 空間的な長距離依存関係の捕捉
  5. Cross-Attention層: テキスト条件の注入

ResBlock + Attention Block

U-Netの各層は、ResBlockとAttention Blockの組み合わせで構成されます。

# 擬似コード
def layer(x, t_emb, context):
    # ResBlock: 特徴変換 + タイムステップ条件
    x = res_block(x, t_emb)

    # Self-Attention: 空間的な依存関係
    x = self_attention(x)

    # Cross-Attention: テキスト条件
    x = cross_attention(x, context)

    return x

Cross-Attentionによるテキスト条件付け

Cross-Attentionでは、画像特徴がQuery、テキスト埋め込みがKey・Valueとなります。

$$ \text{CrossAttn}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d}}\right)\bm{V} $$

ここで、

  • $\bm{Q} = \bm{W}^Q \bm{z}$(画像特徴から生成)
  • $\bm{K} = \bm{W}^K \bm{c}$(テキスト埋め込みから生成)
  • $\bm{V} = \bm{W}^V \bm{c}$(テキスト埋め込みから生成)

$\bm{c}$ はCLIPテキストエンコーダの出力です。

タイムステップ埋め込み

拡散ステップ $t$ はSinusoidal Position Embeddingで埋め込まれ、MLPで変換された後、各ResBlockに加算されます。

$$ \bm{t}_{\text{emb}} = \text{MLP}(\text{SinusoidalEmb}(t)) $$

テキストエンコーダ:CLIP

Stable DiffusionはCLIPのテキストエンコーダを使用してテキストプロンプトを埋め込みに変換します。

$$ \bm{c} = \text{CLIPTextEncoder}(\text{text}) $$

CLIPテキストエンコーダの出力は $(77, 768)$ の形状を持ち、77はコンテキスト長、768は埋め込み次元です。この系列全体がCross-AttentionのKey・Valueとして使用されます。

Stable Diffusion 1.x と 2.x の違い

  • SD 1.x: CLIP ViT-L/14のテキストエンコーダ(768次元)
  • SD 2.x: OpenCLIPのViT-H/14のテキストエンコーダ(1024次元)

サンプリングアルゴリズム

DDPMサンプリング

最も基本的なサンプリング方法です。$T$(通常1000)ステップかけて徐々にノイズを除去します。

$$ \bm{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\bm{x}_t – \frac{1 – \alpha_t}{\sqrt{1 – \bar{\alpha}_t}}\bm{\epsilon}_\theta(\bm{x}_t, t)\right) + \sigma_t \bm{z} $$

ここで $\bm{z} \sim \mathcal{N}(\bm{0}, \bm{I})$($t > 1$ の場合)です。

DDIMサンプリング

DDIM(Denoising Diffusion Implicit Models)は、決定論的なサンプリングを可能にし、ステップ数を大幅に削減できます。

$$ \bm{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\left(\frac{\bm{x}_t – \sqrt{1 – \bar{\alpha}_t} \bm{\epsilon}_\theta(\bm{x}_t, t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\text{予測された } \bm{x}_0} + \sqrt{1 – \bar{\alpha}_{t-1} – \sigma_t^2} \cdot \bm{\epsilon}_\theta(\bm{x}_t, t) + \sigma_t \bm{z} $$

$\sigma_t = 0$ とすると完全に決定論的になり、同じ初期ノイズから常に同じ画像が生成されます。

Euler / Euler Ancestral

Euler法はODEソルバーに基づくサンプリング方法です。

$$ \bm{x}_{t-\Delta t} = \bm{x}_t + \Delta t \cdot \bm{v}_\theta(\bm{x}_t, t) $$

ここで $\bm{v}_\theta$ は速度場の予測です。

PyTorchによる簡易実装

拡散スケジューラ

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class DiffusionScheduler:
    """DDPMノイズスケジューラ"""
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # 線形スケジュール
        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    def add_noise(self, x0, noise, t):
        """x0にノイズを加えてxtを生成"""
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise

    def step(self, model_output, t, x_t):
        """1ステップのデノイジング"""
        alpha = self.alphas[t]
        alpha_cumprod = self.alphas_cumprod[t]
        alpha_cumprod_prev = self.alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
        beta = self.betas[t]

        # 予測されたx0
        pred_x0 = (x_t - torch.sqrt(1 - alpha_cumprod) * model_output) / torch.sqrt(alpha_cumprod)

        # 平均の計算
        coef1 = torch.sqrt(alpha_cumprod_prev) * beta / (1 - alpha_cumprod)
        coef2 = torch.sqrt(alpha) * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)
        mean = coef1 * pred_x0 + coef2 * x_t

        # 分散
        variance = beta * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)

        if t > 0:
            noise = torch.randn_like(x_t)
            return mean + torch.sqrt(variance) * noise
        else:
            return mean

簡易U-Net

class TimeEmbedding(nn.Module):
    """タイムステップ埋め込み"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        # Sinusoidal embedding
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.mlp(emb)


class ResBlock(nn.Module):
    """残差ブロック"""
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + self.shortcut(x)


class CrossAttention(nn.Module):
    """Cross-Attention層"""
    def __init__(self, query_dim, context_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = query_dim // num_heads

        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(self, x, context):
        b, c, h, w = x.shape
        x_flat = x.view(b, c, -1).transpose(1, 2)  # (b, h*w, c)

        q = self.to_q(x_flat)
        k = self.to_k(context)
        v = self.to_v(context)

        # マルチヘッド
        q = q.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(attn, dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).contiguous().view(b, -1, c)
        out = self.to_out(out)
        return out.transpose(1, 2).view(b, c, h, w)


class SimpleUNet(nn.Module):
    """簡易U-Net(Stable Diffusion風)"""
    def __init__(self, in_channels=4, out_channels=4, base_dim=128,
                 time_dim=256, context_dim=768):
        super().__init__()

        self.time_embed = TimeEmbedding(time_dim)

        # エンコーダ
        self.down1 = ResBlock(in_channels, base_dim, time_dim)
        self.down2 = ResBlock(base_dim, base_dim * 2, time_dim)
        self.pool = nn.MaxPool2d(2)

        # ミドル
        self.mid = ResBlock(base_dim * 2, base_dim * 2, time_dim)
        self.cross_attn = CrossAttention(base_dim * 2, context_dim)

        # デコーダ
        self.up1 = ResBlock(base_dim * 4, base_dim, time_dim)
        self.up2 = ResBlock(base_dim * 2, base_dim, time_dim)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.out = nn.Conv2d(base_dim, out_channels, 1)

    def forward(self, x, t, context):
        """
        Args:
            x: (batch, 4, h, w) ノイズが加えられた潜在表現
            t: (batch,) タイムステップ
            context: (batch, seq_len, context_dim) テキスト埋め込み
        """
        t_emb = self.time_embed(t.float())

        # エンコーダ
        h1 = self.down1(x, t_emb)
        h2 = self.down2(self.pool(h1), t_emb)

        # ミドル + Cross-Attention
        h = self.mid(self.pool(h2), t_emb)
        h = h + self.cross_attn(h, context)

        # デコーダ
        h = self.upsample(h)
        h = self.up1(torch.cat([h, h2], dim=1), t_emb)
        h = self.upsample(h)
        h = self.up2(torch.cat([h, h1], dim=1), t_emb)

        return self.out(h)

学習ループ

def train_step(model, vae_encoder, text_encoder, optimizer, images, texts, scheduler):
    """1ステップの学習"""
    optimizer.zero_grad()

    batch_size = images.shape[0]
    device = images.device

    # VAEで潜在表現に変換
    with torch.no_grad():
        latents = vae_encoder(images)

    # テキストエンコード
    with torch.no_grad():
        text_embeddings = text_encoder(texts)

    # ランダムなタイムステップ
    t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)

    # ノイズを生成して加える
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, t)

    # ノイズを予測
    noise_pred = model(noisy_latents, t, text_embeddings)

    # MSE損失
    loss = F.mse_loss(noise_pred, noise)

    loss.backward()
    optimizer.step()

    return loss.item()

サンプリング

@torch.no_grad()
def sample(model, vae_decoder, text_encoder, scheduler, text, num_steps=50,
           latent_shape=(1, 4, 64, 64)):
    """テキストから画像を生成"""
    device = next(model.parameters()).device

    # テキストエンコード
    text_embeddings = text_encoder(text)

    # ランダムノイズから開始
    latents = torch.randn(latent_shape, device=device)

    # ステップを間引く(DDIM風)
    timesteps = torch.linspace(scheduler.num_timesteps - 1, 0, num_steps, dtype=torch.long)

    for t in timesteps:
        t_batch = t.expand(latents.shape[0]).to(device)

        # ノイズ予測
        noise_pred = model(latents, t_batch, text_embeddings)

        # 1ステップのデノイジング
        latents = scheduler.step(noise_pred, int(t.item()), latents)

    # VAEでデコード
    images = vae_decoder(latents)

    return images

Classifier-Free Guidanceとの関係

実際のStable Diffusionでは、生成品質を向上させるためにClassifier-Free Guidance(CFG)が使用されます。これについては別記事で詳しく解説します。

CFGの基本的なアイデアは、条件付き生成と無条件生成の差分を増幅することです。

$$ \tilde{\bm{\epsilon}}_\theta = \bm{\epsilon}_\theta(\bm{z}_t, \varnothing) + w \cdot (\bm{\epsilon}_\theta(\bm{z}_t, \bm{c}) – \bm{\epsilon}_\theta(\bm{z}_t, \varnothing)) $$

ここで $w$ はガイダンススケールで、大きいほどテキストに忠実な生成になります。

まとめ

本記事では、Stable Diffusionの仕組みを解説しました。

  • 潜在空間での拡散: VAEで画像を圧縮し、潜在空間で拡散過程を実行することで効率化
  • 拡散過程: ノイズを徐々に加え、逆過程でノイズを除去して画像を生成
  • ノイズ予測: U-Netが各ステップで加えられたノイズを予測
  • テキスト条件付け: CLIPテキストエンコーダの出力をCross-Attentionで注入
  • サンプリング: DDPM、DDIM、Eulerなど様々なサンプリング方法が存在

Stable Diffusionは、拡散モデルの計算効率問題を解決し、高品質な画像生成を一般に利用可能にした画期的なモデルです。

次のステップとして、以下の記事も参考にしてください。