交差エントロピー(Cross-Entropy)は、自然言語処理における最も基本的な損失関数です。言語モデル、テキスト分類、機械翻訳など、ほぼすべてのNLPタスクで使用されています。
本記事では、交差エントロピーの情報理論的な背景から、NLPへの具体的な適用方法、そしてPyTorchでの実装まで解説します。
本記事の内容
- 交差エントロピーの情報理論的定義
- NLPタスクへの適用(分類、言語モデル、系列ラベリング)
- ソフトマックス関数との組み合わせ
- ラベル平滑化(Label Smoothing)
- 数値安定性と実装の注意点
- PyTorchでの実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
交差エントロピーとは
情報理論的定義
交差エントロピーは、2つの確率分布 $p$ と $q$ の間の「違い」を測る指標です。
真の分布 $p$ に対して、モデル分布 $q$ で符号化したときの平均符号長(ビット数)を表します。
$$ H(p, q) = -\sum_{x} p(x) \log q(x) $$
連続分布の場合:
$$ H(p, q) = -\int p(x) \log q(x) \, dx $$
エントロピーとの関係
真の分布 $p$ のエントロピー $H(p)$ は、最適な符号化での平均符号長です。
$$ H(p) = -\sum_{x} p(x) \log p(x) $$
交差エントロピーは常にエントロピー以上です:
$$ H(p, q) \geq H(p) $$
等号成立は $p = q$ のときのみ。
KLダイバージェンスとの関係
KLダイバージェンス(相対エントロピー)は、交差エントロピーとエントロピーの差です。
$$ D_{\text{KL}}(p \| q) = H(p, q) – H(p) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} $$
したがって:
$$ H(p, q) = H(p) + D_{\text{KL}}(p \| q) $$
重要な洞察: 真の分布 $p$ が固定されているとき(教師あり学習では正解ラベルが固定)、$H(p)$ は定数です。したがって、交差エントロピーの最小化はKLダイバージェンスの最小化と等価です。
$$ \min_q H(p, q) = \min_q D_{\text{KL}}(p \| q) $$
分類タスクへの適用
単一ラベル分類
クラス数を $K$、正解クラスを $y$(one-hotベクトル)、モデルの予測確率を $\hat{y}$ とします。
正解ラベル(one-hot): $$ y_k = \begin{cases} 1 & \text{if } k = c \text{ (正解クラス)} \\ 0 & \text{otherwise} \end{cases} $$
予測確率(softmax出力): $$ \hat{y}_k = P(k \mid x) = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)} $$
交差エントロピー損失: $$ \mathcal{L} = -\sum_{k=1}^{K} y_k \log \hat{y}_k = -\log \hat{y}_c $$
one-hotラベルの場合、正解クラス $c$ の項のみが残ります。
ミニバッチでの計算
バッチサイズ $N$ のミニバッチに対して:
$$ \mathcal{L}_{\text{batch}} = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{y}_{i,c_i} $$
ここで $c_i$ はサンプル $i$ の正解クラスです。
数式の導出
なぜ交差エントロピーが分類タスクに適しているのかを、最尤推定の観点から導出します。
モデルが各クラスの確率 $P(k \mid x; \theta)$ を出力するとき、正解クラス $c$ の対数尤度は:
$$ \log P(c \mid x; \theta) = \log \hat{y}_c $$
$N$ 個の独立なサンプルの対数尤度:
$$ \log L(\theta) = \sum_{i=1}^{N} \log P(c_i \mid x_i; \theta) = \sum_{i=1}^{N} \log \hat{y}_{i,c_i} $$
最尤推定では対数尤度を最大化します。これは負の対数尤度を最小化することと等価です:
$$ \min_\theta \left( -\sum_{i=1}^{N} \log \hat{y}_{i,c_i} \right) $$
これはまさに交差エントロピー損失の和です。
言語モデルへの適用
次トークン予測
言語モデルでは、各位置 $t$ で次のトークン $w_t$ を予測します。
入力: $w_1, w_2, \ldots, w_{t-1}$ 出力: $P(w_t \mid w_1, \ldots, w_{t-1})$
各位置での交差エントロピー:
$$ \mathcal{L}_t = -\log P(w_t \mid w_1, \ldots, w_{t-1}) $$
系列全体の損失(平均):
$$ \mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(w_t \mid w_1, \ldots, w_{t-1}) $$
パープレキシティとの関係
パープレキシティは交差エントロピーの指数関数です:
$$
\text{PPL} = \exp(\mathcal{L}) = \exp\left( -\frac{1}{T} \sum_{t=1}^{T} \log P(w_t \mid w_{ $K$ クラス分類で、モデルの生出力(logits)を $\bm{z} = (z_1, \ldots, z_K)$ とすると: $$
\hat{y}_k = \text{softmax}(\bm{z})_k = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)}
$$ 直接計算すると、$\exp(z_k)$ がオーバーフローする可能性があります。対数を取った形式で計算することで数値安定性を確保できます。 $$
\log \hat{y}_k = z_k – \log \sum_{j=1}^{K} \exp(z_j)
$$ さらに、LogSumExpのトリックを用いて: $$
\log \sum_{j=1}^{K} \exp(z_j) = m + \log \sum_{j=1}^{K} \exp(z_j – m)
$$ ここで $m = \max_j z_j$ とすることで、指数関数の引数が0以下になり、オーバーフローを防げます。 交差エントロピー損失は、softmaxとlog、そして和を組み合わせたものです: $$
\mathcal{L} = -\log \text{softmax}(\bm{z})_c = -z_c + \log \sum_{j=1}^{K} \exp(z_j)
$$ この形式では、logとexpが打ち消し合い、数値的に安定します。 ラベル平滑化は、one-hotラベルを「ソフト」にする正則化手法です。 通常のラベル:
$$
y_k = \begin{cases} 1 & \text{if } k = c \\ 0 & \text{otherwise} \end{cases}
$$ 平滑化後のラベル:
$$
y_k^{\text{smooth}} = \begin{cases} 1 – \epsilon + \frac{\epsilon}{K} & \text{if } k = c \\ \frac{\epsilon}{K} & \text{otherwise} \end{cases}
$$ ここで $\epsilon$ は平滑化パラメータ(通常 0.1)、$K$ はクラス数です。 平滑化ラベルは、one-hotラベルと均一分布の混合と見なせます: $$
\bm{y}^{\text{smooth}} = (1 – \epsilon) \bm{y}_{\text{one-hot}} + \epsilon \bm{u}
$$ ここで $\bm{u} = (1/K, \ldots, 1/K)$ は均一分布です。 1. 過信の防止 ラベル平滑化なしでは、モデルは正解クラスの確率を1に近づけようとし、logitsが極端に大きくなります。これは汎化性能を低下させます。 2. KLダイバージェンス項の追加 ラベル平滑化付き交差エントロピーは、通常の交差エントロピーに正則化項を加えたものと解釈できます: $$
\mathcal{L}_{\text{smooth}} = (1 – \epsilon) \mathcal{L}_{\text{CE}} + \epsilon \cdot D_{\text{KL}}(\bm{u} \| \hat{\bm{y}})
$$ つまり、予測分布が均一分布から離れすぎないように正則化されます。 系列ラベリングでは、各トークンにラベルを付与します。 系列長 $T$、各位置の正解ラベルを $y_t$、予測確率を $\hat{y}_t$ とすると: $$
\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(y_t \mid x, t)
$$ 可変長系列を扱う場合、パディングトークンを損失計算から除外する必要があります。 マスク $m_t \in \{0, 1\}$ を用いて: $$
\mathcal{L} = -\frac{\sum_{t=1}^{T} m_t \log P(y_t \mid x, t)}{\sum_{t=1}^{T} m_t}
$$ logits $\bm{z}$ に対する勾配を導出します。 $$
\frac{\partial \mathcal{L}}{\partial z_k} = \hat{y}_k – y_k
$$ これは非常にシンプルな形です。正解クラス $c$ について: $$
\frac{\partial \mathcal{L}}{\partial z_c} = \hat{y}_c – 1
$$ その他のクラス $k \neq c$ について: $$
\frac{\partial \mathcal{L}}{\partial z_k} = \hat{y}_k
$$ 損失関数を展開すると: $$
\mathcal{L} = -\log \hat{y}_c = -z_c + \log \sum_{j=1}^{K} \exp(z_j)
$$ $z_k$ で微分: $$
\frac{\partial \mathcal{L}}{\partial z_k} = -\delta_{kc} + \frac{\exp(z_k)}{\sum_j \exp(z_j)} = -\delta_{kc} + \hat{y}_k
$$ ここで $\delta_{kc}$ はクロネッカーのデルタ($k = c$ のとき1、それ以外は0)。 本記事では、交差エントロピー損失とNLPの関係について解説しました。 交差エントロピー損失は、NLPにおいて最も基本的かつ重要な損失関数です。その数学的背景を理解することで、より深いモデル設計が可能になります。 次のステップとして、以下の記事も参考にしてください。ソフトマックス関数との組み合わせ
Softmaxの定義
Log-Softmaxと数値安定性
組み合わせた損失関数
ラベル平滑化(Label Smoothing)
概要
数学的定式化
効果
系列ラベリングへの適用
固有表現抽出(NER)など
パディングの処理
PyTorchでの実装
基本的な使用方法
import torch
import torch.nn as nn
import torch.nn.functional as F
# 方法1: nn.CrossEntropyLoss (logitsを入力)
criterion = nn.CrossEntropyLoss()
logits = torch.randn(32, 10) # (batch_size, num_classes)
targets = torch.randint(0, 10, (32,)) # (batch_size,)
loss = criterion(logits, targets)
print(f"CrossEntropyLoss: {loss.item():.4f}")
# 方法2: F.cross_entropy (関数形式)
loss = F.cross_entropy(logits, targets)
print(f"F.cross_entropy: {loss.item():.4f}")
# 方法3: 手動計算(確認用)
log_probs = F.log_softmax(logits, dim=1)
loss_manual = F.nll_loss(log_probs, targets)
print(f"Manual (log_softmax + nll_loss): {loss_manual.item():.4f}")
言語モデルでの使用
def compute_lm_loss(logits, targets, ignore_index=-100):
"""
言語モデルの交差エントロピー損失を計算
Args:
logits: (batch_size, seq_len, vocab_size)
targets: (batch_size, seq_len)
ignore_index: 無視するインデックス(パディング等)
Returns:
loss: スカラー
"""
# (batch_size * seq_len, vocab_size) に reshape
logits_flat = logits.view(-1, logits.size(-1))
targets_flat = targets.view(-1)
loss = F.cross_entropy(
logits_flat,
targets_flat,
ignore_index=ignore_index,
reduction='mean'
)
return loss
# 使用例
batch_size, seq_len, vocab_size = 4, 128, 32000
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
# パディング位置を-100に設定
targets[:, -10:] = -100 # 最後の10トークンをパディングと仮定
loss = compute_lm_loss(logits, targets)
print(f"Language Model Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
ラベル平滑化の実装
class LabelSmoothingCrossEntropy(nn.Module):
"""ラベル平滑化付き交差エントロピー"""
def __init__(self, smoothing=0.1, reduction='mean'):
super().__init__()
self.smoothing = smoothing
self.reduction = reduction
def forward(self, logits, targets):
"""
Args:
logits: (batch_size, num_classes)
targets: (batch_size,) クラスインデックス
"""
num_classes = logits.size(-1)
log_probs = F.log_softmax(logits, dim=-1)
# 正解クラスの対数確率
nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
# 全クラスの対数確率の平均(均一分布とのKL項)
smooth_loss = -log_probs.mean(dim=-1)
# 混合
loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
# 使用例
criterion_smooth = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = torch.randn(32, 10)
targets = torch.randint(0, 10, (32,))
loss_smooth = criterion_smooth(logits, targets)
loss_normal = F.cross_entropy(logits, targets)
print(f"With Label Smoothing: {loss_smooth.item():.4f}")
print(f"Without Label Smoothing: {loss_normal.item():.4f}")
数値安定性の確認
def demonstrate_numerical_stability():
"""数値安定性の重要性を示す"""
# 極端に大きなlogits
logits_extreme = torch.tensor([[1000.0, 0.0, -1000.0]])
targets = torch.tensor([0])
# 方法1: 安定な実装(PyTorchのcross_entropy)
loss_stable = F.cross_entropy(logits_extreme, targets)
print(f"Stable implementation: {loss_stable.item():.4f}")
# 方法2: 不安定な実装(手動でsoftmax + log)
try:
probs = torch.softmax(logits_extreme, dim=-1)
print(f"Probabilities: {probs}") # [1., 0., 0.] - 正しいがlog(0)問題
log_probs = torch.log(probs)
print(f"Log probabilities: {log_probs}") # -infが含まれる可能性
except Exception as e:
print(f"Unstable implementation error: {e}")
demonstrate_numerical_stability()
損失曲線の可視化
import matplotlib.pyplot as plt
import numpy as np
def visualize_cross_entropy():
"""交差エントロピー損失の特性を可視化"""
# 正解クラスの予測確率
p = np.linspace(0.01, 0.99, 100)
# 交差エントロピー = -log(p)
ce_loss = -np.log(p)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 損失 vs 正解クラス確率
axes[0].plot(p, ce_loss, 'b-', linewidth=2)
axes[0].set_xlabel('Predicted probability for correct class', fontsize=12)
axes[0].set_ylabel('Cross-Entropy Loss', fontsize=12)
axes[0].set_title('Cross-Entropy Loss vs Prediction Confidence', fontsize=14)
axes[0].grid(True, alpha=0.3)
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 5)
# 注目点
highlight_points = [(0.1, 'p=0.1'), (0.5, 'p=0.5'), (0.9, 'p=0.9')]
for prob, label in highlight_points:
loss = -np.log(prob)
axes[0].scatter([prob], [loss], color='red', s=100, zorder=5)
axes[0].annotate(f'{label}\nL={loss:.2f}', (prob, loss),
textcoords="offset points", xytext=(10, 10))
# 勾配の可視化
gradient = -1 / p # d/dp (-log p) = -1/p
axes[1].plot(p, np.abs(gradient), 'r-', linewidth=2)
axes[1].set_xlabel('Predicted probability for correct class', fontsize=12)
axes[1].set_ylabel('|Gradient|', fontsize=12)
axes[1].set_title('Gradient Magnitude of Cross-Entropy', fontsize=14)
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 20)
plt.tight_layout()
plt.savefig('cross_entropy_visualization.png', dpi=150, bbox_inches='tight')
plt.show()
visualize_cross_entropy()
交差エントロピーの勾配
Softmax + Cross-Entropyの勾配
導出
まとめ