CLIPの対照学習を数式から理解してPythonで実装する

CLIP(Contrastive Language-Image Pre-training)は、2021年にOpenAIが発表したマルチモーダル学習モデルです。論文 “Learning Transferable Visual Models From Natural Language Supervision”(Radford et al.)で提案され、4億組の画像-テキストペアから対照学習を行うことで、事前に見たことのないカテゴリの画像も分類できるという驚異的なゼロショット能力を獲得しました。

CLIPは、DALL-E、Stable Diffusion、Midjourney、GPT-4Vなど、現代の画像生成・マルチモーダルAIの基盤技術となっています。本記事では、CLIPの理論的背景から対照学習の数式、ゼロショット分類の仕組み、そしてPyTorchでの実装までを解説します。

本記事の内容

  • CLIPの設計思想と全体アーキテクチャ
  • 対照学習とInfoNCE損失の数式
  • 画像エンコーダとテキストエンコーダ
  • ゼロショット分類の仕組み
  • PyTorchによるスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

CLIPのアイデア:自然言語を監督信号として使う

従来のアプローチの限界

従来の画像分類モデル(ResNet、ViTなど)は、ImageNetのような手動でラベル付けされたデータセットで学習されます。このアプローチには以下の限界があります。

  1. スケーラビリティ: 人手でのラベル付けはコストが高く、カテゴリ数の拡張が困難
  2. 固定カテゴリ: 学習時に定義されたカテゴリしか分類できない
  3. ドメイン特化: 特定のタスク・ドメインに過度に最適化される

自然言語による監督

CLIPは、インターネット上に存在する画像とそれに付随するテキスト(キャプション、alt text等)のペアを学習データとして使用します。これにより、

  1. 大規模データの活用: 人手でのラベル付けなしに4億組のデータを収集
  2. オープンな語彙: 固定カテゴリではなく、任意のテキスト記述で分類可能
  3. 汎化性能: 多様なドメインの概念を学習

という利点が得られます。

CLIPの全体アーキテクチャ

CLIPは、画像エンコーダテキストエンコーダの2つのエンコーダから構成されます。

[画像] → [Image Encoder] → [画像埋め込み] ─┐
                                           ├→ [対照学習]
[テキスト] → [Text Encoder] → [テキスト埋め込み] ─┘

両エンコーダは、画像とテキストを同じ次元の埋め込み空間に写像します。対照学習により、対応する画像-テキストペアの埋め込みが近づき、対応しないペアは離れるように学習されます。

画像エンコーダ

画像エンコーダには以下の選択肢があります。

  1. ResNet: 修正版ResNet(ResNet-50, ResNet-101など)
  2. Vision Transformer (ViT): ViT-B/32, ViT-B/16, ViT-L/14など

原論文では、最も性能が高かったのはViT-L/14(パッチサイズ14のViT-Large)で、入力解像度336pxの場合に最高性能を達成しました。

画像エンコーダの出力は、CLSトークンの出力(またはResNetの最終プーリング出力)を線形変換して得られる $d$ 次元のベクトルです。

$$ \bm{z}_{\text{img}} = \text{Proj}_I(\text{ImageEncoder}(\bm{x}_{\text{img}})) $$

テキストエンコーダ

テキストエンコーダはTransformerを使用します。入力テキストをトークナイズし、Transformer Encoderで処理した後、EOT(End of Text)トークンの出力を使用します。

$$ \bm{z}_{\text{text}} = \text{Proj}_T(\text{TextEncoder}(\bm{x}_{\text{text}})) $$

テキストエンコーダは12層、幅512、8ヘッドのTransformerで、語彙サイズは49,152です。コンテキスト長は77トークンに制限されています。

対照学習とInfoNCE損失

対照学習の直感

対照学習の目標は、「正しいペア(画像とそのキャプション)」を「間違ったペア(画像と無関係なキャプション)」から区別することです。

バッチサイズ $N$ のミニバッチを考えます。$N$ 個の画像と $N$ 個のテキストがあり、$i$ 番目の画像と $i$ 番目のテキストが正しいペアです。つまり、$N$ 個の正例と $N \times N – N = N(N-1)$ 個の負例が存在します。

コサイン類似度

画像埋め込み $\bm{z}_i^I$ とテキスト埋め込み $\bm{z}_j^T$ のコサイン類似度を計算します。

$$ s_{ij} = \frac{\langle \bm{z}_i^I, \bm{z}_j^T \rangle}{\|\bm{z}_i^I\| \cdot \|\bm{z}_j^T\|} $$

ここで $\langle \cdot, \cdot \rangle$ は内積、$\| \cdot \|$ はL2ノルムです。

実装上は、埋め込みをL2正規化した後に内積を計算します。

