知識蒸留の理論 — 温度パラメータの数学的導出と実装

知識蒸留(Knowledge Distillation)は、大規模な教師モデルの知識を小規模な生徒モデルに転移する技術です。モデルの軽量化と推論速度の向上を実現しながら、性能の低下を最小限に抑えることができます。

本記事では、知識蒸留の理論的背景から実装まで詳しく解説します。

本記事の内容

  • 知識蒸留の基本概念と直感的理解
  • ソフトターゲットと温度パラメータの数学的意味
  • 蒸留損失の導出と実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

知識蒸留とは

基本的なアイデア

深層学習モデルは、層を深くしパラメータを増やすことで性能が向上しますが、推論時のコスト(計算量、メモリ、遅延)も増加します。知識蒸留は、この問題を解決する技術の一つです。

知識蒸留の基本的なアイデア:

  1. 教師モデル(Teacher): 大規模で高精度なモデル(事前学習済み)
  2. 生徒モデル(Student): 小規模で軽量なモデル(これから学習)
  3. 蒸留: 教師の「知識」を生徒に転移

なぜ蒸留が効くのか

通常の分類学習では、正解ラベル(ハードターゲット)のみを使用します:

$$ y_{\text{hard}} = [0, 0, 1, 0, 0]^\top \quad (\text{クラス3が正解}) $$

一方、教師モデルの出力(ソフトターゲット)には追加の情報が含まれています:

$$ y_{\text{soft}} = [0.02, 0.05, 0.85, 0.05, 0.03]^\top $$

このソフトターゲットには「クラス3が正解だが、クラス2や4も少し似ている」というクラス間の関係性が暗黙的に含まれています。

知識蒸留の数学的定式化

ソフトマックスと温度パラメータ

ニューラルネットワークの出力(ロジット)$\bm{z} = (z_1, z_2, \ldots, z_K)$ に対して、温度付きソフトマックスを定義します:

$$ q_i(T) = \frac{\exp(z_i / T)}{\sum_{j=1}^{K} \exp(z_j / T)} $$

ここで $T > 0$ は温度パラメータです。

温度の効果: – $T = 1$: 通常のソフトマックス – $T > 1$: 分布が平滑化(クラス間の差が小さくなる) – $T < 1$: 分布が尖鋭化(クラス間の差が大きくなる) - $T \to \infty$: 一様分布に近づく - $T \to 0$: ハードな選択(argmax)に近づく

温度による平滑化の導出

温度が高いとき、分布がどのように変化するか見てみましょう。

$T$ が大きいとき、$z_i / T$ が小さくなるため、テイラー展開を適用できます:

$$ \exp(z_i / T) \approx 1 + \frac{z_i}{T} + O\left(\frac{1}{T^2}\right) $$

したがって:

$$ \begin{align} q_i(T) &\approx \frac{1 + z_i / T}{\sum_j (1 + z_j / T)} \\ &= \frac{1 + z_i / T}{K + \sum_j z_j / T} \end{align} $$

$\sum_j z_j = 0$ と仮定すると:

$$ q_i(T) \approx \frac{1}{K} + \frac{z_i}{KT} $$

これは、$T$ が大きいほど分布が一様分布 $1/K$ に近づくことを示しています。

蒸留損失

知識蒸留では、以下の2つの損失を組み合わせます:

  1. ハードターゲット損失: 正解ラベルとの通常のクロスエントロピー
  2. 蒸留損失: 教師の出力分布とのKLダイバージェンス

ハードターゲット損失

$$ \mathcal{L}_{\text{hard}} = -\sum_{i=1}^{K} y_i \log p_i(T=1) $$

ここで $y_i$ は正解ラベル(one-hotベクトル)、$p_i$ は生徒モデルの出力です。

蒸留損失

$$ \mathcal{L}_{\text{soft}} = T^2 \cdot D_{\text{KL}}(q^{(T)} \| p^{(T)}) $$

ここで $q^{(T)}$ は教師モデルの温度 $T$ での出力、$p^{(T)}$ は生徒モデルの温度 $T$ での出力です。

$T^2$ の係数は、温度が高いときの勾配のスケールを調整するために必要です。

全体の損失

$$ \mathcal{L} = \alpha \mathcal{L}_{\text{hard}} + (1 – \alpha) \mathcal{L}_{\text{soft}} $$

