勾配クリッピングの理論と実装

勾配クリッピング(Gradient Clipping)は、勾配爆発を防ぐための手法です。特にRNNやTransformerなどの深いネットワークで、訓練を安定させるために広く使われています。

本記事では、勾配爆発の原因から、各種クリッピング手法の理論と実装までを解説します。

本記事の内容

  • 勾配爆発とは何か
  • ノルムクリッピングと値クリッピング
  • 適応的勾配クリッピング(AGC)
  • PyTorchでの実装
  • 実験と可視化

勾配爆発とは

問題の定義

深層ネットワークにおいて、バックプロパゲーションで計算される勾配が指数的に大きくなる現象を勾配爆発と呼びます。

$L$ 層のネットワークを考えます:

$$ \bm{h}_l = f(\bm{W}_l \bm{h}_{l-1}) $$

チェーンルールにより、損失 $\mathcal{L}$ のパラメータ $\bm{W}_1$ に関する勾配は:

$$ \frac{\partial \mathcal{L}}{\partial \bm{W}_1} = \frac{\partial \mathcal{L}}{\partial \bm{h}_L} \cdot \prod_{l=2}^{L} \frac{\partial \bm{h}_l}{\partial \bm{h}_{l-1}} \cdot \frac{\partial \bm{h}_1}{\partial \bm{W}_1} $$

RNNでの勾配爆発

RNNでは、同じ重み行列 $\bm{W}$ が時間ステップごとに繰り返し適用されます:

$$ \bm{h}_t = \tanh(\bm{W}_h \bm{h}_{t-1} + \bm{W}_x \bm{x}_t + \bm{b}) $$

時刻 $t$ から時刻 $k$ への勾配は:

$$ \frac{\partial \bm{h}_t}{\partial \bm{h}_k} = \prod_{i=k+1}^{t} \frac{\partial \bm{h}_i}{\partial \bm{h}_{i-1}} = \prod_{i=k+1}^{t} \text{diag}(\tanh'(\cdot)) \bm{W}_h $$

$\bm{W}_h$ の最大特異値 $\sigma_{\max} > 1$ の場合、$(t – k)$ ステップで勾配は $\sigma_{\max}^{t-k}$ に比例して増大します。

勾配爆発の影響

  1. 数値オーバーフロー: 勾配が infnan になる
  2. 不安定な更新: 大きな勾配でパラメータが発散
  3. 学習の失敗: 損失が発散して学習が崩壊

ノルムクリッピング

理論

勾配全体のノルムを閾値 $\tau$ 以下に制限します。

$$ \tilde{\bm{g}} = \begin{cases} \bm{g} & \text{if } \|\bm{g}\| \leq \tau \\ \tau \cdot \frac{\bm{g}}{\|\bm{g}\|} & \text{if } \|\bm{g}\| > \tau \end{cases} $$

特徴: – 勾配の方向を保持 – ノルムのみをスケーリング – 最も広く使われる手法

グローバルノルム vs パラメータ別ノルム

グローバルノルム(推奨):

全パラメータの勾配を連結したベクトルのノルムを使用:

$$ \|\bm{g}\|_{\text{global}} = \sqrt{\sum_i \|\bm{g}_i\|^2} $$

パラメータ別ノルム

各パラメータの勾配を独立にクリッピング。パラメータ間の相対的な大きさが変わる可能性がある。

L2ノルム vs L1ノルム vs Linfノルム

ノルム 定義 特徴
L2 $\sqrt{\sum_i g_i^2}$ 最も一般的、回転不変
L1 $\sum_i \|g_i\|$ スパース勾配に有効
Linf $\max_i \|g_i\|$ 外れ値に敏感

値クリッピング

理論

各勾配成分を個別に閾値 $[-\tau, \tau]$ にクリッピングします。

$$ \tilde{g}_i = \text{clip}(g_i, -\tau, \tau) = \max(-\tau, \min(\tau, g_i)) $$

特徴: – 勾配の方向が変わる可能性がある – 実装が単純 – 一部のフレームワークでデフォルト

ノルムクリッピングとの比較

観点 ノルムクリッピング 値クリッピング
方向保持 保持する 変わりうる
スケール 一様にスケール 成分ごとに異なる
使用場面 一般的 RNN等の特定ケース

適応的勾配クリッピング(AGC)

理論

NFNet (Brock et al., 2021) で導入された手法です。パラメータのノルムに対する勾配のノルムの比率に基づいてクリッピングします。

$$ \tilde{\bm{g}}_i = \begin{cases} \bm{g}_i & \text{if } \frac{\|\bm{g}_i\|}{\|\bm{W}_i\|} \leq \lambda \\ \lambda \cdot \frac{\|\bm{W}_i\| \cdot \bm{g}_i}{\|\bm{g}_i\|} & \text{otherwise} \end{cases} $$

ここで、$\lambda$ は閾値(例:0.01)です。

利点: – Batch Normalizationなしで深いネットワークを訓練可能 – パラメータスケールに自動適応 – 層ごとに適切なクリッピング強度

PyTorchでの実装

基本的な勾配クリッピング

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# ノルムクリッピング(PyTorch組み込み)
def train_step_with_clip_norm(model, optimizer, criterion, x, y, max_norm=1.0):
    """ノルムクリッピングを使った訓練ステップ"""
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()

    # 勾配ノルムを計算(クリッピング前)
    total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    optimizer.step()

    return loss.item(), total_norm_before.item()


# 値クリッピング(PyTorch組み込み)
def train_step_with_clip_value(model, optimizer, criterion, x, y, clip_value=1.0):
    """値クリッピングを使った訓練ステップ"""
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()

    # クリッピング前の最大勾配値
    max_grad_before = max(p.grad.abs().max().item() for p in model.parameters() if p.grad is not None)

    # 値クリッピング
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)

    optimizer.step()

    return loss.item(), max_grad_before


