Early Stoppingの理論 — L2正則化との等価性を導出する

早期終了(Early Stopping)は、検証誤差の推移を監視し、過学習が始まる前に訓練を停止する正則化テクニックです。実装が簡単でありながら、過学習を効果的に防ぐことができます。

本記事では、早期終了の理論的背景、様々な戦略、そしてPyTorchでの実装を解説します。

本記事の内容

  • 早期終了の基本概念
  • 正則化としての理論的解釈
  • patience戦略の設計
  • PyTorchでの実装
  • 実験と可視化

早期終了とは

基本概念

深層学習では、訓練を続けると訓練誤差は減少し続けますが、ある時点から検証誤差は増加に転じます。この現象が過学習です。

早期終了は、検証誤差が改善しなくなったら訓練を停止する手法です。

$$ \text{停止条件}: \mathcal{L}_{\text{val}}^{(t)} > \mathcal{L}_{\text{val}}^{(t-1)} > \cdots > \mathcal{L}_{\text{val}}^{(t-p)} $$

ここで、$p$ はpatience(我慢回数)と呼ばれるハイパーパラメータです。

アルゴリズム

1. best_val_loss = infinity
2. counter = 0
3. for each epoch:
     train(model)
     val_loss = evaluate(model)

     if val_loss < best_val_loss:
       best_val_loss = val_loss
       save(model)
       counter = 0
     else:
       counter += 1

     if counter >= patience:
       load(best_model)
       break

正則化としての理論的解釈

暗黙的な正則化

早期終了は、明示的な正則化項を加えなくても、正則化と同等の効果を持ちます。

線形回帰を勾配降下法で解く場合を考えます:

$$ \bm{\theta}^{(t+1)} = \bm{\theta}^{(t)} – \eta \nabla_{\bm{\theta}} \mathcal{L}(\bm{\theta}^{(t)}) $$

$t$ ステップ後のパラメータは:

$$ \bm{\theta}^{(t)} = \sum_{k=0}^{t-1} \eta (\bm{I} – \eta \bm{X}^\top \bm{X})^k \bm{X}^\top \bm{y} $$

L2正則化との等価性

早期終了後のパラメータは、ある $\lambda$ に対するL2正則化解と近似的に等価です。

L2正則化解:

$$ \bm{\theta}_{\text{ridge}} = (\bm{X}^\top \bm{X} + \lambda \bm{I})^{-1} \bm{X}^\top \bm{y} $$

学習率 $\eta$、反復回数 $t$ に対して、暗黙的な正則化係数は近似的に:

$$ \lambda_{\text{eff}} \approx \frac{1}{\eta t} $$

つまり、早く停止するほど($t$ が小さいほど)強い正則化効果が得られます。

スペクトル分解による解釈

特異値分解 $\bm{X} = \bm{U} \bm{\Sigma} \bm{V}^\top$ を用いると、早期終了の効果は特異値のフィルタリングとして解釈できます。

$t$ ステップ後の解は、特異値 $\sigma_i$ に対するフィルタ係数:

$$ f_i^{(t)} = 1 – (1 – \eta \sigma_i^2)^t $$

小さい特異値(ノイズに対応)は抑制され、大きい特異値(信号に対応)は保持されます。

patience戦略

基本戦略

固定patience

最もシンプルな戦略。patience回連続で改善がなければ停止。

$$ \text{stop if } \mathcal{L}_{\text{val}}^{(t)} \geq \mathcal{L}_{\text{val}}^{*} \text{ for } p \text{ consecutive epochs} $$

相対改善threshold

絶対的な改善だけでなく、相対的な改善率も考慮:

$$ \text{improved} = \mathcal{L}_{\text{val}}^{(t)} < \mathcal{L}_{\text{val}}^{*} \cdot (1 - \delta) $$

ここで、$\delta$ は最小改善率(例:0.001 = 0.1%)。

学習率スケジューリングとの組み合わせ

Reduce on Plateau戦略では、改善がない場合にまず学習率を下げ、それでも改善しなければ停止します。

if no improvement for patience epochs:
    if lr > min_lr:
        lr = lr * factor
        counter = 0
    else:
        stop

PyTorchでの実装

基本的な早期終了

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