$\alpha \in [0, 1]$ はハードターゲットとソフトターゲットのバランスを調整するハイパーパラメータです。

$T^2$ の導出

なぜ蒸留損失に $T^2$ が必要なのかを導出します。

温度 $T$ でのソフトマックスの勾配を考えます。生徒のロジット $z_i$ に関する勾配:

$$ \frac{\partial q_i(T)}{\partial z_i} = \frac{q_i(T)(1 – q_i(T))}{T} $$

温度が高いとき $q_i(T) \approx 1/K$ なので:

$$ \frac{\partial q_i(T)}{\partial z_i} \approx \frac{1}{KT} \cdot \frac{K-1}{K} \approx \frac{1}{T} $$

したがって、$T = 1$ のときと同程度の勾配を得るには $T^2$ でスケーリングする必要があります。

Pythonでの実装

温度付きソフトマックスの実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def softmax_with_temperature(logits, temperature=1.0):
    """温度付きソフトマックス"""
    return F.softmax(logits / temperature, dim=-1)

# 温度の効果を可視化
logits = torch.tensor([2.0, 1.0, 0.5, 0.2, 0.1])

temperatures = [0.5, 1.0, 2.0, 5.0, 10.0]

plt.figure(figsize=(10, 6))

for T in temperatures:
    probs = softmax_with_temperature(logits, T).numpy()
    plt.plot(range(len(probs)), probs, 'o-', label=f'T = {T}')

plt.xlabel('Class Index')
plt.ylabel('Probability')
plt.title('Effect of Temperature on Softmax Distribution')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

蒸留損失の実装

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """知識蒸留の損失関数"""

    def __init__(self, temperature=4.0, alpha=0.5):
        """
        Args:
            temperature: 蒸留の温度パラメータ
            alpha: ハードターゲット損失の重み (0-1)
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: 生徒モデルのロジット (batch_size, num_classes)
            teacher_logits: 教師モデルのロジット (batch_size, num_classes)
            labels: 正解ラベル (batch_size,)
        Returns:
            loss: 蒸留損失
        """
        T = self.temperature

        # ハードターゲット損失(通常のクロスエントロピー)
        hard_loss = self.ce_loss(student_logits, labels)

        # ソフトターゲット損失(KLダイバージェンス)
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)

        # KLダイバージェンス(T^2でスケーリング)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)

        # 全体の損失
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        return loss

# 動作確認
torch.manual_seed(42)
batch_size = 32
num_classes = 10

student_logits = torch.randn(batch_size, num_classes)
teacher_logits = torch.randn(batch_size, num_classes) * 2  # 教師はより自信を持った出力
labels = torch.randint(0, num_classes, (batch_size,))

criterion = DistillationLoss(temperature=4.0, alpha=0.5)
loss = criterion(student_logits, teacher_logits, labels)
print(f"Distillation Loss: {loss.item():.4f}")

