【PyTorch】分散学習(DDP/FSDP)の仕組みと実装

分散学習は、複数のGPUやノードを使ってモデルの訓練を並列化する技術です。大規模言語モデル(LLM)の訓練には不可欠であり、Data Parallel、DDP、FSDPなど様々な手法が存在します。

本記事では、各手法の理論と特徴、PyTorchでの実装を解説します。

本記事の内容

  • 分散学習の基本概念
  • Data ParallelとDDP
  • Fully Sharded Data Parallel (FSDP)
  • PyTorchでの実装
  • 大規模モデル訓練のテクニック

分散学習の基本概念

並列化の種類

手法 分割対象 特徴
Data Parallel データ 各GPUで同じモデル、異なるデータ
Model Parallel モデル モデルの異なる層を異なるGPUに配置
Pipeline Parallel モデル 層をパイプライン的に処理
Tensor Parallel モデル 単一の層を複数GPUに分割

Data Parallelの原理

$N$ 個のGPUで訓練する場合、バッチサイズ $B$ を各GPUに $B/N$ ずつ分配します。

各GPU $i$ で: 1. ローカルバッチで順伝播・逆伝播 2. 勾配を計算:$\bm{g}_i = \nabla_{\bm{\theta}} \mathcal{L}_i$ 3. 全GPUで勾配を集約:$\bm{g} = \frac{1}{N} \sum_{i=1}^{N} \bm{g}_i$ 4. パラメータを更新

通信パターン

All-Reduce

全GPUの勾配を合計し、結果を全GPUに配布:

$$ \bm{g}_{\text{all}} = \sum_{i=1}^{N} \bm{g}_i \xrightarrow{\text{broadcast}} \text{all GPUs} $$

Ring All-Reduce

効率的なAll-Reduce実装。$N$ GPUでリング状に勾配を集約: – 通信量:$2(N-1)/N \times \text{data size}$ – 帯域効率が最適

DataParallel vs DistributedDataParallel

DataParallel (DP)

PyTorchの nn.DataParallel は最もシンプルな並列化手法です。

アーキテクチャ

GPU 0 (Master): Model copy + Gradient aggregation
GPU 1-N: Model copies

Forward: GPU 0 → scatter data → all GPUs → gather outputs → GPU 0
Backward: GPU 0 → scatter grads → all GPUs → reduce grads → GPU 0

問題点: – GPU 0がボトルネック(メモリ、通信) – GILによる並列化の制限 – 効率が低い(特にGPU数が多い場合)

DistributedDataParallel (DDP)

torch.nn.parallel.DistributedDataParallel はより効率的な実装です。

アーキテクチャ

Process 0 (GPU 0): Model copy
Process 1 (GPU 1): Model copy
...
Process N-1 (GPU N-1): Model copy

Each process: Forward → Backward → All-Reduce gradients

利点: – マスターGPUのボトルネックなし – マルチプロセス(GILを回避) – 勾配の集約と逆伝播を重ねられる(オーバーラップ)

性能比較

項目 DataParallel DDP
GPU効率 低い(GPU 0負荷大) 高い(均等分散)
通信 GPU 0経由 直接All-Reduce
GIL 制限あり なし(マルチプロセス)
スケーラビリティ 低い 高い

Fully Sharded Data Parallel (FSDP)

理論

FSDPは、モデルパラメータ自体も分割(シャード)することで、メモリ効率を大幅に改善します。

ZeRO(Zero Redundancy Optimizer)の概念

Stage シャード対象 メモリ削減
ZeRO-1 Optimizer states ~4x
ZeRO-2 + Gradients ~8x
ZeRO-3 + Parameters ~$N$x(GPU数)

FSDPはZeRO-3相当の機能を提供します。

動作原理

各GPU: パラメータの1/N、勾配の1/N、Optimizer状態の1/N

Forward pass:
  for each layer:
    All-Gather: 全パラメータを一時的に収集
    Compute forward
    Discard gathered parameters

Backward pass:
  for each layer (reverse):
    All-Gather: 全パラメータを収集
    Compute backward
    Reduce-Scatter: 勾配を分散
    Discard gathered parameters

Update:
  各GPUは自分のシャードのみを更新

メモリ分析

$P$ パラメータ、$N$ GPU、FP16訓練の場合:

項目 DDP FSDP
パラメータ $2P$ $2P/N$
勾配 $2P$ $2P/N$
Optimizer (Adam) $8P$ $8P/N$
合計 $12P$ $12P/N$

例:7B パラメータモデル – DDP: 各GPUに84GB必要 – FSDP (8 GPU): 各GPUに10.5GB

PyTorchでの実装

DDPの実装

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp

def setup(rank, world_size):
    """分散環境の初期化"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """分散環境のクリーンアップ"""
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def train_ddp(rank, world_size, epochs=5):
    """DDPを使った訓練"""
    setup(rank, world_size)

    # モデル、オプティマイザの設定
    model = SimpleModel().to(rank)
    model = DDP(model, device_ids=[rank])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # ダミーデータセット
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 784),
        torch.randint(0, 10, (1000,))
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # シャッフルのために必要
        total_loss = 0

        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if rank == 0:
            print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")

    cleanup()


def main():
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print("DDPには2つ以上のGPUが必要です")
        return

    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)


# 実行
# if __name__ == "__main__":
#     main()

FSDPの実装

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
import torch.multiprocessing as mp
from functools import partial


class TransformerBlock(nn.Module):
    """Transformerブロック"""
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-Attention
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)

        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x


class SimpleTransformer(nn.Module):
    """シンプルなTransformerモデル"""
    def __init__(self, vocab_size=10000, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)


def train_fsdp(rank, world_size, epochs=5):
    """FSDPを使った訓練"""
    # 分散環境の初期化
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # モデルの作成
    model = SimpleTransformer()

    # FSDPのラップポリシー
    # TransformerBlockごとにシャーディング
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    # 混合精度設定
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # FSDPでラップ
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=rank,
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    # ダミーデータ
    for epoch in range(epochs):
        total_loss = 0
        n_batches = 10

        for _ in range(n_batches):
            # ランダムな入力
            x = torch.randint(0, 10000, (8, 128)).to(rank)
            y = torch.randint(0, 10000, (8, 128)).to(rank)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output.view(-1, 10000), y.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if rank == 0:
            print(f"Epoch {epoch+1}: Loss = {total_loss / n_batches:.4f}")

    dist.destroy_process_group()


def main_fsdp():
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print("FSDPには2つ以上のGPUが必要です")
        return

    mp.spawn(train_fsdp, args=(world_size,), nprocs=world_size, join=True)

torchrunを使った実行

# 単一ノード、4 GPU
torchrun --nproc_per_node=4 train.py

# マルチノード(2ノード、各4 GPU)
# ノード1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=<master_ip> --master_port=12355 train.py

# ノード2
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
    --master_addr=<master_ip> --master_port=12355 train.py

torchrunに対応したスクリプト

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def main():
    # 環境変数から情報を取得
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    # 分散環境の初期化
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # モデルの作成とDDPラップ
    model = SimpleModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # 訓練ループ
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # ダミーデータ
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 784),
        torch.randint(0, 10, (1000,))
    )
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(5):
        sampler.set_epoch(epoch)
        total_loss = 0

        for data, target in dataloader:
            data, target = data.to(local_rank), target.to(local_rank)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # rank 0のみログ出力
        if rank == 0:
            print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

大規模モデル訓練のテクニック

勾配蓄積

実効バッチサイズを大きくしながらメモリを節約:

accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(dataloader):
    output = model(data)
    loss = criterion(output, target) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

勾配チェックポイント

中間活性化を再計算してメモリを節約:

from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
        self.layer3 = nn.Linear(1024, 1024)

    def forward(self, x):
        # チェックポイント:順伝播時に活性化を保存せず、逆伝播時に再計算
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        x = self.layer3(x)
        return x

混合精度 + DDP

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in dataloader:
    optimizer.zero_grad()

    with autocast():
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

分散学習のデバッグ

def debug_distributed_info():
    """分散環境の情報を出力"""
    if dist.is_initialized():
        print(f"Rank: {dist.get_rank()}")
        print(f"World Size: {dist.get_world_size()}")
        print(f"Backend: {dist.get_backend()}")
        print(f"Local Rank: {os.environ.get('LOCAL_RANK', 'N/A')}")

        # CUDA情報
        if torch.cuda.is_available():
            print(f"CUDA Device: {torch.cuda.current_device()}")
            print(f"Device Name: {torch.cuda.get_device_name()}")
            print(f"Memory Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    else:
        print("Distributed not initialized")


# 同期のテスト
def test_all_reduce():
    """All-Reduceのテスト"""
    tensor = torch.tensor([dist.get_rank()], device='cuda')
    print(f"Before All-Reduce (rank {dist.get_rank()}): {tensor}")

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"After All-Reduce (rank {dist.get_rank()}): {tensor}")

まとめ

本記事では、分散学習について解説しました。

  • Data Parallel: シンプルだが効率が低い
  • DDP: マルチプロセスで効率的な並列化
  • FSDP: パラメータもシャードしてメモリ効率を最大化
  • 通信: All-Reduce、Ring All-Reduceで勾配を集約
  • テクニック: 勾配蓄積、勾配チェックポイント、混合精度

大規模モデルの訓練には、FSDPと混合精度、勾配チェックポイントの組み合わせが効果的です。

次のステップとして、以下の記事も参考にしてください。