バッチ正規化(Batch Normalization)の理論と実装を解説

バッチ正規化(Batch Normalization, BN)は、2015年にIoffeとSzegedyによって提案された深層学習の学習安定化テクニックです。各層の入力を正規化することで、学習の高速化と安定化を実現します。

現在のCNNやTransformerの多くのアーキテクチャで標準的に使用されており、深層学習を学ぶ上で必ず理解しておくべき手法です。

本記事の内容

  • バッチ正規化が解決する問題(内部共変量シフト)
  • アルゴリズムの数学的定式化
  • 学習時と推論時の違い
  • PyTorchでの実装と効果の確認

前提知識

この記事を読む前に、ニューラルネットワークの基礎(全結合層、活性化関数、誤差逆伝播法)を押さえておくと理解が深まります。

内部共変量シフト

深層ネットワークの学習では、各層のパラメータが更新されるたびに、次の層への入力分布が変化します。この現象を 内部共変量シフト(Internal Covariate Shift) と呼びます。

入力分布が絶えず変化すると、各層は常に新しい分布に適応する必要があり、学習が不安定になったり、収束が遅くなったりします。バッチ正規化はこの問題を軽減します。

アルゴリズムの定式化

ミニバッチ $\mathcal{B} = \{x_1, x_2, \dots, x_m\}$ が与えられたとき、バッチ正規化は以下の手順で行います。

Step 1: ミニバッチの平均

$$ \mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i $$

Step 2: ミニバッチの分散

$$ \sigma_{\mathcal{B}}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i – \mu_{\mathcal{B}})^2 $$

Step 3: 正規化

$$ \hat{x}_i = \frac{x_i – \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} $$

$\epsilon$ は数値安定性のための小さな定数(例: $10^{-5}$)です。

Step 4: スケールとシフト

$$ y_i = \gamma \hat{x}_i + \beta $$

$\gamma$(スケール)と $\beta$(シフト)は学習可能なパラメータです。

正規化によって表現力が失われることを防ぐために、$\gamma$ と $\beta$ で再スケール・再シフトを行います。$\gamma = \sigma_{\mathcal{B}}$, $\beta = \mu_{\mathcal{B}}$ と学習すれば元の分布に戻るため、恒等変換を表現でき、表現力は損なわれません。

学習時と推論時の違い

学習時

ミニバッチの統計量($\mu_{\mathcal{B}}$, $\sigma_{\mathcal{B}}^2$)を使って正規化します。同時に、移動平均(running mean/variance)を更新します。

$$ \mu_{\text{running}} \leftarrow (1 – \alpha)\mu_{\text{running}} + \alpha \mu_{\mathcal{B}} $$

$$ \sigma^2_{\text{running}} \leftarrow (1 – \alpha)\sigma^2_{\text{running}} + \alpha \sigma^2_{\mathcal{B}} $$

ここで $\alpha$(モメンタム)は通常0.1に設定されます。

推論時

ミニバッチの統計量ではなく、学習中に蓄積した移動平均を使います。

$$ \hat{x} = \frac{x – \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} $$

これにより、推論時の結果がバッチサイズや構成に依存しなくなります。

PyTorchでの実装

バッチ正規化の有無による学習の違いを確認します。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- データの準備 ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# --- BNなしのモデル ---
class ModelWithoutBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layers(x.view(-1, 784))

# --- BNありのモデル ---
class ModelWithBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layers(x.view(-1, 784))

# --- 学習関数 ---
def train_model(model, n_epochs=10):
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    test_accs = []

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        train_losses.append(epoch_loss / len(train_loader))

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                output = model(batch_x)
                pred = output.argmax(dim=1)
                correct += (pred == batch_y).sum().item()
                total += batch_y.size(0)
        test_accs.append(correct / total)

    return train_losses, test_accs

# --- 学習と比較 ---
model_no_bn = ModelWithoutBN().to(device)
model_with_bn = ModelWithBN().to(device)

losses_no_bn, accs_no_bn = train_model(model_no_bn, n_epochs=15)
losses_with_bn, accs_with_bn = train_model(model_with_bn, n_epochs=15)

# --- 可視化 ---
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

ax1 = axes[0]
ax1.plot(losses_no_bn, 'r-o', markersize=4, label='Without BN')
ax1.plot(losses_with_bn, 'b-o', markersize=4, label='With BN')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2 = axes[1]
ax2.plot(accs_no_bn, 'r-o', markersize=4, label='Without BN')
ax2.plot(accs_with_bn, 'b-o', markersize=4, label='With BN')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"BNなし - 最終精度: {accs_no_bn[-1]:.4f}")
print(f"BNあり - 最終精度: {accs_with_bn[-1]:.4f}")

バッチ正規化を導入することで、学習の収束が速くなり、テスト精度も向上することが確認できます。

まとめ

本記事では、バッチ正規化(Batch Normalization)について解説しました。

  • バッチ正規化はミニバッチの統計量で各層の入力を正規化し、学習を安定化させる
  • 正規化後に学習可能なパラメータ $\gamma$, $\beta$ でスケール・シフトを行い表現力を維持する
  • 学習時はミニバッチ統計量、推論時は移動平均を使用する
  • 学習の高速化、より高い学習率の使用、正則化効果などの利点がある

次のステップとして、以下の記事も参考にしてください。