完全な知識蒸留の例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# シンプルなニューラルネットワーク
class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dims, num_classes):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def create_synthetic_dataset(n_samples=2000, input_dim=20, num_classes=5):
    """合成データセットの作成"""
    np.random.seed(42)

    X = []
    y = []

    for class_idx in range(num_classes):
        center = np.random.randn(input_dim) * 3
        samples = center + np.random.randn(n_samples // num_classes, input_dim) * 0.5
        X.append(samples)
        y.extend([class_idx] * (n_samples // num_classes))

    X = np.vstack(X).astype(np.float32)
    y = np.array(y)

    # シャッフル
    indices = np.random.permutation(len(X))
    X, y = X[indices], y[indices]

    # 訓練/テスト分割
    split = int(0.8 * len(X))
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    return X_train, y_train, X_test, y_test

def train_model(model, train_loader, criterion, optimizer, n_epochs):
    """モデルの学習"""
    model.train()
    losses = []

    for epoch in range(n_epochs):
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)

    return losses

def evaluate_model(model, test_loader):
    """モデルの評価"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch)
            _, predicted = torch.max(outputs, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    return correct / total

def distill_knowledge(teacher, student, train_loader, n_epochs, temperature=4.0, alpha=0.5, lr=0.001):
    """知識蒸留の実行"""
    teacher.eval()
    student.train()

    criterion = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = optim.Adam(student.parameters(), lr=lr)

    losses = []

    for epoch in range(n_epochs):
        epoch_loss = 0

        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()

            # 教師の出力(勾配不要)
            with torch.no_grad():
                teacher_logits = teacher(X_batch)

            # 生徒の出力
            student_logits = student(X_batch)

            # 蒸留損失
            loss = criterion(student_logits, teacher_logits, y_batch)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {avg_loss:.4f}")

    return losses

# データ準備
X_train, y_train, X_test, y_test = create_synthetic_dataset()

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

input_dim = X_train.shape[1]
num_classes = len(np.unique(y_train))

# 教師モデル(大きなモデル)の学習
print("Training Teacher Model...")
teacher = SimpleNet(input_dim, hidden_dims=[256, 128, 64], num_classes=num_classes)
teacher_optimizer = optim.Adam(teacher.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()
teacher_losses = train_model(teacher, train_loader, teacher_criterion, teacher_optimizer, n_epochs=50)
teacher_acc = evaluate_model(teacher, test_loader)
print(f"Teacher Accuracy: {teacher_acc:.4f}")

# 生徒モデル(小さなモデル)を通常学習
print("\nTraining Student (No Distillation)...")
student_no_distill = SimpleNet(input_dim, hidden_dims=[32], num_classes=num_classes)
student_optimizer = optim.Adam(student_no_distill.parameters(), lr=0.001)
student_criterion = nn.CrossEntropyLoss()
student_losses = train_model(student_no_distill, train_loader, student_criterion, student_optimizer, n_epochs=50)
student_no_distill_acc = evaluate_model(student_no_distill, test_loader)
print(f"Student (No Distillation) Accuracy: {student_no_distill_acc:.4f}")

# 生徒モデルを知識蒸留で学習
print("\nTraining Student (With Distillation)...")
student_distilled = SimpleNet(input_dim, hidden_dims=[32], num_classes=num_classes)
distill_losses = distill_knowledge(teacher, student_distilled, train_loader, n_epochs=50, temperature=4.0, alpha=0.3)
student_distilled_acc = evaluate_model(student_distilled, test_loader)
print(f"Student (With Distillation) Accuracy: {student_distilled_acc:.4f}")

# 結果の可視化
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(teacher_losses, label='Teacher')
plt.plot(student_losses, label='Student (No Distill)')
plt.plot(distill_losses, label='Student (Distilled)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
models = ['Teacher\n(Large)', 'Student\n(No Distill)', 'Student\n(Distilled)']
accuracies = [teacher_acc, student_no_distill_acc, student_distilled_acc]
colors = ['green', 'orange', 'blue']
bars = plt.bar(models, accuracies, color=colors, alpha=0.7)
plt.ylabel('Accuracy')
plt.title('Model Accuracy Comparison')
plt.ylim(0, 1)
for bar, acc in zip(bars, accuracies):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f'{acc:.3f}', ha='center')

plt.tight_layout()
plt.show()

# モデルサイズの比較
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Size Comparison:")
print(f"Teacher: {count_parameters(teacher):,} parameters")
print(f"Student: {count_parameters(student_distilled):,} parameters")
print(f"Compression Ratio: {count_parameters(teacher) / count_parameters(student_distilled):.1f}x")

知識蒸留の発展形

様々な蒸留手法

手法 説明
Response-based 出力層のロジット/確率を蒸留(本記事で解説)
Feature-based 中間層の特徴マップを蒸留
Relation-based サンプル間の関係性を蒸留
Self-distillation 同じモデルの過去の状態から蒸留

Feature-based Distillation

中間層の特徴を蒸留する場合、追加の損失項を加えます:

$$ \mathcal{L}_{\text{feature}} = \sum_{l} \| \phi(\bm{h}_l^{(S)}) – \bm{h}_l^{(T)} \|^2 $$

ここで $\bm{h}_l^{(S)}, \bm{h}_l^{(T)}$ はそれぞれ生徒と教師の第 $l$ 層の特徴、$\phi$ は次元を合わせる変換関数です。

まとめ

本記事では、知識蒸留(Knowledge Distillation)について解説しました。

  • 知識蒸留は、大規模な教師モデルから小規模な生徒モデルへ知識を転移する技術
  • ソフトターゲット(教師の出力分布)にはクラス間の関係性という「暗黙知」が含まれる
  • 温度パラメータにより分布の平滑化を制御(典型的には $T = 3 \sim 5$)
  • 蒸留損失は $T^2$ でスケーリングして勾配のバランスを取る

次のステップとして、以下の記事も参考にしてください。