チェックポイント管理は、深層学習の訓練において重要な実践的スキルです。訓練中のモデル状態を保存し、障害からの復旧や最良モデルの保持を可能にします。
本記事では、PyTorchでのチェックポイント管理の基本から、分散学習での実装まで解説します。
本記事の内容
- チェックポイントの基本概念
- 保存すべき情報と形式
- 訓練の再開と復元
- 分散学習でのチェックポイント
- ベストプラクティス
チェックポイントの基本概念
なぜチェックポイントが必要か
- 障害復旧: GPU故障やOOMで訓練が中断しても再開可能
- 最良モデルの保持: 検証性能が最良の時点のモデルを保存
- 長時間訓練の管理: 定期的に状態を保存して進捗を記録
- 実験の再現性: 訓練状態を完全に復元して再現実験
保存すべき情報
完全な訓練状態の復元には以下が必要です:
| 項目 | 内容 | 必須度 |
|---|---|---|
| model_state_dict | モデルのパラメータ | 必須 |
| optimizer_state_dict | オプティマイザの状態(momentum等) | 訓練再開に必須 |
| scheduler_state_dict | 学習率スケジューラの状態 | 訓練再開に必須 |
| epoch | 現在のエポック数 | 推奨 |
| loss | 訓練/検証損失 | 推奨 |
| rng_state | 乱数生成器の状態 | 完全再現に必要 |
| scaler_state_dict | GradScalerの状態(AMP使用時) | AMP使用時必須 |
PyTorchでの基本実装
基本的なチェックポイント保存
import torch
import torch.nn as nn
import torch.optim as optim
import os
from datetime import datetime
def save_checkpoint(state, filename):
"""チェックポイントを保存"""
torch.save(state, filename)
print(f"Checkpoint saved: {filename}")
def load_checkpoint(filename, model, optimizer=None, scheduler=None):
"""チェックポイントを読み込み"""
if not os.path.exists(filename):
print(f"Checkpoint not found: {filename}")
return None
checkpoint = torch.load(filename)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print(f"Checkpoint loaded: {filename}")
return checkpoint
# 使用例
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# 保存
checkpoint = {
'epoch': 50,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'train_loss': 0.15,
'val_loss': 0.18,
}
save_checkpoint(checkpoint, 'checkpoint_epoch50.pt')
# 読み込み
loaded = load_checkpoint('checkpoint_epoch50.pt', model, optimizer, scheduler)
start_epoch = loaded['epoch'] + 1 if loaded else 0
完全な訓練ループ
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
class CheckpointManager:
"""チェックポイント管理クラス"""
def __init__(self, checkpoint_dir='checkpoints', max_keep=5):
self.checkpoint_dir = checkpoint_dir
self.max_keep = max_keep
os.makedirs(checkpoint_dir, exist_ok=True)
self.checkpoints = []
def save(self, state, epoch, is_best=False):
"""チェックポイントを保存"""
# 通常のチェックポイント
filename = os.path.join(self.checkpoint_dir, f'checkpoint_epoch{epoch:04d}.pt')
torch.save(state, filename)
self.checkpoints.append(filename)
# 最良モデルの保存
if is_best:
best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
torch.save(state, best_path)
print(f"Best model saved at epoch {epoch}")
# 古いチェックポイントを削除
while len(self.checkpoints) > self.max_keep:
old_ckpt = self.checkpoints.pop(0)
if os.path.exists(old_ckpt) and 'best' not in old_ckpt:
os.remove(old_ckpt)
print(f"Checkpoint saved: {filename}")
def load_latest(self, model, optimizer=None, scheduler=None):
"""最新のチェックポイントを読み込み"""
# チェックポイントファイルを探す
ckpt_files = sorted([
f for f in os.listdir(self.checkpoint_dir)
if f.startswith('checkpoint_epoch') and f.endswith('.pt')
])
if not ckpt_files:
print("No checkpoint found")
return None
latest = os.path.join(self.checkpoint_dir, ckpt_files[-1])
return self.load(latest, model, optimizer, scheduler)
def load_best(self, model):
"""最良のチェックポイントを読み込み"""
best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
if not os.path.exists(best_path):
print("Best checkpoint not found")
return None
checkpoint = torch.load(best_path)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Best model loaded (epoch {checkpoint.get('epoch', 'N/A')})")
return checkpoint
def load(self, path, model, optimizer=None, scheduler=None):
"""指定されたチェックポイントを読み込み"""
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler and 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print(f"Loaded checkpoint: {path}")
return checkpoint
def train_with_checkpointing(model, train_loader, val_loader, epochs=100,
checkpoint_dir='checkpoints', resume=False):
"""チェックポイント付きの訓練"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
ckpt_manager = CheckpointManager(checkpoint_dir)
best_val_loss = float('inf')
start_epoch = 0
# 訓練再開
if resume:
checkpoint = ckpt_manager.load_latest(model, optimizer, scheduler)
if checkpoint:
start_epoch = checkpoint['epoch'] + 1
best_val_loss = checkpoint.get('best_val_loss', float('inf'))
print(f"Resuming from epoch {start_epoch}")
for epoch in range(start_epoch, epochs):
# 訓練
model.train()
train_loss = 0
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_loss /= len(train_loader)
# 検証
model.eval()
val_loss = 0
correct = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
val_loss += criterion(output, target).item()
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
val_loss /= len(val_loader)
accuracy = correct / len(val_loader.dataset)
scheduler.step(val_loss)
# チェックポイント保存
is_best = val_loss < best_val_loss
if is_best:
best_val_loss = val_loss
state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'best_val_loss': best_val_loss,
'accuracy': accuracy,
}
# 10エポックごと、または最良モデル時に保存
if epoch % 10 == 0 or is_best:
ckpt_manager.save(state, epoch, is_best=is_best)
print(f"Epoch {epoch}: train_loss={train_loss:.4f}, "
f"val_loss={val_loss:.4f}, accuracy={accuracy:.4f}")
# ダミーデータ
train_data = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
val_data = TensorDataset(torch.randn(200, 784), torch.randint(0, 10, (200,)))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32)
# 訓練
model = SimpleModel()
# train_with_checkpointing(model, train_loader, val_loader, epochs=100, resume=True)
乱数状態の完全保存
import random
import numpy as np
import torch
def save_rng_states():
"""乱数生成器の状態を保存"""
return {
'python': random.getstate(),
'numpy': np.random.get_state(),
'torch': torch.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
}
def load_rng_states(rng_states):
"""乱数生成器の状態を復元"""
random.setstate(rng_states['python'])
np.random.set_state(rng_states['numpy'])
torch.set_rng_state(rng_states['torch'])
if rng_states['cuda'] is not None and torch.cuda.is_available():
torch.cuda.set_rng_state_all(rng_states['cuda'])
# 使用例
checkpoint = {
'model_state_dict': model.state_dict(),
'rng_states': save_rng_states(),
# ...
}
torch.save(checkpoint, 'checkpoint_with_rng.pt')
# 復元
loaded = torch.load('checkpoint_with_rng.pt')
load_rng_states(loaded['rng_states'])
分散学習でのチェックポイント
DDPでのチェックポイント
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def save_ddp_checkpoint(model, optimizer, epoch, path, rank):
"""DDPモデルのチェックポイント保存"""
# rank 0のみが保存
if rank != 0:
return
# DDPモデルの場合、.moduleでアクセス
model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
checkpoint = {
'epoch': epoch,
'model_state_dict': model_state,
'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, path)
print(f"Checkpoint saved: {path}")
def load_ddp_checkpoint(path, model, optimizer, map_location):
"""DDPモデルのチェックポイント読み込み"""
checkpoint = torch.load(path, map_location=map_location)
# DDPモデルの場合、.moduleにロード
if hasattr(model, 'module'):
model.module.load_state_dict(checkpoint['model_state_dict'])
else:
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint
# 使用例(DDPの訓練ループ内)
def train_ddp_with_checkpoint(rank, world_size, epochs=100):
# ... 初期化 ...
model = SimpleModel().to(rank)
model = DDP(model, device_ids=[rank])
optimizer = optim.Adam(model.parameters())
for epoch in range(epochs):
# ... 訓練 ...
# チェックポイント保存(rank 0のみ)
if epoch % 10 == 0:
save_ddp_checkpoint(model, optimizer, epoch, f'checkpoint_{epoch}.pt', rank)
# 全プロセスで同期(保存完了を待つ)
dist.barrier()
FSDPでのチェックポイント
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig,
ShardedStateDictConfig,
)
from torch.distributed.checkpoint import (
save_state_dict,
load_state_dict,
)
def save_fsdp_full_state(model, optimizer, path, rank):
"""FSDP: 完全な状態辞書として保存(rank 0のみ)"""
# 全パラメータをrank 0に集約
full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state_dict = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
if rank == 0:
checkpoint = {
'model_state_dict': state_dict,
'optimizer_state_dict': optim_state,
}
torch.save(checkpoint, path)
print(f"Full checkpoint saved: {path}")
def save_fsdp_sharded_state(model, optimizer, checkpoint_dir):
"""FSDP: シャード状態として保存(各GPUが自分の部分を保存)"""
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
state_dict = {
'model': model.state_dict(),
'optimizer': FSDP.optim_state_dict(model, optimizer),
}
save_state_dict(state_dict, checkpoint_dir=checkpoint_dir)
print(f"Sharded checkpoint saved: {checkpoint_dir}")
def load_fsdp_sharded_state(model, optimizer, checkpoint_dir):
"""FSDP: シャード状態を読み込み"""
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
state_dict = {
'model': model.state_dict(),
'optimizer': {},
}
load_state_dict(state_dict, checkpoint_dir=checkpoint_dir)
model.load_state_dict(state_dict['model'])
# オプティマイザの復元
optim_state = FSDP.optim_state_dict_to_load(
model, optimizer, state_dict['optimizer']
)
optimizer.load_state_dict(optim_state)
AMPでのチェックポイント
from torch.cuda.amp import GradScaler
def save_amp_checkpoint(model, optimizer, scaler, epoch, path):
"""AMPのチェックポイント保存"""
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scaler_state_dict': scaler.state_dict(),
}
torch.save(checkpoint, path)
def load_amp_checkpoint(path, model, optimizer, scaler):
"""AMPのチェックポイント読み込み"""
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])
return checkpoint
# 使用例
scaler = GradScaler()
# ... 訓練後 ...
save_amp_checkpoint(model, optimizer, scaler, epoch, 'amp_checkpoint.pt')
ベストプラクティス
1. 定期的な保存戦略
class CheckpointStrategy:
"""チェックポイント保存戦略"""
def __init__(self, checkpoint_dir, save_every_n_epochs=10,
save_every_n_steps=None, max_keep=5):
self.checkpoint_dir = checkpoint_dir
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_steps = save_every_n_steps
self.max_keep = max_keep
self.best_metric = float('inf')
os.makedirs(checkpoint_dir, exist_ok=True)
def should_save(self, epoch, step=None, metric=None):
"""保存すべきかを判定"""
reasons = []
# エポックベース
if self.save_every_n_epochs and epoch % self.save_every_n_epochs == 0:
reasons.append('periodic_epoch')
# ステップベース
if self.save_every_n_steps and step and step % self.save_every_n_steps == 0:
reasons.append('periodic_step')
# 最良モデル
if metric is not None and metric < self.best_metric:
self.best_metric = metric
reasons.append('best_model')
return reasons
2. アトミックな保存
import tempfile
import shutil
def save_checkpoint_atomic(state, path):
"""アトミックなチェックポイント保存(中間状態での破損を防ぐ)"""
# 一時ファイルに保存
dir_name = os.path.dirname(path)
with tempfile.NamedTemporaryFile(dir=dir_name, delete=False) as tmp:
torch.save(state, tmp.name)
tmp_path = tmp.name
# アトミックにリネーム
shutil.move(tmp_path, path)
3. チェックポイントの検証
def validate_checkpoint(path, model_class):
"""チェックポイントの整合性を検証"""
try:
checkpoint = torch.load(path, map_location='cpu')
# 必須キーの確認
required_keys = ['model_state_dict']
for key in required_keys:
if key not in checkpoint:
return False, f"Missing key: {key}"
# モデルとの互換性確認
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
return True, "Checkpoint is valid"
except Exception as e:
return False, f"Error: {str(e)}"
# 使用例
is_valid, message = validate_checkpoint('checkpoint.pt', SimpleModel)
print(f"Validation: {message}")
4. 軽量チェックポイント
def save_lightweight_checkpoint(model, path):
"""推論用の軽量チェックポイント(パラメータのみ)"""
torch.save(model.state_dict(), path)
def save_full_model(model, path):
"""モデル構造ごと保存(非推奨だが便利な場合も)"""
torch.save(model, path)
def export_to_torchscript(model, example_input, path):
"""TorchScriptにエクスポート"""
model.eval()
traced = torch.jit.trace(model, example_input)
traced.save(path)
まとめ
本記事では、チェックポイント管理について解説しました。
- 基本: model、optimizer、scheduler、epoch等を保存
- 完全復元: 乱数状態も含めて保存すると完全再現可能
- DDP: rank 0のみが保存、barrierで同期
- FSDP: Full StateとSharded Stateの2つの方式
- AMP: GradScalerの状態も保存が必要
チェックポイント管理は地味ですが、長時間の訓練を安全に実行するために不可欠な技術です。
次のステップとして、以下の記事も参考にしてください。