残差接続とTransformerの学習安定性 — スキップ接続が深層モデルを支える仕組み

なぜ100層のTransformerは学習できるのでしょうか?

ニューラルネットワークの層を深くすればするほど、表現力は高まるはずです。しかし実際には、素朴に層を積み重ねると10層程度でも学習が破綻してしまうことがあります。勾配が層を逆伝播するたびに指数的に減衰(あるいは爆発)し、パラメータが全く更新されなくなるのです。

この問題を劇的に解決したのが残差接続(Residual Connection)です。2015年にHe et al.がResNetで導入したこの技術は、「入力をそのまま出力に足し合わせる」という驚くほど単純なアイデアでありながら、100層を超えるネットワークの学習を可能にしました。そしてこの技術はTransformerの中核的な構成要素として受け継がれ、GPT-4やLLaMAのような数十〜数百層にもなる大規模言語モデルの学習を支えています。

残差接続を理解すると、以下のことが見えてきます。

  • Transformerの設計原理: なぜ各サブレイヤーの入出力次元が $d_{\text{model}}$ で統一されているのか
  • 大規模言語モデルの学習安定性: GPT-2以降でPre-Norm構造が採用された理由
  • 深層ネットワーク全般の設計: ResNet、DenseNet、U-Netなどに共通する「ショートカット」の数学的意味

本記事の内容

  • 深層ネットワークにおける勾配消失・勾配爆発の問題
  • 残差接続の基本アイデアとResNetからの系譜
  • Transformerにおける残差接続の役割
  • Post-Norm vs Pre-Normの比較と数学的解析
  • 勾配の流れの数学的解析(残差接続あり/なし)
  • PyTorchでの実装と実験(3パターン比較)

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

画像なし
ResNetとスキップ接続
残差接続の元祖であるResNetのアーキテクチャと、スキップ接続が深層学習に与えた影響を解説
画像なし
Layer Normalization
Transformerで使われるLayer Normalizationの理論と実装を解説
画像なし
Transformerのアーキテクチャ
Transformerの全体構造と各コンポーネントの役割を解説

深層ネットワークの問題 — 勾配消失と勾配爆発

なぜ「深いネットワーク」が欲しいのか

ニューラルネットワークの理論では、2層(入力層 + 隠れ層 + 出力層)のネットワークで任意の連続関数を近似できる(万能近似定理)ことが知られています。しかし、2層で近似するには隠れ層のユニット数が指数的に必要になる場合があります。一方、層を深くすると、各層で少しずつ抽象的な特徴を抽出し、少ないパラメータ数でも複雑な関数を表現できます。

イメージとしては、人間が長い文章を理解するプロセスに似ています。文字列をいきなり意味に変換するのではなく、「文字 → 単語 → 句 → 文 → 段落 → 文書全体の意味」と段階的に抽象度を上げていきます。深いネットワークも同様に、低レベルの特徴から高レベルの概念まで段階的に構築していくのです。

したがって、層を深くすることは理論的にも実用的にも強い動機があります。しかし、ここで深刻な壁に突き当たります。

連鎖律と勾配の伝播

ニューラルネットワークの学習では、損失関数 $\mathcal{L}$ の各パラメータに対する勾配を誤差逆伝播法(Backpropagation)で計算します。層が $L$ 個ある場合、入力に近い層のパラメータに対する勾配は、連鎖律(Chain Rule)によって計算されます。

$L$ 層のネットワークで、第 $l$ 層の出力を $\bm{y}_l$ 、第 $l$ 層の変換を $f_l$ とすると、各層は次のように計算されます。

$$ \bm{y}_l = f_l(\bm{y}_{l-1}) $$

つまり、各層が前の層の出力に変換を適用し、それを次の層に渡します。損失 $\mathcal{L}$ の第0層の入力 $\bm{y}_0$ に対する勾配は、連鎖律から次のように展開されます。

$$ \frac{\partial \mathcal{L}}{\partial \bm{y}_0} = \frac{\partial \mathcal{L}}{\partial \bm{y}_L} \cdot \frac{\partial \bm{y}_L}{\partial \bm{y}_{L-1}} \cdot \frac{\partial \bm{y}_{L-1}}{\partial \bm{y}_{L-2}} \cdots \frac{\partial \bm{y}_1}{\partial \bm{y}_0} $$

右辺をヤコビ行列の積として書くと、次のようになります。

$$ \frac{\partial \mathcal{L}}{\partial \bm{y}_0} = \frac{\partial \mathcal{L}}{\partial \bm{y}_L} \prod_{l=0}^{L-1} \bm{J}_l $$

ここで $\bm{J}_l = \frac{\partial \bm{y}_{l+1}}{\partial \bm{y}_l}$ は第 $l$ 層のヤコビ行列です。

指数的な減衰と増大

