バッチ正規化(Batch Normalization, BN)は、2015年にIoffeとSzegedyによって提案された深層学習の学習安定化テクニックです。各層の入力を正規化することで、学習の高速化と安定化を実現します。
現在のCNNやTransformerの多くのアーキテクチャで標準的に使用されており、深層学習を学ぶ上で必ず理解しておくべき手法です。
本記事の内容
- バッチ正規化が解決する問題(内部共変量シフト)
- アルゴリズムの数学的定式化
- 学習時と推論時の違い
- PyTorchでの実装と効果の確認
前提知識
この記事を読む前に、ニューラルネットワークの基礎(全結合層、活性化関数、誤差逆伝播法)を押さえておくと理解が深まります。
内部共変量シフト
深層ネットワークの学習では、各層のパラメータが更新されるたびに、次の層への入力分布が変化します。この現象を 内部共変量シフト(Internal Covariate Shift) と呼びます。
入力分布が絶えず変化すると、各層は常に新しい分布に適応する必要があり、学習が不安定になったり、収束が遅くなったりします。バッチ正規化はこの問題を軽減します。
アルゴリズムの定式化
ミニバッチ $\mathcal{B} = \{x_1, x_2, \dots, x_m\}$ が与えられたとき、バッチ正規化は以下の手順で行います。
Step 1: ミニバッチの平均
$$ \mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i $$
Step 2: ミニバッチの分散
$$ \sigma_{\mathcal{B}}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i – \mu_{\mathcal{B}})^2 $$
Step 3: 正規化
$$ \hat{x}_i = \frac{x_i – \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} $$
$\epsilon$ は数値安定性のための小さな定数(例: $10^{-5}$)です。
Step 4: スケールとシフト
$$ y_i = \gamma \hat{x}_i + \beta $$
$\gamma$(スケール)と $\beta$(シフト)は学習可能なパラメータです。
正規化によって表現力が失われることを防ぐために、$\gamma$ と $\beta$ で再スケール・再シフトを行います。$\gamma = \sigma_{\mathcal{B}}$, $\beta = \mu_{\mathcal{B}}$ と学習すれば元の分布に戻るため、恒等変換を表現でき、表現力は損なわれません。
学習時と推論時の違い
学習時
ミニバッチの統計量($\mu_{\mathcal{B}}$, $\sigma_{\mathcal{B}}^2$)を使って正規化します。同時に、移動平均(running mean/variance)を更新します。
$$ \mu_{\text{running}} \leftarrow (1 – \alpha)\mu_{\text{running}} + \alpha \mu_{\mathcal{B}} $$
$$ \sigma^2_{\text{running}} \leftarrow (1 – \alpha)\sigma^2_{\text{running}} + \alpha \sigma^2_{\mathcal{B}} $$
ここで $\alpha$(モメンタム)は通常0.1に設定されます。
推論時
ミニバッチの統計量ではなく、学習中に蓄積した移動平均を使います。
$$ \hat{x} = \frac{x – \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} $$
これにより、推論時の結果がバッチサイズや構成に依存しなくなります。
PyTorchでの実装
バッチ正規化の有無による学習の違いを確認します。
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# --- データの準備 ---
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
# --- BNなしのモデル ---
class ModelWithoutBN(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10)
)
def forward(self, x):
return self.layers(x.view(-1, 784))
# --- BNありのモデル ---
class ModelWithBN(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(784, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Linear(64, 10)
)
def forward(self, x):
return self.layers(x.view(-1, 784))
# --- 学習関数 ---
def train_model(model, n_epochs=10):
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train_losses = []
test_accs = []
for epoch in range(n_epochs):
model.train()
epoch_loss = 0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
output = model(batch_x)
loss = criterion(output, batch_y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
train_losses.append(epoch_loss / len(train_loader))
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_x, batch_y in test_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
output = model(batch_x)
pred = output.argmax(dim=1)
correct += (pred == batch_y).sum().item()
total += batch_y.size(0)
test_accs.append(correct / total)
return train_losses, test_accs
# --- 学習と比較 ---
model_no_bn = ModelWithoutBN().to(device)
model_with_bn = ModelWithBN().to(device)
losses_no_bn, accs_no_bn = train_model(model_no_bn, n_epochs=15)
losses_with_bn, accs_with_bn = train_model(model_with_bn, n_epochs=15)
# --- 可視化 ---
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax1 = axes[0]
ax1.plot(losses_no_bn, 'r-o', markersize=4, label='Without BN')
ax1.plot(losses_with_bn, 'b-o', markersize=4, label='With BN')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax2 = axes[1]
ax2.plot(accs_no_bn, 'r-o', markersize=4, label='Without BN')
ax2.plot(accs_with_bn, 'b-o', markersize=4, label='With BN')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"BNなし - 最終精度: {accs_no_bn[-1]:.4f}")
print(f"BNあり - 最終精度: {accs_with_bn[-1]:.4f}")
バッチ正規化を導入することで、学習の収束が速くなり、テスト精度も向上することが確認できます。
まとめ
本記事では、バッチ正規化(Batch Normalization)について解説しました。
- バッチ正規化はミニバッチの統計量で各層の入力を正規化し、学習を安定化させる
- 正規化後に学習可能なパラメータ $\gamma$, $\beta$ でスケール・シフトを行い表現力を維持する
- 学習時はミニバッチ統計量、推論時は移動平均を使用する
- 学習の高速化、より高い学習率の使用、正則化効果などの利点がある
次のステップとして、以下の記事も参考にしてください。