SimCLRの仕組みとNT-Xent損失をPythonで実装する

SimCLR(Simple Framework for Contrastive Learning of Visual Representations)は、2020年にGoogleが発表した自己教師あり学習のフレームワークです。シンプルながら強力な性能を発揮し、対照学習の研究に大きな影響を与えました。

本記事では、SimCLRの理論的背景から実装まで詳しく解説します。

本記事の内容

  • SimCLRのアーキテクチャと各コンポーネントの役割
  • NT-Xent損失の数学的定義
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

SimCLRとは

SimCLRは、対照学習のシンプルなフレームワークです。主な特徴は:

  1. 強力なデータ拡張: 複数の変換を組み合わせた強いデータ拡張
  2. Projection Head: 非線形のProjection Headが表現学習に重要
  3. 大きなバッチサイズ: 多くのNegativeサンプルを利用
  4. NT-Xent損失: 正規化された温度付きクロスエントロピー損失

SimCLRのアーキテクチャ

全体構造

SimCLRは以下の4つのコンポーネントで構成されます:

  1. データ拡張モジュール $\mathcal{T}$: 入力画像に2つの異なる変換を適用
  2. ベースエンコーダ $f(\cdot)$: 画像から特徴ベクトルを抽出(例:ResNet)
  3. Projection Head $g(\cdot)$: 特徴ベクトルを対照学習用の空間に射影
  4. 対照損失関数: NT-Xent損失

処理の流れは以下の通りです:

$$ x \xrightarrow{t \sim \mathcal{T}} \tilde{x}_i \xrightarrow{f(\cdot)} \bm{h}_i \xrightarrow{g(\cdot)} \bm{z}_i $$

データ拡張

SimCLRでは、以下のデータ拡張を組み合わせます:

変換 説明
Random Crop & Resize ランダムな領域を切り出してリサイズ
Color Distortion 色相、彩度、明度、コントラストの変更
Gaussian Blur ガウシアンぼかし
Random Horizontal Flip 水平方向の反転

論文では、Random Crop と Color Distortion の組み合わせが特に重要であることが示されています。

Projection Head

Projection Head $g(\cdot)$ は、ベースエンコーダの出力 $\bm{h}$ を対照学習用の表現 $\bm{z}$ に変換します:

$$ \bm{z} = g(\bm{h}) = W^{(2)} \sigma(W^{(1)} \bm{h}) $$

ここで $\sigma$ は ReLU 活性化関数です。

重要な発見として、Projection Headを使った空間 $\bm{z}$ で対照学習を行い、下流タスクにはProjection Head前の表現 $\bm{h}$ を使用すると性能が向上します。

これは、対照学習の損失関数によって $\bm{z}$ からデータ拡張に関する情報が失われる一方、$\bm{h}$ にはより一般的な情報が保持されるためと考えられています。

NT-Xent損失

数学的定義

NT-Xent(Normalized Temperature-scaled Cross Entropy)損失は、InfoNCE損失のSimCLR版です。

ミニバッチサイズを $N$ とすると、データ拡張により $2N$ 個のサンプルが得られます。サンプル $i$ に対する損失は:

$$ \ell_{i,j} = -\log \frac{\exp(\text{sim}(\bm{z}_i, \bm{z}_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(\bm{z}_i, \bm{z}_k) / \tau)} $$

ここで、$(i, j)$ はPositive Pair(同じ画像から生成された2つのビュー)です。

類似度関数はコサイン類似度を使用します:

$$ \text{sim}(\bm{z}_i, \bm{z}_j) = \frac{\bm{z}_i^\top \bm{z}_j}{\|\bm{z}_i\| \|\bm{z}_j\|} $$

全体の損失

$2N$ 個のサンプルすべてについて損失を計算し、平均を取ります:

$$ \mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell_{2k-1, 2k} + \ell_{2k, 2k-1} \right] $$

損失関数の性質

