Batch Normalizationの順伝播・逆伝播を完全導出する

Batch Normalization(バッチ正規化)は、2015年に提案されて以来、深層学習の標準的なテクニックとなっています。学習の高速化と安定化を実現し、より大きな学習率の使用を可能にします。

本記事では、Batch Normalizationの理論から実装まで詳しく解説します。

本記事の内容

  • Batch Normalizationの動機と効果
  • 順伝播・逆伝播の数学的導出
  • 訓練時と推論時の違い
  • PyTorchでのスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

Batch Normalizationの動機

内部共変量シフト(Internal Covariate Shift)

深層ニューラルネットワークでは、各層の入力分布が学習中に変化します。これを「内部共変量シフト」と呼びます。

前の層のパラメータが更新されると、現在の層への入力分布が変化し: – 各層は常に変化する入力分布に適応し続ける必要がある – 学習が不安定になりやすい – 小さな学習率しか使えない

Batch Normalizationの効果

Batch Normalizationは各層の入力を正規化することで:

  1. 学習の高速化: より大きな学習率を使用可能
  2. 初期化への依存性低減: パラメータ初期化の影響を軽減
  3. 正則化効果: ミニバッチのノイズが正則化として機能
  4. 勾配の安定化: 勾配消失・爆発を軽減

Batch Normalizationのアルゴリズム

順伝播

ミニバッチ $\mathcal{B} = \{x_1, x_2, \ldots, x_m\}$ に対して:

ステップ1: ミニバッチの平均

$$ \mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i $$

ステップ2: ミニバッチの分散

$$ \sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i – \mu_\mathcal{B})^2 $$

ステップ3: 正規化

$$ \hat{x}_i = \frac{x_i – \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} $$

$\epsilon$ は数値安定性のための小さな定数(通常 $10^{-5}$)です。

ステップ4: スケールとシフト

$$ y_i = \gamma \hat{x}_i + \beta $$

$\gamma$(スケール)と $\beta$(シフト)は学習可能なパラメータです。

なぜスケールとシフトが必要か

正規化後の $\hat{x}$ は平均0、分散1に固定されます。しかし、これでは表現力が制限されてしまいます。

$\gamma$ と $\beta$ を導入することで: – 必要に応じて元の分布を復元可能($\gamma = \sigma$, $\beta = \mu$ のとき) – ネットワークが最適な分布を学習できる

逆伝播の導出

損失関数を $\mathcal{L}$ として、各パラメータの勾配を導出します。

表記の整理

  • 入力: $x_i$
  • 正規化後: $\hat{x}_i = \frac{x_i – \mu}{\sqrt{\sigma^2 + \epsilon}}$
  • 出力: $y_i = \gamma \hat{x}_i + \beta$
  • 上流からの勾配: $\frac{\partial \mathcal{L}}{\partial y_i}$

簡単のため $\sigma^2 + \epsilon$ を $\sigma^2$ と書きます。

$\gamma$ と $\beta$ の勾配

$$ \frac{\partial \mathcal{L}}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i} \cdot \hat{x}_i $$

$$ \frac{\partial \mathcal{L}}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i} $$

$\hat{x}_i$ の勾配

$$ \frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \frac{\partial \mathcal{L}}{\partial y_i} \cdot \gamma $$

$\sigma^2$ の勾配

$\hat{x}_i = (x_i – \mu) \sigma^{-1}$ なので:

$$ \frac{\partial \hat{x}_i}{\partial \sigma^2} = -\frac{1}{2}(x_i – \mu)(\sigma^2)^{-3/2} $$

したがって:

$$ \frac{\partial \mathcal{L}}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot (x_i – \mu) \cdot \left( -\frac{1}{2} \right) (\sigma^2)^{-3/2} $$

$\mu$ の勾配

$\hat{x}_i$ と $\sigma^2$ の両方が $\mu$ に依存することに注意:

$$ \frac{\partial \mathcal{L}}{\partial \mu} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2}} + \frac{\partial \mathcal{L}}{\partial \sigma^2} \cdot \frac{-2}{m} \sum_{i=1}^{m} (x_i – \mu) $$

$x_i$ の勾配

$$ \frac{\partial \mathcal{L}}{\partial x_i} = \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma^2}} + \frac{\partial \mathcal{L}}{\partial \sigma^2} \cdot \frac{2(x_i – \mu)}{m} + \frac{\partial \mathcal{L}}{\partial \mu} \cdot \frac{1}{m} $$

訓練時と推論時の違い

訓練時

  • ミニバッチの統計量($\mu_\mathcal{B}$, $\sigma_\mathcal{B}^2$)を使用
  • 移動平均(Running Mean/Variance)を更新

$$ \begin{align} \mu_{\text{running}} &\leftarrow (1 – \alpha) \mu_{\text{running}} + \alpha \mu_\mathcal{B} \\ \sigma^2_{\text{running}} &\leftarrow (1 – \alpha) \sigma^2_{\text{running}} + \alpha \sigma^2_\mathcal{B} \end{align} $$

