ControlNetの仕組み — Zero Convolutionと条件付き画像生成

ControlNetは、2023年にZhangらが発表した、拡散モデルに空間的な条件付けを追加する手法です。論文 “Adding Conditional Control to Text-to-Image Diffusion Models” で提案され、エッジ検出結果、深度マップ、ポーズ推定結果などの追加条件を入力として、画像生成をより細かく制御できるようになりました。

テキストプロンプトだけでは難しい「この構図で」「このポーズで」「この輪郭で」といった空間的な制約を指定できることから、ControlNetはStable Diffusionのエコシステムにおいて非常に人気の高い拡張機能となっています。

本記事の内容

  • ControlNetの設計思想
  • Zero Convolutionによる学習の安定化
  • 各種条件(Canny、Pose、Depthなど)
  • PyTorchによる実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

ControlNetのアイデア

テキスト条件付けの限界

Stable DiffusionはCLIPテキストエンコーダを通じてテキスト条件付けを行いますが、テキストだけでは表現しにくい制約があります。

  • 「画面左上に人物を配置」のような空間的な指示
  • 「特定のポーズ」の詳細な指定
  • 「この写真と同じ構図で」という参照画像の活用

空間的条件の導入

ControlNetは、画像形式の追加条件を入力として受け取ります。

  • Canny Edge: エッジ検出結果(輪郭線)
  • OpenPose: 人体のポーズ推定結果(関節位置)
  • Depth: 深度推定結果
  • Normal Map: 法線マップ
  • Segmentation: セマンティックセグメンテーション
  • Scribble: ユーザーのラフスケッチ

これらの条件は、元の画像と同じ解像度を持ち、「どの位置に何を生成すべきか」という空間的な情報を提供します。

ControlNetのアーキテクチャ

基本構造

ControlNetは、事前学習済みのStable Diffusion U-Netのエンコーダ部分のコピーを作成し、追加条件を処理します。

[ノイズ付き潜在表現 z_t]    [追加条件(エッジ等)]
         ↓                        ↓
   [U-Net Encoder]  ←───── [ControlNet Encoder]
         ↓                        ↓
   [U-Net Middle]   ←───── [ControlNet Middle]
         ↓
   [U-Net Decoder]
         ↓
   [予測ノイズ ε]

ControlNetのエンコーダは、元のU-Netエンコーダと同じ構造を持ち、学習可能なコピーとして初期化されます。追加条件を処理した結果は、元のU-Netの対応する層に加算されます。

なぜコピーを使うのか

この設計には以下の利点があります。

  1. 事前学習済みの知識を保持: 元のU-Netは凍結(freeze)されるため、Stable Diffusionの生成能力は維持される
  2. 学習の効率化: ControlNetは追加条件の処理のみを学習すれば良い
  3. モジュール性: 異なる条件に対して別々のControlNetを学習し、組み合わせることが可能

Zero Convolution

ControlNetの重要な技術的工夫として、Zero Convolutionがあります。

学習開始時、ControlNetの出力がU-Netに影響を与えないよう、出力層の重みとバイアスをゼロで初期化した1×1畳み込み層を使用します。

$$ \bm{y} = \text{ZeroConv}(\bm{x}) = \bm{W} \cdot \bm{x} + \bm{b}, \quad \bm{W} = \bm{0}, \bm{b} = \bm{0} $$

学習初期: – ControlNetの出力はゼロなので、元のU-Netはそのまま動作 – Stable Diffusionの事前学習された生成能力が保持される

学習が進むと: – Zero Convの重みが徐々に更新される – ControlNetの条件が徐々に生成に反映される

この設計により、学習の初期段階で生成品質が崩壊することを防ぎます。

詳細なアーキテクチャ

条件入力の処理

追加条件(例:Cannyエッジ画像)は、まず小さなネットワークで潜在空間のサイズに変換されます。