NT-Xent損失の特徴:

  1. 正規化: コサイン類似度による正規化で、特徴ベクトルの大きさに依存しない
  2. 温度スケーリング: $\tau$ により分布の鋭さを調整
  3. 対称性: Positive Pairの両方向から損失を計算

バッチサイズの影響

SimCLRでは大きなバッチサイズが重要です。バッチサイズ $N$ のとき、各サンプルに対して $2(N-1)$ 個のNegativeサンプルが存在します。

バッチサイズが大きいほど: – より多様なNegativeサンプルを利用できる – より難しい(類似度の高い)Negativeサンプルに遭遇する確率が上がる

論文では、バッチサイズ256から8192まで性能が向上し続けることが報告されています。

PyTorchでの実装

データ拡張の実装

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

class SimCLRTransform:
    """SimCLR用のデータ拡張"""

    def __init__(self, size=224, s=1.0):
        """
        Args:
            size: 出力画像サイズ
            s: Color Distortionの強度
        """
        # Color Distortion
        color_jitter = transforms.ColorJitter(
            brightness=0.8 * s,
            contrast=0.8 * s,
            saturation=0.8 * s,
            hue=0.2 * s
        )

        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=int(0.1 * size) | 1),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __call__(self, x):
        """同じ画像に2つの異なる変換を適用"""
        return self.transform(x), self.transform(x)

NT-Xent損失の実装

import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    """NT-Xent損失の実装"""

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """
        Args:
            z1: 1つ目のビューの特徴ベクトル (batch_size, feature_dim)
            z2: 2つ目のビューの特徴ベクトル (batch_size, feature_dim)
        Returns:
            loss: NT-Xent損失
        """
        batch_size = z1.shape[0]
        device = z1.device

        # L2正規化
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        # 全サンプルを結合 (2 * batch_size, feature_dim)
        z = torch.cat([z1, z2], dim=0)

        # 類似度行列 (2 * batch_size, 2 * batch_size)
        sim_matrix = torch.matmul(z, z.T) / self.temperature

        # 自分自身との類似度をマスク
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))

        # Positive Pairのインデックス
        # z1[i]のPositive Pairはz2[i](インデックス: batch_size + i)
        # z2[i]のPositive Pairはz1[i](インデックス: i)
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(batch_size)
        ], dim=0).to(device)

        # クロスエントロピー損失
        loss = F.cross_entropy(sim_matrix, labels)

        return loss

# 動作確認
torch.manual_seed(42)
batch_size = 256
feature_dim = 128

z1 = torch.randn(batch_size, feature_dim)
z2 = torch.randn(batch_size, feature_dim)

criterion = NTXentLoss(temperature=0.5)
loss = criterion(z1, z2)
print(f"NT-Xent Loss: {loss.item():.4f}")

SimCLRモデルの実装

import torch
import torch.nn as nn
import torchvision.models as models

class SimCLR(nn.Module):
    """SimCLRモデル"""

    def __init__(self, base_encoder='resnet18', projection_dim=128):
        super().__init__()

        # ベースエンコーダ
        if base_encoder == 'resnet18':
            self.encoder = models.resnet18(pretrained=False)
            hidden_dim = self.encoder.fc.in_features
            self.encoder.fc = nn.Identity()
        elif base_encoder == 'resnet50':
            self.encoder = models.resnet50(pretrained=False)
            hidden_dim = self.encoder.fc.in_features
            self.encoder.fc = nn.Identity()
        else:
            raise ValueError(f"Unknown encoder: {base_encoder}")

        # Projection Head (MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x):
        """
        Args:
            x: 入力画像 (batch_size, 3, H, W)
        Returns:
            h: ベースエンコーダの出力(下流タスク用)
            z: Projection Headの出力(対照学習用)
        """
        h = self.encoder(x)
        z = self.projection_head(h)
        return h, z

    def encode(self, x):
        """下流タスク用の特徴抽出"""
        return self.encoder(x)

