対照学習(Contrastive Learning)は、ラベルなしデータから有用な特徴表現を学習する自己教師あり学習の代表的手法です。近年の大規模言語モデルや画像認識モデルの事前学習において中核的な役割を果たしています。
本記事では、対照学習の基本原理から数学的な定式化、Pythonでの実装までを解説します。
本記事の内容
- 対照学習の直感的な理解
- InfoNCE損失の数学的導出
- Pythonでのシンプルな実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
対照学習とは
対照学習は、「似ているサンプル同士は近く、異なるサンプル同士は遠くなるような特徴空間を学習する」という考え方に基づいています。
基本的なアイデア
教師なし学習では、ラベルがないためどのサンプルが似ているかを直接知ることができません。対照学習では、データ拡張(Data Augmentation)を活用してこの問題を解決します。
- 同じサンプルに異なるデータ拡張を適用した2つのビュー(Positive Pair)
- 異なるサンプル同士のペア(Negative Pair)
を作成し、Positive Pairは近く、Negative Pairは遠くなるように学習します。
直感的な理解
画像分類を例にすると、同じ猫の画像に対して「回転」「色変換」「クロップ」などの変換を適用しても、それらはすべて「猫」という本質的な情報を保持しています。対照学習は、このような変換に対して不変な表現を学習することで、クラスラベルなしでも意味のある特徴を獲得できます。
対照学習の数学的定式化
問題設定
$N$ 個のサンプルからなるミニバッチ $\{x_1, x_2, \ldots, x_N\}$ を考えます。各サンプル $x_i$ に対して、2つのデータ拡張 $t, t’$ を適用し、2つのビュー $\tilde{x}_i = t(x_i)$, $\tilde{x}’_i = t'(x_i)$ を生成します。
エンコーダ $f_\theta$ を用いて、各ビューの特徴ベクトルを得ます:
$$ \bm{z}_i = f_\theta(\tilde{x}_i), \quad \bm{z}’_i = f_\theta(\tilde{x}’_i) $$
類似度の定義
2つの特徴ベクトル間の類似度として、コサイン類似度を使用します:
$$ \text{sim}(\bm{z}_i, \bm{z}_j) = \frac{\bm{z}_i^\top \bm{z}_j}{\|\bm{z}_i\| \|\bm{z}_j\|} $$
特徴ベクトルを正規化して $\|\bm{z}\| = 1$ とすると、これは単純な内積になります:
$$ \text{sim}(\bm{z}_i, \bm{z}_j) = \bm{z}_i^\top \bm{z}_j $$
InfoNCE損失の導出
損失関数の定義
InfoNCE(Information Noise Contrastive Estimation)損失は、対照学習で最も広く使われる損失関数です。
サンプル $i$ に対するInfoNCE損失は以下のように定義されます:
$$ \mathcal{L}_i = -\log \frac{\exp(\text{sim}(\bm{z}_i, \bm{z}’_i) / \tau)}{\sum_{k=1}^{N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(\bm{z}_i, \bm{z}’_k) / \tau) + \exp(\text{sim}(\bm{z}_i, \bm{z}’_i) / \tau)} $$
ここで、$\tau > 0$ は温度パラメータ、$\mathbb{1}_{[k \neq i]}$ は $k \neq i$ のとき1、そうでないとき0を取る指示関数です。
損失関数の解釈
この損失関数をより詳しく見てみましょう。
$$ \mathcal{L}_i = -\log \frac{\exp(s_{ii’} / \tau)}{\exp(s_{ii’} / \tau) + \sum_{k \neq i} \exp(s_{ik’} / \tau)} $$
ここで $s_{ij’} = \text{sim}(\bm{z}_i, \bm{z}’_j)$ としました。
これはソフトマックス関数の形をしており、Positive Pair $(\bm{z}_i, \bm{z}’_i)$ が選ばれる確率を最大化することに対応します:
$$ p(i’ | i) = \frac{\exp(s_{ii’} / \tau)}{\sum_{k=1}^{N} \exp(s_{ik’} / \tau)} $$
温度パラメータの役割
温度パラメータ $\tau$ は分布の鋭さを制御します:
- $\tau \to 0$:ハードな選択(最も類似度の高いペアのみを考慮)
- $\tau \to \infty$:一様分布に近づく
一般的に $\tau = 0.07 \sim 0.5$ 程度の値が使用されます。
全体の損失関数
ミニバッチ全体の損失は、各サンプルの損失の平均として計算されます:
$$ \mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left[ \mathcal{L}_i + \mathcal{L}’_i \right] $$
ここで $\mathcal{L}’_i$ は $\bm{z}’_i$ をアンカーとした場合の損失です。
勾配の導出
InfoNCE損失の勾配を導出しましょう。簡単のため、$\bm{z}_i, \bm{z}’_j$ が正規化されているとします。
損失関数を以下のように書き直します:
$$ \mathcal{L}_i = -\frac{s_{ii’}}{\tau} + \log \sum_{k=1}^{N} \exp\left(\frac{s_{ik’}}{\tau}\right) $$
特徴ベクトル $\bm{z}_i$ に関する勾配は:
$$ \begin{align} \frac{\partial \mathcal{L}_i}{\partial \bm{z}_i} &= -\frac{1}{\tau} \bm{z}’_i + \frac{1}{\tau} \sum_{k=1}^{N} p(k’ | i) \bm{z}’_k \\ &= \frac{1}{\tau} \left( \sum_{k=1}^{N} p(k’ | i) \bm{z}’_k – \bm{z}’_i \right) \end{align} $$
この勾配は、現在の表現をPositive Pair $\bm{z}’_i$ に近づけ、Negative Pair $\bm{z}’_k$ $(k \neq i)$ から遠ざける方向を示しています。
Pythonでの実装
InfoNCE損失の実装
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class InfoNCELoss(nn.Module):
"""InfoNCE損失関数の実装"""
def __init__(self, temperature=0.07):
super().__init__()
self.temperature = temperature
def forward(self, z1, z2):
"""
Args:
z1: 1つ目のビューの特徴ベクトル (batch_size, feature_dim)
z2: 2つ目のビューの特徴ベクトル (batch_size, feature_dim)
Returns:
loss: InfoNCE損失
"""
# L2正規化
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
batch_size = z1.shape[0]
# 類似度行列の計算
# sim[i, j] = z1[i] と z2[j] の類似度
sim_matrix = torch.matmul(z1, z2.T) / self.temperature
# 対角成分がPositive Pairの類似度
labels = torch.arange(batch_size, device=z1.device)
# クロスエントロピー損失(両方向)
loss_12 = F.cross_entropy(sim_matrix, labels)
loss_21 = F.cross_entropy(sim_matrix.T, labels)
loss = (loss_12 + loss_21) / 2
return loss
# 動作確認
torch.manual_seed(42)
batch_size = 256
feature_dim = 128
# ランダムな特徴ベクトルを生成
z1 = torch.randn(batch_size, feature_dim)
z2 = torch.randn(batch_size, feature_dim)
criterion = InfoNCELoss(temperature=0.07)
loss = criterion(z1, z2)
print(f"InfoNCE Loss: {loss.item():.4f}")
温度パラメータの影響の可視化
import numpy as np
import matplotlib.pyplot as plt
def compute_softmax_probs(similarities, temperature):
"""温度付きソフトマックスの計算"""
scaled = similarities / temperature
exp_scaled = np.exp(scaled - np.max(scaled)) # 数値安定性のため
return exp_scaled / np.sum(exp_scaled)
# 類似度の例
similarities = np.array([0.9, 0.3, 0.2, 0.1, 0.05]) # 最初がPositive
temperatures = [0.01, 0.07, 0.1, 0.5, 1.0]
plt.figure(figsize=(10, 6))
for tau in temperatures:
probs = compute_softmax_probs(similarities, tau)
plt.plot(range(len(similarities)), probs, 'o-', label=f'τ = {tau}')
plt.xlabel('Sample Index')
plt.ylabel('Probability')
plt.title('Effect of Temperature on Softmax Distribution')
plt.legend()
plt.grid(True, alpha=0.3)
plt.xticks(range(len(similarities)), ['Positive', 'Neg 1', 'Neg 2', 'Neg 3', 'Neg 4'])
plt.tight_layout()
plt.show()
シンプルな対照学習の学習ループ
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
class SimpleEncoder(nn.Module):
"""シンプルなエンコーダネットワーク"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
return self.net(x)
def generate_synthetic_data(n_samples, n_clusters=5, dim=10):
"""クラスタ構造を持つ合成データの生成"""
np.random.seed(42)
data = []
labels = []
for i in range(n_clusters):
center = np.random.randn(dim) * 5
cluster_data = center + np.random.randn(n_samples // n_clusters, dim) * 0.5
data.append(cluster_data)
labels.extend([i] * (n_samples // n_clusters))
data = np.vstack(data)
labels = np.array(labels)
return data.astype(np.float32), labels
def add_noise(x, noise_level=0.1):
"""データ拡張としてノイズを追加"""
return x + np.random.randn(*x.shape).astype(np.float32) * noise_level
# データ生成
data, true_labels = generate_synthetic_data(500, n_clusters=5, dim=10)
# モデルとオプティマイザの設定
encoder = SimpleEncoder(input_dim=10, hidden_dim=64, output_dim=32)
criterion = InfoNCELoss(temperature=0.1)
optimizer = optim.Adam(encoder.parameters(), lr=0.001)
# 学習ループ
n_epochs = 100
batch_size = 64
losses = []
for epoch in range(n_epochs):
# ミニバッチのサンプリング
indices = np.random.choice(len(data), batch_size, replace=False)
x = data[indices]
# データ拡張で2つのビューを作成
x1 = torch.tensor(add_noise(x, 0.1))
x2 = torch.tensor(add_noise(x, 0.1))
# 特徴抽出
z1 = encoder(x1)
z2 = encoder(x2)
# 損失計算と更新
loss = criterion(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if (epoch + 1) % 20 == 0:
print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}")
# 学習曲線のプロット
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('InfoNCE Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
# 学習後の特徴空間の可視化(PCAで2次元に)
from sklearn.decomposition import PCA
encoder.eval()
with torch.no_grad():
features = encoder(torch.tensor(data)).numpy()
pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)
plt.subplot(1, 2, 2)
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=true_labels, cmap='tab10', alpha=0.6)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Learned Feature Space (PCA)')
plt.colorbar(scatter)
plt.tight_layout()
plt.show()
対照学習の発展
対照学習は多くの発展形があります:
- SimCLR: データ拡張とProjection Headの重要性を示した手法
- MoCo(Momentum Contrast): メモリバンクを用いて大量のNegativeサンプルを効率的に利用
- BYOL, SimSiam: Negativeサンプルなしで学習する手法
- CLIP: 画像とテキストの対照学習
まとめ
本記事では、対照学習(Contrastive Learning)の基礎について解説しました。
- 対照学習は、Positive Pairを近く、Negative Pairを遠くする特徴空間を学習する
- InfoNCE損失は、ソフトマックスベースの損失関数で、Positive Pairが選ばれる確率を最大化する
- 温度パラメータは分布の鋭さを制御する重要なハイパーパラメータ
次のステップとして、以下の記事も参考にしてください。