# スクラッチ実装
def clip_grad_norm_manual(parameters, max_norm, norm_type=2.0):
    """勾配ノルムクリッピングのスクラッチ実装"""
    parameters = list(filter(lambda p: p.grad is not None, parameters))

    if len(parameters) == 0:
        return torch.tensor(0.0)

    # グローバルノルムを計算
    if norm_type == float('inf'):
        total_norm = max(p.grad.abs().max() for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad, norm_type) for p in parameters]),
            norm_type
        )

    # クリッピング係数
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef = min(clip_coef, 1.0)

    # 勾配をスケーリング
    for p in parameters:
        p.grad.mul_(clip_coef)

    return total_norm


# テスト用のRNN
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out


# 勾配爆発のデモンストレーション
torch.manual_seed(42)
np.random.seed(42)

input_size = 10
hidden_size = 32
output_size = 1
seq_length = 50
batch_size = 16

model = SimpleRNN(input_size, hidden_size, output_size, num_layers=5)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# ダミーデータ
X = torch.randn(batch_size, seq_length, input_size)
y = torch.randn(batch_size, output_size)

# クリッピングなしで勾配を計算
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()

# 勾配ノルムを確認
grad_norms = []
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        grad_norms.append((name, grad_norm))
        print(f"{name}: grad_norm = {grad_norm:.4f}")

print(f"\nTotal grad norm: {sum(n**2 for _, n in grad_norms)**0.5:.4f}")

適応的勾配クリッピング(AGC)

def adaptive_gradient_clipping(parameters, clip_factor=0.01, eps=1e-3):
    """
    適応的勾配クリッピング(AGC)の実装

    Parameters:
    -----------
    parameters : iterable
        モデルパラメータ
    clip_factor : float
        クリッピング係数 λ
    eps : float
        数値安定性のための小さな値
    """
    for p in parameters:
        if p.grad is None:
            continue

        # パラメータノルムと勾配ノルムを計算
        param_norm = p.data.norm()
        grad_norm = p.grad.norm()

        # ノルムが小さすぎる場合はスキップ
        if param_norm < eps or grad_norm < eps:
            continue

        # クリッピング判定
        max_norm = param_norm * clip_factor
        if grad_norm > max_norm:
            p.grad.mul_(max_norm / grad_norm)


# AGCを使った訓練ループ
def train_with_agc(model, train_loader, criterion, optimizer, epochs=10, clip_factor=0.01):
    """AGCを使った訓練"""
    losses = []
    grad_norms = []

    for epoch in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()

            # 勾配ノルムを記録(AGC前)
            total_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
                           if p.grad is not None)**0.5
            grad_norms.append(total_norm)

            # AGCを適用
            adaptive_gradient_clipping(model.parameters(), clip_factor=clip_factor)

            optimizer.step()
            losses.append(loss.item())

    return losses, grad_norms

