深層学習において、正規化(Normalization)は学習の安定化と高速化に欠かせない技術です。特にTransformerアーキテクチャでは、Batch Normalization(BN)ではなくLayer Normalization(LN)が標準的に使われています。
本記事では、Layer Normalizationの数学的な仕組みを丁寧に導出し、Batch Normalizationとの違いを明確にした上で、なぜTransformerではLayer Normが採用されているのかを解説します。
本記事の内容
- 正規化がなぜ重要か
- Batch Normalizationの復習
- Layer Normalizationの数式と導出
- なぜTransformerでLayer Normが使われるか
- Pre-LN vs Post-LN
- Pythonでの実装と可視化
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
正規化の重要性
深層ニューラルネットワークでは、層を重ねるごとに各層への入力の分布が変動する問題が知られています。この現象は内部共変量シフト(Internal Covariate Shift)と呼ばれ、学習を不安定にする原因となります。
正規化は、各層への入力を平均0・分散1に近づけることで、この問題を緩和します。
- 学習の安定化:入力分布が安定し、勾配の流れが改善される
- 学習の高速化:より大きな学習率を使用できる
- 正則化効果:ある程度の正則化効果があり、過学習を抑制する
Batch Normalizationの復習
Batch Normalization(BN)は2015年に提案された正規化手法です。ミニバッチ内のサンプルを使って、各特徴量(チャネル)ごとに正規化を行います。
入力テンソルを $\bm{X} \in \mathbb{R}^{B \times D}$ とします($B$: バッチサイズ、$D$: 特徴量次元)。各特徴量次元 $d$ について、
$$ \mu_d = \frac{1}{B}\sum_{b=1}^{B} x_{b,d} $$
$$ \sigma_d^2 = \frac{1}{B}\sum_{b=1}^{B} (x_{b,d} – \mu_d)^2 $$
$$ \hat{x}_{b,d} = \frac{x_{b,d} – \mu_d}{\sqrt{\sigma_d^2 + \epsilon}} $$
$$ y_{b,d} = \gamma_d \hat{x}_{b,d} + \beta_d $$
Batch Normalizationの問題点
- バッチサイズ依存性:バッチサイズが小さいと統計量の推定が不安定
- 推論時の扱い:学習時とは異なり、移動平均で保存した統計量を使う必要がある
- 系列データへの適用が困難:系列長が可変のデータでは正規化が複雑になる
Layer Normalizationの数式と導出
Layer Normalizationの定義
Layer Normalization(LN)は各サンプル内の特徴量方向で正規化します。
入力テンソルを $\bm{X} \in \mathbb{R}^{B \times D}$ とします。各サンプル $b$ について、
$$ \mu_b = \frac{1}{D}\sum_{d=1}^{D} x_{b,d} $$
$$ \sigma_b^2 = \frac{1}{D}\sum_{d=1}^{D} (x_{b,d} – \mu_b)^2 $$
$$ \hat{x}_{b,d} = \frac{x_{b,d} – \mu_b}{\sqrt{\sigma_b^2 + \epsilon}} $$
$$ y_{b,d} = \gamma_d \hat{x}_{b,d} + \beta_d $$
ベクトル表記
$$ \text{LayerNorm}(\bm{x}) = \frac{\bm{x} – \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \bm{\gamma} + \bm{\beta} $$
ここで、$\mu$ と $\sigma^2$ はスカラー(そのサンプルの平均と分散)、$\bm{\gamma}, \bm{\beta} \in \mathbb{R}^D$ は学習可能なパラメータです。
Batch NormalizationとLayer Normalizationの比較
正規化の方向の違い
入力テンソル (B, D):
特徴量 d →
┌─────────────────┐
サ b │ x₁₁ x₁₂ ... x₁D │ ← LN: この行で平均・分散を計算
ン ↓ │ x₂₁ x₂₂ ... x₂D │
プ │ : : : │
ル │ xB₁ xB₂ ... xBD │
└─────────────────┘
↑
BN: この列で平均・分散を計算
比較表
| 項目 | Batch Normalization | Layer Normalization |
|---|---|---|
| 正規化の方向 | バッチ方向 | 特徴量方向 |
| バッチサイズ依存性 | あり | なし |
| 系列長依存性 | あり | なし |
| 学習/推論の違い | 移動平均を使用 | なし(同一処理) |
| 主な用途 | CNN | RNN, Transformer |
なぜTransformerでLayer Normが使われるか
- 系列長の可変性: 各トークンを独立に正規化するため、系列長の違いやパディングの影響を受けない
- バッチサイズへの非依存性: 小さなバッチサイズでも安定した学習が可能
- 推論時の一貫性: 学習時も推論時も同じ計算を行う
- Self-Attentionとの相性: 各トークンの表現を独立に正規化するため、トークン間相互作用を阻害しない
Pre-LN vs Post-LN
Post-LN(原論文の構成)
$$ \bm{h} = \text{LayerNorm}(\bm{x} + \text{SubLayer}(\bm{x})) $$
Pre-LN
$$ \bm{h} = \bm{x} + \text{SubLayer}(\text{LayerNorm}(\bm{x})) $$
Pre-LNの利点: – 勾配の流れが安定し、学習が容易 – ウォームアップなしでも学習可能な場合がある – GPT-2、GPT-3などの大規模モデルで採用
RMSNorm
最近のLLM(LLaMA、Mistralなど)では、平均の減算を省略した軽量な正規化手法が使われることがあります。
$$ \text{RMSNorm}(\bm{x}) = \frac{\bm{x}}{\sqrt{\frac{1}{D}\sum_{d=1}^{D} x_d^2 + \epsilon}} \odot \bm{\gamma} $$
Pythonでの実装
Layer Normalizationのスクラッチ実装
import numpy as np
import torch
import torch.nn as nn
class LayerNormalization:
"""Layer Normalizationのスクラッチ実装(NumPy)"""
def __init__(self, normalized_shape, eps=1e-5):
self.normalized_shape = normalized_shape
self.eps = eps
self.gamma = np.ones(normalized_shape)
self.beta = np.zeros(normalized_shape)
def forward(self, x):
self.mean = np.mean(x, axis=-1, keepdims=True)
self.var = np.var(x, axis=-1, keepdims=True)
self.x_norm = (x - self.mean) / np.sqrt(self.var + self.eps)
out = self.gamma * self.x_norm + self.beta
return out
def __call__(self, x):
return self.forward(x)
# 動作確認
np.random.seed(42)
batch_size = 4
seq_len = 10
d_model = 64
x = np.random.randn(batch_size, seq_len, d_model)
print(f"入力形状: {x.shape}")
print(f"入力の統計量(最初のトークン):")
print(f" 平均: {x[0, 0].mean():.4f}, 分散: {x[0, 0].var():.4f}")
ln = LayerNormalization(d_model)
y = ln(x)
print(f"\n出力形状: {y.shape}")
print(f"出力の統計量(最初のトークン):")
print(f" 平均: {y[0, 0].mean():.4f}, 分散: {y[0, 0].var():.4f}")
PyTorchによる実装と比較
import torch
import torch.nn as nn
class LayerNormPyTorch(nn.Module):
"""Layer Normalizationのスクラッチ実装(PyTorch)"""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.normalized_shape = normalized_shape
self.eps = eps
self.gamma = nn.Parameter(torch.ones(normalized_shape))
self.beta = nn.Parameter(torch.zeros(normalized_shape))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
# PyTorch公式実装との比較
torch.manual_seed(42)
x = torch.randn(4, 10, 64)
my_ln = LayerNormPyTorch(64)
official_ln = nn.LayerNorm(64)
my_output = my_ln(x)
official_output = official_ln(x)
diff = (my_output - official_output).abs().max().item()
print(f"自作実装とPyTorch公式実装の最大差: {diff:.2e}")
RMSNormの実装
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization"""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms * self.weight
まとめ
本記事では、Layer Normalizationの仕組みとBatch Normalizationとの違いについて解説しました。
- 正規化の方向の違い: Batch Normalizationはバッチ方向、Layer Normalizationは特徴量方向で正規化を行う
- Layer Normalizationの計算: 各サンプル内で平均と分散を計算し、正規化後に学習可能なスケール・シフトパラメータを適用する
- TransformerでLNが使われる理由: 系列長の可変性、バッチサイズ非依存性、推論時の一貫性、Self-Attentionとの相性
- Pre-LN vs Post-LN: Pre-LNはLayerNormを先に適用し、学習が安定する。大規模LLMではPre-LNが主流
- RMSNorm: 平均の減算を省略した軽量な正規化手法で、最近のLLMで採用されている
次のステップとして、以下の記事も参考にしてください。