早期終了(Early Stopping)は、検証誤差の推移を監視し、過学習が始まる前に訓練を停止する正則化テクニックです。実装が簡単でありながら、過学習を効果的に防ぐことができます。
本記事では、早期終了の理論的背景、様々な戦略、そしてPyTorchでの実装を解説します。
本記事の内容
- 早期終了の基本概念
- 正則化としての理論的解釈
- patience戦略の設計
- PyTorchでの実装
- 実験と可視化
早期終了とは
基本概念
深層学習では、訓練を続けると訓練誤差は減少し続けますが、ある時点から検証誤差は増加に転じます。この現象が過学習です。
早期終了は、検証誤差が改善しなくなったら訓練を停止する手法です。
$$ \text{停止条件}: \mathcal{L}_{\text{val}}^{(t)} > \mathcal{L}_{\text{val}}^{(t-1)} > \cdots > \mathcal{L}_{\text{val}}^{(t-p)} $$
ここで、$p$ はpatience(我慢回数)と呼ばれるハイパーパラメータです。
アルゴリズム
1. best_val_loss = infinity
2. counter = 0
3. for each epoch:
train(model)
val_loss = evaluate(model)
if val_loss < best_val_loss:
best_val_loss = val_loss
save(model)
counter = 0
else:
counter += 1
if counter >= patience:
load(best_model)
break
正則化としての理論的解釈
暗黙的な正則化
早期終了は、明示的な正則化項を加えなくても、正則化と同等の効果を持ちます。
線形回帰を勾配降下法で解く場合を考えます:
$$ \bm{\theta}^{(t+1)} = \bm{\theta}^{(t)} – \eta \nabla_{\bm{\theta}} \mathcal{L}(\bm{\theta}^{(t)}) $$
$t$ ステップ後のパラメータは:
$$ \bm{\theta}^{(t)} = \sum_{k=0}^{t-1} \eta (\bm{I} – \eta \bm{X}^\top \bm{X})^k \bm{X}^\top \bm{y} $$
L2正則化との等価性
早期終了後のパラメータは、ある $\lambda$ に対するL2正則化解と近似的に等価です。
L2正則化解:
$$ \bm{\theta}_{\text{ridge}} = (\bm{X}^\top \bm{X} + \lambda \bm{I})^{-1} \bm{X}^\top \bm{y} $$
学習率 $\eta$、反復回数 $t$ に対して、暗黙的な正則化係数は近似的に:
$$ \lambda_{\text{eff}} \approx \frac{1}{\eta t} $$
つまり、早く停止するほど($t$ が小さいほど)強い正則化効果が得られます。
スペクトル分解による解釈
特異値分解 $\bm{X} = \bm{U} \bm{\Sigma} \bm{V}^\top$ を用いると、早期終了の効果は特異値のフィルタリングとして解釈できます。
$t$ ステップ後の解は、特異値 $\sigma_i$ に対するフィルタ係数:
$$ f_i^{(t)} = 1 – (1 – \eta \sigma_i^2)^t $$
小さい特異値(ノイズに対応)は抑制され、大きい特異値(信号に対応)は保持されます。
patience戦略
基本戦略
固定patience
最もシンプルな戦略。patience回連続で改善がなければ停止。
$$ \text{stop if } \mathcal{L}_{\text{val}}^{(t)} \geq \mathcal{L}_{\text{val}}^{*} \text{ for } p \text{ consecutive epochs} $$
相対改善threshold
絶対的な改善だけでなく、相対的な改善率も考慮:
$$ \text{improved} = \mathcal{L}_{\text{val}}^{(t)} < \mathcal{L}_{\text{val}}^{*} \cdot (1 - \delta) $$
ここで、$\delta$ は最小改善率(例:0.001 = 0.1%)。
学習率スケジューリングとの組み合わせ
Reduce on Plateau戦略では、改善がない場合にまず学習率を下げ、それでも改善しなければ停止します。
if no improvement for patience epochs:
if lr > min_lr:
lr = lr * factor
counter = 0
else:
stop
PyTorchでの実装
基本的な早期終了
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
class EarlyStopping:
"""早期終了を実装するクラス"""
def __init__(self, patience=10, min_delta=0, restore_best=True, verbose=True):
"""
Parameters:
-----------
patience : int
改善がない場合に待つエポック数
min_delta : float
改善と見なす最小変化量
restore_best : bool
停止時に最良のモデルを復元するか
verbose : bool
進捗を表示するか
"""
self.patience = patience
self.min_delta = min_delta
self.restore_best = restore_best
self.verbose = verbose
self.best_loss = float('inf')
self.counter = 0
self.best_epoch = 0
self.best_state = None
def __call__(self, val_loss, model, epoch):
"""
検証損失を監視し、早期終了の判定を行う
Returns:
--------
should_stop : bool
訓練を停止すべきかどうか
"""
if val_loss < self.best_loss - self.min_delta:
# 改善があった
self.best_loss = val_loss
self.counter = 0
self.best_epoch = epoch
self.best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
if self.verbose:
print(f" -> New best: {val_loss:.6f}")
return False
else:
# 改善がない
self.counter += 1
if self.verbose:
print(f" -> No improvement: {self.counter}/{self.patience}")
if self.counter >= self.patience:
if self.restore_best and self.best_state is not None:
model.load_state_dict(self.best_state)
if self.verbose:
print(f"Restored best model from epoch {self.best_epoch}")
return True
return False
# シンプルなMLPモデル
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim=64, output_dim=1):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
self.dropout = nn.Dropout(0.2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = torch.relu(self.fc2(x))
x = self.dropout(x)
return self.fc3(x)
# データ生成(過学習しやすい設定)
np.random.seed(42)
torch.manual_seed(42)
n_samples = 200
n_features = 50
X = np.random.randn(n_samples, n_features)
y = X[:, 0] + 0.5 * X[:, 1] - 0.3 * X[:, 2] + np.random.randn(n_samples) * 0.5
# 訓練・検証分割
split = int(0.8 * n_samples)
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train).unsqueeze(1)
X_val = torch.FloatTensor(X_val)
y_val = torch.FloatTensor(y_val).unsqueeze(1)
def train_with_early_stopping(patience=10, max_epochs=500):
"""早期終了を使った訓練"""
model = MLP(n_features)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
early_stopping = EarlyStopping(patience=patience, min_delta=1e-4, verbose=False)
train_losses = []
val_losses = []
for epoch in range(max_epochs):
# 訓練
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
train_losses.append(loss.item())
# 検証
model.eval()
with torch.no_grad():
val_outputs = model(X_val)
val_loss = criterion(val_outputs, y_val)
val_losses.append(val_loss.item())
# 早期終了チェック
if early_stopping(val_loss.item(), model, epoch):
print(f"Early stopping at epoch {epoch}")
break
return train_losses, val_losses, early_stopping.best_epoch
def train_without_early_stopping(max_epochs=500):
"""早期終了なしの訓練"""
model = MLP(n_features)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
train_losses = []
val_losses = []
for epoch in range(max_epochs):
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
train_losses.append(loss.item())
model.eval()
with torch.no_grad():
val_outputs = model(X_val)
val_loss = criterion(val_outputs, y_val)
val_losses.append(val_loss.item())
return train_losses, val_losses
# 両方の方法で訓練
train_es, val_es, best_epoch = train_with_early_stopping(patience=20)
train_no_es, val_no_es = train_without_early_stopping(max_epochs=200)
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 早期終了あり
ax = axes[0]
ax.plot(train_es, label='Train Loss', color='blue')
ax.plot(val_es, label='Validation Loss', color='red')
ax.axvline(x=best_epoch, color='green', linestyle='--', label=f'Best Epoch ({best_epoch})')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('With Early Stopping')
ax.legend()
ax.grid(True, alpha=0.3)
# 早期終了なし
ax = axes[1]
ax.plot(train_no_es, label='Train Loss', color='blue')
ax.plot(val_no_es, label='Validation Loss', color='red')
best_no_es = np.argmin(val_no_es)
ax.axvline(x=best_no_es, color='green', linestyle='--', label=f'Best Epoch ({best_no_es})')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Without Early Stopping')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('early_stopping_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"\n早期終了あり: 停止エポック={len(train_es)}, 最良検証損失={min(val_es):.4f}")
print(f"早期終了なし: 最良検証損失={min(val_no_es):.4f} (エポック{best_no_es})")
学習率スケジューリングとの組み合わせ
class EarlyStoppingWithScheduler:
"""学習率スケジューリングと組み合わせた早期終了"""
def __init__(self, patience=10, min_delta=0, lr_patience=5,
lr_factor=0.5, min_lr=1e-6, verbose=True):
self.patience = patience
self.min_delta = min_delta
self.lr_patience = lr_patience
self.lr_factor = lr_factor
self.min_lr = min_lr
self.verbose = verbose
self.best_loss = float('inf')
self.counter = 0
self.lr_counter = 0
self.best_state = None
def __call__(self, val_loss, model, optimizer, epoch):
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
self.lr_counter = 0
self.best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
return False
self.counter += 1
self.lr_counter += 1
# 学習率を下げる
if self.lr_counter >= self.lr_patience:
current_lr = optimizer.param_groups[0]['lr']
if current_lr > self.min_lr:
new_lr = max(current_lr * self.lr_factor, self.min_lr)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
self.lr_counter = 0
if self.verbose:
print(f" -> Reduced LR to {new_lr:.2e}")
# 早期終了判定
if self.counter >= self.patience:
if self.best_state is not None:
model.load_state_dict(self.best_state)
return True
return False
# 使用例
model = MLP(n_features)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
early_stopping = EarlyStoppingWithScheduler(
patience=30,
lr_patience=10,
lr_factor=0.5,
min_lr=1e-5,
verbose=True
)
train_losses = []
val_losses = []
learning_rates = []
for epoch in range(500):
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
train_losses.append(loss.item())
learning_rates.append(optimizer.param_groups[0]['lr'])
model.eval()
with torch.no_grad():
val_outputs = model(X_val)
val_loss = criterion(val_outputs, y_val)
val_losses.append(val_loss.item())
if epoch % 20 == 0:
print(f"Epoch {epoch}: train={loss.item():.4f}, val={val_loss.item():.4f}")
if early_stopping(val_loss.item(), model, optimizer, epoch):
print(f"Early stopping at epoch {epoch}")
break
# 学習率の推移を可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax = axes[0]
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
ax = axes[1]
ax.plot(learning_rates)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('early_stopping_with_scheduler.png', dpi=150, bbox_inches='tight')
plt.show()
PyTorch Lightningでの実装
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset
class LitModel(pl.LightningModule):
def __init__(self, input_dim, hidden_dim=64, lr=0.01):
super().__init__()
self.save_hyperparameters()
self.model = MLP(input_dim, hidden_dim)
self.criterion = nn.MSELoss()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss'
}
}
# データローダー
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
# コールバック
early_stopping = EarlyStopping(
monitor='val_loss',
patience=20,
mode='min',
verbose=True
)
checkpoint = ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=1,
filename='best-{epoch:02d}-{val_loss:.4f}'
)
# 訓練
model = LitModel(n_features)
trainer = pl.Trainer(
max_epochs=500,
callbacks=[early_stopping, checkpoint],
enable_progress_bar=True
)
trainer.fit(model, train_loader, val_loader)
print(f"Best model path: {checkpoint.best_model_path}")
print(f"Best val_loss: {checkpoint.best_model_score:.4f}")
正則化効果の可視化
import numpy as np
import matplotlib.pyplot as plt
# 早期終了の正則化効果を可視化
def visualize_regularization_effect():
"""特異値フィルタリングによる早期終了の効果を可視化"""
# 特異値の範囲
sigma = np.linspace(0.1, 2, 100)
eta = 0.1 # 学習率
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 左:フィルタ係数の推移
ax = axes[0]
for t in [10, 50, 100, 500, 1000]:
f = 1 - (1 - eta * sigma**2)**t
ax.plot(sigma, f, label=f't = {t}')
ax.set_xlabel('Singular Value σ')
ax.set_ylabel('Filter Coefficient f(σ, t)')
ax.set_title('Early Stopping as Spectral Filtering')
ax.legend()
ax.grid(True, alpha=0.3)
# 右:等価L2正則化との比較
ax = axes[1]
t_values = np.array([10, 50, 100, 500, 1000])
lambda_eff = 1 / (eta * t_values)
# 早期終了フィルタ
t = 100
f_early = 1 - (1 - eta * sigma**2)**t
ax.plot(sigma, f_early, 'b-', linewidth=2, label=f'Early Stopping (t={t})')
# 等価L2フィルタ
lam = 1 / (eta * t)
f_ridge = sigma**2 / (sigma**2 + lam)
ax.plot(sigma, f_ridge, 'r--', linewidth=2, label=f'Ridge (λ={lam:.3f})')
ax.set_xlabel('Singular Value σ')
ax.set_ylabel('Filter Coefficient')
ax.set_title('Early Stopping ≈ L2 Regularization')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('early_stopping_regularization.png', dpi=150, bbox_inches='tight')
plt.show()
visualize_regularization_effect()
まとめ
本記事では、早期終了(Early Stopping)について解説しました。
- 基本概念: 検証誤差が改善しなくなったら訓練を停止
- 正則化としての解釈: 暗黙的にL2正則化と同等の効果
- patience戦略: 固定patience、相対改善threshold、学習率スケジューリングとの組み合わせ
- 実装: PyTorch/PyTorch Lightningでの実装
早期終了は、追加のハイパーパラメータをほとんど導入せずに過学習を防ぐ、実用的で効果的なテクニックです。
次のステップとして、以下の記事も参考にしてください。