チェックポイント管理のベストプラクティス

チェックポイント管理は、深層学習の訓練において重要な実践的スキルです。訓練中のモデル状態を保存し、障害からの復旧や最良モデルの保持を可能にします。

本記事では、PyTorchでのチェックポイント管理の基本から、分散学習での実装まで解説します。

本記事の内容

  • チェックポイントの基本概念
  • 保存すべき情報と形式
  • 訓練の再開と復元
  • 分散学習でのチェックポイント
  • ベストプラクティス

チェックポイントの基本概念

なぜチェックポイントが必要か

  1. 障害復旧: GPU故障やOOMで訓練が中断しても再開可能
  2. 最良モデルの保持: 検証性能が最良の時点のモデルを保存
  3. 長時間訓練の管理: 定期的に状態を保存して進捗を記録
  4. 実験の再現性: 訓練状態を完全に復元して再現実験

保存すべき情報

完全な訓練状態の復元には以下が必要です:

項目 内容 必須度
model_state_dict モデルのパラメータ 必須
optimizer_state_dict オプティマイザの状態(momentum等) 訓練再開に必須
scheduler_state_dict 学習率スケジューラの状態 訓練再開に必須
epoch 現在のエポック数 推奨
loss 訓練/検証損失 推奨
rng_state 乱数生成器の状態 完全再現に必要
scaler_state_dict GradScalerの状態(AMP使用時) AMP使用時必須

PyTorchでの基本実装

基本的なチェックポイント保存

import torch
import torch.nn as nn
import torch.optim as optim
import os
from datetime import datetime

def save_checkpoint(state, filename):
    """チェックポイントを保存"""
    torch.save(state, filename)
    print(f"Checkpoint saved: {filename}")


def load_checkpoint(filename, model, optimizer=None, scheduler=None):
    """チェックポイントを読み込み"""
    if not os.path.exists(filename):
        print(f"Checkpoint not found: {filename}")
        return None

    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    print(f"Checkpoint loaded: {filename}")
    return checkpoint


# 使用例
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 保存
checkpoint = {
    'epoch': 50,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_loss': 0.15,
    'val_loss': 0.18,
}
save_checkpoint(checkpoint, 'checkpoint_epoch50.pt')

# 読み込み
loaded = load_checkpoint('checkpoint_epoch50.pt', model, optimizer, scheduler)
start_epoch = loaded['epoch'] + 1 if loaded else 0

完全な訓練ループ

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