$\alpha$ はモメンタム(通常0.1)です。

推論時

  • 訓練中に計算した移動平均を使用
  • バッチサイズに依存しない予測が可能

$$ \hat{x} = \frac{x – \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} $$

$$ y = \gamma \hat{x} + \beta $$

なぜ移動平均を使うのか

推論時には: 1. バッチサイズ1の入力が来る可能性がある 2. 入力ごとに結果が変わるのは望ましくない 3. 訓練データ全体の統計量を代表する値が必要

他の正規化手法との比較

Layer Normalization

特徴マップの空間方向ではなく、チャネル方向で正規化:

$$ \mu = \frac{1}{D} \sum_{d=1}^{D} x_d, \quad \sigma^2 = \frac{1}{D} \sum_{d=1}^{D} (x_d – \mu)^2 $$

  • バッチサイズに依存しない
  • Transformerで広く使用
  • RNNに適している

Instance Normalization

各サンプルの各チャネルで独立に正規化:

  • 画像のスタイル変換で使用
  • バッチ内の他のサンプルに依存しない

Group Normalization

チャネルをグループに分けて正規化:

  • バッチサイズが小さいときに有効
  • 物体検出、セグメンテーションで使用
手法 正規化の軸 主な用途
Batch Norm バッチ & 空間 CNN
Layer Norm チャネル Transformer, RNN
Instance Norm 空間 スタイル変換
Group Norm チャネルグループ 小バッチサイズ

Pythonでの実装

スクラッチ実装(NumPy)

import numpy as np

class BatchNormalization:
    """Batch Normalizationのスクラッチ実装"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # 学習可能パラメータ
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

        # 移動平均
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

        # 逆伝播用のキャッシュ
        self.cache = None

        # 訓練モード
        self.training = True

    def forward(self, x):
        """
        順伝播
        Args:
            x: 入力 (batch_size, num_features)
        Returns:
            out: 正規化された出力
        """
        if self.training:
            # ミニバッチ統計量
            mean = np.mean(x, axis=0)
            var = np.var(x, axis=0)

            # 正規化
            x_norm = (x - mean) / np.sqrt(var + self.eps)

            # スケールとシフト
            out = self.gamma * x_norm + self.beta

            # 移動平均の更新
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var

            # キャッシュ
            self.cache = (x, x_norm, mean, var)

        else:
            # 推論時は移動平均を使用
            x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
            out = self.gamma * x_norm + self.beta

        return out

    def backward(self, dout):
        """
        逆伝播
        Args:
            dout: 上流からの勾配 (batch_size, num_features)
        Returns:
            dx: 入力に対する勾配
            dgamma: gammaに対する勾配
            dbeta: betaに対する勾配
        """
        x, x_norm, mean, var = self.cache
        m = x.shape[0]

        # gamma, betaの勾配
        dgamma = np.sum(dout * x_norm, axis=0)
        dbeta = np.sum(dout, axis=0)

        # x_normの勾配
        dx_norm = dout * self.gamma

        # varの勾配
        dvar = np.sum(dx_norm * (x - mean) * (-0.5) * (var + self.eps)**(-1.5), axis=0)

        # meanの勾配
        dmean = np.sum(dx_norm * (-1 / np.sqrt(var + self.eps)), axis=0)
        dmean += dvar * np.mean(-2 * (x - mean), axis=0)

        # xの勾配
        dx = dx_norm / np.sqrt(var + self.eps)
        dx += dvar * 2 * (x - mean) / m
        dx += dmean / m

        return dx, dgamma, dbeta

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

# 動作確認
np.random.seed(42)
bn = BatchNormalization(num_features=10)

# 順伝播
x = np.random.randn(32, 10) * 3 + 5  # 平均5, 標準偏差3のデータ
out = bn.forward(x)

print("Input statistics:")
print(f"  Mean: {x.mean(axis=0)[:3].round(3)}")
print(f"  Std:  {x.std(axis=0)[:3].round(3)}")

print("\nOutput statistics (after BatchNorm):")
print(f"  Mean: {out.mean(axis=0)[:3].round(3)}")
print(f"  Std:  {out.std(axis=0)[:3].round(3)}")

PyTorchでの実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class CustomBatchNorm1d(nn.Module):
    """PyTorchでのBatch Normalization実装"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # 学習可能パラメータ
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # 移動平均(学習対象外)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # 移動平均の更新
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var

        else:
            mean = self.running_mean
            var = self.running_var

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * x_norm + self.beta

        return out

# 動作確認
torch.manual_seed(42)

# 自作BNとPyTorch標準の比較
custom_bn = CustomBatchNorm1d(num_features=10)
pytorch_bn = nn.BatchNorm1d(num_features=10, momentum=0.1)

x = torch.randn(32, 10) * 3 + 5

custom_bn.train()
pytorch_bn.train()

out_custom = custom_bn(x)
out_pytorch = pytorch_bn(x)