$$ \tilde{\bm{z}}_i^I = \frac{\bm{z}_i^I}{\|\bm{z}_i^I\|}, \quad s_{ij} = \langle \tilde{\bm{z}}_i^I, \tilde{\bm{z}}_j^T \rangle $$

温度パラメータ

類似度に学習可能な温度パラメータ $\tau$ を適用します。

$$ \text{logits}_{ij} = s_{ij} / \tau $$

温度が低いほど類似度の差が強調され、高いほど平滑化されます。CLIPでは $\tau$ を学習パラメータとし、対数スケール $\log \tau$ で最適化します(初期値は $\tau = 0.07$、つまり $\log \tau \approx -2.66$)。

InfoNCE損失

バッチ内の全ペアに対してクロスエントロピー損失を計算します。

画像→テキストの損失(各画像について、正しいテキストを $N$ 個のテキストから識別):

$$ \mathcal{L}_{\text{I2T}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(s_{ii} / \tau)}{\sum_{j=1}^{N} \exp(s_{ij} / \tau)} $$

テキスト→画像の損失(各テキストについて、正しい画像を $N$ 個の画像から識別):

$$ \mathcal{L}_{\text{T2I}} = -\frac{1}{N} \sum_{j=1}^{N} \log \frac{\exp(s_{jj} / \tau)}{\sum_{i=1}^{N} \exp(s_{ij} / \tau)} $$

最終的な損失は両者の平均です。

$$ \mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{\text{I2T}} + \mathcal{L}_{\text{T2I}}) $$

これはInfoNCE損失(Noise Contrastive Estimation)とも呼ばれ、相互情報量の下界を最大化することと等価です。

損失関数の直感的理解

InfoNCE損失を行列形式で見ると理解しやすくなります。類似度行列 $\bm{S} \in \mathbb{R}^{N \times N}$ を考えると、

$$ \bm{S} = \begin{pmatrix} s_{11} & s_{12} & \cdots & s_{1N} \\ s_{21} & s_{22} & \cdots & s_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ s_{N1} & s_{N2} & \cdots & s_{NN} \end{pmatrix} $$

対角成分 $s_{ii}$ が正例(マッチするペア)、非対角成分 $s_{ij}$ ($i \neq j$) が負例です。

$\mathcal{L}_{\text{I2T}}$ は各行に対するクロスエントロピー(正解ラベルは対角成分)、$\mathcal{L}_{\text{T2I}}$ は各列に対するクロスエントロピーです。

ゼロショット分類

CLIPの最も強力な特性は、事前に学習していないカテゴリの画像も分類できることです。

分類の手順

  1. 分類したいカテゴリ(例: “cat”, “dog”, “bird”)をテンプレートに埋め込んでテキストを生成 – 例: “a photo of a cat”, “a photo of a dog”, “a photo of a bird”
  2. 各テキストをテキストエンコーダで埋め込みに変換
  3. 分類対象の画像を画像エンコーダで埋め込みに変換
  4. 画像埋め込みと各テキスト埋め込みのコサイン類似度を計算
  5. 最も類似度の高いテキストに対応するカテゴリを予測

数式での表現

カテゴリ集合 $\mathcal{C} = \{c_1, c_2, \dots, c_K\}$ に対して、

$$ \hat{c} = \arg\max_{c \in \mathcal{C}} \frac{\langle \bm{z}_{\text{img}}, \bm{z}_c \rangle}{\|\bm{z}_{\text{img}}\| \cdot \|\bm{z}_c\|} $$

ここで $\bm{z}_c = \text{TextEncoder}(\text{prompt}(c))$ です。

プロンプトエンジニアリング

テキストのテンプレート(プロンプト)の選び方が性能に大きく影響します。

  • 単純: “cat”, “dog”
  • テンプレート: “a photo of a {class}”
  • 詳細: “a photo of a {class}, a type of pet”

原論文では、80種類のテンプレートを用意し、それらの埋め込みを平均(アンサンブル)することで性能を向上させています。

PyTorchによる実装