class CheckpointManager:
    """チェックポイント管理クラス"""

    def __init__(self, checkpoint_dir='checkpoints', max_keep=5):
        self.checkpoint_dir = checkpoint_dir
        self.max_keep = max_keep
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoints = []

    def save(self, state, epoch, is_best=False):
        """チェックポイントを保存"""
        # 通常のチェックポイント
        filename = os.path.join(self.checkpoint_dir, f'checkpoint_epoch{epoch:04d}.pt')
        torch.save(state, filename)
        self.checkpoints.append(filename)

        # 最良モデルの保存
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(state, best_path)
            print(f"Best model saved at epoch {epoch}")

        # 古いチェックポイントを削除
        while len(self.checkpoints) > self.max_keep:
            old_ckpt = self.checkpoints.pop(0)
            if os.path.exists(old_ckpt) and 'best' not in old_ckpt:
                os.remove(old_ckpt)

        print(f"Checkpoint saved: {filename}")

    def load_latest(self, model, optimizer=None, scheduler=None):
        """最新のチェックポイントを読み込み"""
        # チェックポイントファイルを探す
        ckpt_files = sorted([
            f for f in os.listdir(self.checkpoint_dir)
            if f.startswith('checkpoint_epoch') and f.endswith('.pt')
        ])

        if not ckpt_files:
            print("No checkpoint found")
            return None

        latest = os.path.join(self.checkpoint_dir, ckpt_files[-1])
        return self.load(latest, model, optimizer, scheduler)

    def load_best(self, model):
        """最良のチェックポイントを読み込み"""
        best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
        if not os.path.exists(best_path):
            print("Best checkpoint not found")
            return None

        checkpoint = torch.load(best_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Best model loaded (epoch {checkpoint.get('epoch', 'N/A')})")
        return checkpoint

    def load(self, path, model, optimizer=None, scheduler=None):
        """指定されたチェックポイントを読み込み"""
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduler and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print(f"Loaded checkpoint: {path}")
        return checkpoint


def train_with_checkpointing(model, train_loader, val_loader, epochs=100,
                             checkpoint_dir='checkpoints', resume=False):
    """チェックポイント付きの訓練"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

    ckpt_manager = CheckpointManager(checkpoint_dir)
    best_val_loss = float('inf')
    start_epoch = 0

    # 訓練再開
    if resume:
        checkpoint = ckpt_manager.load_latest(model, optimizer, scheduler)
        if checkpoint:
            start_epoch = checkpoint['epoch'] + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            print(f"Resuming from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        # 訓練
        model.train()
        train_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # 検証
        model.eval()
        val_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()

        val_loss /= len(val_loader)
        accuracy = correct / len(val_loader.dataset)

        scheduler.step(val_loss)

        # チェックポイント保存
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'best_val_loss': best_val_loss,
            'accuracy': accuracy,
        }

        # 10エポックごと、または最良モデル時に保存
        if epoch % 10 == 0 or is_best:
            ckpt_manager.save(state, epoch, is_best=is_best)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, "
              f"val_loss={val_loss:.4f}, accuracy={accuracy:.4f}")


# ダミーデータ
train_data = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
val_data = TensorDataset(torch.randn(200, 784), torch.randint(0, 10, (200,)))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32)

# 訓練
model = SimpleModel()
# train_with_checkpointing(model, train_loader, val_loader, epochs=100, resume=True)

乱数状態の完全保存

import random
import numpy as np
import torch

def save_rng_states():
    """乱数生成器の状態を保存"""
    return {
        'python': random.getstate(),
        'numpy': np.random.get_state(),
        'torch': torch.get_rng_state(),
        'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }


def load_rng_states(rng_states):
    """乱数生成器の状態を復元"""
    random.setstate(rng_states['python'])
    np.random.set_state(rng_states['numpy'])
    torch.set_rng_state(rng_states['torch'])
    if rng_states['cuda'] is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(rng_states['cuda'])


# 使用例
checkpoint = {
    'model_state_dict': model.state_dict(),
    'rng_states': save_rng_states(),
    # ...
}
torch.save(checkpoint, 'checkpoint_with_rng.pt')

# 復元
loaded = torch.load('checkpoint_with_rng.pt')
load_rng_states(loaded['rng_states'])

分散学習でのチェックポイント

DDPでのチェックポイント

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def save_ddp_checkpoint(model, optimizer, epoch, path, rank):
    """DDPモデルのチェックポイント保存"""
    # rank 0のみが保存
    if rank != 0:
        return

    # DDPモデルの場合、.moduleでアクセス
    model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")


def load_ddp_checkpoint(path, model, optimizer, map_location):
    """DDPモデルのチェックポイント読み込み"""
    checkpoint = torch.load(path, map_location=map_location)

    # DDPモデルの場合、.moduleにロード
    if hasattr(model, 'module'):
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint


# 使用例(DDPの訓練ループ内)
def train_ddp_with_checkpoint(rank, world_size, epochs=100):
    # ... 初期化 ...

    model = SimpleModel().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters())

    for epoch in range(epochs):
        # ... 訓練 ...

        # チェックポイント保存(rank 0のみ)
        if epoch % 10 == 0:
            save_ddp_checkpoint(model, optimizer, epoch, f'checkpoint_{epoch}.pt', rank)

        # 全プロセスで同期(保存完了を待つ)
        dist.barrier()

FSDPでのチェックポイント

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
    ShardedStateDictConfig,
)
from torch.distributed.checkpoint import (
    save_state_dict,
    load_state_dict,
)

def save_fsdp_full_state(model, optimizer, path, rank):
    """FSDP: 完全な状態辞書として保存(rank 0のみ)"""
    # 全パラメータをrank 0に集約
    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        state_dict = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)

        if rank == 0:
            checkpoint = {
                'model_state_dict': state_dict,
                'optimizer_state_dict': optim_state,
            }
            torch.save(checkpoint, path)
            print(f"Full checkpoint saved: {path}")


def save_fsdp_sharded_state(model, optimizer, checkpoint_dir):
    """FSDP: シャード状態として保存(各GPUが自分の部分を保存)"""
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            'model': model.state_dict(),
            'optimizer': FSDP.optim_state_dict(model, optimizer),
        }
        save_state_dict(state_dict, checkpoint_dir=checkpoint_dir)
        print(f"Sharded checkpoint saved: {checkpoint_dir}")


def load_fsdp_sharded_state(model, optimizer, checkpoint_dir):
    """FSDP: シャード状態を読み込み"""
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            'model': model.state_dict(),
            'optimizer': {},
        }
        load_state_dict(state_dict, checkpoint_dir=checkpoint_dir)
        model.load_state_dict(state_dict['model'])
        # オプティマイザの復元
        optim_state = FSDP.optim_state_dict_to_load(
            model, optimizer, state_dict['optimizer']
        )
        optimizer.load_state_dict(optim_state)

AMPでのチェックポイント

from torch.cuda.amp import GradScaler

def save_amp_checkpoint(model, optimizer, scaler, epoch, path):
    """AMPのチェックポイント保存"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
    }
    torch.save(checkpoint, path)


