交差エントロピー損失の数学的導出とNLPへの応用

交差エントロピー(Cross-Entropy)は、自然言語処理における最も基本的な損失関数です。言語モデル、テキスト分類、機械翻訳など、ほぼすべてのNLPタスクで使用されています。

本記事では、交差エントロピーの情報理論的な背景から、NLPへの具体的な適用方法、そしてPyTorchでの実装まで解説します。

本記事の内容

  • 交差エントロピーの情報理論的定義
  • NLPタスクへの適用(分類、言語モデル、系列ラベリング)
  • ソフトマックス関数との組み合わせ
  • ラベル平滑化(Label Smoothing)
  • 数値安定性と実装の注意点
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

交差エントロピーとは

情報理論的定義

交差エントロピーは、2つの確率分布 $p$ と $q$ の間の「違い」を測る指標です。

真の分布 $p$ に対して、モデル分布 $q$ で符号化したときの平均符号長(ビット数)を表します。

$$ H(p, q) = -\sum_{x} p(x) \log q(x) $$

連続分布の場合:

$$ H(p, q) = -\int p(x) \log q(x) \, dx $$

エントロピーとの関係

真の分布 $p$ のエントロピー $H(p)$ は、最適な符号化での平均符号長です。

$$ H(p) = -\sum_{x} p(x) \log p(x) $$

交差エントロピーは常にエントロピー以上です:

$$ H(p, q) \geq H(p) $$

等号成立は $p = q$ のときのみ。

KLダイバージェンスとの関係

KLダイバージェンス(相対エントロピー)は、交差エントロピーとエントロピーの差です。

$$ D_{\text{KL}}(p \| q) = H(p, q) – H(p) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} $$

したがって:

$$ H(p, q) = H(p) + D_{\text{KL}}(p \| q) $$

重要な洞察: 真の分布 $p$ が固定されているとき(教師あり学習では正解ラベルが固定)、$H(p)$ は定数です。したがって、交差エントロピーの最小化はKLダイバージェンスの最小化と等価です。

$$ \min_q H(p, q) = \min_q D_{\text{KL}}(p \| q) $$

分類タスクへの適用

単一ラベル分類

クラス数を $K$、正解クラスを $y$(one-hotベクトル)、モデルの予測確率を $\hat{y}$ とします。

正解ラベル(one-hot): $$ y_k = \begin{cases} 1 & \text{if } k = c \text{ (正解クラス)} \\ 0 & \text{otherwise} \end{cases} $$

予測確率(softmax出力): $$ \hat{y}_k = P(k \mid x) = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)} $$

交差エントロピー損失: $$ \mathcal{L} = -\sum_{k=1}^{K} y_k \log \hat{y}_k = -\log \hat{y}_c $$

one-hotラベルの場合、正解クラス $c$ の項のみが残ります。

ミニバッチでの計算

バッチサイズ $N$ のミニバッチに対して:

$$ \mathcal{L}_{\text{batch}} = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{y}_{i,c_i} $$

ここで $c_i$ はサンプル $i$ の正解クラスです。

数式の導出

なぜ交差エントロピーが分類タスクに適しているのかを、最尤推定の観点から導出します。

モデルが各クラスの確率 $P(k \mid x; \theta)$ を出力するとき、正解クラス $c$ の対数尤度は:

$$ \log P(c \mid x; \theta) = \log \hat{y}_c $$

$N$ 個の独立なサンプルの対数尤度:

$$ \log L(\theta) = \sum_{i=1}^{N} \log P(c_i \mid x_i; \theta) = \sum_{i=1}^{N} \log \hat{y}_{i,c_i} $$

最尤推定では対数尤度を最大化します。これは負の対数尤度を最小化することと等価です:

$$ \min_\theta \left( -\sum_{i=1}^{N} \log \hat{y}_{i,c_i} \right) $$

これはまさに交差エントロピー損失の和です。

言語モデルへの適用