CLIPの簡易実装

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """画像をパッチに分割して埋め込み"""
    def __init__(self, img_size=224, patch_size=32, in_channels=3, embed_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class VisionEncoder(nn.Module):
    """Vision Transformer ベースの画像エンコーダ"""
    def __init__(
        self, img_size=224, patch_size=32, in_channels=3,
        embed_dim=512, depth=6, num_heads=8
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.0, activation='gelu', batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.ln_post = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]

        # パッチ埋め込み
        x = self.patch_embed(x)

        # CLSトークンを追加
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # 位置埋め込み
        x = x + self.pos_embed

        # Transformer
        x = self.transformer(x)

        # CLSトークンの出力
        x = self.ln_post(x[:, 0])

        return x


class TextEncoder(nn.Module):
    """Transformer ベースのテキストエンコーダ"""
    def __init__(
        self, vocab_size=49408, context_length=77,
        embed_dim=512, depth=6, num_heads=8
    ):
        super().__init__()
        self.context_length = context_length

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Parameter(
            torch.zeros(context_length, embed_dim)
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.0, activation='gelu', batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.ln_final = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def forward(self, text, eot_indices):
        """
        Args:
            text: (batch_size, context_length) トークンID
            eot_indices: (batch_size,) EOTトークンの位置
        """
        # トークン埋め込み + 位置埋め込み
        x = self.token_embedding(text) + self.positional_embedding

        # 因果マスク(オプション、CLIPでは使用)
        mask = self._generate_causal_mask(text.size(1)).to(text.device)

        # Transformer
        x = self.transformer(x, mask=mask)

        # Layer Norm
        x = self.ln_final(x)

        # EOTトークンの出力を取得
        batch_indices = torch.arange(x.size(0), device=x.device)
        x = x[batch_indices, eot_indices]

        return x

    def _generate_causal_mask(self, size):
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask


class CLIP(nn.Module):
    """CLIP モデル"""
    def __init__(
        self,
        img_size=224,
        patch_size=32,
        in_channels=3,
        embed_dim=512,
        vision_depth=6,
        vision_heads=8,
        vocab_size=49408,
        context_length=77,
        text_depth=6,
        text_heads=8,
    ):
        super().__init__()

        self.visual = VisionEncoder(
            img_size, patch_size, in_channels,
            embed_dim, vision_depth, vision_heads
        )
        self.text = TextEncoder(
            vocab_size, context_length,
            embed_dim, text_depth, text_heads
        )

        # 射影層
        self.visual_projection = nn.Linear(embed_dim, embed_dim, bias=False)
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False)

        # 学習可能な温度パラメータ(対数スケール)
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def encode_image(self, image):
        """画像を埋め込みに変換"""
        x = self.visual(image)
        x = self.visual_projection(x)
        return x

    def encode_text(self, text, eot_indices):
        """テキストを埋め込みに変換"""
        x = self.text(text, eot_indices)
        x = self.text_projection(x)
        return x

    def forward(self, image, text, eot_indices):
        """
        Args:
            image: (batch_size, channels, height, width)
            text: (batch_size, context_length)
            eot_indices: (batch_size,)
        Returns:
            logits_per_image: (batch_size, batch_size)
            logits_per_text: (batch_size, batch_size)
        """
        # エンコード
        image_features = self.encode_image(image)
        text_features = self.encode_text(text, eot_indices)

        # L2正規化
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # コサイン類似度を計算(温度でスケーリング)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text


def clip_loss(logits_per_image, logits_per_text):
    """CLIP の対照学習損失(InfoNCE)"""
    batch_size = logits_per_image.size(0)
    labels = torch.arange(batch_size, device=logits_per_image.device)

    # 画像→テキストの損失
    loss_i2t = F.cross_entropy(logits_per_image, labels)
    # テキスト→画像の損失
    loss_t2i = F.cross_entropy(logits_per_text, labels)

    return (loss_i2t + loss_t2i) / 2

動作確認

# モデル作成
model = CLIP(
    img_size=224,
    patch_size=32,
    embed_dim=256,
    vision_depth=4,
    vision_heads=4,
    context_length=32,
    text_depth=4,
    text_heads=4,
)

# ダミー入力
batch_size = 4
images = torch.randn(batch_size, 3, 224, 224)
texts = torch.randint(0, 1000, (batch_size, 32))
eot_indices = torch.tensor([10, 15, 8, 12])  # EOTトークンの位置

# 順伝播
logits_per_image, logits_per_text = model(images, texts, eot_indices)

print(f"画像入力形状: {images.shape}")
print(f"テキスト入力形状: {texts.shape}")
print(f"logits_per_image形状: {logits_per_image.shape}")
print(f"logits_per_text形状: {logits_per_text.shape}")

# 損失計算
loss = clip_loss(logits_per_image, logits_per_text)
print(f"CLIP損失: {loss.item():.4f}")

# パラメータ数
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}")

ゼロショット分類の実装

