Batch Normalization(バッチ正規化)は、2015年に提案されて以来、深層学習の標準的なテクニックとなっています。学習の高速化と安定化を実現し、より大きな学習率の使用を可能にします。
本記事では、Batch Normalizationの理論から実装まで詳しく解説します。
本記事の内容
- Batch Normalizationの動機と効果
- 順伝播・逆伝播の数学的導出
- 訓練時と推論時の違い
- PyTorchでのスクラッチ実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
Batch Normalizationの動機
内部共変量シフト(Internal Covariate Shift)
深層ニューラルネットワークでは、各層の入力分布が学習中に変化します。これを「内部共変量シフト」と呼びます。
前の層のパラメータが更新されると、現在の層への入力分布が変化し: – 各層は常に変化する入力分布に適応し続ける必要がある – 学習が不安定になりやすい – 小さな学習率しか使えない
Batch Normalizationの効果
Batch Normalizationは各層の入力を正規化することで:
- 学習の高速化: より大きな学習率を使用可能
- 初期化への依存性低減: パラメータ初期化の影響を軽減
- 正則化効果: ミニバッチのノイズが正則化として機能
- 勾配の安定化: 勾配消失・爆発を軽減
Batch Normalizationのアルゴリズム
順伝播
ミニバッチ $\mathcal{B} = \{x_1, x_2, \ldots, x_m\}$ に対して:
ステップ1: ミニバッチの平均
$$ \mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i $$
ステップ2: ミニバッチの分散
$$ \sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i – \mu_\mathcal{B})^2 $$
ステップ3: 正規化
$$ \hat{x}_i = \frac{x_i – \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} $$
$\epsilon$ は数値安定性のための小さな定数(通常 $10^{-5}$)です。
ステップ4: スケールとシフト
$$ y_i = \gamma \hat{x}_i + \beta $$
$\gamma$(スケール)と $\beta$(シフト)は学習可能なパラメータです。
なぜスケールとシフトが必要か
正規化後の $\hat{x}$ は平均0、分散1に固定されます。しかし、これでは表現力が制限されてしまいます。
$\gamma$ と $\beta$ を導入することで: – 必要に応じて元の分布を復元可能($\gamma = \sigma$, $\beta = \mu$ のとき) – ネットワークが最適な分布を学習できる
逆伝播の導出
損失関数を $\mathcal{L}$ として、各パラメータの勾配を導出します。
表記の整理
- 入力: $x_i$
- 正規化後: $\hat{x}_i = \frac{x_i – \mu}{\sqrt{\sigma^2 + \epsilon}}$
- 出力: $y_i = \gamma \hat{x}_i + \beta$
- 上流からの勾配: $\frac{\partial \mathcal{L}}{\partial y_i}$
簡単のため $\sigma^2 + \epsilon$ を $\sigma^2$ と書きます。
$\gamma$ と $\beta$ の勾配
$$ \frac{\partial \mathcal{L}}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i} \cdot \hat{x}_i $$
$$ \frac{\partial \mathcal{L}}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i} $$
$\hat{x}_i$ の勾配
$$ \frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \frac{\partial \mathcal{L}}{\partial y_i} \cdot \gamma $$
$\sigma^2$ の勾配
$\hat{x}_i = (x_i – \mu) \sigma^{-1}$ なので:
$$ \frac{\partial \hat{x}_i}{\partial \sigma^2} = -\frac{1}{2}(x_i – \mu)(\sigma^2)^{-3/2} $$
したがって:
$$ \frac{\partial \mathcal{L}}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot (x_i – \mu) \cdot \left( -\frac{1}{2} \right) (\sigma^2)^{-3/2} $$
$\mu$ の勾配
$\hat{x}_i$ と $\sigma^2$ の両方が $\mu$ に依存することに注意:
$$ \frac{\partial \mathcal{L}}{\partial \mu} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2}} + \frac{\partial \mathcal{L}}{\partial \sigma^2} \cdot \frac{-2}{m} \sum_{i=1}^{m} (x_i – \mu) $$
$x_i$ の勾配
$$ \frac{\partial \mathcal{L}}{\partial x_i} = \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma^2}} + \frac{\partial \mathcal{L}}{\partial \sigma^2} \cdot \frac{2(x_i – \mu)}{m} + \frac{\partial \mathcal{L}}{\partial \mu} \cdot \frac{1}{m} $$
訓練時と推論時の違い
訓練時
- ミニバッチの統計量($\mu_\mathcal{B}$, $\sigma_\mathcal{B}^2$)を使用
- 移動平均(Running Mean/Variance)を更新
$$ \begin{align} \mu_{\text{running}} &\leftarrow (1 – \alpha) \mu_{\text{running}} + \alpha \mu_\mathcal{B} \\ \sigma^2_{\text{running}} &\leftarrow (1 – \alpha) \sigma^2_{\text{running}} + \alpha \sigma^2_\mathcal{B} \end{align} $$
$\alpha$ はモメンタム(通常0.1)です。
推論時
- 訓練中に計算した移動平均を使用
- バッチサイズに依存しない予測が可能
$$ \hat{x} = \frac{x – \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} $$
$$ y = \gamma \hat{x} + \beta $$
なぜ移動平均を使うのか
推論時には: 1. バッチサイズ1の入力が来る可能性がある 2. 入力ごとに結果が変わるのは望ましくない 3. 訓練データ全体の統計量を代表する値が必要
他の正規化手法との比較
Layer Normalization
特徴マップの空間方向ではなく、チャネル方向で正規化:
$$ \mu = \frac{1}{D} \sum_{d=1}^{D} x_d, \quad \sigma^2 = \frac{1}{D} \sum_{d=1}^{D} (x_d – \mu)^2 $$
- バッチサイズに依存しない
- Transformerで広く使用
- RNNに適している
Instance Normalization
各サンプルの各チャネルで独立に正規化:
- 画像のスタイル変換で使用
- バッチ内の他のサンプルに依存しない
Group Normalization
チャネルをグループに分けて正規化:
- バッチサイズが小さいときに有効
- 物体検出、セグメンテーションで使用
| 手法 | 正規化の軸 | 主な用途 |
|---|---|---|
| Batch Norm | バッチ & 空間 | CNN |
| Layer Norm | チャネル | Transformer, RNN |
| Instance Norm | 空間 | スタイル変換 |
| Group Norm | チャネルグループ | 小バッチサイズ |
Pythonでの実装
スクラッチ実装(NumPy)
import numpy as np
class BatchNormalization:
"""Batch Normalizationのスクラッチ実装"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# 学習可能パラメータ
self.gamma = np.ones(num_features)
self.beta = np.zeros(num_features)
# 移動平均
self.running_mean = np.zeros(num_features)
self.running_var = np.ones(num_features)
# 逆伝播用のキャッシュ
self.cache = None
# 訓練モード
self.training = True
def forward(self, x):
"""
順伝播
Args:
x: 入力 (batch_size, num_features)
Returns:
out: 正規化された出力
"""
if self.training:
# ミニバッチ統計量
mean = np.mean(x, axis=0)
var = np.var(x, axis=0)
# 正規化
x_norm = (x - mean) / np.sqrt(var + self.eps)
# スケールとシフト
out = self.gamma * x_norm + self.beta
# 移動平均の更新
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
# キャッシュ
self.cache = (x, x_norm, mean, var)
else:
# 推論時は移動平均を使用
x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
out = self.gamma * x_norm + self.beta
return out
def backward(self, dout):
"""
逆伝播
Args:
dout: 上流からの勾配 (batch_size, num_features)
Returns:
dx: 入力に対する勾配
dgamma: gammaに対する勾配
dbeta: betaに対する勾配
"""
x, x_norm, mean, var = self.cache
m = x.shape[0]
# gamma, betaの勾配
dgamma = np.sum(dout * x_norm, axis=0)
dbeta = np.sum(dout, axis=0)
# x_normの勾配
dx_norm = dout * self.gamma
# varの勾配
dvar = np.sum(dx_norm * (x - mean) * (-0.5) * (var + self.eps)**(-1.5), axis=0)
# meanの勾配
dmean = np.sum(dx_norm * (-1 / np.sqrt(var + self.eps)), axis=0)
dmean += dvar * np.mean(-2 * (x - mean), axis=0)
# xの勾配
dx = dx_norm / np.sqrt(var + self.eps)
dx += dvar * 2 * (x - mean) / m
dx += dmean / m
return dx, dgamma, dbeta
def train(self):
self.training = True
def eval(self):
self.training = False
# 動作確認
np.random.seed(42)
bn = BatchNormalization(num_features=10)
# 順伝播
x = np.random.randn(32, 10) * 3 + 5 # 平均5, 標準偏差3のデータ
out = bn.forward(x)
print("Input statistics:")
print(f" Mean: {x.mean(axis=0)[:3].round(3)}")
print(f" Std: {x.std(axis=0)[:3].round(3)}")
print("\nOutput statistics (after BatchNorm):")
print(f" Mean: {out.mean(axis=0)[:3].round(3)}")
print(f" Std: {out.std(axis=0)[:3].round(3)}")
PyTorchでの実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class CustomBatchNorm1d(nn.Module):
"""PyTorchでのBatch Normalization実装"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# 学習可能パラメータ
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# 移動平均(学習対象外)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, x):
if self.training:
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
# 移動平均の更新
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
else:
mean = self.running_mean
var = self.running_var
x_norm = (x - mean) / torch.sqrt(var + self.eps)
out = self.gamma * x_norm + self.beta
return out
# 動作確認
torch.manual_seed(42)
# 自作BNとPyTorch標準の比較
custom_bn = CustomBatchNorm1d(num_features=10)
pytorch_bn = nn.BatchNorm1d(num_features=10, momentum=0.1)
x = torch.randn(32, 10) * 3 + 5
custom_bn.train()
pytorch_bn.train()
out_custom = custom_bn(x)
out_pytorch = pytorch_bn(x)
print("Custom BatchNorm output mean:", out_custom.mean(dim=0)[:3].detach().numpy().round(4))
print("PyTorch BatchNorm output mean:", out_pytorch.mean(dim=0)[:3].detach().numpy().round(4))
学習の安定化効果の実験
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
class MLPWithoutBN(nn.Module):
"""Batch Normalizationなしのネットワーク"""
def __init__(self, input_dim=20, hidden_dim=256, num_layers=5, num_classes=10):
super().__init__()
layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
for _ in range(num_layers - 1):
layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
layers.append(nn.Linear(hidden_dim, num_classes))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
class MLPWithBN(nn.Module):
"""Batch Normalization付きのネットワーク"""
def __init__(self, input_dim=20, hidden_dim=256, num_layers=5, num_classes=10):
super().__init__()
layers = [nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()]
for _ in range(num_layers - 1):
layers.extend([
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU()
])
layers.append(nn.Linear(hidden_dim, num_classes))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
def create_dataset(n_samples=5000, input_dim=20, num_classes=10):
np.random.seed(42)
X = np.random.randn(n_samples, input_dim).astype(np.float32)
W = np.random.randn(input_dim, num_classes).astype(np.float32)
logits = X @ W
y = np.argmax(logits, axis=1)
split = int(0.8 * n_samples)
return (X[:split], y[:split]), (X[split:], y[split:])
def train_model(model, train_loader, test_loader, n_epochs=100, lr=0.01):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
train_losses = []
test_accuracies = []
gradient_norms = []
for epoch in range(n_epochs):
model.train()
epoch_loss = 0
epoch_grad_norm = 0
for X_batch, y_batch in train_loader:
optimizer.zero_grad()
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
loss.backward()
# 勾配のノルムを記録
total_norm = 0
for p in model.parameters():
if p.grad is not None:
total_norm += p.grad.data.norm(2).item() ** 2
epoch_grad_norm += total_norm ** 0.5
optimizer.step()
epoch_loss += loss.item()
train_losses.append(epoch_loss / len(train_loader))
gradient_norms.append(epoch_grad_norm / len(train_loader))
# Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
outputs = model(X_batch)
_, predicted = torch.max(outputs, 1)
total += y_batch.size(0)
correct += (predicted == y_batch).sum().item()
test_accuracies.append(correct / total)
return train_losses, test_accuracies, gradient_norms
# データ準備
(X_train, y_train), (X_test, y_test) = create_dataset()
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 学習
torch.manual_seed(42)
model_no_bn = MLPWithoutBN()
results_no_bn = train_model(model_no_bn, train_loader, test_loader, n_epochs=100, lr=0.01)
torch.manual_seed(42)
model_with_bn = MLPWithBN()
results_with_bn = train_model(model_with_bn, train_loader, test_loader, n_epochs=100, lr=0.01)
# より大きな学習率でも試す
torch.manual_seed(42)
model_with_bn_high_lr = MLPWithBN()
results_with_bn_high = train_model(model_with_bn_high_lr, train_loader, test_loader, n_epochs=100, lr=0.1)
# 可視化
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# 訓練損失
ax1 = axes[0]
ax1.plot(results_no_bn[0], label='Without BN (lr=0.01)', alpha=0.8)
ax1.plot(results_with_bn[0], label='With BN (lr=0.01)', alpha=0.8)
ax1.plot(results_with_bn_high[0], label='With BN (lr=0.1)', alpha=0.8)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)
# テスト精度
ax2 = axes[1]
ax2.plot(results_no_bn[1], label='Without BN (lr=0.01)', alpha=0.8)
ax2.plot(results_with_bn[1], label='With BN (lr=0.01)', alpha=0.8)
ax2.plot(results_with_bn_high[1], label='With BN (lr=0.1)', alpha=0.8)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Test Accuracy')
ax2.set_title('Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 勾配ノルム
ax3 = axes[2]
ax3.plot(results_no_bn[2], label='Without BN (lr=0.01)', alpha=0.8)
ax3.plot(results_with_bn[2], label='With BN (lr=0.01)', alpha=0.8)
ax3.plot(results_with_bn_high[2], label='With BN (lr=0.1)', alpha=0.8)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Gradient Norm')
ax3.set_title('Gradient Norm')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')
plt.tight_layout()
plt.show()
print("\nFinal Test Accuracy:")
print(f" Without BN (lr=0.01): {results_no_bn[1][-1]:.4f}")
print(f" With BN (lr=0.01): {results_with_bn[1][-1]:.4f}")
print(f" With BN (lr=0.1): {results_with_bn_high[1][-1]:.4f}")
Batch Normalizationの注意点
バッチサイズへの依存
- 小さいバッチサイズでは統計量が不安定
- バッチサイズ16未満では性能が低下する傾向
- 対策: Group Normalization や Layer Normalization を使用
訓練・推論モードの切り替え
# 訓練時
model.train()
# 推論時(重要!)
model.eval()
model.eval() を忘れると、推論時もミニバッチ統計量を使ってしまい、結果が不安定になります。
配置位置
一般的な配置:
Linear -> BatchNorm -> Activation
または:
Conv -> BatchNorm -> Activation
まとめ
本記事では、Batch Normalizationについて解説しました。
- Batch Normalizationは各層の入力を正規化し、学習を安定化させる
- $\gamma$(スケール)と $\beta$(シフト)により表現力を維持
- 訓練時はミニバッチ統計量、推論時は移動平均を使用
- より大きな学習率を使用でき、学習が高速化される
次のステップとして、以下の記事も参考にしてください。