def load_amp_checkpoint(path, model, optimizer, scaler):
    """AMPのチェックポイント読み込み"""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    return checkpoint


# 使用例
scaler = GradScaler()
# ... 訓練後 ...
save_amp_checkpoint(model, optimizer, scaler, epoch, 'amp_checkpoint.pt')

ベストプラクティス

1. 定期的な保存戦略

class CheckpointStrategy:
    """チェックポイント保存戦略"""

    def __init__(self, checkpoint_dir, save_every_n_epochs=10,
                 save_every_n_steps=None, max_keep=5):
        self.checkpoint_dir = checkpoint_dir
        self.save_every_n_epochs = save_every_n_epochs
        self.save_every_n_steps = save_every_n_steps
        self.max_keep = max_keep
        self.best_metric = float('inf')
        os.makedirs(checkpoint_dir, exist_ok=True)

    def should_save(self, epoch, step=None, metric=None):
        """保存すべきかを判定"""
        reasons = []

        # エポックベース
        if self.save_every_n_epochs and epoch % self.save_every_n_epochs == 0:
            reasons.append('periodic_epoch')

        # ステップベース
        if self.save_every_n_steps and step and step % self.save_every_n_steps == 0:
            reasons.append('periodic_step')

        # 最良モデル
        if metric is not None and metric < self.best_metric:
            self.best_metric = metric
            reasons.append('best_model')

        return reasons

2. アトミックな保存

import tempfile
import shutil

def save_checkpoint_atomic(state, path):
    """アトミックなチェックポイント保存(中間状態での破損を防ぐ)"""
    # 一時ファイルに保存
    dir_name = os.path.dirname(path)
    with tempfile.NamedTemporaryFile(dir=dir_name, delete=False) as tmp:
        torch.save(state, tmp.name)
        tmp_path = tmp.name

    # アトミックにリネーム
    shutil.move(tmp_path, path)

3. チェックポイントの検証

def validate_checkpoint(path, model_class):
    """チェックポイントの整合性を検証"""
    try:
        checkpoint = torch.load(path, map_location='cpu')

        # 必須キーの確認
        required_keys = ['model_state_dict']
        for key in required_keys:
            if key not in checkpoint:
                return False, f"Missing key: {key}"

        # モデルとの互換性確認
        model = model_class()
        model.load_state_dict(checkpoint['model_state_dict'])

        return True, "Checkpoint is valid"

    except Exception as e:
        return False, f"Error: {str(e)}"


# 使用例
is_valid, message = validate_checkpoint('checkpoint.pt', SimpleModel)
print(f"Validation: {message}")

4. 軽量チェックポイント

def save_lightweight_checkpoint(model, path):
    """推論用の軽量チェックポイント(パラメータのみ)"""
    torch.save(model.state_dict(), path)


def save_full_model(model, path):
    """モデル構造ごと保存(非推奨だが便利な場合も)"""
    torch.save(model, path)


def export_to_torchscript(model, example_input, path):
    """TorchScriptにエクスポート"""
    model.eval()
    traced = torch.jit.trace(model, example_input)
    traced.save(path)

まとめ

本記事では、チェックポイント管理について解説しました。

  • 基本: model、optimizer、scheduler、epoch等を保存
  • 完全復元: 乱数状態も含めて保存すると完全再現可能
  • DDP: rank 0のみが保存、barrierで同期
  • FSDP: Full StateとSharded Stateの2つの方式
  • AMP: GradScalerの状態も保存が必要

チェックポイント管理は地味ですが、長時間の訓練を安全に実行するために不可欠な技術です。

次のステップとして、以下の記事も参考にしてください。