問題はこのヤコビ行列の積にあります。仮に各層のヤコビ行列のスペクトル半径(最大固有値の絶対値)が $\gamma$ であるとすると、積のスペクトル半径はおおよそ $\gamma^L$ のスケールになります。

  • $\gamma < 1$ の場合: $\gamma^L \to 0$ — 勾配消失(Vanishing Gradient)
  • $\gamma > 1$ の場合: $\gamma^L \to \infty$ — 勾配爆発(Exploding Gradient)
  • $\gamma = 1$ の場合のみ安定 — しかし実際にこれを維持するのは非常に困難

たとえば $\gamma = 0.9$ で $L = 50$ とすると、$0.9^{50} \approx 0.0052$ です。最終層の勾配の0.5%しか入力層に届かないことになります。$\gamma = 0.8$ なら $0.8^{50} \approx 1.4 \times 10^{-5}$ であり、実質的にゼロです。

具体的な数値で見てみましょう。活性化関数にシグモイド関数 $\sigma(x) = \frac{1}{1+e^{-x}}$ を使う場合、その導関数の最大値は $\sigma'(x) \leq 0.25$ です。重みの影響を考慮しても、各層でヤコビ行列のノルムが1未満になりやすく、50層を超えると勾配はほぼ消滅します。

ReLU関数 $\text{ReLU}(x) = \max(0, x)$ はこの問題を緩和しましたが、完全には解決できません。正の領域では導関数が1ですが、重み行列との積を考えると、やはり層が深くなるほど勾配は不安定になります。

劣化問題(Degradation Problem)

勾配消失は「学習が遅い」だけでなく、もっと奇妙な現象を引き起こします。He et al. (2015) は、深いネットワークは浅いネットワークよりも訓練誤差が大きくなることを実験で示しました。

理論的には、20層のネットワークは56層のネットワークの部分集合です(追加の36層を恒等写像にすれば良い)。したがって、56層のネットワークの訓練誤差は20層以下になるはずです。しかし実際にはそうならない — これは最適化の問題であり、SGDが恒等写像に近い解を見つけられないことを意味しています。

この劣化問題は、単にネットワークを深くしても性能が向上しないどころか悪化することを示しており、深層学習の発展にとって大きな壁でした。

では、この壁をどう乗り越えるのでしょうか。アイデアは驚くほどシンプルで、「恒等写像をショートカットとして明示的にネットワークに組み込む」というものです。

残差接続の基本アイデア

恒等写像のバイパス

残差接続の発想は、日常的なアナロジーで理解できます。

高速道路の迂回路を想像してください。メインの道路($F(\bm{x})$)が渋滞していても、バイパス($\bm{x}$)を通れば確実に目的地に着けます。残差接続はこの「バイパス」に相当し、情報と勾配が確実に流れるルートを保証します。

もう少し技術的に言えば、通常のネットワーク層は入力 $\bm{x}$ に対して出力 $\bm{y} = F(\bm{x})$ を直接学習します。一方、残差接続付きの層は次のように定義されます。

$$ \begin{equation} \bm{y} = F(\bm{x}) + \bm{x} \end{equation} $$

ここで $F(\bm{x})$ は残差関数(Residual Function)と呼ばれます。ネットワークは出力 $\bm{y}$ を直接学習するのではなく、入力からの「差分」 $F(\bm{x}) = \bm{y} – \bm{x}$ を学習するのです。

なぜ「差分の学習」が有利なのか

直感的に考えてみましょう。もし最適な変換がほぼ恒等写像に近いなら(つまり $\bm{y} \approx \bm{x}$)、通常のネットワークでは $F(\bm{x}) \approx \bm{x}$ という非自明な関数を学習する必要があります。しかし残差接続があれば、$F(\bm{x}) \approx \bm{0}$ を学習するだけで済みます。重みをゼロに近づけるだけなので、はるかに容易です。

実際、深い層ほど「ほとんど何もしない」ことが最適解に近い場合が多く、残差接続はこの性質と非常に相性が良いのです。ResNetの論文では、残差ブロックの重みが実際にゼロ付近に集中していることが確認されています。

勾配が直接流れるパスの確保

残差接続の最も重要な効果は、勾配の流れに対するものです。残差接続付きの層 $\bm{y} = F(\bm{x}) + \bm{x}$ の勾配を計算すると、次のようになります。

$$ \frac{\partial \bm{y}}{\partial \bm{x}} = \frac{\partial F(\bm{x})}{\partial \bm{x}} + \bm{I} $$

$\bm{I}$ は恒等行列です。$F$ の勾配がどれだけ小さくなっても、$\bm{I}$ の項が常に存在するため、勾配が完全にゼロにはなりません。

これは先ほどの高速道路のアナロジーに完全に対応します。メインの道路($\frac{\partial F}{\partial \bm{x}}$)がどんなに渋滞(勾配消失)していても、バイパス($\bm{I}$)が常に通行可能なのです。