print("Custom BatchNorm output mean:", out_custom.mean(dim=0)[:3].detach().numpy().round(4))
print("PyTorch BatchNorm output mean:", out_pytorch.mean(dim=0)[:3].detach().numpy().round(4))

学習の安定化効果の実験

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

class MLPWithoutBN(nn.Module):
    """Batch Normalizationなしのネットワーク"""

    def __init__(self, input_dim=20, hidden_dim=256, num_layers=5, num_classes=10):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class MLPWithBN(nn.Module):
    """Batch Normalization付きのネットワーク"""

    def __init__(self, input_dim=20, hidden_dim=256, num_layers=5, num_classes=10):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU()
            ])
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def create_dataset(n_samples=5000, input_dim=20, num_classes=10):
    np.random.seed(42)
    X = np.random.randn(n_samples, input_dim).astype(np.float32)
    W = np.random.randn(input_dim, num_classes).astype(np.float32)
    logits = X @ W
    y = np.argmax(logits, axis=1)
    split = int(0.8 * n_samples)
    return (X[:split], y[:split]), (X[split:], y[split:])

def train_model(model, train_loader, test_loader, n_epochs=100, lr=0.01):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    train_losses = []
    test_accuracies = []
    gradient_norms = []

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        epoch_grad_norm = 0

        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()

            # 勾配のノルムを記録
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            epoch_grad_norm += total_norm ** 0.5

            optimizer.step()
            epoch_loss += loss.item()

        train_losses.append(epoch_loss / len(train_loader))
        gradient_norms.append(epoch_grad_norm / len(train_loader))

        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                _, predicted = torch.max(outputs, 1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
        test_accuracies.append(correct / total)

    return train_losses, test_accuracies, gradient_norms

# データ準備
(X_train, y_train), (X_test, y_test) = create_dataset()
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 学習
torch.manual_seed(42)
model_no_bn = MLPWithoutBN()
results_no_bn = train_model(model_no_bn, train_loader, test_loader, n_epochs=100, lr=0.01)

torch.manual_seed(42)
model_with_bn = MLPWithBN()
results_with_bn = train_model(model_with_bn, train_loader, test_loader, n_epochs=100, lr=0.01)

# より大きな学習率でも試す
torch.manual_seed(42)
model_with_bn_high_lr = MLPWithBN()
results_with_bn_high = train_model(model_with_bn_high_lr, train_loader, test_loader, n_epochs=100, lr=0.1)

# 可視化
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 訓練損失
ax1 = axes[0]
ax1.plot(results_no_bn[0], label='Without BN (lr=0.01)', alpha=0.8)
ax1.plot(results_with_bn[0], label='With BN (lr=0.01)', alpha=0.8)
ax1.plot(results_with_bn_high[0], label='With BN (lr=0.1)', alpha=0.8)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# テスト精度
ax2 = axes[1]
ax2.plot(results_no_bn[1], label='Without BN (lr=0.01)', alpha=0.8)
ax2.plot(results_with_bn[1], label='With BN (lr=0.01)', alpha=0.8)
ax2.plot(results_with_bn_high[1], label='With BN (lr=0.1)', alpha=0.8)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Test Accuracy')
ax2.set_title('Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 勾配ノルム
ax3 = axes[2]
ax3.plot(results_no_bn[2], label='Without BN (lr=0.01)', alpha=0.8)
ax3.plot(results_with_bn[2], label='With BN (lr=0.01)', alpha=0.8)
ax3.plot(results_with_bn_high[2], label='With BN (lr=0.1)', alpha=0.8)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Gradient Norm')
ax3.set_title('Gradient Norm')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')

plt.tight_layout()
plt.show()

print("\nFinal Test Accuracy:")
print(f"  Without BN (lr=0.01): {results_no_bn[1][-1]:.4f}")
print(f"  With BN (lr=0.01):    {results_with_bn[1][-1]:.4f}")
print(f"  With BN (lr=0.1):     {results_with_bn_high[1][-1]:.4f}")

Batch Normalizationの注意点

バッチサイズへの依存

  • 小さいバッチサイズでは統計量が不安定
  • バッチサイズ16未満では性能が低下する傾向
  • 対策: Group Normalization や Layer Normalization を使用

訓練・推論モードの切り替え

# 訓練時
model.train()

# 推論時(重要!)
model.eval()

model.eval() を忘れると、推論時もミニバッチ統計量を使ってしまい、結果が不安定になります。

配置位置

一般的な配置:

Linear -> BatchNorm -> Activation

または:

Conv -> BatchNorm -> Activation

まとめ

本記事では、Batch Normalizationについて解説しました。

  • Batch Normalizationは各層の入力を正規化し、学習を安定化させる
  • $\gamma$(スケール)と $\beta$(シフト)により表現力を維持
  • 訓練時はミニバッチ統計量、推論時は移動平均を使用
  • より大きな学習率を使用でき、学習が高速化される

次のステップとして、以下の記事も参考にしてください。