U-Netは、もともと医用画像のセグメンテーションのために2015年に提案されたニューラルネットワークアーキテクチャです。その後、画像生成タスク、特に拡散モデルにおいてノイズ予測ネットワークとして採用され、DDPM、Stable Diffusion、DALL-Eなど現代の画像生成モデルの中核を担っています。
拡散モデルにおけるU-Netは、オリジナルのU-Netを拡張し、タイムステップ条件付け、Self-Attention、Cross-Attentionなどを組み込んでいます。本記事では、U-Netの基本構造から拡散モデル向けの拡張、そしてPyTorchでの実装までを解説します。
本記事の内容
- U-Netの基本構造(エンコーダ・デコーダ・スキップ接続)
- 拡散モデル向けU-Netの拡張
- タイムステップ埋め込み
- Self-AttentionとCross-Attention
- PyTorchによるスクラッチ実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
U-Netの基本構造
オリジナルU-Net
U-Netは、その名の通りU字型の構造を持つネットワークです。
入力画像
↓
[エンコーダ(収縮パス)]
↓ ダウンサンプル
↓ ダウンサンプル
↓ ダウンサンプル
[ボトルネック]
↓ アップサンプル ← スキップ接続
↓ アップサンプル ← スキップ接続
↓ アップサンプル ← スキップ接続
[デコーダ(拡張パス)]
↓
出力
エンコーダ(収縮パス)
エンコーダは、入力画像の解像度を段階的に下げながら、チャネル数を増やして特徴を抽出します。
各ステージは以下の構造を持ちます。 – 2つの3×3畳み込み + ReLU – 2×2 Max Pooling(ダウンサンプル)
解像度: $H \times W \to H/2 \times W/2 \to H/4 \times W/4 \to \cdots$
チャネル: $64 \to 128 \to 256 \to 512 \to 1024$
デコーダ(拡張パス)
デコーダは、解像度を徐々に上げながら、詳細な出力を生成します。
各ステージは以下の構造を持ちます。 – 2×2転置畳み込み(アップサンプル) – エンコーダからのスキップ接続(特徴マップを結合) – 2つの3×3畳み込み + ReLU
スキップ接続
U-Netの最も重要な特徴はスキップ接続です。エンコーダの各解像度レベルの出力を、対応するデコーダの入力に直接結合します。
スキップ接続の利点: 1. 細部の保持: ダウンサンプルで失われがちな空間的な詳細を保持 2. 勾配伝播: 深いネットワークでも勾配が効率的に伝播 3. マルチスケール特徴: 低レベル(エッジ等)と高レベル(意味)の特徴を組み合わせ
拡散モデル向けU-Netの拡張
拡散モデルでは、オリジナルのU-Netに以下の拡張が加えられます。
1. タイムステップ条件付け
拡散過程のどのステップにいるか($t$)をネットワークに伝える必要があります。
2. 条件付け(テキスト等)
テキストプロンプトなどの条件をネットワークに注入します。
3. Attention機構
Self-AttentionとCross-Attentionを組み込み、長距離依存関係と条件の反映を可能にします。
4. ResBlock
単純な畳み込みブロックの代わりに、残差接続を持つResBlockを使用します。
拡散U-Netの全体構造
[入力: ノイズ付き潜在表現 z_t]
↓
[タイムステップ埋め込み t_emb]
↓
┌───────────────────────────────────────┐
│ ダウンブロック 1: ResBlock + Attention │
│ 解像度: 64x64 → 32x32 │
└───────────────────────────────────────┘
↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ダウンブロック 2: ResBlock + Attention │
│ 解像度: 32x32 → 16x16 │
└───────────────────────────────────────┘
↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ダウンブロック 3: ResBlock + Attention │
│ 解像度: 16x16 → 8x8 │
└───────────────────────────────────────┘
↓ ─────────────────→ スキップ接続
┌───────────────────────────────────────┐
│ ミドルブロック: ResBlock + Attention │
│ 解像度: 8x8(維持) │
└───────────────────────────────────────┘
↓
┌───────────────────────────────────────┐
│ アップブロック 1: ResBlock + Attention │ ← スキップ接続
│ 解像度: 8x8 → 16x16 │
└───────────────────────────────────────┘
↓
┌───────────────────────────────────────┐
│ アップブロック 2: ResBlock + Attention │ ← スキップ接続
│ 解像度: 16x16 → 32x32 │
└───────────────────────────────────────┘
↓
┌───────────────────────────────────────┐
│ アップブロック 3: ResBlock + Attention │ ← スキップ接続
│ 解像度: 32x32 → 64x64 │
└───────────────────────────────────────┘
↓
[出力: 予測ノイズ ε]
タイムステップ埋め込み
Sinusoidal Position Embedding
Transformerの位置エンコーディングと同様に、タイムステップ $t$ を連続的なベクトルに変換します。
$$ \text{PE}(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) $$
$$ \text{PE}(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) $$
ここで $d$ は埋め込み次元、$i$ は次元のインデックスです。
MLPによる変換
Sinusoidal埋め込みの後、MLPで非線形変換を行います。
$$ \bm{t}_{\text{emb}} = \text{SiLU}(\text{Linear}(\text{SiLU}(\text{Linear}(\text{PE}(t))))) $$
この埋め込みは各ResBlockに注入されます。
ResBlock
拡散モデルのResBlockは、タイムステップ条件付けを含む以下の構造を持ちます。
$$ \begin{align} \bm{h} &= \text{Conv}(\text{SiLU}(\text{GroupNorm}(\bm{x}))) \\ \bm{h} &= \bm{h} + \text{Linear}(\bm{t}_{\text{emb}}) \\ \bm{h} &= \text{Conv}(\text{SiLU}(\text{GroupNorm}(\bm{h}))) \\ \text{out} &= \bm{h} + \text{shortcut}(\bm{x}) \end{align} $$
ポイント: – GroupNorm: Batch Normalizationの代わりに使用。バッチサイズに依存しない – SiLU (Swish): $\text{SiLU}(x) = x \cdot \sigma(x)$ 滑らかな活性化関数 – タイムステップの加算: タイムステップ埋め込みを空間次元に放送して加算
Attention機構
Self-Attention
特徴マップ内の長距離依存関係を捉えます。
特徴マップ: (B, C, H, W)
↓ reshape
(B, H*W, C)
↓ Self-Attention
(B, H*W, C)
↓ reshape
(B, C, H, W)
空間的に離れた位置同士の関連を直接計算できるため、大域的な構造の生成に有効です。
Cross-Attention
テキスト条件など外部の情報を注入します。
$$ \text{CrossAttn}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d}}\right)\bm{V} $$
- Query: 画像特徴から生成 $\bm{Q} = \bm{W}^Q \bm{z}$
- Key, Value: テキスト埋め込みから生成 $\bm{K} = \bm{W}^K \bm{c}$, $\bm{V} = \bm{W}^V \bm{c}$
これにより、「”cat”という単語に対応する画像領域に猫を生成する」といった条件付けが可能になります。
PyTorchによる実装
タイムステップ埋め込み
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SinusoidalPositionEmbedding(nn.Module):
"""Sinusoidal位置エンコーディング"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
"""
Args:
t: (batch_size,) タイムステップ
Returns:
(batch_size, dim)
"""
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return emb
class TimeEmbedding(nn.Module):
"""タイムステップ埋め込み(Sinusoidal + MLP)"""
def __init__(self, dim, time_emb_dim):
super().__init__()
self.sinusoidal = SinusoidalPositionEmbedding(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, time_emb_dim),
nn.SiLU(),
nn.Linear(time_emb_dim, time_emb_dim),
)
def forward(self, t):
emb = self.sinusoidal(t)
return self.mlp(emb)
ResBlock
class ResBlock(nn.Module):
"""残差ブロック(タイムステップ条件付き)"""
def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.0):
super().__init__()
# 第1畳み込み
self.norm1 = nn.GroupNorm(32, in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
# タイムステップ射影
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, out_channels),
)
# 第2畳み込み
self.norm2 = nn.GroupNorm(32, out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
# ショートカット接続
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
else:
self.shortcut = nn.Identity()
def forward(self, x, t_emb):
"""
Args:
x: (B, C, H, W) 入力特徴
t_emb: (B, time_emb_dim) タイムステップ埋め込み
"""
h = self.conv1(F.silu(self.norm1(x)))
# タイムステップを加算
h = h + self.time_mlp(t_emb)[:, :, None, None]
h = self.conv2(self.dropout(F.silu(self.norm2(h))))
return h + self.shortcut(x)
Attentionブロック
class AttentionBlock(nn.Module):
"""Self-Attention + Cross-Attention ブロック"""
def __init__(self, channels, num_heads=8, context_dim=None):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
# Self-Attention
self.norm1 = nn.GroupNorm(32, channels)
self.self_attn = nn.MultiheadAttention(
channels, num_heads, batch_first=True
)
# Cross-Attention(条件がある場合)
self.has_cross_attn = context_dim is not None
if self.has_cross_attn:
self.norm2 = nn.LayerNorm(channels)
self.cross_attn_q = nn.Linear(channels, channels)
self.cross_attn_k = nn.Linear(context_dim, channels)
self.cross_attn_v = nn.Linear(context_dim, channels)
self.cross_attn_out = nn.Linear(channels, channels)
# Feed-Forward
self.norm3 = nn.LayerNorm(channels)
self.ff = nn.Sequential(
nn.Linear(channels, channels * 4),
nn.GELU(),
nn.Linear(channels * 4, channels),
)
def forward(self, x, context=None):
"""
Args:
x: (B, C, H, W)
context: (B, seq_len, context_dim) テキスト埋め込み等
"""
B, C, H, W = x.shape
# (B, C, H, W) -> (B, H*W, C)
x_flat = x.view(B, C, -1).transpose(1, 2)
# Self-Attention
x_norm = self.norm1(x).view(B, C, -1).transpose(1, 2)
attn_out, _ = self.self_attn(x_norm, x_norm, x_norm)
x_flat = x_flat + attn_out
# Cross-Attention
if self.has_cross_attn and context is not None:
x_norm = self.norm2(x_flat)
q = self.cross_attn_q(x_norm)
k = self.cross_attn_k(context)
v = self.cross_attn_v(context)
# Scaled dot-product attention
scale = self.head_dim ** -0.5
attn_weights = torch.softmax(
torch.bmm(q, k.transpose(-2, -1)) * scale, dim=-1
)
cross_out = torch.bmm(attn_weights, v)
cross_out = self.cross_attn_out(cross_out)
x_flat = x_flat + cross_out
# Feed-Forward
x_flat = x_flat + self.ff(self.norm3(x_flat))
# (B, H*W, C) -> (B, C, H, W)
return x_flat.transpose(1, 2).view(B, C, H, W)
ダウンサンプル・アップサンプル
class Downsample(nn.Module):
"""ダウンサンプル(解像度を半分に)"""
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class Upsample(nn.Module):
"""アップサンプル(解像度を2倍に)"""
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2, mode='nearest')
return self.conv(x)
U-Net全体
class UNet(nn.Module):
"""拡散モデル用U-Net"""
def __init__(
self,
in_channels=4,
out_channels=4,
base_channels=128,
channel_mults=(1, 2, 4, 4),
num_res_blocks=2,
attention_resolutions=(16, 8),
num_heads=8,
context_dim=768,
dropout=0.0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
time_emb_dim = base_channels * 4
# タイムステップ埋め込み
self.time_embed = TimeEmbedding(base_channels, time_emb_dim)
# 入力畳み込み
self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
# ダウンサンプルブロック
self.down_blocks = nn.ModuleList()
self.down_samples = nn.ModuleList()
channels = base_channels
current_res = 64 # 入力解像度を仮定
for i, mult in enumerate(channel_mults):
out_ch = base_channels * mult
# ResBlocks
for _ in range(num_res_blocks):
block = ResBlock(channels, out_ch, time_emb_dim, dropout)
self.down_blocks.append(block)
channels = out_ch
# Attention(特定の解像度で)
if current_res in attention_resolutions:
self.down_blocks.append(
AttentionBlock(channels, num_heads, context_dim)
)
# ダウンサンプル(最後のレベルを除く)
if i < len(channel_mults) - 1:
self.down_samples.append(Downsample(channels))
current_res //= 2
# ミドルブロック
self.mid_block1 = ResBlock(channels, channels, time_emb_dim, dropout)
self.mid_attn = AttentionBlock(channels, num_heads, context_dim)
self.mid_block2 = ResBlock(channels, channels, time_emb_dim, dropout)
# アップサンプルブロック
self.up_blocks = nn.ModuleList()
self.up_samples = nn.ModuleList()
for i, mult in enumerate(reversed(channel_mults)):
out_ch = base_channels * mult
# ResBlocks(スキップ接続を受け取るので入力チャネルが2倍)
for j in range(num_res_blocks + 1):
in_ch = channels + (base_channels * mult if j == 0 else out_ch)
block = ResBlock(in_ch, out_ch, time_emb_dim, dropout)
self.up_blocks.append(block)
channels = out_ch
# Attention
if current_res in attention_resolutions:
self.up_blocks.append(
AttentionBlock(channels, num_heads, context_dim)
)
# アップサンプル(最後のレベルを除く)
if i < len(channel_mults) - 1:
self.up_samples.append(Upsample(channels))
current_res *= 2
# 出力畳み込み
self.output_norm = nn.GroupNorm(32, channels)
self.output_conv = nn.Conv2d(channels, out_channels, 3, padding=1)
# スキップ接続用のインデックスを記録
self.num_res_blocks = num_res_blocks
self.channel_mults = channel_mults
def forward(self, x, t, context=None):
"""
Args:
x: (B, in_channels, H, W) ノイズ付き入力
t: (B,) タイムステップ
context: (B, seq_len, context_dim) テキスト埋め込み等
Returns:
(B, out_channels, H, W) 予測ノイズ
"""
# タイムステップ埋め込み
t_emb = self.time_embed(t)
# 入力畳み込み
h = self.input_conv(x)
# スキップ接続を保存
skips = [h]
# ダウンサンプル
block_idx = 0
sample_idx = 0
for i, mult in enumerate(self.channel_mults):
for _ in range(self.num_res_blocks):
h = self.down_blocks[block_idx](h, t_emb)
block_idx += 1
skips.append(h)
# Attention
if block_idx < len(self.down_blocks) and isinstance(
self.down_blocks[block_idx], AttentionBlock
):
h = self.down_blocks[block_idx](h, context)
block_idx += 1
if i < len(self.channel_mults) - 1:
h = self.down_samples[sample_idx](h)
sample_idx += 1
skips.append(h)
# ミドル
h = self.mid_block1(h, t_emb)
h = self.mid_attn(h, context)
h = self.mid_block2(h, t_emb)
# アップサンプル
block_idx = 0
sample_idx = 0
for i, mult in enumerate(reversed(self.channel_mults)):
for j in range(self.num_res_blocks + 1):
skip = skips.pop()
h = torch.cat([h, skip], dim=1)
h = self.up_blocks[block_idx](h, t_emb)
block_idx += 1
# Attention
if block_idx < len(self.up_blocks) and isinstance(
self.up_blocks[block_idx], AttentionBlock
):
h = self.up_blocks[block_idx](h, context)
block_idx += 1
if i < len(self.channel_mults) - 1:
h = self.up_samples[sample_idx](h)
sample_idx += 1
# 出力
h = F.silu(self.output_norm(h))
return self.output_conv(h)
動作確認
# モデル作成
model = UNet(
in_channels=4,
out_channels=4,
base_channels=64,
channel_mults=(1, 2, 4),
num_res_blocks=2,
attention_resolutions=(16, 8),
num_heads=4,
context_dim=256,
)
# ダミー入力
batch_size = 2
x = torch.randn(batch_size, 4, 64, 64) # 潜在表現
t = torch.randint(0, 1000, (batch_size,)) # タイムステップ
context = torch.randn(batch_size, 77, 256) # テキスト埋め込み
# 順伝播
output = model(x, t, context)
print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
# パラメータ数
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}")
U-Netの設計上の工夫
GroupNormの使用
Batch Normalizationの代わりにGroupNormalizationを使用する理由: – バッチサイズに依存しない(小バッチでも安定) – 生成タスクでは統計量がバッチ内で大きく変動しうる
SiLU活性化関数
SiLU(Swish)は滑らかな活性化関数で、ReLUより勾配の流れが良好です。
$$ \text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}} $$
解像度に応じたAttention
高解像度でのSelf-Attentionは計算コストが高い($O(n^2)$)ため、低解像度のレベルでのみAttentionを適用します。
Stable Diffusionでは、$64 \times 64$ の潜在表現に対して、$32 \times 32$、$16 \times 16$、$8 \times 8$ の解像度でAttentionを適用します。
まとめ
本記事では、拡散モデルで使用されるU-Netの仕組みを解説しました。
- U字構造: エンコーダで特徴を圧縮し、デコーダで復元。スキップ接続で細部を保持
- タイムステップ埋め込み: Sinusoidal埋め込み + MLPで拡散ステップを条件付け
- ResBlock: 残差接続 + タイムステップ加算で安定した学習
- Self-Attention: 空間的な長距離依存関係を捉える
- Cross-Attention: テキストなどの外部条件を注入
- GroupNorm + SiLU: 生成タスクに適した正規化と活性化
U-Netは拡散モデルの心臓部であり、その設計はStable Diffusion、DALL-E、Imagenなど多くの最先端モデルで採用されています。
次のステップとして、以下の記事も参考にしてください。