【拡散モデル】U-Netアーキテクチャの理論と実装

U-Netは、もともと医用画像のセグメンテーションのために2015年に提案されたニューラルネットワークアーキテクチャです。その後、画像生成タスク、特に拡散モデルにおいてノイズ予測ネットワークとして採用され、DDPM、Stable Diffusion、DALL-Eなど現代の画像生成モデルの中核を担っています。

拡散モデルにおけるU-Netは、オリジナルのU-Netを拡張し、タイムステップ条件付けSelf-AttentionCross-Attentionなどを組み込んでいます。本記事では、U-Netの基本構造から拡散モデル向けの拡張、そしてPyTorchでの実装までを解説します。

本記事の内容

  • U-Netの基本構造(エンコーダ・デコーダ・スキップ接続)
  • 拡散モデル向けU-Netの拡張
  • タイムステップ埋め込み
  • Self-AttentionとCross-Attention
  • PyTorchによるスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

U-Netの基本構造

オリジナルU-Net

U-Netは、その名の通りU字型の構造を持つネットワークです。

入力画像
   ↓
[エンコーダ(収縮パス)]
   ↓ ダウンサンプル
   ↓ ダウンサンプル
   ↓ ダウンサンプル
[ボトルネック]
   ↓ アップサンプル ← スキップ接続
   ↓ アップサンプル ← スキップ接続
   ↓ アップサンプル ← スキップ接続
[デコーダ(拡張パス)]
   ↓
出力

エンコーダ(収縮パス)

エンコーダは、入力画像の解像度を段階的に下げながら、チャネル数を増やして特徴を抽出します。

各ステージは以下の構造を持ちます。 – 2つの3×3畳み込み + ReLU – 2×2 Max Pooling(ダウンサンプル)

解像度: $H \times W \to H/2 \times W/2 \to H/4 \times W/4 \to \cdots$

チャネル: $64 \to 128 \to 256 \to 512 \to 1024$

デコーダ(拡張パス)

デコーダは、解像度を徐々に上げながら、詳細な出力を生成します。

各ステージは以下の構造を持ちます。 – 2×2転置畳み込み(アップサンプル) – エンコーダからのスキップ接続(特徴マップを結合) – 2つの3×3畳み込み + ReLU

スキップ接続

U-Netの最も重要な特徴はスキップ接続です。エンコーダの各解像度レベルの出力を、対応するデコーダの入力に直接結合します。

スキップ接続の利点: 1. 細部の保持: ダウンサンプルで失われがちな空間的な詳細を保持 2. 勾配伝播: 深いネットワークでも勾配が効率的に伝播 3. マルチスケール特徴: 低レベル(エッジ等)と高レベル(意味)の特徴を組み合わせ

拡散モデル向けU-Netの拡張

拡散モデルでは、オリジナルのU-Netに以下の拡張が加えられます。

1. タイムステップ条件付け

拡散過程のどのステップにいるか($t$)をネットワークに伝える必要があります。

2. 条件付け(テキスト等)

テキストプロンプトなどの条件をネットワークに注入します。

3. Attention機構

Self-AttentionとCross-Attentionを組み込み、長距離依存関係と条件の反映を可能にします。

4. ResBlock

単純な畳み込みブロックの代わりに、残差接続を持つResBlockを使用します。

拡散U-Netの全体構造

[入力: ノイズ付き潜在表現 z_t]
        ↓
[タイムステップ埋め込み t_emb]
        ↓
┌───────────────────────────────────────┐
│ ダウンブロック 1: ResBlock + Attention │
│   解像度: 64x64 → 32x32              │
└───────────────────────────────────────┘
        ↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ダウンブロック 2: ResBlock + Attention │
│   解像度: 32x32 → 16x16              │
└───────────────────────────────────────┘
        ↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ダウンブロック 3: ResBlock + Attention │
│   解像度: 16x16 → 8x8                │
└───────────────────────────────────────┘
        ↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ミドルブロック: ResBlock + Attention   │
│   解像度: 8x8(維持)                 │
└───────────────────────────────────────┘
        ↓
┌───────────────────────────────────────┐
│ アップブロック 1: ResBlock + Attention │ ← スキップ接続
│   解像度: 8x8 → 16x16                │
└───────────────────────────────────────┘
        ↓
┌───────────────────────────────────────┐
│ アップブロック 2: ResBlock + Attention │ ← スキップ接続
│   解像度: 16x16 → 32x32              │
└───────────────────────────────────────┘
        ↓
