Mixed Precision Trainingで学習を高速化する方法

混合精度学習(Mixed Precision Training)は、FP32とFP16/BF16を組み合わせて訓練を効率化する手法です。メモリ使用量を削減し、計算を高速化しながら、モデルの精度を維持できます。

本記事では、浮動小数点数の基礎から、Loss Scaling、PyTorchでのAMP実装までを解説します。

本記事の内容

  • 浮動小数点数の精度とその影響
  • 混合精度学習の原理
  • Loss Scalingの必要性
  • PyTorchでのAMP実装
  • 実験と効果の検証

浮動小数点数の基礎

IEEE 754フォーマット

浮動小数点数は以下の形式で表現されます:

$$ (-1)^s \times 2^{e – \text{bias}} \times (1 + m) $$

精度 ビット 符号 指数 仮数 バイアス
FP32 32 1 8 23 127
FP16 16 1 5 10 15
BF16 16 1 8 7 127

数値範囲と精度

精度 最小正値 最大値 有効桁数
FP32 $\approx 10^{-38}$ $\approx 10^{38}$ 約7桁
FP16 $\approx 6 \times 10^{-8}$ $\approx 65504$ 約3桁
BF16 $\approx 10^{-38}$ $\approx 10^{38}$ 約2桁

FP16の問題点

アンダーフロー問題

勾配が小さい場合($< 6 \times 10^{-8}$)、FP16ではゼロに丸められます。

$$ \text{gradient} = 1 \times 10^{-8} \xrightarrow{\text{FP16}} 0 $$

オーバーフロー問題

値が大きい場合($> 65504$)、FP16では無限大になります。

BF16の特徴

BF16(Brain Floating Point)は、FP32と同じ指数部ビット数を持つため:

  • 数値範囲がFP32と同等
  • 精度はFP16より低いが、アンダーフロー/オーバーフローに強い
  • Loss Scalingが不要な場合が多い

混合精度学習の原理

基本アイデア

  • 順伝播・逆伝播: FP16/BF16で高速計算
  • パラメータ更新: FP32で精度を維持
  • マスターウェイト: FP32でパラメータのコピーを保持

メモリと速度の利点

項目 FP32 混合精度 改善率
パラメータメモリ 4N bytes 6N bytes*
活性化メモリ 4A bytes 2A bytes 50%削減
計算速度 1x 2-8x** GPU依存

マスターウェイト(FP32)とモデルウェイト(FP16)の合計 *Tensor Cores対応GPUで大きな効果

Tensor Cores

NVIDIA GPUのTensor Coresは、FP16行列演算を高速化する専用ハードウェアです。

$$ \bm{D} = \bm{A} \times \bm{B} + \bm{C} $$

A, Bは FP16、C, DはFP16またはFP32で、1サイクルで4×4行列演算を実行します。

Loss Scaling

必要性

FP16では小さな勾配がアンダーフローします。Loss Scalingは損失を大きくスケーリングして勾配を表現可能な範囲に持ち上げます。

静的Loss Scaling

固定のスケール係数 $s$ を使用:

$$ \mathcal{L}_{\text{scaled}} = s \cdot \mathcal{L} $$

勾配更新時にスケールを戻す:

$$ \bm{\theta} \leftarrow \bm{\theta} – \eta \cdot \frac{1}{s} \nabla_{\bm{\theta}} \mathcal{L}_{\text{scaled}} $$

動的Loss Scaling

訓練中にスケール係数を自動調整:

1. scale = 初期値(例:65536)
2. for each iteration:
     scaled_loss = loss * scale
     backward(scaled_loss)

     if grad contains inf or nan:
       scale = scale / 2
       skip update
     else:
       update weights with grad / scale
       every N iterations without overflow:
         scale = scale * 2

アルゴリズムの詳細

  • オーバーフロー検出: 勾配に infnan が含まれるかチェック
  • スケールダウン: オーバーフロー時にスケールを半減
  • スケールアップ: N回連続で成功したらスケールを倍増

PyTorchでのAMP実装

基本的な使い方

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import time

# シンプルなモデル
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