複数の残差ブロックを積み重ねた場合の効果は、次のセクションで数学的に詳しく解析します。まずは、Transformerでこの残差接続がどのように使われているかを見ていきましょう。

Transformerにおける残差接続

Transformerの各サブレイヤーと残差接続

Transformerのエンコーダブロックは、主に2つのサブレイヤーで構成されています。

  1. Multi-Head Attention(MHA): 入力系列の各位置が、他の全位置との関係を計算する
  2. Feed-Forward Network(FFN): 各位置に独立に適用される2層の全結合ネットワーク

そして、各サブレイヤーには残差接続が適用されます。原論文(Vaswani et al., 2017)の記法では、次のようになります。

$$ \begin{align} \bm{z} &= \text{LayerNorm}(\bm{x} + \text{MHA}(\bm{x})) \\ \bm{y} &= \text{LayerNorm}(\bm{z} + \text{FFN}(\bm{z})) \end{align} $$

1つ目の式では、入力 $\bm{x}$ がMulti-Head Attentionの出力に足し合わされ、その結果にLayer Normalizationが適用されます。2つ目も同様に、$\bm{z}$ がFFNの出力に足し合わされます。

これにより、入力信号は「Attentionを通るルート」と「そのまま通り抜けるルート」の2つの経路を持ちます。仮にAttentionの学習が不十分であっても、入力がそのまま次の層に伝わるため、ネットワーク全体の学習が破綻することはありません。

次元の一致 — $d_{\text{model}}$ の統一

残差接続 $\bm{y} = F(\bm{x}) + \bm{x}$ が成立するためには、$F(\bm{x})$ と $\bm{x}$ のベクトル次元が一致しなければなりません。これがTransformerで全サブレイヤーの入出力次元が $d_{\text{model}}$ に統一されている理由です。

具体的には、以下の設計上の制約が生まれます。

  • Multi-Head Attention: 入力 $\bm{x} \in \mathbb{R}^{n \times d_{\text{model}}}$ に対して、出力も $\mathbb{R}^{n \times d_{\text{model}}}$ でなければなりません。各ヘッドの次元 $d_k = d_v = d_{\text{model}} / h$($h$ はヘッド数)とし、最後の線形射影 $\bm{W}_O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$ で元の次元に戻します。
  • FFN: 内部の隠れ層は $d_{\text{ff}}$(通常 $4d_{\text{model}}$)次元に拡張しますが、出力層で必ず $d_{\text{model}}$ 次元に射影し直します。

$$ \text{FFN}(\bm{x}) = \bm{W}_2 \, \text{ReLU}(\bm{W}_1 \bm{x} + \bm{b}_1) + \bm{b}_2 $$