┌───────────────────────────────────────┐
│ アップブロック 3: ResBlock + Attention │ ← スキップ接続
│   解像度: 32x32 → 64x64              │
└───────────────────────────────────────┘
        ↓
[出力: 予測ノイズ ε]

タイムステップ埋め込み

Sinusoidal Position Embedding

Transformerの位置エンコーディングと同様に、タイムステップ $t$ を連続的なベクトルに変換します。

$$ \text{PE}(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) $$

$$ \text{PE}(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) $$

ここで $d$ は埋め込み次元、$i$ は次元のインデックスです。

MLPによる変換

Sinusoidal埋め込みの後、MLPで非線形変換を行います。

$$ \bm{t}_{\text{emb}} = \text{SiLU}(\text{Linear}(\text{SiLU}(\text{Linear}(\text{PE}(t))))) $$

この埋め込みは各ResBlockに注入されます。

ResBlock

拡散モデルのResBlockは、タイムステップ条件付けを含む以下の構造を持ちます。

$$ \begin{align} \bm{h} &= \text{Conv}(\text{SiLU}(\text{GroupNorm}(\bm{x}))) \\ \bm{h} &= \bm{h} + \text{Linear}(\bm{t}_{\text{emb}}) \\ \bm{h} &= \text{Conv}(\text{SiLU}(\text{GroupNorm}(\bm{h}))) \\ \text{out} &= \bm{h} + \text{shortcut}(\bm{x}) \end{align} $$

ポイント: – GroupNorm: Batch Normalizationの代わりに使用。バッチサイズに依存しない – SiLU (Swish): $\text{SiLU}(x) = x \cdot \sigma(x)$ 滑らかな活性化関数 – タイムステップの加算: タイムステップ埋め込みを空間次元に放送して加算

Attention機構

Self-Attention

特徴マップ内の長距離依存関係を捉えます。

特徴マップ: (B, C, H, W)
    ↓ reshape
(B, H*W, C)
    ↓ Self-Attention
(B, H*W, C)
    ↓ reshape
(B, C, H, W)

空間的に離れた位置同士の関連を直接計算できるため、大域的な構造の生成に有効です。

Cross-Attention

テキスト条件など外部の情報を注入します。

$$ \text{CrossAttn}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d}}\right)\bm{V} $$

  • Query: 画像特徴から生成 $\bm{Q} = \bm{W}^Q \bm{z}$
  • Key, Value: テキスト埋め込みから生成 $\bm{K} = \bm{W}^K \bm{c}$, $\bm{V} = \bm{W}^V \bm{c}$

これにより、「”cat”という単語に対応する画像領域に猫を生成する」といった条件付けが可能になります。

PyTorchによる実装