def train_with_amp(model, train_loader, epochs=5, device='cuda'):
    """AMPを使った訓練"""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # GradScalerの初期化
    scaler = GradScaler()

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            # autocastコンテキスト内で順伝播
            with autocast():
                output = model(data)
                loss = criterion(output, target)

            # スケーリングされた勾配で逆伝播
            scaler.scale(loss).backward()

            # 勾配のアンスケーリングとクリッピング
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # パラメータ更新
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Scale = {scaler.get_scale():.1f}")


def train_without_amp(model, train_loader, epochs=5, device='cuda'):
    """AMPなしの訓練(比較用)"""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")


# ダミーデータローダー
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64
n_samples = 1000
X = torch.randn(n_samples, 3, 32, 32)
y = torch.randint(0, 10, (n_samples,))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# GPU利用可能な場合のみ実行
if torch.cuda.is_available():
    print("=== Training with AMP ===")
    model_amp = SimpleCNN()
    start = time.time()
    train_with_amp(model_amp, train_loader, epochs=5)
    amp_time = time.time() - start
    print(f"AMP Training Time: {amp_time:.2f}s")

    print("\n=== Training without AMP ===")
    model_fp32 = SimpleCNN()
    start = time.time()
    train_without_amp(model_fp32, train_loader, epochs=5)
    fp32_time = time.time() - start
    print(f"FP32 Training Time: {fp32_time:.2f}s")

    print(f"\nSpeedup: {fp32_time / amp_time:.2f}x")
else:
    print("CUDA is not available. Skipping GPU training.")

BF16の使用

# BF16を使用する場合
if torch.cuda.is_bf16_supported():
    with autocast(dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)

    # BF16ではLoss Scalingが不要な場合が多い
    loss.backward()
    optimizer.step()

特定の演算の精度を制御

# 特定の演算でFP32を強制
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def custom_forward(x):
    # この演算はFP32で実行される
    return torch.log_softmax(x, dim=-1)


# autocast内で一時的にFP32を使用
with autocast():
    output = model(data)

    # 損失計算はFP32で
    with autocast(enabled=False):
        output_fp32 = output.float()
        loss = criterion(output_fp32, target)

メモリ使用量の比較

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

def measure_memory(model, input_shape, use_amp=False):
    """メモリ使用量を測定"""
    if not torch.cuda.is_available():
        return None

    model = model.cuda()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    x = torch.randn(*input_shape).cuda()

    if use_amp:
        with autocast():
            y = model(x)
            loss = y.sum()
    else:
        y = model(x)
        loss = y.sum()

    loss.backward()

    peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return peak_memory


# 大きなモデルでテスト
class LargeModel(nn.Module):
    def __init__(self, hidden_size=2048, num_layers=12):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers.extend([
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            ])
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


if torch.cuda.is_available():
    model = LargeModel()
    input_shape = (32, 2048)

    mem_fp32 = measure_memory(model, input_shape, use_amp=False)
    mem_amp = measure_memory(model, input_shape, use_amp=True)

    print(f"FP32 Peak Memory: {mem_fp32:.2f} GB")
    print(f"AMP Peak Memory: {mem_amp:.2f} GB")
    print(f"Memory Reduction: {(1 - mem_amp/mem_fp32) * 100:.1f}%")

GradScalerの詳細設定

from torch.cuda.amp import GradScaler

# カスタム設定のGradScaler
scaler = GradScaler(
    init_scale=65536.0,        # 初期スケール
    growth_factor=2.0,          # スケール増加係数
    backoff_factor=0.5,         # スケール減少係数
    growth_interval=2000,       # スケール増加の間隔
    enabled=True                # AMPの有効/無効
)

# スケールの状態を確認
print(f"Current scale: {scaler.get_scale()}")
print(f"Growth tracker: {scaler._get_growth_tracker()}")

# 状態の保存と復元
state_dict = scaler.state_dict()
new_scaler = GradScaler()
new_scaler.load_state_dict(state_dict)

混合精度学習の注意点

FP16で不安定になりやすい演算

演算 問題 対策
Softmax 数値安定性 FP32で計算
Loss計算 アンダーフロー FP32で計算
LayerNorm 統計量の精度 FP32で計算
累積和 丸め誤差の蓄積 FP32で計算

PyTorchのautocastの挙動