勾配クリッピングの効果の可視化

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 勾配爆発を起こしやすいRNN
class DeepRNN(nn.Module):
    def __init__(self, input_size=10, hidden_size=64, num_layers=10):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

        # 重みを大きめに初期化(勾配爆発を誘発)
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param, std=1.0)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.fc(out[:, -1, :])


def train_and_record(model, max_norm=None, clip_value=None, epochs=100):
    """訓練しながら勾配ノルムを記録"""
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    losses = []
    grad_norms = []

    for epoch in range(epochs):
        # ダミーデータ
        x = torch.randn(16, 30, 10)
        y = torch.randn(16, 1)

        optimizer.zero_grad()
        try:
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
        except RuntimeError:
            losses.append(float('nan'))
            grad_norms.append(float('nan'))
            continue

        # 勾配ノルムを記録
        total_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
                        if p.grad is not None)**0.5
        grad_norms.append(total_norm)

        # クリッピング
        if max_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        elif clip_value is not None:
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)

        optimizer.step()
        losses.append(loss.item())

    return losses, grad_norms


# 3つの条件で訓練
torch.manual_seed(42)
model_no_clip = DeepRNN()
losses_no_clip, norms_no_clip = train_and_record(model_no_clip, max_norm=None)

torch.manual_seed(42)
model_norm_clip = DeepRNN()
losses_norm_clip, norms_norm_clip = train_and_record(model_norm_clip, max_norm=1.0)

torch.manual_seed(42)
model_value_clip = DeepRNN()
losses_value_clip, norms_value_clip = train_and_record(model_value_clip, clip_value=0.5)

# 可視化
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 損失の推移
ax = axes[0, 0]
ax.plot(losses_no_clip, label='No Clipping', alpha=0.7)
ax.plot(losses_norm_clip, label='Norm Clipping (max=1.0)', alpha=0.7)
ax.plot(losses_value_clip, label='Value Clipping (max=0.5)', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 10)