ここで $\bm{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$、$\bm{W}_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ です。

この設計は一見すると制約に思えますが、残差接続を可能にすることで得られる学習安定性の恩恵は絶大です。Transformerの「全てを $d_{\text{model}}$ 次元に揃える」という設計思想は、残差接続から自然に導かれるものなのです。

デコーダにおける残差接続

デコーダブロックは3つのサブレイヤーを持ちますが、全てに残差接続が適用されます。

$$ \begin{align} \bm{z}_1 &= \text{LayerNorm}(\bm{x} + \text{MaskedMHA}(\bm{x})) \\ \bm{z}_2 &= \text{LayerNorm}(\bm{z}_1 + \text{CrossMHA}(\bm{z}_1, \bm{m})) \\ \bm{y} &= \text{LayerNorm}(\bm{z}_2 + \text{FFN}(\bm{z}_2)) \end{align} $$

ここで $\bm{m}$ はエンコーダの出力です。Masked Multi-Head Attention、Cross Attention、FFNの全てに残差接続があることで、デコーダ内でも勾配が安定して流れます。

Transformerにおける残差接続の配置がわかったところで、次に重要な問題に取り組みましょう。Layer Normalizationを残差接続のに置くかに置くかで、学習の安定性が大きく変わるのです。

Post-Norm vs Pre-Norm

2つの構造の定義

残差接続とLayer Normalizationの組み合わせ方には、主に2つのバリエーションがあります。

Post-Norm(原論文の構造):

$$ \bm{y} = \text{LayerNorm}(\bm{x} + F(\bm{x})) $$

Layer Normalizationが残差接続の(Post)に適用されます。

Pre-Norm(GPT-2以降の主流構造):

$$ \bm{y} = \bm{x} + F(\text{LayerNorm}(\bm{x})) $$

Layer Normalizationが残差接続の(Pre)に適用されます。

一見すると些細な違いに見えますが、この順序の違いが学習の安定性に決定的な影響を与えます。その理由を、勾配の流れの観点から理解しましょう。

Post-Normにおける勾配の流れ

Post-Norm構造で $L$ ブロックを積み重ねたとき、第 $l$ ブロックの出力は次のように再帰的に定義されます。

$$ \bm{y}_l = \text{LN}(\bm{y}_{l-1} + F_l(\bm{y}_{l-1})) $$

ここで $\text{LN}$ はLayer Normalizationの略記です。この構造で最終出力 $\bm{y}_L$ の入力 $\bm{y}_0$ に対する勾配を考えると、各層のLayer Normalizationのヤコビ行列が積として掛かることになります。

第 $l$ 層の勾配は、連鎖律を適用すると次のように展開されます。

$$ \frac{\partial \bm{y}_l}{\partial \bm{y}_{l-1}} = \bm{J}_{\text{LN},l} \left(\bm{I} + \frac{\partial F_l}{\partial \bm{y}_{l-1}}\right) $$

ここで $\bm{J}_{\text{LN},l}$ はLayer Normalizationのヤコビ行列です。

$L$ 層を通した全体の勾配は、この積になります。

$$ \frac{\partial \bm{y}_L}{\partial \bm{y}_0} = \prod_{l=1}^{L} \bm{J}_{\text{LN},l} \left(\bm{I} + \frac{\partial F_l}{\partial \bm{y}_{l-1}}\right) $$

注目すべきは、Layer Normalizationのヤコビ行列 $\bm{J}_{\text{LN},l}$ が全ての層にわたって積として掛かることです。Layer Normalizationは入力の平均を引いて標準偏差で割る操作なので、そのヤコビ行列は恒等行列とは異なります。この積が深い層で不安定になりうるのです。

Pre-Normにおける勾配の流れ

Pre-Norm構造では、各ブロックの出力は次のように定義されます。

$$ \bm{y}_l = \bm{y}_{l-1} + F_l(\text{LN}(\bm{y}_{l-1})) $$

この構造のポイントは、Layer Normalizationが $F_l$ の内部に組み込まれ、残差接続のメインパスにはLayer Normalizationが介在しないことです。

$L$ 層を通した全体の勾配を展開しましょう。まず $\bm{y}_L$ を再帰的に展開すると、次のようになります。

$$ \bm{y}_L = \bm{y}_0 + \sum_{l=1}^{L} F_l(\text{LN}(\bm{y}_{l-1})) $$

この式の右辺を $\bm{y}_0$ で微分すると、次のように得られます。

$$ \frac{\partial \bm{y}_L}{\partial \bm{y}_0} = \bm{I} + \sum_{l=1}^{L} \frac{\partial F_l(\text{LN}(\bm{y}_{l-1}))}{\partial \bm{y}_0} $$

ここで最も重要なのは、恒等行列 $\bm{I}$ が他の項と独立に存在することです。Post-Normでは恒等行列がLayer Normalizationのヤコビ行列と積になって「吸収」されてしまいますが、Pre-Normでは恒等行列がクリーンな形で残ります。

Pre-Normが学習安定性に優れる理由

上の解析から、Pre-Normの優位性は明確です。

  1. 勾配のベースライン: Pre-Normでは $\frac{\partial \bm{y}_L}{\partial \bm{y}_0}$ に常に $\bm{I}$ が含まれるため、$F_l$ の勾配が小さくても、勾配が完全にゼロにはなりません。
  2. Layer Normalizationの非干渉: Post-Normでは $L$ 個のLayer Normalizationのヤコビ行列が積として蓄積しますが、Pre-Normではメインパスにこの干渉がありません。
  3. 勾配の分散の安定性: Pre-Normでは各ブロックの勾配の寄与が和として足し合わさるため、特定の層で勾配が消失しても他の層が補償できます。

Warmup学習率スケジューリングとの関係

Post-Norm構造のTransformerを学習する際には、学習率Warmup(最初の数千ステップで学習率を徐々に上げる)が必須であることが知られています。これは、学習初期にPost-Normの勾配が不安定であるためです。

Pre-Norm構造では、勾配が初期から安定しているため、Warmupなしでも学習が収束することが多くの研究で報告されています(Xiong et al., 2020)。ただし、実用上はPre-Normでも適度なWarmupを使うことが多いです。

一方で、Post-Normにも利点があります。学習が安定しさえすれば、Post-Normの方が最終的な性能が高いという報告もあります。これは、Layer Normalizationがサブレイヤーの出力を正規化することで、表現力の観点では有利に働く場合があるためです。

Post-NormとPre-Normの違いが直感的に理解できたところで、次に残差接続がもたらす勾配の流れの改善をより厳密に数学的に解析しましょう。

勾配の流れの数学的解析

残差接続なしの場合

まず、残差接続のない $L$ 層ネットワークを考えます。

$$ \bm{y}_l = F_l(\bm{y}_{l-1}), \quad l = 1, 2, \ldots, L $$

最終出力 $\bm{y}_L$ の入力 $\bm{y}_0$ に対するヤコビ行列は、連鎖律から次のように得られます。

$$ \frac{\partial \bm{y}_L}{\partial \bm{y}_0} = \prod_{l=1}^{L} \frac{\partial F_l}{\partial \bm{y}_{l-1}} = \prod_{l=1}^{L} \bm{J}_l $$

各ヤコビ行列 $\bm{J}_l$ のスペクトルノルム(最大特異値)を $\sigma_l$ とすると、全体のヤコビ行列のスペクトルノルムは次のように上界が与えられます。

$$ \left\| \frac{\partial \bm{y}_L}{\partial \bm{y}_0} \right\| \leq \prod_{l=1}^{L} \sigma_l $$

もし全ての $\sigma_l < 1$ なら、この積は $L$ の増大とともに指数的にゼロに近づきます。逆に全ての $\sigma_l > 1$ なら指数的に増大します。これが勾配消失と勾配爆発の数学的な正体です。

残差接続ありの場合

残差接続を導入すると、各層は次のようになります。

$$ \bm{y}_l = \bm{y}_{l-1} + F_l(\bm{y}_{l-1}), \quad l = 1, 2, \ldots, L $$

この再帰式を展開すると、最終出力は入力と全ての残差項の和として書けます。

$$ \bm{y}_L = \bm{y}_0 + \sum_{l=1}^{L} F_l(\bm{y}_{l-1}) $$

この式の $\bm{y}_0$ に対する勾配を計算しましょう。$\bm{y}_0$ は直接第1項に現れるだけでなく、$\bm{y}_1, \bm{y}_2, \ldots$ を通じて間接的にも $F_l$ に影響を与えます。

各層のヤコビ行列は、残差接続によって次のようになります。

$$ \frac{\partial \bm{y}_l}{\partial \bm{y}_{l-1}} = \bm{I} + \frac{\partial F_l}{\partial \bm{y}_{l-1}} $$

$L$ 層全体のヤコビ行列は次の積になります。

$$ \frac{\partial \bm{y}_L}{\partial \bm{y}_0} = \prod_{l=1}^{L} \left(\bm{I} + \frac{\partial F_l}{\partial \bm{y}_{l-1}}\right) $$

この積を展開すると、次のような形になります。

$$ \frac{\partial \bm{y}_L}{\partial \bm{y}_0} = \bm{I} + \sum_{l=1}^{L} \frac{\partial F_l}{\partial \bm{y}_{l-1}} + \sum_{l_1 < l_2} \frac{\partial F_{l_2}}{\partial \bm{y}_{l_2-1}} \frac{\partial F_{l_1}}{\partial \bm{y}_{l_1-1}} + \cdots $$

右辺の第1項が恒等行列 $\bm{I}$ であることが決定的に重要です。

恒等行列の項が保証するもの

展開式の第1項 $\bm{I}$ は、以下の2つのことを保証します。

1. 勾配の下限の保証

恒等行列は常にノルム1を持ちます。したがって、他の項 $\frac{\partial F_l}{\partial \bm{y}_{l-1}}$ がどれだけ小さくても(たとえゼロでも)、全体のヤコビ行列のノルムは少なくとも1を下回りにくくなります。厳密には、交差項が負の方向に寄与する可能性もありますが、ランダムに初期化されたネットワークでは、これらの項が系統的に打ち消し合う確率は低いです。

2. 任意の2層間の直接的なパス

展開式を見ると、$\bm{I}$ は「全ての $F_l$ をバイパスするパス」に対応します。同様に、$\frac{\partial F_l}{\partial \bm{y}_{l-1}}$ の項は「第 $l$ 層のみを通るパス」に対応します。つまり、残差接続は $2^L$ 個の異なるパスを作り出し、勾配はそのうちの短いパス(少数の $F_l$ のみを通るパス)を通じて効率的に流れることができます。

He et al. (2016) は、この $2^L$ パスの解釈をさらに発展させ、ResNetは多くの浅いネットワークのアンサンブルとして機能していると論じました。各パスは異なる深さのサブネットワークに対応し、実効的には浅いネットワークが支配的に寄与するため、深さに伴う勾配の問題が緩和されるのです。

残差接続なしとの比較のまとめ

性質 残差接続なし 残差接続あり
ヤコビ行列 $\prod_{l} \bm{J}_l$ $\prod_{l} (\bm{I} + \bm{J}_l)$
展開形 単一の積のみ $\bm{I}$ + 和 + 交差項
勾配消失 $L$ に対して指数的に減衰 $\bm{I}$ が下限を提供
勾配爆発 $L$ に対して指数的に増大 依然として起こりうるが緩和
情報パス 1本のみ $2^L$ 本

数学的な解析で残差接続の効果が明確になりました。それでは、この理論をPyTorchで実装し、実際に学習曲線と勾配の挙動を比較してみましょう。

PyTorchでの実装と実験

実験の設計

ここでは、残差接続の効果を実験的に検証するため、以下の3つの構造を比較します。

  1. 残差接続なし(Plain): $\bm{y}_l = F_l(\bm{y}_{l-1})$
  2. Post-Norm: $\bm{y}_l = \text{LN}(\bm{y}_{l-1} + F_l(\bm{y}_{l-1}))$
  3. Pre-Norm: $\bm{y}_l = \bm{y}_{l-1} + F_l(\text{LN}(\bm{y}_{l-1}))$

単純な系列分類タスク(合成データ)に対して、24層と48層のネットワークで学習曲線を比較します。

3つのブロック構造の実装

まず、3つの構造のTransformer風ブロックをPyTorchで実装します。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# ===== 3つのブロック構造の定義 =====

class PlainBlock(nn.Module):
    """残差接続なしのブロック"""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ffn(x)


class PostNormBlock(nn.Module):
    """Post-Norm: LayerNorm(x + F(x))"""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.ffn(x))


