混合精度学習(Mixed Precision Training)は、FP32とFP16/BF16を組み合わせて訓練を効率化する手法です。メモリ使用量を削減し、計算を高速化しながら、モデルの精度を維持できます。
本記事では、浮動小数点数の基礎から、Loss Scaling、PyTorchでのAMP実装までを解説します。
本記事の内容
- 浮動小数点数の精度とその影響
- 混合精度学習の原理
- Loss Scalingの必要性
- PyTorchでのAMP実装
- 実験と効果の検証
浮動小数点数の基礎
IEEE 754フォーマット
浮動小数点数は以下の形式で表現されます:
$$ (-1)^s \times 2^{e – \text{bias}} \times (1 + m) $$
| 精度 | ビット | 符号 | 指数 | 仮数 | バイアス |
|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | 127 |
| FP16 | 16 | 1 | 5 | 10 | 15 |
| BF16 | 16 | 1 | 8 | 7 | 127 |
数値範囲と精度
| 精度 | 最小正値 | 最大値 | 有効桁数 |
|---|---|---|---|
| FP32 | $\approx 10^{-38}$ | $\approx 10^{38}$ | 約7桁 |
| FP16 | $\approx 6 \times 10^{-8}$ | $\approx 65504$ | 約3桁 |
| BF16 | $\approx 10^{-38}$ | $\approx 10^{38}$ | 約2桁 |
FP16の問題点
アンダーフロー問題
勾配が小さい場合($< 6 \times 10^{-8}$)、FP16ではゼロに丸められます。
$$ \text{gradient} = 1 \times 10^{-8} \xrightarrow{\text{FP16}} 0 $$
オーバーフロー問題
値が大きい場合($> 65504$)、FP16では無限大になります。
BF16の特徴
BF16(Brain Floating Point)は、FP32と同じ指数部ビット数を持つため:
- 数値範囲がFP32と同等
- 精度はFP16より低いが、アンダーフロー/オーバーフローに強い
- Loss Scalingが不要な場合が多い
混合精度学習の原理
基本アイデア
- 順伝播・逆伝播: FP16/BF16で高速計算
- パラメータ更新: FP32で精度を維持
- マスターウェイト: FP32でパラメータのコピーを保持
メモリと速度の利点
| 項目 | FP32 | 混合精度 | 改善率 |
|---|---|---|---|
| パラメータメモリ | 4N bytes | 6N bytes* | – |
| 活性化メモリ | 4A bytes | 2A bytes | 50%削減 |
| 計算速度 | 1x | 2-8x** | GPU依存 |
マスターウェイト(FP32)とモデルウェイト(FP16)の合計 *Tensor Cores対応GPUで大きな効果
Tensor Cores
NVIDIA GPUのTensor Coresは、FP16行列演算を高速化する専用ハードウェアです。
$$ \bm{D} = \bm{A} \times \bm{B} + \bm{C} $$
A, Bは FP16、C, DはFP16またはFP32で、1サイクルで4×4行列演算を実行します。
Loss Scaling
必要性
FP16では小さな勾配がアンダーフローします。Loss Scalingは損失を大きくスケーリングして勾配を表現可能な範囲に持ち上げます。
静的Loss Scaling
固定のスケール係数 $s$ を使用:
$$ \mathcal{L}_{\text{scaled}} = s \cdot \mathcal{L} $$
勾配更新時にスケールを戻す:
$$ \bm{\theta} \leftarrow \bm{\theta} – \eta \cdot \frac{1}{s} \nabla_{\bm{\theta}} \mathcal{L}_{\text{scaled}} $$
動的Loss Scaling
訓練中にスケール係数を自動調整:
1. scale = 初期値(例:65536)
2. for each iteration:
scaled_loss = loss * scale
backward(scaled_loss)
if grad contains inf or nan:
scale = scale / 2
skip update
else:
update weights with grad / scale
every N iterations without overflow:
scale = scale * 2
アルゴリズムの詳細:
- オーバーフロー検出: 勾配に
infやnanが含まれるかチェック - スケールダウン: オーバーフロー時にスケールを半減
- スケールアップ: N回連続で成功したらスケールを倍増
PyTorchでのAMP実装
基本的な使い方
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import time
# シンプルなモデル
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.classifier = nn.Linear(256, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
def train_with_amp(model, train_loader, epochs=5, device='cuda'):
"""AMPを使った訓練"""
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# GradScalerの初期化
scaler = GradScaler()
for epoch in range(epochs):
model.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# autocastコンテキスト内で順伝播
with autocast():
output = model(data)
loss = criterion(output, target)
# スケーリングされた勾配で逆伝播
scaler.scale(loss).backward()
# 勾配のアンスケーリングとクリッピング
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# パラメータ更新
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Scale = {scaler.get_scale():.1f}")
def train_without_amp(model, train_loader, epochs=5, device='cuda'):
"""AMPなしの訓練(比較用)"""
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
model.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")
# ダミーデータローダー
from torch.utils.data import DataLoader, TensorDataset
batch_size = 64
n_samples = 1000
X = torch.randn(n_samples, 3, 32, 32)
y = torch.randint(0, 10, (n_samples,))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# GPU利用可能な場合のみ実行
if torch.cuda.is_available():
print("=== Training with AMP ===")
model_amp = SimpleCNN()
start = time.time()
train_with_amp(model_amp, train_loader, epochs=5)
amp_time = time.time() - start
print(f"AMP Training Time: {amp_time:.2f}s")
print("\n=== Training without AMP ===")
model_fp32 = SimpleCNN()
start = time.time()
train_without_amp(model_fp32, train_loader, epochs=5)
fp32_time = time.time() - start
print(f"FP32 Training Time: {fp32_time:.2f}s")
print(f"\nSpeedup: {fp32_time / amp_time:.2f}x")
else:
print("CUDA is not available. Skipping GPU training.")
BF16の使用
# BF16を使用する場合
if torch.cuda.is_bf16_supported():
with autocast(dtype=torch.bfloat16):
output = model(data)
loss = criterion(output, target)
# BF16ではLoss Scalingが不要な場合が多い
loss.backward()
optimizer.step()
特定の演算の精度を制御
# 特定の演算でFP32を強制
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def custom_forward(x):
# この演算はFP32で実行される
return torch.log_softmax(x, dim=-1)
# autocast内で一時的にFP32を使用
with autocast():
output = model(data)
# 損失計算はFP32で
with autocast(enabled=False):
output_fp32 = output.float()
loss = criterion(output_fp32, target)
メモリ使用量の比較
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
def measure_memory(model, input_shape, use_amp=False):
"""メモリ使用量を測定"""
if not torch.cuda.is_available():
return None
model = model.cuda()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
x = torch.randn(*input_shape).cuda()
if use_amp:
with autocast():
y = model(x)
loss = y.sum()
else:
y = model(x)
loss = y.sum()
loss.backward()
peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # GB
return peak_memory
# 大きなモデルでテスト
class LargeModel(nn.Module):
def __init__(self, hidden_size=2048, num_layers=12):
super().__init__()
layers = []
for i in range(num_layers):
layers.extend([
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
])
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
if torch.cuda.is_available():
model = LargeModel()
input_shape = (32, 2048)
mem_fp32 = measure_memory(model, input_shape, use_amp=False)
mem_amp = measure_memory(model, input_shape, use_amp=True)
print(f"FP32 Peak Memory: {mem_fp32:.2f} GB")
print(f"AMP Peak Memory: {mem_amp:.2f} GB")
print(f"Memory Reduction: {(1 - mem_amp/mem_fp32) * 100:.1f}%")
GradScalerの詳細設定
from torch.cuda.amp import GradScaler
# カスタム設定のGradScaler
scaler = GradScaler(
init_scale=65536.0, # 初期スケール
growth_factor=2.0, # スケール増加係数
backoff_factor=0.5, # スケール減少係数
growth_interval=2000, # スケール増加の間隔
enabled=True # AMPの有効/無効
)
# スケールの状態を確認
print(f"Current scale: {scaler.get_scale()}")
print(f"Growth tracker: {scaler._get_growth_tracker()}")
# 状態の保存と復元
state_dict = scaler.state_dict()
new_scaler = GradScaler()
new_scaler.load_state_dict(state_dict)
混合精度学習の注意点
FP16で不安定になりやすい演算
| 演算 | 問題 | 対策 |
|---|---|---|
| Softmax | 数値安定性 | FP32で計算 |
| Loss計算 | アンダーフロー | FP32で計算 |
| LayerNorm | 統計量の精度 | FP32で計算 |
| 累積和 | 丸め誤差の蓄積 | FP32で計算 |
PyTorchのautocastの挙動
# autocastが自動的にFP32にする演算
fp32_ops = [
'batch_norm', 'layer_norm', 'group_norm', 'instance_norm',
'softmax', 'log_softmax', 'nll_loss', 'cross_entropy',
'binary_cross_entropy', 'binary_cross_entropy_with_logits'
]
# autocastがFP16にする演算
fp16_ops = [
'linear', 'matmul', 'conv1d', 'conv2d', 'conv3d',
'bmm', 'addmm', 'addbmm'
]
効果の可視化
import numpy as np
import matplotlib.pyplot as plt
def visualize_precision_effects():
"""浮動小数点精度の影響を可視化"""
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# FP16の表現可能範囲
ax = axes[0, 0]
x = np.logspace(-10, 5, 1000)
fp16_min = 6e-8
fp16_max = 65504
y = np.ones_like(x)
y[x < fp16_min] = 0
y[x > fp16_max] = 0
ax.semilogx(x, y, 'b-', linewidth=2)
ax.axvline(x=fp16_min, color='r', linestyle='--', label=f'FP16 min: {fp16_min:.0e}')
ax.axvline(x=fp16_max, color='g', linestyle='--', label=f'FP16 max: {fp16_max}')
ax.fill_between(x, 0, y, alpha=0.3)
ax.set_xlabel('Value')
ax.set_ylabel('Representable in FP16')
ax.set_title('FP16 Representable Range')
ax.legend()
ax.grid(True, alpha=0.3)
# Loss Scalingの効果
ax = axes[0, 1]
scales = [1, 128, 1024, 8192, 65536]
gradient = 1e-6
for scale in scales:
scaled_grad = gradient * scale
can_represent = scaled_grad >= fp16_min
ax.scatter(scale, scaled_grad, s=100,
c='green' if can_represent else 'red',
label=f'Scale={scale}')
ax.axhline(y=fp16_min, color='r', linestyle='--', alpha=0.5)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Loss Scale')
ax.set_ylabel('Scaled Gradient')
ax.set_title(f'Loss Scaling Effect (original grad = {gradient:.0e})')
ax.legend()
ax.grid(True, alpha=0.3)
# FP16 vs BF16 vs FP32の精度
ax = axes[1, 0]
precisions = ['FP32', 'BF16', 'FP16']
mantissa_bits = [23, 7, 10]
exponent_bits = [8, 8, 5]
x_pos = np.arange(len(precisions))
width = 0.35
ax.bar(x_pos - width/2, mantissa_bits, width, label='Mantissa bits', color='steelblue')
ax.bar(x_pos + width/2, exponent_bits, width, label='Exponent bits', color='coral')
ax.set_xticks(x_pos)
ax.set_xticklabels(precisions)
ax.set_ylabel('Number of Bits')
ax.set_title('Floating Point Format Comparison')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
# 典型的なメモリ/速度改善
ax = axes[1, 1]
metrics = ['Memory\n(Activations)', 'Throughput\n(Tensor Core)', 'Training\nTime']
fp32_values = [1.0, 1.0, 1.0]
amp_values = [0.5, 3.0, 0.5] # 典型的な改善率
x_pos = np.arange(len(metrics))
width = 0.35
ax.bar(x_pos - width/2, fp32_values, width, label='FP32', color='steelblue')
ax.bar(x_pos + width/2, amp_values, width, label='Mixed Precision', color='coral')
ax.set_xticks(x_pos)
ax.set_xticklabels(metrics)
ax.set_ylabel('Relative Value (FP32 = 1.0)')
ax.set_title('Mixed Precision Benefits')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig('mixed_precision_effects.png', dpi=150, bbox_inches='tight')
plt.show()
visualize_precision_effects()
まとめ
本記事では、混合精度学習について解説しました。
- 浮動小数点精度: FP16は範囲が狭く、アンダーフロー/オーバーフローに注意
- 混合精度の原理: 計算はFP16、パラメータ更新はFP32
- Loss Scaling: 小さな勾配のアンダーフローを防ぐ
- BF16: 範囲が広くLoss Scalingが不要な場合が多い
- PyTorch AMP: autocastとGradScalerで簡単に実装可能
混合精度学習は、メモリ使用量を削減し訓練を高速化する効果的な手法です。特にGPUのTensor Coresを活用できる場合、大きな性能向上が期待できます。
次のステップとして、以下の記事も参考にしてください。