def zero_shot_classify(model, image, class_names, templates=None):
    """
    ゼロショット分類を実行

    Args:
        model: CLIPモデル
        image: (1, C, H, W) 分類対象の画像
        class_names: クラス名のリスト
        templates: テンプレートのリスト(省略時はシンプルなテンプレート)
    Returns:
        predicted_class: 予測されたクラス名
        probabilities: 各クラスの確率
    """
    if templates is None:
        templates = ["a photo of a {}"]

    model.eval()
    with torch.no_grad():
        # 画像の埋め込み
        image_features = model.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)

        # 各クラスのテキスト埋め込み(テンプレートの平均)
        class_embeddings = []
        for class_name in class_names:
            texts_for_class = [t.format(class_name) for t in templates]
            # 簡易的なトークン化(実際はBPEトークナイザを使用)
            # ここではダミーのトークンIDを使用
            text_tokens = torch.randint(0, 1000, (len(texts_for_class), 32))
            eot_indices = torch.tensor([10] * len(texts_for_class))

            text_features = model.encode_text(text_tokens, eot_indices)
            text_features = F.normalize(text_features, dim=-1)
            text_features = text_features.mean(dim=0)  # テンプレートの平均
            class_embeddings.append(text_features)

        class_embeddings = torch.stack(class_embeddings)  # (num_classes, embed_dim)
        class_embeddings = F.normalize(class_embeddings, dim=-1)

        # コサイン類似度
        similarities = image_features @ class_embeddings.t()  # (1, num_classes)
        probabilities = F.softmax(similarities * model.logit_scale.exp(), dim=-1)

    predicted_idx = probabilities.argmax(dim=-1).item()
    predicted_class = class_names[predicted_idx]

    return predicted_class, probabilities[0]


# 使用例
class_names = ["cat", "dog", "bird", "car", "airplane"]
dummy_image = torch.randn(1, 3, 224, 224)

predicted, probs = zero_shot_classify(model, dummy_image, class_names)
print(f"予測クラス: {predicted}")
print(f"確率分布: {probs}")

CLIPの類似度可視化

画像-テキスト間の類似度行列を可視化してみましょう。

import matplotlib.pyplot as plt
import numpy as np


def visualize_similarity_matrix(model, images, texts, eot_indices, image_labels, text_labels):
    """類似度行列を可視化"""
    model.eval()
    with torch.no_grad():
        logits_per_image, _ = model(images, texts, eot_indices)
        similarities = logits_per_image.softmax(dim=-1).cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(similarities, cmap='Blues')

    ax.set_xticks(range(len(text_labels)))
    ax.set_yticks(range(len(image_labels)))
    ax.set_xticklabels(text_labels, rotation=45, ha='right')
    ax.set_yticklabels(image_labels)
    ax.set_xlabel('Text')
    ax.set_ylabel('Image')
    ax.set_title('Image-Text Similarity (CLIP)')

    # 値を表示
    for i in range(len(image_labels)):
        for j in range(len(text_labels)):
            ax.text(j, i, f'{similarities[i, j]:.2f}',
                    ha='center', va='center', fontsize=10)

    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    plt.show()


# 可視化例
batch_size = 4
images = torch.randn(batch_size, 3, 224, 224)
texts = torch.randint(0, 1000, (batch_size, 32))
eot_indices = torch.tensor([10, 10, 10, 10])

image_labels = ["Image 1", "Image 2", "Image 3", "Image 4"]
text_labels = ["Text 1", "Text 2", "Text 3", "Text 4"]

visualize_similarity_matrix(model, images, texts, eot_indices, image_labels, text_labels)

CLIPの応用と発展

画像生成への応用

CLIPは画像生成モデルの重要なコンポーネントとして使われています。

  • DALL-E 2: CLIPの画像埋め込みをpriorモデルで生成し、それをデコードして画像生成
  • Stable Diffusion: CLIPのテキストエンコーダを条件付け用のテキストエンコーダとして使用

マルチモーダル理解

CLIPの成功は、マルチモーダルLLM(GPT-4V、LLaVAなど)の発展にも貢献しています。画像とテキストを統一的な表現空間で扱うという考え方は、現代のマルチモーダルAIの基盤となっています。

限界

  • 細粒度の理解: 物体のカウント、空間関係、属性の理解が苦手
  • テキスト読解: 画像内のテキストの読み取りが困難
  • 構成的理解: “a red cube on a blue sphere” と “a blue cube on a red sphere” の区別が苦手

まとめ

本記事では、CLIPの仕組みを解説しました。

  • 自然言語監督: 人手ラベルなしに画像-テキストペアから学習
  • 対照学習: InfoNCE損失により、正しいペアの類似度を高め、間違ったペアの類似度を下げる
  • デュアルエンコーダ: 画像エンコーダとテキストエンコーダが同じ埋め込み空間に写像
  • 温度パラメータ: 学習可能な温度で類似度のシャープネスを制御
  • ゼロショット分類: テキストプロンプトを使って未知のカテゴリも分類可能

CLIPは、マルチモーダルAIの基盤技術として、DALL-E、Stable Diffusion、GPT-4Vなど多くの後続モデルに影響を与えています。

次のステップとして、以下の記事も参考にしてください。