class PreNormBlock(nn.Module):
    """Pre-Norm: x + F(LayerNorm(x))"""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.ffn(self.norm(x))

各ブロックの構造は非常にシンプルです。PlainBlockは入力をFFNに通すだけ、PostNormBlockは残差接続の後にLayer Normalizationを適用し、PreNormBlockはLayer Normalizationの後にFFNを通した結果を残差接続で足します。

次に、これらのブロックを積み重ねたネットワーク全体を定義します。

class DeepNetwork(nn.Module):
    """深いネットワーク(分類タスク用)"""
    def __init__(self, d_model, d_ff, n_layers, n_classes, block_type="pre_norm"):
        super().__init__()
        # 入力射影
        self.input_proj = nn.Linear(d_model, d_model)

        # ブロックの選択
        block_cls = {
            "plain": PlainBlock,
            "post_norm": PostNormBlock,
            "pre_norm": PreNormBlock
        }[block_type]

        # ブロックを積み重ねる
        self.blocks = nn.ModuleList([
            block_cls(d_model, d_ff) for _ in range(n_layers)
        ])

        # Pre-Normの場合、最終的なLayerNormを追加
        self.final_norm = nn.LayerNorm(d_model) if block_type == "pre_norm" else nn.Identity()

        # 分類ヘッド
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_norm(x)
        # 系列の平均をプーリング
        x = x.mean(dim=1)
        return self.classifier(x)