タイムステップ埋め込み

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SinusoidalPositionEmbedding(nn.Module):
    """Sinusoidal位置エンコーディング"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        Args:
            t: (batch_size,) タイムステップ
        Returns:
            (batch_size, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb


class TimeEmbedding(nn.Module):
    """タイムステップ埋め込み(Sinusoidal + MLP)"""
    def __init__(self, dim, time_emb_dim):
        super().__init__()
        self.sinusoidal = SinusoidalPositionEmbedding(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

    def forward(self, t):
        emb = self.sinusoidal(t)
        return self.mlp(emb)

ResBlock

class ResBlock(nn.Module):
    """残差ブロック(タイムステップ条件付き)"""
    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.0):
        super().__init__()

        # 第1畳み込み
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # タイムステップ射影
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels),
        )

        # 第2畳み込み
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # ショートカット接続
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t_emb):
        """
        Args:
            x: (B, C, H, W) 入力特徴
            t_emb: (B, time_emb_dim) タイムステップ埋め込み
        """
        h = self.conv1(F.silu(self.norm1(x)))

        # タイムステップを加算
        h = h + self.time_mlp(t_emb)[:, :, None, None]

        h = self.conv2(self.dropout(F.silu(self.norm2(h))))

        return h + self.shortcut(x)

Attentionブロック

class AttentionBlock(nn.Module):
    """Self-Attention + Cross-Attention ブロック"""
    def __init__(self, channels, num_heads=8, context_dim=None):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        # Self-Attention
        self.norm1 = nn.GroupNorm(32, channels)
        self.self_attn = nn.MultiheadAttention(
            channels, num_heads, batch_first=True
        )

        # Cross-Attention(条件がある場合)
        self.has_cross_attn = context_dim is not None
        if self.has_cross_attn:
            self.norm2 = nn.LayerNorm(channels)
            self.cross_attn_q = nn.Linear(channels, channels)
            self.cross_attn_k = nn.Linear(context_dim, channels)
            self.cross_attn_v = nn.Linear(context_dim, channels)
            self.cross_attn_out = nn.Linear(channels, channels)

        # Feed-Forward
        self.norm3 = nn.LayerNorm(channels)
        self.ff = nn.Sequential(
            nn.Linear(channels, channels * 4),
            nn.GELU(),
            nn.Linear(channels * 4, channels),
        )

    def forward(self, x, context=None):
        """
        Args:
            x: (B, C, H, W)
            context: (B, seq_len, context_dim) テキスト埋め込み等
        """
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, H*W, C)
        x_flat = x.view(B, C, -1).transpose(1, 2)

        # Self-Attention
        x_norm = self.norm1(x).view(B, C, -1).transpose(1, 2)
        attn_out, _ = self.self_attn(x_norm, x_norm, x_norm)
        x_flat = x_flat + attn_out

        # Cross-Attention
        if self.has_cross_attn and context is not None:
            x_norm = self.norm2(x_flat)
            q = self.cross_attn_q(x_norm)
            k = self.cross_attn_k(context)
            v = self.cross_attn_v(context)

            # Scaled dot-product attention
            scale = self.head_dim ** -0.5
            attn_weights = torch.softmax(
                torch.bmm(q, k.transpose(-2, -1)) * scale, dim=-1
            )
            cross_out = torch.bmm(attn_weights, v)
            cross_out = self.cross_attn_out(cross_out)
            x_flat = x_flat + cross_out

        # Feed-Forward
        x_flat = x_flat + self.ff(self.norm3(x_flat))

        # (B, H*W, C) -> (B, C, H, W)
        return x_flat.transpose(1, 2).view(B, C, H, W)

ダウンサンプル・アップサンプル

class Downsample(nn.Module):
    """ダウンサンプル(解像度を半分に)"""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """アップサンプル(解像度を2倍に)"""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

U-Net全体

class UNet(nn.Module):
    """拡散モデル用U-Net"""
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        base_channels=128,
        channel_mults=(1, 2, 4, 4),
        num_res_blocks=2,
        attention_resolutions=(16, 8),
        num_heads=8,
        context_dim=768,
        dropout=0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        time_emb_dim = base_channels * 4

        # タイムステップ埋め込み
        self.time_embed = TimeEmbedding(base_channels, time_emb_dim)

        # 入力畳み込み
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # ダウンサンプルブロック
        self.down_blocks = nn.ModuleList()
        self.down_samples = nn.ModuleList()

        channels = base_channels
        current_res = 64  # 入力解像度を仮定

        for i, mult in enumerate(channel_mults):
            out_ch = base_channels * mult

            # ResBlocks
            for _ in range(num_res_blocks):
                block = ResBlock(channels, out_ch, time_emb_dim, dropout)
                self.down_blocks.append(block)
                channels = out_ch

                # Attention(特定の解像度で)
                if current_res in attention_resolutions:
                    self.down_blocks.append(
                        AttentionBlock(channels, num_heads, context_dim)
                    )

            # ダウンサンプル(最後のレベルを除く)
            if i < len(channel_mults) - 1:
                self.down_samples.append(Downsample(channels))
                current_res //= 2

        # ミドルブロック
        self.mid_block1 = ResBlock(channels, channels, time_emb_dim, dropout)
        self.mid_attn = AttentionBlock(channels, num_heads, context_dim)
        self.mid_block2 = ResBlock(channels, channels, time_emb_dim, dropout)

        # アップサンプルブロック
        self.up_blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()

        for i, mult in enumerate(reversed(channel_mults)):
            out_ch = base_channels * mult

            # ResBlocks(スキップ接続を受け取るので入力チャネルが2倍)
            for j in range(num_res_blocks + 1):
                in_ch = channels + (base_channels * mult if j == 0 else out_ch)
                block = ResBlock(in_ch, out_ch, time_emb_dim, dropout)
                self.up_blocks.append(block)
                channels = out_ch

                # Attention
                if current_res in attention_resolutions:
                    self.up_blocks.append(
                        AttentionBlock(channels, num_heads, context_dim)
                    )

            # アップサンプル(最後のレベルを除く)
            if i < len(channel_mults) - 1:
                self.up_samples.append(Upsample(channels))
                current_res *= 2

        # 出力畳み込み
        self.output_norm = nn.GroupNorm(32, channels)
        self.output_conv = nn.Conv2d(channels, out_channels, 3, padding=1)

        # スキップ接続用のインデックスを記録
        self.num_res_blocks = num_res_blocks
        self.channel_mults = channel_mults

    def forward(self, x, t, context=None):
        """
        Args:
            x: (B, in_channels, H, W) ノイズ付き入力
            t: (B,) タイムステップ
            context: (B, seq_len, context_dim) テキスト埋め込み等
        Returns:
            (B, out_channels, H, W) 予測ノイズ
        """
        # タイムステップ埋め込み
        t_emb = self.time_embed(t)

        # 入力畳み込み
        h = self.input_conv(x)

        # スキップ接続を保存
        skips = [h]

        # ダウンサンプル
        block_idx = 0
        sample_idx = 0
        for i, mult in enumerate(self.channel_mults):
            for _ in range(self.num_res_blocks):
                h = self.down_blocks[block_idx](h, t_emb)
                block_idx += 1
                skips.append(h)

                # Attention
                if block_idx < len(self.down_blocks) and isinstance(
                    self.down_blocks[block_idx], AttentionBlock
                ):
                    h = self.down_blocks[block_idx](h, context)
                    block_idx += 1

            if i < len(self.channel_mults) - 1:
                h = self.down_samples[sample_idx](h)
                sample_idx += 1
                skips.append(h)

        # ミドル
        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h, context)
        h = self.mid_block2(h, t_emb)

        # アップサンプル
        block_idx = 0
        sample_idx = 0
        for i, mult in enumerate(reversed(self.channel_mults)):
            for j in range(self.num_res_blocks + 1):
                skip = skips.pop()
                h = torch.cat([h, skip], dim=1)
                h = self.up_blocks[block_idx](h, t_emb)
                block_idx += 1

                # Attention
                if block_idx < len(self.up_blocks) and isinstance(
                    self.up_blocks[block_idx], AttentionBlock
                ):
                    h = self.up_blocks[block_idx](h, context)
                    block_idx += 1

            if i < len(self.channel_mults) - 1:
                h = self.up_samples[sample_idx](h)
                sample_idx += 1

        # 出力
        h = F.silu(self.output_norm(h))
        return self.output_conv(h)

動作確認

# モデル作成
model = UNet(
    in_channels=4,
    out_channels=4,
    base_channels=64,
    channel_mults=(1, 2, 4),
    num_res_blocks=2,
    attention_resolutions=(16, 8),
    num_heads=4,
    context_dim=256,
)

# ダミー入力
batch_size = 2
x = torch.randn(batch_size, 4, 64, 64)  # 潜在表現
t = torch.randint(0, 1000, (batch_size,))  # タイムステップ
context = torch.randn(batch_size, 77, 256)  # テキスト埋め込み

# 順伝播
output = model(x, t, context)

print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")

# パラメータ数
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}")

U-Netの設計上の工夫

GroupNormの使用

Batch Normalizationの代わりにGroupNormalizationを使用する理由: – バッチサイズに依存しない(小バッチでも安定) – 生成タスクでは統計量がバッチ内で大きく変動しうる

SiLU活性化関数

SiLU(Swish)は滑らかな活性化関数で、ReLUより勾配の流れが良好です。

$$ \text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}} $$

解像度に応じたAttention

高解像度でのSelf-Attentionは計算コストが高い($O(n^2)$)ため、低解像度のレベルでのみAttentionを適用します。

Stable Diffusionでは、$64 \times 64$ の潜在表現に対して、$32 \times 32$、$16 \times 16$、$8 \times 8$ の解像度でAttentionを適用します。

まとめ

本記事では、拡散モデルで使用されるU-Netの仕組みを解説しました。

  • U字構造: エンコーダで特徴を圧縮し、デコーダで復元。スキップ接続で細部を保持
  • タイムステップ埋め込み: Sinusoidal埋め込み + MLPで拡散ステップを条件付け
  • ResBlock: 残差接続 + タイムステップ加算で安定した学習
  • Self-Attention: 空間的な長距離依存関係を捉える
  • Cross-Attention: テキストなどの外部条件を注入
  • GroupNorm + SiLU: 生成タスクに適した正規化と活性化

U-Netは拡散モデルの心臓部であり、その設計はStable Diffusion、DALL-E、Imagenなど多くの最先端モデルで採用されています。

次のステップとして、以下の記事も参考にしてください。