# 勾配ノルムの推移
ax = axes[0, 1]
ax.plot(norms_no_clip, label='No Clipping', alpha=0.7)
ax.plot(norms_norm_clip, label='Norm Clipping', alpha=0.7)
ax.plot(norms_value_clip, label='Value Clipping', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Gradient Norm')
ax.set_title('Gradient Norm (before clipping)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# クリッピングの効果を図解
ax = axes[1, 0]
theta = np.linspace(0, 2*np.pi, 100)
r_original = 3  # 元の勾配ノルム
r_clipped = 1   # クリッピング後

# 元の勾配(大きい)
ax.arrow(0, 0, r_original*np.cos(np.pi/4), r_original*np.sin(np.pi/4),
         head_width=0.15, head_length=0.1, fc='blue', ec='blue', alpha=0.5,
         label='Original Gradient')

# クリッピング後(方向は同じ、ノルムは制限)
ax.arrow(0, 0, r_clipped*np.cos(np.pi/4), r_clipped*np.sin(np.pi/4),
         head_width=0.15, head_length=0.1, fc='red', ec='red',
         label='Clipped Gradient')

# クリッピング境界(円)
ax.plot(r_clipped*np.cos(theta), r_clipped*np.sin(theta), 'g--', label='Clip Boundary')

ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.set_aspect('equal')
ax.set_xlabel('Gradient Component 1')
ax.set_ylabel('Gradient Component 2')
ax.set_title('Norm Clipping Illustration')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)

# 値クリッピングの図解
ax = axes[1, 1]
g = np.linspace(-5, 5, 100)
clip_val = 1.0
g_clipped = np.clip(g, -clip_val, clip_val)

ax.plot(g, g, 'b-', label='Original', alpha=0.7)
ax.plot(g, g_clipped, 'r-', linewidth=2, label='Value Clipped')
ax.axhline(y=clip_val, color='g', linestyle='--', alpha=0.5)
ax.axhline(y=-clip_val, color='g', linestyle='--', alpha=0.5)
ax.set_xlabel('Original Gradient Value')
ax.set_ylabel('Clipped Gradient Value')
ax.set_title('Value Clipping Function')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('gradient_clipping_effect.png', dpi=150, bbox_inches='tight')
plt.show()

Transformerでの勾配クリッピング

import torch
import torch.nn as nn
import torch.optim as optim

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(512, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.pos_encoder(pos)
        x = self.transformer(x)
        return self.fc(x)


def train_transformer_with_clipping(model, train_loader, epochs=10, max_norm=1.0, lr=1e-4):
    """Transformerの訓練(勾配クリッピング付き)"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    history = {'loss': [], 'grad_norm': [], 'clipped_norm': []}

    for epoch in range(epochs):
        epoch_loss = 0
        epoch_grad_norm = 0
        epoch_clipped_norm = 0
        n_batches = 0

        for x, y in train_loader:
            optimizer.zero_grad()
            output = model(x)
            output = output.view(-1, output.size(-1))
            y = y.view(-1)
            loss = criterion(output, y)
            loss.backward()

            # 勾配ノルム(クリッピング前)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # 勾配ノルム(クリッピング後)
            clipped_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
                              if p.grad is not None)**0.5

            optimizer.step()

            epoch_loss += loss.item()
            epoch_grad_norm += grad_norm.item()
            epoch_clipped_norm += clipped_norm
            n_batches += 1

        scheduler.step()

        history['loss'].append(epoch_loss / n_batches)
        history['grad_norm'].append(epoch_grad_norm / n_batches)
        history['clipped_norm'].append(epoch_clipped_norm / n_batches)

        print(f"Epoch {epoch+1}: Loss={epoch_loss/n_batches:.4f}, "
              f"GradNorm={epoch_grad_norm/n_batches:.4f}, "
              f"ClippedNorm={epoch_clipped_norm/n_batches:.4f}")

    return history


# ダミーデータローダー
from torch.utils.data import DataLoader, TensorDataset

vocab_size = 1000
seq_len = 32
batch_size = 16
n_samples = 160

X = torch.randint(0, vocab_size, (n_samples, seq_len))
y = torch.randint(0, vocab_size, (n_samples, seq_len))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 訓練
model = TransformerModel(vocab_size)
history = train_transformer_with_clipping(model, train_loader, epochs=20, max_norm=1.0)

クリッピング閾値の選択

import numpy as np
import matplotlib.pyplot as plt

def analyze_gradient_distribution(model, train_loader, n_batches=50):
    """勾配の分布を分析してクリッピング閾値を決定"""
    criterion = nn.MSELoss()
    grad_norms = []

    for i, (x, y) in enumerate(train_loader):
        if i >= n_batches:
            break

        model.zero_grad()
        output = model(x)
        loss = criterion(output, y.unsqueeze(1))
        loss.backward()

        total_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
                        if p.grad is not None)**0.5
        grad_norms.append(total_norm)

    grad_norms = np.array(grad_norms)

    # 統計量
    print(f"Gradient Norm Statistics:")
    print(f"  Mean: {grad_norms.mean():.4f}")
    print(f"  Std: {grad_norms.std():.4f}")
    print(f"  Median: {np.median(grad_norms):.4f}")
    print(f"  95th percentile: {np.percentile(grad_norms, 95):.4f}")
    print(f"  99th percentile: {np.percentile(grad_norms, 99):.4f}")

    # 推奨閾値(95パーセンタイル付近)
    recommended_threshold = np.percentile(grad_norms, 95)
    print(f"\nRecommended clip threshold: {recommended_threshold:.4f}")

    # 可視化
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(grad_norms, bins=30, density=True, alpha=0.7, edgecolor='black')
    ax.axvline(x=recommended_threshold, color='r', linestyle='--',
               label=f'Recommended threshold: {recommended_threshold:.2f}')
    ax.set_xlabel('Gradient Norm')
    ax.set_ylabel('Density')
    ax.set_title('Gradient Norm Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.savefig('gradient_norm_distribution.png', dpi=150, bbox_inches='tight')
    plt.show()

    return recommended_threshold

まとめ

本記事では、勾配クリッピングについて解説しました。

  • 勾配爆発: 深いネットワークで勾配が指数的に増大する問題
  • ノルムクリッピング: 勾配の方向を保持しつつノルムを制限(最も一般的)
  • 値クリッピング: 各成分を独立に制限
  • AGC: パラメータノルムに適応したクリッピング

勾配クリッピングは、特にRNNやTransformerの訓練で重要なテクニックです。閾値は経験的に決定することが多いですが、勾配ノルムの分布を分析して適切な値を選ぶことが推奨されます。

次のステップとして、以下の記事も参考にしてください。