class EarlyStopping:
    """早期終了を実装するクラス"""

    def __init__(self, patience=10, min_delta=0, restore_best=True, verbose=True):
        """
        Parameters:
        -----------
        patience : int
            改善がない場合に待つエポック数
        min_delta : float
            改善と見なす最小変化量
        restore_best : bool
            停止時に最良のモデルを復元するか
        verbose : bool
            進捗を表示するか
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.verbose = verbose

        self.best_loss = float('inf')
        self.counter = 0
        self.best_epoch = 0
        self.best_state = None

    def __call__(self, val_loss, model, epoch):
        """
        検証損失を監視し、早期終了の判定を行う

        Returns:
        --------
        should_stop : bool
            訓練を停止すべきかどうか
        """
        if val_loss < self.best_loss - self.min_delta:
            # 改善があった
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            self.best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

            if self.verbose:
                print(f"  -> New best: {val_loss:.6f}")

            return False

        else:
            # 改善がない
            self.counter += 1

            if self.verbose:
                print(f"  -> No improvement: {self.counter}/{self.patience}")

            if self.counter >= self.patience:
                if self.restore_best and self.best_state is not None:
                    model.load_state_dict(self.best_state)
                    if self.verbose:
                        print(f"Restored best model from epoch {self.best_epoch}")
                return True

            return False


# シンプルなMLPモデル
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=64, output_dim=1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)


# データ生成(過学習しやすい設定)
np.random.seed(42)
torch.manual_seed(42)

n_samples = 200
n_features = 50

X = np.random.randn(n_samples, n_features)
y = X[:, 0] + 0.5 * X[:, 1] - 0.3 * X[:, 2] + np.random.randn(n_samples) * 0.5

# 訓練・検証分割
split = int(0.8 * n_samples)
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]

X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train).unsqueeze(1)
X_val = torch.FloatTensor(X_val)
y_val = torch.FloatTensor(y_val).unsqueeze(1)


def train_with_early_stopping(patience=10, max_epochs=500):
    """早期終了を使った訓練"""
    model = MLP(n_features)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    early_stopping = EarlyStopping(patience=patience, min_delta=1e-4, verbose=False)

    train_losses = []
    val_losses = []

    for epoch in range(max_epochs):
        # 訓練
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        # 検証
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val)
            val_losses.append(val_loss.item())

        # 早期終了チェック
        if early_stopping(val_loss.item(), model, epoch):
            print(f"Early stopping at epoch {epoch}")
            break

    return train_losses, val_losses, early_stopping.best_epoch


def train_without_early_stopping(max_epochs=500):
    """早期終了なしの訓練"""
    model = MLP(n_features)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    train_losses = []
    val_losses = []

    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val)
            val_losses.append(val_loss.item())

    return train_losses, val_losses


# 両方の方法で訓練
train_es, val_es, best_epoch = train_with_early_stopping(patience=20)
train_no_es, val_no_es = train_without_early_stopping(max_epochs=200)

# 可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 早期終了あり
ax = axes[0]
ax.plot(train_es, label='Train Loss', color='blue')
ax.plot(val_es, label='Validation Loss', color='red')
ax.axvline(x=best_epoch, color='green', linestyle='--', label=f'Best Epoch ({best_epoch})')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('With Early Stopping')
ax.legend()
ax.grid(True, alpha=0.3)

# 早期終了なし
ax = axes[1]
ax.plot(train_no_es, label='Train Loss', color='blue')
ax.plot(val_no_es, label='Validation Loss', color='red')
best_no_es = np.argmin(val_no_es)
ax.axvline(x=best_no_es, color='green', linestyle='--', label=f'Best Epoch ({best_no_es})')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Without Early Stopping')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('early_stopping_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n早期終了あり: 停止エポック={len(train_es)}, 最良検証損失={min(val_es):.4f}")
print(f"早期終了なし: 最良検証損失={min(val_no_es):.4f} (エポック{best_no_es})")

学習率スケジューリングとの組み合わせ

class EarlyStoppingWithScheduler:
    """学習率スケジューリングと組み合わせた早期終了"""

    def __init__(self, patience=10, min_delta=0, lr_patience=5,
                 lr_factor=0.5, min_lr=1e-6, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.min_lr = min_lr
        self.verbose = verbose

        self.best_loss = float('inf')
        self.counter = 0
        self.lr_counter = 0
        self.best_state = None

    def __call__(self, val_loss, model, optimizer, epoch):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.lr_counter = 0
            self.best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            return False

        self.counter += 1
        self.lr_counter += 1

        # 学習率を下げる
        if self.lr_counter >= self.lr_patience:
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr > self.min_lr:
                new_lr = max(current_lr * self.lr_factor, self.min_lr)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr
                self.lr_counter = 0
                if self.verbose:
                    print(f"  -> Reduced LR to {new_lr:.2e}")

        # 早期終了判定
        if self.counter >= self.patience:
            if self.best_state is not None:
                model.load_state_dict(self.best_state)
            return True

        return False


# 使用例
model = MLP(n_features)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
early_stopping = EarlyStoppingWithScheduler(
    patience=30,
    lr_patience=10,
    lr_factor=0.5,
    min_lr=1e-5,
    verbose=True
)

train_losses = []
val_losses = []
learning_rates = []

for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())
    learning_rates.append(optimizer.param_groups[0]['lr'])

    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, y_val)
        val_losses.append(val_loss.item())

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: train={loss.item():.4f}, val={val_loss.item():.4f}")

    if early_stopping(val_loss.item(), model, optimizer, epoch):
        print(f"Early stopping at epoch {epoch}")
        break

# 学習率の推移を可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.plot(learning_rates)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('early_stopping_with_scheduler.png', dpi=150, bbox_inches='tight')
plt.show()

PyTorch Lightningでの実装

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset

class LitModel(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=64, lr=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.model = MLP(input_dim, hidden_dim)
        self.criterion = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }


# データローダー
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# コールバック
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    verbose=True
)

checkpoint = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    filename='best-{epoch:02d}-{val_loss:.4f}'
)

# 訓練
model = LitModel(n_features)
trainer = pl.Trainer(
    max_epochs=500,
    callbacks=[early_stopping, checkpoint],
    enable_progress_bar=True
)
trainer.fit(model, train_loader, val_loader)

print(f"Best model path: {checkpoint.best_model_path}")
print(f"Best val_loss: {checkpoint.best_model_score:.4f}")

正則化効果の可視化

import numpy as np
import matplotlib.pyplot as plt

# 早期終了の正則化効果を可視化
def visualize_regularization_effect():
    """特異値フィルタリングによる早期終了の効果を可視化"""
    # 特異値の範囲
    sigma = np.linspace(0.1, 2, 100)
    eta = 0.1  # 学習率

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # 左:フィルタ係数の推移
    ax = axes[0]
    for t in [10, 50, 100, 500, 1000]:
        f = 1 - (1 - eta * sigma**2)**t
        ax.plot(sigma, f, label=f't = {t}')

    ax.set_xlabel('Singular Value σ')
    ax.set_ylabel('Filter Coefficient f(σ, t)')
    ax.set_title('Early Stopping as Spectral Filtering')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 右:等価L2正則化との比較
    ax = axes[1]
    t_values = np.array([10, 50, 100, 500, 1000])
    lambda_eff = 1 / (eta * t_values)

    # 早期終了フィルタ
    t = 100
    f_early = 1 - (1 - eta * sigma**2)**t
    ax.plot(sigma, f_early, 'b-', linewidth=2, label=f'Early Stopping (t={t})')

    # 等価L2フィルタ
    lam = 1 / (eta * t)
    f_ridge = sigma**2 / (sigma**2 + lam)
    ax.plot(sigma, f_ridge, 'r--', linewidth=2, label=f'Ridge (λ={lam:.3f})')

    ax.set_xlabel('Singular Value σ')
    ax.set_ylabel('Filter Coefficient')
    ax.set_title('Early Stopping ≈ L2 Regularization')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('early_stopping_regularization.png', dpi=150, bbox_inches='tight')
    plt.show()

visualize_regularization_effect()

まとめ

本記事では、早期終了(Early Stopping)について解説しました。

  • 基本概念: 検証誤差が改善しなくなったら訓練を停止
  • 正則化としての解釈: 暗黙的にL2正則化と同等の効果
  • patience戦略: 固定patience、相対改善threshold、学習率スケジューリングとの組み合わせ
  • 実装: PyTorch/PyTorch Lightningでの実装

早期終了は、追加のハイパーパラメータをほとんど導入せずに過学習を防ぐ、実用的で効果的なテクニックです。

次のステップとして、以下の記事も参考にしてください。