次トークン予測

言語モデルでは、各位置 $t$ で次のトークン $w_t$ を予測します。

入力: $w_1, w_2, \ldots, w_{t-1}$ 出力: $P(w_t \mid w_1, \ldots, w_{t-1})$

各位置での交差エントロピー:

$$ \mathcal{L}_t = -\log P(w_t \mid w_1, \ldots, w_{t-1}) $$

系列全体の損失(平均):

$$ \mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(w_t \mid w_1, \ldots, w_{t-1}) $$

パープレキシティとの関係

パープレキシティは交差エントロピーの指数関数です:

$$ \text{PPL} = \exp(\mathcal{L}) = \exp\left( -\frac{1}{T} \sum_{t=1}^{T} \log P(w_t \mid w_{

ソフトマックス関数との組み合わせ

Softmaxの定義

$K$ クラス分類で、モデルの生出力(logits)を $\bm{z} = (z_1, \ldots, z_K)$ とすると:

$$ \hat{y}_k = \text{softmax}(\bm{z})_k = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)} $$

Log-Softmaxと数値安定性

直接計算すると、$\exp(z_k)$ がオーバーフローする可能性があります。対数を取った形式で計算することで数値安定性を確保できます。

$$ \log \hat{y}_k = z_k – \log \sum_{j=1}^{K} \exp(z_j) $$

さらに、LogSumExpのトリックを用いて:

$$ \log \sum_{j=1}^{K} \exp(z_j) = m + \log \sum_{j=1}^{K} \exp(z_j – m) $$

ここで $m = \max_j z_j$ とすることで、指数関数の引数が0以下になり、オーバーフローを防げます。

組み合わせた損失関数

交差エントロピー損失は、softmaxとlog、そして和を組み合わせたものです:

$$ \mathcal{L} = -\log \text{softmax}(\bm{z})_c = -z_c + \log \sum_{j=1}^{K} \exp(z_j) $$

この形式では、logとexpが打ち消し合い、数値的に安定します。

ラベル平滑化(Label Smoothing)

概要

ラベル平滑化は、one-hotラベルを「ソフト」にする正則化手法です。

通常のラベル: $$ y_k = \begin{cases} 1 & \text{if } k = c \\ 0 & \text{otherwise} \end{cases} $$

平滑化後のラベル: $$ y_k^{\text{smooth}} = \begin{cases} 1 – \epsilon + \frac{\epsilon}{K} & \text{if } k = c \\ \frac{\epsilon}{K} & \text{otherwise} \end{cases} $$

ここで $\epsilon$ は平滑化パラメータ(通常 0.1)、$K$ はクラス数です。

数学的定式化

平滑化ラベルは、one-hotラベルと均一分布の混合と見なせます:

$$ \bm{y}^{\text{smooth}} = (1 – \epsilon) \bm{y}_{\text{one-hot}} + \epsilon \bm{u} $$

ここで $\bm{u} = (1/K, \ldots, 1/K)$ は均一分布です。

効果

1. 過信の防止

ラベル平滑化なしでは、モデルは正解クラスの確率を1に近づけようとし、logitsが極端に大きくなります。これは汎化性能を低下させます。

2. KLダイバージェンス項の追加

ラベル平滑化付き交差エントロピーは、通常の交差エントロピーに正則化項を加えたものと解釈できます:

$$ \mathcal{L}_{\text{smooth}} = (1 – \epsilon) \mathcal{L}_{\text{CE}} + \epsilon \cdot D_{\text{KL}}(\bm{u} \| \hat{\bm{y}}) $$

つまり、予測分布が均一分布から離れすぎないように正則化されます。

系列ラベリングへの適用

固有表現抽出(NER)など

系列ラベリングでは、各トークンにラベルを付与します。

系列長 $T$、各位置の正解ラベルを $y_t$、予測確率を $\hat{y}_t$ とすると:

$$ \mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(y_t \mid x, t) $$

パディングの処理

可変長系列を扱う場合、パディングトークンを損失計算から除外する必要があります。

マスク $m_t \in \{0, 1\}$ を用いて:

$$ \mathcal{L} = -\frac{\sum_{t=1}^{T} m_t \log P(y_t \mid x, t)}{\sum_{t=1}^{T} m_t} $$

PyTorchでの実装

基本的な使用方法

import torch
import torch.nn as nn
import torch.nn.functional as F


# 方法1: nn.CrossEntropyLoss (logitsを入力)
criterion = nn.CrossEntropyLoss()
logits = torch.randn(32, 10)  # (batch_size, num_classes)
targets = torch.randint(0, 10, (32,))  # (batch_size,)
loss = criterion(logits, targets)
print(f"CrossEntropyLoss: {loss.item():.4f}")

# 方法2: F.cross_entropy (関数形式)
loss = F.cross_entropy(logits, targets)
print(f"F.cross_entropy: {loss.item():.4f}")

# 方法3: 手動計算(確認用)
log_probs = F.log_softmax(logits, dim=1)
loss_manual = F.nll_loss(log_probs, targets)
print(f"Manual (log_softmax + nll_loss): {loss_manual.item():.4f}")

言語モデルでの使用

def compute_lm_loss(logits, targets, ignore_index=-100):
    """
    言語モデルの交差エントロピー損失を計算

    Args:
        logits: (batch_size, seq_len, vocab_size)
        targets: (batch_size, seq_len)
        ignore_index: 無視するインデックス(パディング等)
    Returns:
        loss: スカラー
    """
    # (batch_size * seq_len, vocab_size) に reshape
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flat = targets.view(-1)

    loss = F.cross_entropy(
        logits_flat,
        targets_flat,
        ignore_index=ignore_index,
        reduction='mean'
    )
    return loss


# 使用例
batch_size, seq_len, vocab_size = 4, 128, 32000
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# パディング位置を-100に設定
targets[:, -10:] = -100  # 最後の10トークンをパディングと仮定

loss = compute_lm_loss(logits, targets)
print(f"Language Model Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")

ラベル平滑化の実装

class LabelSmoothingCrossEntropy(nn.Module):
    """ラベル平滑化付き交差エントロピー"""

    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Args:
            logits: (batch_size, num_classes)
            targets: (batch_size,) クラスインデックス
        """
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)

        # 正解クラスの対数確率
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)

        # 全クラスの対数確率の平均(均一分布とのKL項)
        smooth_loss = -log_probs.mean(dim=-1)

        # 混合
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


# 使用例
criterion_smooth = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = torch.randn(32, 10)
targets = torch.randint(0, 10, (32,))

loss_smooth = criterion_smooth(logits, targets)
loss_normal = F.cross_entropy(logits, targets)

print(f"With Label Smoothing: {loss_smooth.item():.4f}")
print(f"Without Label Smoothing: {loss_normal.item():.4f}")

数値安定性の確認

def demonstrate_numerical_stability():
    """数値安定性の重要性を示す"""
    # 極端に大きなlogits
    logits_extreme = torch.tensor([[1000.0, 0.0, -1000.0]])
    targets = torch.tensor([0])

    # 方法1: 安定な実装(PyTorchのcross_entropy)
    loss_stable = F.cross_entropy(logits_extreme, targets)
    print(f"Stable implementation: {loss_stable.item():.4f}")

    # 方法2: 不安定な実装(手動でsoftmax + log)
    try:
        probs = torch.softmax(logits_extreme, dim=-1)
        print(f"Probabilities: {probs}")  # [1., 0., 0.] - 正しいがlog(0)問題
        log_probs = torch.log(probs)
        print(f"Log probabilities: {log_probs}")  # -infが含まれる可能性
    except Exception as e:
        print(f"Unstable implementation error: {e}")


demonstrate_numerical_stability()

損失曲線の可視化

import matplotlib.pyplot as plt
import numpy as np


def visualize_cross_entropy():
    """交差エントロピー損失の特性を可視化"""
    # 正解クラスの予測確率
    p = np.linspace(0.01, 0.99, 100)

    # 交差エントロピー = -log(p)
    ce_loss = -np.log(p)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # 損失 vs 正解クラス確率
    axes[0].plot(p, ce_loss, 'b-', linewidth=2)
    axes[0].set_xlabel('Predicted probability for correct class', fontsize=12)
    axes[0].set_ylabel('Cross-Entropy Loss', fontsize=12)
    axes[0].set_title('Cross-Entropy Loss vs Prediction Confidence', fontsize=14)
    axes[0].grid(True, alpha=0.3)
    axes[0].set_xlim(0, 1)
    axes[0].set_ylim(0, 5)

    # 注目点
    highlight_points = [(0.1, 'p=0.1'), (0.5, 'p=0.5'), (0.9, 'p=0.9')]
    for prob, label in highlight_points:
        loss = -np.log(prob)
        axes[0].scatter([prob], [loss], color='red', s=100, zorder=5)
        axes[0].annotate(f'{label}\nL={loss:.2f}', (prob, loss),
                        textcoords="offset points", xytext=(10, 10))

    # 勾配の可視化
    gradient = -1 / p  # d/dp (-log p) = -1/p

    axes[1].plot(p, np.abs(gradient), 'r-', linewidth=2)
    axes[1].set_xlabel('Predicted probability for correct class', fontsize=12)
    axes[1].set_ylabel('|Gradient|', fontsize=12)
    axes[1].set_title('Gradient Magnitude of Cross-Entropy', fontsize=14)
    axes[1].grid(True, alpha=0.3)
    axes[1].set_xlim(0, 1)
    axes[1].set_ylim(0, 20)

    plt.tight_layout()
    plt.savefig('cross_entropy_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()


visualize_cross_entropy()

交差エントロピーの勾配

Softmax + Cross-Entropyの勾配

logits $\bm{z}$ に対する勾配を導出します。

$$ \frac{\partial \mathcal{L}}{\partial z_k} = \hat{y}_k – y_k $$

これは非常にシンプルな形です。正解クラス $c$ について:

$$ \frac{\partial \mathcal{L}}{\partial z_c} = \hat{y}_c – 1 $$

その他のクラス $k \neq c$ について:

$$ \frac{\partial \mathcal{L}}{\partial z_k} = \hat{y}_k $$

導出

損失関数を展開すると:

$$ \mathcal{L} = -\log \hat{y}_c = -z_c + \log \sum_{j=1}^{K} \exp(z_j) $$

$z_k$ で微分:

$$ \frac{\partial \mathcal{L}}{\partial z_k} = -\delta_{kc} + \frac{\exp(z_k)}{\sum_j \exp(z_j)} = -\delta_{kc} + \hat{y}_k $$

ここで $\delta_{kc}$ はクロネッカーのデルタ($k = c$ のとき1、それ以外は0)。

まとめ

本記事では、交差エントロピー損失とNLPの関係について解説しました。

  • 情報理論的背景: 交差エントロピーは、真の分布をモデル分布で符号化したときの平均符号長
  • 最尤推定との等価性: 交差エントロピーの最小化は、負の対数尤度の最小化と等価
  • NLPへの適用: 分類、言語モデル、系列ラベリングなど、あらゆるタスクの基盤
  • ラベル平滑化: 過信を防ぎ、汎化性能を向上させる正則化手法
  • 数値安定性: log-softmaxトリックにより、オーバーフローを防止
  • シンプルな勾配: $\hat{y}_k – y_k$ という直感的な形

交差エントロピー損失は、NLPにおいて最も基本的かつ重要な損失関数です。その数学的背景を理解することで、より深いモデル設計が可能になります。

次のステップとして、以下の記事も参考にしてください。