# autocastが自動的にFP32にする演算
fp32_ops = [
    'batch_norm', 'layer_norm', 'group_norm', 'instance_norm',
    'softmax', 'log_softmax', 'nll_loss', 'cross_entropy',
    'binary_cross_entropy', 'binary_cross_entropy_with_logits'
]

# autocastがFP16にする演算
fp16_ops = [
    'linear', 'matmul', 'conv1d', 'conv2d', 'conv3d',
    'bmm', 'addmm', 'addbmm'
]

効果の可視化

import numpy as np
import matplotlib.pyplot as plt

def visualize_precision_effects():
    """浮動小数点精度の影響を可視化"""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # FP16の表現可能範囲
    ax = axes[0, 0]
    x = np.logspace(-10, 5, 1000)
    fp16_min = 6e-8
    fp16_max = 65504

    y = np.ones_like(x)
    y[x < fp16_min] = 0
    y[x > fp16_max] = 0

    ax.semilogx(x, y, 'b-', linewidth=2)
    ax.axvline(x=fp16_min, color='r', linestyle='--', label=f'FP16 min: {fp16_min:.0e}')
    ax.axvline(x=fp16_max, color='g', linestyle='--', label=f'FP16 max: {fp16_max}')
    ax.fill_between(x, 0, y, alpha=0.3)
    ax.set_xlabel('Value')
    ax.set_ylabel('Representable in FP16')
    ax.set_title('FP16 Representable Range')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Loss Scalingの効果
    ax = axes[0, 1]
    scales = [1, 128, 1024, 8192, 65536]
    gradient = 1e-6

    for scale in scales:
        scaled_grad = gradient * scale
        can_represent = scaled_grad >= fp16_min
        ax.scatter(scale, scaled_grad, s=100,
                  c='green' if can_represent else 'red',
                  label=f'Scale={scale}')

    ax.axhline(y=fp16_min, color='r', linestyle='--', alpha=0.5)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Loss Scale')
    ax.set_ylabel('Scaled Gradient')
    ax.set_title(f'Loss Scaling Effect (original grad = {gradient:.0e})')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # FP16 vs BF16 vs FP32の精度
    ax = axes[1, 0]
    precisions = ['FP32', 'BF16', 'FP16']
    mantissa_bits = [23, 7, 10]
    exponent_bits = [8, 8, 5]

    x_pos = np.arange(len(precisions))
    width = 0.35

    ax.bar(x_pos - width/2, mantissa_bits, width, label='Mantissa bits', color='steelblue')
    ax.bar(x_pos + width/2, exponent_bits, width, label='Exponent bits', color='coral')

    ax.set_xticks(x_pos)
    ax.set_xticklabels(precisions)
    ax.set_ylabel('Number of Bits')
    ax.set_title('Floating Point Format Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # 典型的なメモリ/速度改善
    ax = axes[1, 1]
    metrics = ['Memory\n(Activations)', 'Throughput\n(Tensor Core)', 'Training\nTime']
    fp32_values = [1.0, 1.0, 1.0]
    amp_values = [0.5, 3.0, 0.5]  # 典型的な改善率

    x_pos = np.arange(len(metrics))
    width = 0.35

    ax.bar(x_pos - width/2, fp32_values, width, label='FP32', color='steelblue')
    ax.bar(x_pos + width/2, amp_values, width, label='Mixed Precision', color='coral')

    ax.set_xticks(x_pos)
    ax.set_xticklabels(metrics)
    ax.set_ylabel('Relative Value (FP32 = 1.0)')
    ax.set_title('Mixed Precision Benefits')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('mixed_precision_effects.png', dpi=150, bbox_inches='tight')
    plt.show()

visualize_precision_effects()

まとめ

本記事では、混合精度学習について解説しました。

  • 浮動小数点精度: FP16は範囲が狭く、アンダーフロー/オーバーフローに注意
  • 混合精度の原理: 計算はFP16、パラメータ更新はFP32
  • Loss Scaling: 小さな勾配のアンダーフローを防ぐ
  • BF16: 範囲が広くLoss Scalingが不要な場合が多い
  • PyTorch AMP: autocastとGradScalerで簡単に実装可能

混合精度学習は、メモリ使用量を削減し訓練を高速化する効果的な手法です。特にGPUのTensor Coresを活用できる場合、大きな性能向上が期待できます。

次のステップとして、以下の記事も参考にしてください。