# 条件エンコーダ(擬似コード)
condition_encoder = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(16, 16, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 解像度1/2
    nn.SiLU(),
    nn.Conv2d(32, 32, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(32, 96, 3, stride=2, padding=1),  # 解像度1/4
    nn.SiLU(),
    nn.Conv2d(96, 96, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(96, 256, 3, stride=2, padding=1),  # 解像度1/8 = 潜在空間と同サイズ
    nn.SiLU(),
    nn.Conv2d(256, 320, 3, padding=1),  # U-Netの入力チャネルに合わせる
)

ControlNetブロックの構造

各解像度レベルで、ControlNetのブロックは以下のように接続されます。

[U-Net入力層からの特徴 + 条件特徴]
         ↓
[ControlNet ResBlock]
         ↓
[ControlNet Attention(該当解像度の場合)]
         ↓
[Zero Conv]
         ↓
[U-Net の対応する層に加算]

複数スケールでの接続

ControlNetは、U-Netの複数の解像度レベルに接続されます。

  • エンコーダの各ダウンサンプル層の出力に接続
  • ミドルブロックの出力に接続

これにより、高レベル(意味的)と低レベル(詳細)の両方の特徴に条件を注入できます。

各種条件タイプ

Canny Edge

Cannyエッジ検出器で抽出された輪郭線を使用します。

import cv2

def get_canny_condition(image, low_threshold=100, high_threshold=200):
    """Cannyエッジ検出で条件画像を生成"""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return edges

用途: 線画からの生成、イラストのスタイル変換など

OpenPose

人体のポーズ推定結果(関節位置と骨格)を使用します。

各関節を円で、関節間の接続を線で描画した画像が条件として使われます。

用途: 特定のポーズでの人物生成

Depth

深度推定モデル(MiDaSなど)で推定された深度マップを使用します。

グレースケールまたはカラーマップで近距離・遠距離を表現します。

用途: 3D構造を保持した生成、シーンの奥行き制御

Segmentation

セマンティックセグメンテーション結果(各ピクセルのクラス)を使用します。

各クラスを異なる色で塗り分けた画像が条件となります。

用途: シーンのレイアウト制御、特定オブジェクトの配置

Scribble / Sketch

ユーザーが描いたラフなスケッチを使用します。

Cannyよりも抽象的な入力でも機能するよう学習されています。

用途: 手描きスケッチからの画像生成

PyTorchによる実装

Zero Convolution

import torch
import torch.nn as nn


class ZeroConv(nn.Module):
    """ゼロ初期化された1x1畳み込み"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # ゼロ初期化
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

条件エンコーダ

class ConditionEncoder(nn.Module):
    """追加条件を潜在空間サイズにエンコード"""
    def __init__(self, in_channels=3, out_channels=320):
        super().__init__()

        self.encoder = nn.Sequential(
            # 512x512 -> 256x256
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.SiLU(),

            # 256x256 -> 128x128
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.SiLU(),

            # 128x128 -> 64x64
            nn.Conv2d(64, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.SiLU(),

            # 最終層
            nn.Conv2d(128, out_channels, 3, padding=1),
        )

    def forward(self, x):
        return self.encoder(x)

ControlNetブロック

class ControlNetBlock(nn.Module):
    """ControlNetの1ブロック"""
    def __init__(self, channels, time_emb_dim, num_heads=8, context_dim=768, use_attention=True):
        super().__init__()

        # ResBlock(U-Netと同じ構造)
        self.resblock = ResBlock(channels, channels, time_emb_dim)

        # Attention(オプション)
        self.use_attention = use_attention
        if use_attention:
            self.attention = AttentionBlock(channels, num_heads, context_dim)

        # Zero Convolution
        self.zero_conv = ZeroConv(channels, channels)

    def forward(self, x, t_emb, context=None):
        h = self.resblock(x, t_emb)

        if self.use_attention and context is not None:
            h = self.attention(h, context)

        return self.zero_conv(h)

ControlNet全体

class ControlNet(nn.Module):
    """ControlNetモジュール"""
    def __init__(
        self,
        condition_channels=3,
        base_channels=320,
        channel_mults=(1, 2, 4, 4),
        num_res_blocks=2,
        attention_resolutions=(4, 2, 1),
        num_heads=8,
        time_emb_dim=1280,
        context_dim=768,
    ):
        super().__init__()

        # 条件エンコーダ
        self.condition_encoder = ConditionEncoder(condition_channels, base_channels)

        # 入力畳み込みのコピー
        self.input_conv = nn.Conv2d(4, base_channels, 3, padding=1)
        self.input_zero_conv = ZeroConv(base_channels, base_channels)

        # ダウンブロック
        self.down_blocks = nn.ModuleList()
        self.down_zero_convs = nn.ModuleList()

        channels = base_channels
        current_res = 64  # 仮定

        for level, mult in enumerate(channel_mults):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks):
                use_attn = (current_res in attention_resolutions)
                self.down_blocks.append(
                    ControlNetBlock(channels, time_emb_dim, num_heads, context_dim, use_attn)
                )
                self.down_zero_convs.append(ZeroConv(channels, channels))
                channels = out_channels

            # ダウンサンプル(最後以外)
            if level < len(channel_mults) - 1:
                self.down_blocks.append(
                    nn.Conv2d(channels, channels, 3, stride=2, padding=1)
                )
                self.down_zero_convs.append(ZeroConv(channels, channels))
                current_res //= 2

        # ミドルブロック
        self.mid_block = ControlNetBlock(
            channels, time_emb_dim, num_heads, context_dim, True
        )
        self.mid_zero_conv = ZeroConv(channels, channels)

    def forward(self, x, t_emb, context, condition):
        """
        Args:
            x: (B, 4, H, W) ノイズ付き潜在表現
            t_emb: (B, time_emb_dim) タイムステップ埋め込み
            context: (B, seq_len, context_dim) テキスト埋め込み
            condition: (B, 3, H*8, W*8) 追加条件画像
        Returns:
            outputs: 各解像度レベルの出力リスト
        """
        outputs = []

        # 条件をエンコード
        cond_emb = self.condition_encoder(condition)

        # 入力処理(潜在表現 + 条件)
        h = self.input_conv(x) + cond_emb
        outputs.append(self.input_zero_conv(h))

        # ダウンブロック
        for block, zero_conv in zip(self.down_blocks, self.down_zero_convs):
            if isinstance(block, ControlNetBlock):
                h = h + block(h, t_emb, context)
            else:
                h = block(h)  # ダウンサンプル
            outputs.append(zero_conv(h))

        # ミドルブロック
        h = h + self.mid_block(h, t_emb, context)
        outputs.append(self.mid_zero_conv(h))

        return outputs

ControlNet付きU-Net

class UNetWithControlNet(nn.Module):
    """ControlNetを統合したU-Net"""
    def __init__(self, unet, controlnet):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet

        # U-Netの重みを凍結
        for param in self.unet.parameters():
            param.requires_grad = False

    def forward(self, x, t, context, condition, controlnet_scale=1.0):
        """
        Args:
            x: ノイズ付き潜在表現
            t: タイムステップ
            context: テキスト埋め込み
            condition: ControlNet条件
            controlnet_scale: ControlNetの出力に乗じるスケール
        """
        # タイムステップ埋め込み
        t_emb = self.unet.time_embed(t)

        # ControlNetの出力を取得
        controlnet_outputs = self.controlnet(x, t_emb, context, condition)

        # ControlNetの出力をスケーリング
        controlnet_outputs = [out * controlnet_scale for out in controlnet_outputs]

        # U-Netの順伝播にControlNet出力を注入
        # (実際の実装では、U-Netの各層に加算する)
        output = self.unet_forward_with_control(x, t_emb, context, controlnet_outputs)

        return output

    def unet_forward_with_control(self, x, t_emb, context, control_outputs):
        """ControlNet出力を統合したU-Net順伝播(簡略化版)"""
        # 実際には各層でcontrol_outputs[i]を加算
        # ここでは概念的な実装を示す
        h = self.unet.input_conv(x) + control_outputs[0]

        skips = [h]
        control_idx = 1

        # ダウンサンプル
        for block in self.unet.down_blocks:
            h = block(h, t_emb, context)
            if control_idx < len(control_outputs):
                h = h + control_outputs[control_idx]
                control_idx += 1
            skips.append(h)

        # ミドル
        h = self.unet.mid_block(h, t_emb, context)
        if control_idx < len(control_outputs):
            h = h + control_outputs[control_idx]

        # アップサンプル(通常のU-Net処理)
        for block in self.unet.up_blocks:
            skip = skips.pop()
            h = torch.cat([h, skip], dim=1)
            h = block(h, t_emb, context)

        return self.unet.output_conv(h)

ControlNetの使用例

推論パイプライン

def generate_with_controlnet(
    unet,
    controlnet,
    vae,
    text_encoder,
    scheduler,
    prompt,
    condition_image,
    guidance_scale=7.5,
    controlnet_scale=1.0,
    num_steps=50,
):
    """ControlNetを使った画像生成"""
    # テキストエンコード
    text_emb = text_encoder(prompt)
    null_emb = text_encoder("")

    # 条件画像の前処理
    condition = preprocess_condition(condition_image)

    # 潜在表現の初期化
    latents = torch.randn(1, 4, 64, 64)

    scheduler.set_timesteps(num_steps)

    for t in scheduler.timesteps:
        # CFG用にバッチを2倍
        latent_input = torch.cat([latents, latents])
        t_input = t.expand(2)
        context = torch.cat([null_emb, text_emb])

        # ControlNetの出力
        t_emb = unet.time_embed(t_input)
        control_outputs = controlnet(
            latent_input, t_emb, context,
            torch.cat([condition, condition])  # バッチ分複製
        )
        control_outputs = [out * controlnet_scale for out in control_outputs]

        # U-Netの予測
        noise_pred = unet_forward_with_control(
            latent_input, t_emb, context, control_outputs
        )

        # CFG
        noise_uncond, noise_cond = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # デノイジング
        latents = scheduler.step(noise_pred, t, latents)

    # VAEデコード
    image = vae.decode(latents)

    return image

複数のControlNetの組み合わせ

複数のControlNetを同時に使用することも可能です。例えば、Canny(輪郭)+ Depth(奥行き)を組み合わせることで、より詳細な制御ができます。

def apply_multiple_controlnets(controlnets, conditions, controlnet_scales):
    """複数のControlNetの出力を合成"""
    combined_outputs = None

    for controlnet, condition, scale in zip(controlnets, conditions, controlnet_scales):
        outputs = controlnet(x, t_emb, context, condition)
        outputs = [out * scale for out in outputs]

        if combined_outputs is None:
            combined_outputs = outputs
        else:
            combined_outputs = [c + o for c, o in zip(combined_outputs, outputs)]

    return combined_outputs

まとめ

本記事では、ControlNetの仕組みを解説しました。

  • 空間的条件付け: エッジ、ポーズ、深度などの画像形式の条件で生成を制御
  • アーキテクチャ: U-Netエンコーダのコピーを学習し、出力を元のU-Netに加算
  • Zero Convolution: 学習初期の安定性を確保するためのゼロ初期化
  • 事前学習の活用: 元のU-Netは凍結し、事前学習済みの生成能力を保持
  • モジュール性: 異なる条件に対して別々のControlNetを学習・組み合わせ可能

ControlNetは、テキストプロンプトだけでは表現困難な空間的な制約を指定できることから、画像生成のワークフローにおいて非常に重要なツールとなっています。

次のステップとして、以下の記事も参考にしてください。