SimCLR(Simple Framework for Contrastive Learning of Visual Representations)は、2020年にGoogleが発表した自己教師あり学習のフレームワークです。シンプルながら強力な性能を発揮し、対照学習の研究に大きな影響を与えました。
本記事では、SimCLRの理論的背景から実装まで詳しく解説します。
本記事の内容
- SimCLRのアーキテクチャと各コンポーネントの役割
- NT-Xent損失の数学的定義
- PyTorchでの実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
SimCLRとは
SimCLRは、対照学習のシンプルなフレームワークです。主な特徴は:
- 強力なデータ拡張: 複数の変換を組み合わせた強いデータ拡張
- Projection Head: 非線形のProjection Headが表現学習に重要
- 大きなバッチサイズ: 多くのNegativeサンプルを利用
- NT-Xent損失: 正規化された温度付きクロスエントロピー損失
SimCLRのアーキテクチャ
全体構造
SimCLRは以下の4つのコンポーネントで構成されます:
- データ拡張モジュール $\mathcal{T}$: 入力画像に2つの異なる変換を適用
- ベースエンコーダ $f(\cdot)$: 画像から特徴ベクトルを抽出(例:ResNet)
- Projection Head $g(\cdot)$: 特徴ベクトルを対照学習用の空間に射影
- 対照損失関数: NT-Xent損失
処理の流れは以下の通りです:
$$ x \xrightarrow{t \sim \mathcal{T}} \tilde{x}_i \xrightarrow{f(\cdot)} \bm{h}_i \xrightarrow{g(\cdot)} \bm{z}_i $$
データ拡張
SimCLRでは、以下のデータ拡張を組み合わせます:
| 変換 | 説明 |
|---|---|
| Random Crop & Resize | ランダムな領域を切り出してリサイズ |
| Color Distortion | 色相、彩度、明度、コントラストの変更 |
| Gaussian Blur | ガウシアンぼかし |
| Random Horizontal Flip | 水平方向の反転 |
論文では、Random Crop と Color Distortion の組み合わせが特に重要であることが示されています。
Projection Head
Projection Head $g(\cdot)$ は、ベースエンコーダの出力 $\bm{h}$ を対照学習用の表現 $\bm{z}$ に変換します:
$$ \bm{z} = g(\bm{h}) = W^{(2)} \sigma(W^{(1)} \bm{h}) $$
ここで $\sigma$ は ReLU 活性化関数です。
重要な発見として、Projection Headを使った空間 $\bm{z}$ で対照学習を行い、下流タスクにはProjection Head前の表現 $\bm{h}$ を使用すると性能が向上します。
これは、対照学習の損失関数によって $\bm{z}$ からデータ拡張に関する情報が失われる一方、$\bm{h}$ にはより一般的な情報が保持されるためと考えられています。
NT-Xent損失
数学的定義
NT-Xent(Normalized Temperature-scaled Cross Entropy)損失は、InfoNCE損失のSimCLR版です。
ミニバッチサイズを $N$ とすると、データ拡張により $2N$ 個のサンプルが得られます。サンプル $i$ に対する損失は:
$$ \ell_{i,j} = -\log \frac{\exp(\text{sim}(\bm{z}_i, \bm{z}_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(\bm{z}_i, \bm{z}_k) / \tau)} $$
ここで、$(i, j)$ はPositive Pair(同じ画像から生成された2つのビュー)です。
類似度関数はコサイン類似度を使用します:
$$ \text{sim}(\bm{z}_i, \bm{z}_j) = \frac{\bm{z}_i^\top \bm{z}_j}{\|\bm{z}_i\| \|\bm{z}_j\|} $$
全体の損失
$2N$ 個のサンプルすべてについて損失を計算し、平均を取ります:
$$ \mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell_{2k-1, 2k} + \ell_{2k, 2k-1} \right] $$
損失関数の性質
NT-Xent損失の特徴:
- 正規化: コサイン類似度による正規化で、特徴ベクトルの大きさに依存しない
- 温度スケーリング: $\tau$ により分布の鋭さを調整
- 対称性: Positive Pairの両方向から損失を計算
バッチサイズの影響
SimCLRでは大きなバッチサイズが重要です。バッチサイズ $N$ のとき、各サンプルに対して $2(N-1)$ 個のNegativeサンプルが存在します。
バッチサイズが大きいほど: – より多様なNegativeサンプルを利用できる – より難しい(類似度の高い)Negativeサンプルに遭遇する確率が上がる
論文では、バッチサイズ256から8192まで性能が向上し続けることが報告されています。
PyTorchでの実装
データ拡張の実装
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
class SimCLRTransform:
"""SimCLR用のデータ拡張"""
def __init__(self, size=224, s=1.0):
"""
Args:
size: 出力画像サイズ
s: Color Distortionの強度
"""
# Color Distortion
color_jitter = transforms.ColorJitter(
brightness=0.8 * s,
contrast=0.8 * s,
saturation=0.8 * s,
hue=0.2 * s
)
self.transform = transforms.Compose([
transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=int(0.1 * size) | 1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def __call__(self, x):
"""同じ画像に2つの異なる変換を適用"""
return self.transform(x), self.transform(x)
NT-Xent損失の実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class NTXentLoss(nn.Module):
"""NT-Xent損失の実装"""
def __init__(self, temperature=0.5):
super().__init__()
self.temperature = temperature
def forward(self, z1, z2):
"""
Args:
z1: 1つ目のビューの特徴ベクトル (batch_size, feature_dim)
z2: 2つ目のビューの特徴ベクトル (batch_size, feature_dim)
Returns:
loss: NT-Xent損失
"""
batch_size = z1.shape[0]
device = z1.device
# L2正規化
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 全サンプルを結合 (2 * batch_size, feature_dim)
z = torch.cat([z1, z2], dim=0)
# 類似度行列 (2 * batch_size, 2 * batch_size)
sim_matrix = torch.matmul(z, z.T) / self.temperature
# 自分自身との類似度をマスク
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))
# Positive Pairのインデックス
# z1[i]のPositive Pairはz2[i](インデックス: batch_size + i)
# z2[i]のPositive Pairはz1[i](インデックス: i)
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
], dim=0).to(device)
# クロスエントロピー損失
loss = F.cross_entropy(sim_matrix, labels)
return loss
# 動作確認
torch.manual_seed(42)
batch_size = 256
feature_dim = 128
z1 = torch.randn(batch_size, feature_dim)
z2 = torch.randn(batch_size, feature_dim)
criterion = NTXentLoss(temperature=0.5)
loss = criterion(z1, z2)
print(f"NT-Xent Loss: {loss.item():.4f}")
SimCLRモデルの実装
import torch
import torch.nn as nn
import torchvision.models as models
class SimCLR(nn.Module):
"""SimCLRモデル"""
def __init__(self, base_encoder='resnet18', projection_dim=128):
super().__init__()
# ベースエンコーダ
if base_encoder == 'resnet18':
self.encoder = models.resnet18(pretrained=False)
hidden_dim = self.encoder.fc.in_features
self.encoder.fc = nn.Identity()
elif base_encoder == 'resnet50':
self.encoder = models.resnet50(pretrained=False)
hidden_dim = self.encoder.fc.in_features
self.encoder.fc = nn.Identity()
else:
raise ValueError(f"Unknown encoder: {base_encoder}")
# Projection Head (MLP)
self.projection_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x):
"""
Args:
x: 入力画像 (batch_size, 3, H, W)
Returns:
h: ベースエンコーダの出力(下流タスク用)
z: Projection Headの出力(対照学習用)
"""
h = self.encoder(x)
z = self.projection_head(h)
return h, z
def encode(self, x):
"""下流タスク用の特徴抽出"""
return self.encoder(x)
学習ループの実装
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
# シンプルな実験用の合成データ
def create_synthetic_dataset(n_samples=1000, n_classes=10, img_size=32):
"""合成データセットの作成(デモ用)"""
np.random.seed(42)
images = []
labels = []
for class_idx in range(n_classes):
# 各クラスで異なる色のパターン
color = np.random.rand(3)
for _ in range(n_samples // n_classes):
img = np.zeros((3, img_size, img_size))
for c in range(3):
img[c] = color[c] + np.random.randn(img_size, img_size) * 0.1
img = np.clip(img, 0, 1)
images.append(img)
labels.append(class_idx)
images = np.array(images, dtype=np.float32)
labels = np.array(labels)
return images, labels
# 簡易版のSimCLR学習(デモ用)
class SimpleSimCLR(nn.Module):
"""簡易版SimCLR(小規模データ用)"""
def __init__(self, input_channels=3, img_size=32, hidden_dim=256, projection_dim=64):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(input_channels, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
)
self.projection_head = nn.Sequential(
nn.Linear(128, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x):
h = self.encoder(x)
z = self.projection_head(h)
return h, z
def simple_augmentation(x):
"""簡易データ拡張"""
# ノイズ追加
noise = torch.randn_like(x) * 0.1
x_aug = x + noise
# ランダムにチャンネルの順序を変更(色変換の簡易版)
if torch.rand(1) > 0.5:
perm = torch.randperm(3)
x_aug = x_aug[:, perm, :, :]
return torch.clamp(x_aug, 0, 1)
# 学習の実行
def train_simclr(model, data, n_epochs=50, batch_size=128, temperature=0.5, lr=0.001):
"""SimCLRの学習"""
criterion = NTXentLoss(temperature=temperature)
optimizer = optim.Adam(model.parameters(), lr=lr)
dataset = TensorDataset(torch.tensor(data))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
losses = []
for epoch in range(n_epochs):
epoch_loss = 0
n_batches = 0
for (batch,) in dataloader:
# 2つのビューを作成
x1 = simple_augmentation(batch)
x2 = simple_augmentation(batch)
# 順伝播
_, z1 = model(x1)
_, z2 = model(x2)
# 損失計算
loss = criterion(z1, z2)
# 逆伝播
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
n_batches += 1
avg_loss = epoch_loss / n_batches
losses.append(avg_loss)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {avg_loss:.4f}")
return losses
# 実験
images, labels = create_synthetic_dataset(n_samples=1000, n_classes=10)
model = SimpleSimCLR()
losses = train_simclr(model, images, n_epochs=50, batch_size=64)
# 学習曲線のプロット
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('NT-Xent Loss')
plt.title('SimCLR Training Loss')
plt.grid(True, alpha=0.3)
# 学習後の特徴空間の可視化
model.eval()
with torch.no_grad():
h, _ = model(torch.tensor(images))
features = h.numpy()
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)
plt.subplot(1, 2, 2)
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='tab10', alpha=0.6, s=10)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Learned Feature Space')
plt.colorbar(scatter)
plt.tight_layout()
plt.show()
SimCLRの実験結果(論文より)
論文で報告された主な結果:
| 設定 | ImageNet Top-1 (%) |
|---|---|
| SimCLR (ResNet-50, 1x) | 69.3 |
| SimCLR (ResNet-50, 4x) | 76.5 |
| 教師あり (ResNet-50) | 76.5 |
自己教師あり学習でありながら、教師あり学習と同等の性能を達成しています。
各コンポーネントの寄与
| 要素 | 性能への寄与 |
|---|---|
| データ拡張(Crop + Color) | +10% 以上 |
| Projection Head(非線形) | +3% |
| 大きなバッチサイズ | +5% |
まとめ
本記事では、SimCLR(Simple Framework for Contrastive Learning)について解説しました。
- SimCLRは、データ拡張、ベースエンコーダ、Projection Head、NT-Xent損失の4つのコンポーネントで構成される
- 強力なデータ拡張(特にRandom CropとColor Distortion)が重要
- Projection Headを使った空間で学習し、下流タスクにはその前の表現を使用
- 大きなバッチサイズにより多くのNegativeサンプルを活用
次のステップとして、以下の記事も参考にしてください。