学習ループの実装

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# シンプルな実験用の合成データ
def create_synthetic_dataset(n_samples=1000, n_classes=10, img_size=32):
    """合成データセットの作成(デモ用)"""
    np.random.seed(42)

    images = []
    labels = []

    for class_idx in range(n_classes):
        # 各クラスで異なる色のパターン
        color = np.random.rand(3)
        for _ in range(n_samples // n_classes):
            img = np.zeros((3, img_size, img_size))
            for c in range(3):
                img[c] = color[c] + np.random.randn(img_size, img_size) * 0.1
            img = np.clip(img, 0, 1)
            images.append(img)
            labels.append(class_idx)

    images = np.array(images, dtype=np.float32)
    labels = np.array(labels)

    return images, labels

# 簡易版のSimCLR学習(デモ用)
class SimpleSimCLR(nn.Module):
    """簡易版SimCLR(小規模データ用)"""

    def __init__(self, input_channels=3, img_size=32, hidden_dim=256, projection_dim=64):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )

        self.projection_head = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projection_head(h)
        return h, z

def simple_augmentation(x):
    """簡易データ拡張"""
    # ノイズ追加
    noise = torch.randn_like(x) * 0.1
    x_aug = x + noise

    # ランダムにチャンネルの順序を変更(色変換の簡易版)
    if torch.rand(1) > 0.5:
        perm = torch.randperm(3)
        x_aug = x_aug[:, perm, :, :]

    return torch.clamp(x_aug, 0, 1)

# 学習の実行
def train_simclr(model, data, n_epochs=50, batch_size=128, temperature=0.5, lr=0.001):
    """SimCLRの学習"""
    criterion = NTXentLoss(temperature=temperature)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    dataset = TensorDataset(torch.tensor(data))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    losses = []

    for epoch in range(n_epochs):
        epoch_loss = 0
        n_batches = 0

        for (batch,) in dataloader:
            # 2つのビューを作成
            x1 = simple_augmentation(batch)
            x2 = simple_augmentation(batch)

            # 順伝播
            _, z1 = model(x1)
            _, z2 = model(x2)

            # 損失計算
            loss = criterion(z1, z2)

            # 逆伝播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {avg_loss:.4f}")

    return losses

# 実験
images, labels = create_synthetic_dataset(n_samples=1000, n_classes=10)

model = SimpleSimCLR()
losses = train_simclr(model, images, n_epochs=50, batch_size=64)

# 学習曲線のプロット
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('NT-Xent Loss')
plt.title('SimCLR Training Loss')
plt.grid(True, alpha=0.3)

# 学習後の特徴空間の可視化
model.eval()
with torch.no_grad():
    h, _ = model(torch.tensor(images))
    features = h.numpy()

from sklearn.decomposition import PCA
pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)

plt.subplot(1, 2, 2)
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='tab10', alpha=0.6, s=10)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Learned Feature Space')
plt.colorbar(scatter)

plt.tight_layout()
plt.show()

SimCLRの実験結果(論文より)

論文で報告された主な結果:

設定 ImageNet Top-1 (%)
SimCLR (ResNet-50, 1x) 69.3
SimCLR (ResNet-50, 4x) 76.5
教師あり (ResNet-50) 76.5

自己教師あり学習でありながら、教師あり学習と同等の性能を達成しています。

各コンポーネントの寄与

要素 性能への寄与
データ拡張(Crop + Color) +10% 以上
Projection Head(非線形) +3%
大きなバッチサイズ +5%

まとめ

本記事では、SimCLR(Simple Framework for Contrastive Learning)について解説しました。

  • SimCLRは、データ拡張、ベースエンコーダ、Projection Head、NT-Xent損失の4つのコンポーネントで構成される
  • 強力なデータ拡張(特にRandom CropとColor Distortion)が重要
  • Projection Headを使った空間で学習し、下流タスクにはその前の表現を使用
  • 大きなバッチサイズにより多くのNegativeサンプルを活用

次のステップとして、以下の記事も参考にしてください。