DeepNetworkクラスは、指定されたブロックタイプを n_layers 層分積み重ね、最後にグローバル平均プーリングと分類ヘッドを適用します。Pre-Normの場合は最終出力の前にLayer Normalizationを追加しています。これはPre-Norm構造では最終ブロックの出力が正規化されていないため、安定した分類を行うために必要です。

合成データの生成と学習

実験用の合成データを生成し、3つの構造で学習を行います。

def generate_data(n_samples, seq_len, d_model, n_classes, seed=42):
    """分類用の合成データを生成"""
    rng = np.random.RandomState(seed)

    X = rng.randn(n_samples, seq_len, d_model).astype(np.float32)
    # クラスラベル: 系列の平均の符号に基づく単純な分類
    means = X.mean(axis=(1, 2))
    # n_classes個のビンに分割
    percentiles = np.linspace(0, 100, n_classes + 1)[1:-1]
    thresholds = np.percentile(means, percentiles)
    y = np.digitize(means, thresholds).astype(np.int64)

    return torch.tensor(X), torch.tensor(y)


def train_model(model, X_train, y_train, n_epochs=100, lr=1e-3):
    """モデルを学習してロス履歴を返す"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = []
    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(X_train)
        loss = criterion(logits, y_train)
        loss.backward()

        # 勾配クリッピング(爆発対策)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        losses.append(loss.item())

    return losses

generate_data 関数は、ランダムな系列データを生成し、系列全体の平均値に基づいて分類ラベルを割り当てます。train_model 関数はAdamオプティマイザで学習を行い、各エポックの損失を記録します。勾配クリッピングを適用していますが、これは残差接続なしの場合に勾配爆発が起きた際のフェイルセーフです。

学習曲線の比較(24層)

まず24層のネットワークで3構造を比較します。

# 実験パラメータ
d_model = 64
d_ff = 128
n_classes = 4
seq_len = 16
n_samples = 500
n_epochs = 200

# データ生成
X_train, y_train = generate_data(n_samples, seq_len, d_model, n_classes)

# 再現性のためシードを固定
torch.manual_seed(42)

# 24層での比較
n_layers = 24
results_24 = {}

for block_type in ["plain", "post_norm", "pre_norm"]:
    torch.manual_seed(42)
    model = DeepNetwork(d_model, d_ff, n_layers, n_classes, block_type)
    losses = train_model(model, X_train, y_train, n_epochs=n_epochs, lr=1e-3)
    results_24[block_type] = losses

# 学習曲線のプロット
plt.figure(figsize=(10, 6))
labels = {"plain": "残差接続なし (Plain)", "post_norm": "Post-Norm", "pre_norm": "Pre-Norm"}
colors = {"plain": "#e74c3c", "post_norm": "#3498db", "pre_norm": "#2ecc71"}

for block_type, losses in results_24.items():
    plt.plot(losses, label=labels[block_type], color=colors[block_type], linewidth=2)

plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Cross-Entropy Loss", fontsize=12)
plt.title("Training Loss Comparison (24 Layers)", fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.ylim(0, 3.0)
plt.tight_layout()
plt.show()

このグラフからは、3つの構造の学習挙動の違いが明確に見て取れます。Pre-Norm(緑)は最も速やかに損失が減少し、安定した学習曲線を描きます。Post-Norm(青)も学習は進みますが、初期の数十エポックでPre-Normよりも不安定な振動を示すことがあります。残差接続なし(赤)は24層でも学習が非常に困難で、損失がほとんど減少しないか、極めて緩やかにしか減少しません。これは、勾配消失により入力に近い層のパラメータがほとんど更新されないためです。

学習曲線の比較(48層)

より深い48層ではどうなるか見てみましょう。

# 48層での比較
n_layers = 48
results_48 = {}

for block_type in ["plain", "post_norm", "pre_norm"]:
    torch.manual_seed(42)
    model = DeepNetwork(d_model, d_ff, n_layers, n_classes, block_type)
    losses = train_model(model, X_train, y_train, n_epochs=n_epochs, lr=1e-3)
    results_48[block_type] = losses

plt.figure(figsize=(10, 6))
for block_type, losses in results_48.items():
    plt.plot(losses, label=labels[block_type], color=colors[block_type], linewidth=2)

plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Cross-Entropy Loss", fontsize=12)
plt.title("Training Loss Comparison (48 Layers)", fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.ylim(0, 3.0)
plt.tight_layout()
plt.show()

48層になると、構造の違いがさらに顕著になります。Pre-Normは48層でもなお安定した学習が可能であり、24層の場合とほぼ同等の学習曲線を示します。これは恒等行列 $\bm{I}$ がメインパスで保存されているためです。Post-Normは48層では学習初期の不安定性が増し、収束が遅くなることがあります。残差接続なしは48層では完全に学習不能となり、損失がほぼ初期値のまま推移します。

勾配ノルムの可視化

学習曲線の違いの原因を直接確認するため、各層における勾配のノルムを可視化します。

def compute_gradient_norms(model, X, y, block_type):
    """各ブロックの勾配ノルムを計算"""
    model.train()
    criterion = nn.CrossEntropyLoss()

    # フォワードパス
    logits = model(X)
    loss = criterion(logits, y)
    loss.backward()

    # 各ブロックのFFNの第1層の重みの勾配ノルムを記録
    grad_norms = []
    for block in model.blocks:
        if block_type == "plain":
            grad = block.ffn[0].weight.grad
        elif block_type == "post_norm":
            grad = block.ffn[0].weight.grad
        else:  # pre_norm
            grad = block.ffn[0].weight.grad

        if grad is not None:
            grad_norms.append(grad.norm().item())
        else:
            grad_norms.append(0.0)

    return grad_norms


# 24層での勾配ノルム比較
n_layers = 24
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for idx, block_type in enumerate(["plain", "post_norm", "pre_norm"]):
    torch.manual_seed(42)
    model = DeepNetwork(d_model, d_ff, n_layers, n_classes, block_type)
    grad_norms = compute_gradient_norms(model, X_train, y_train, block_type)

    axes[idx].bar(range(len(grad_norms)), grad_norms, color=colors[block_type], alpha=0.8)
    axes[idx].set_xlabel("Layer Index", fontsize=11)
    axes[idx].set_ylabel("Gradient Norm", fontsize=11)
    axes[idx].set_title(labels[block_type], fontsize=12)
    axes[idx].set_yscale("log")
    axes[idx].grid(True, alpha=0.3)

plt.suptitle("Gradient Norms per Layer (24 Layers, Initial State)", fontsize=14)
plt.tight_layout()
plt.show()

この勾配ノルムの可視化から、理論と完全に一致する結果が得られます。残差接続なし(左)では、出力に近い層(右側)の勾配は大きいのに対し、入力に近い層(左側)の勾配は桁違いに小さくなっています。対数スケールで見ると、層のインデックスに対してほぼ線形(つまり指数的な減衰)であることが確認できます。Post-Norm(中央)では勾配のばらつきが残差接続なしよりも大幅に改善されていますが、層間で若干の変動があります。Pre-Norm(右)では全ての層にわたって勾配のノルムがほぼ均一に保たれており、これが安定した学習を可能にする直接的な証拠です。

勾配ノルムの深さ依存性

最後に、ネットワークの深さを変えたときの勾配ノルムの変化を定量的に比較します。

# 深さを変えたときの入力層付近の勾配ノルム
depths = [4, 8, 12, 16, 24, 32, 48]
gradient_at_first_layer = {"plain": [], "post_norm": [], "pre_norm": []}

for n_layers in depths:
    for block_type in ["plain", "post_norm", "pre_norm"]:
        torch.manual_seed(42)
        model = DeepNetwork(d_model, d_ff, n_layers, n_classes, block_type)
        grad_norms = compute_gradient_norms(model, X_train, y_train, block_type)
        # 最初の層の勾配ノルムを記録
        gradient_at_first_layer[block_type].append(grad_norms[0])

plt.figure(figsize=(10, 6))
for block_type in ["plain", "post_norm", "pre_norm"]:
    plt.plot(depths, gradient_at_first_layer[block_type],
             "o-", label=labels[block_type], color=colors[block_type],
             linewidth=2, markersize=8)

plt.xlabel("Number of Layers", fontsize=12)
plt.ylabel("Gradient Norm at First Layer", fontsize=12)
plt.title("Gradient Norm at Input Layer vs Network Depth", fontsize=14)
plt.legend(fontsize=12)
plt.yscale("log")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

このグラフは、ネットワークの深さに対する入力層付近の勾配ノルムの変化を示しています。残差接続なし(赤)では、深さの増加に対して勾配が指数的に減衰していることがはっきりと見て取れます。4層で十分な大きさだった勾配が、48層ではほぼゼロになっています。Post-Norm(青)は残差接続のおかげで減衰が大幅に緩和されていますが、深さが48層を超えると徐々に不安定になり始めます。Pre-Norm(緑)は驚くほど安定しており、4層から48層まで勾配ノルムがほぼ同じ水準に保たれています。この結果は、理論解析で示した「Pre-Normでは $\bm{I}$ がメインパスに保存される」という性質を実験的に裏付けるものです。

実装上の注意点

実験から得られた知見に基づき、実装時の注意点をまとめます。

初期化の重要性: 残差ブロック内のFFNの最終層の重みをゼロに近い値で初期化することで、初期状態でブロックが恒等写像に近くなり、学習の序盤が安定します。GPT-2では、残差ブロックの出力層を $1/\sqrt{N}$($N$ はブロック数)でスケーリングする手法が使われています。

勾配クリッピングとの併用: 残差接続があっても勾配爆発は完全には防げません。実用上は、勾配クリッピング(通常は最大ノルム1.0程度)を併用することが標準的です。

学習率の選択: Pre-Normでは比較的大きな学習率が使えますが、Post-Normでは小さめの初期学習率とWarmupの組み合わせが必要です。

まとめ

本記事では、残差接続がTransformerの学習安定性に果たす役割を、理論と実験の両面から解説しました。

  • 勾配消失・爆発の問題: 深いネットワークではヤコビ行列の積が指数的に減衰/増大し、素朴な構造では10層程度でも学習が困難になります。
  • 残差接続の効果: $\bm{y} = F(\bm{x}) + \bm{x}$ という単純な構造が、勾配のメインパスに恒等行列 $\bm{I}$ を保証し、$2^L$ 本の情報パスを作り出します。
  • Transformerでの活用: 各サブレイヤー(MHA, FFN)に残差接続を適用し、$d_{\text{model}}$ の次元統一によって加算を可能にしています。
  • Pre-Norm vs Post-Norm: Pre-Norm構造は恒等行列をメインパスにクリーンに保存するため、学習が安定し、Warmupなしでも収束可能です。Post-Normは条件付きで高い最終性能を達成しうるものの、学習の安定化にはWarmupが必須です。
  • 実験による検証: 24層・48層のネットワークで、残差接続なし/Post-Norm/Pre-Normの学習曲線と勾配ノルムを比較し、理論的予測と一致する結果を確認しました。

残差接続は現代のTransformerベースモデルの基盤であり、GPT、BERT、LLaMA、Vision Transformerなど、あらゆるモデルでこの技術が使われています。残差接続の原理を理解することは、これらのモデルの設計思想を深く理解するための重要なステップです。

次のステップとして、以下の記事も参考にしてください。

画像なし
Transformer Encoderの実装
Transformer Encoderの各コンポーネントをPyTorchで実装し、動作を確認します
画像なし
LLaMAのアーキテクチャ
Meta社のLLaMAモデルの設計思想と、Pre-Normの採用をはじめとする実装上の工夫を解説します