TransformerのFeed-Forward Network(FFN)の役割と実装

TransformerアーキテクチャにおいてSelf-Attentionと並んで重要な役割を果たすのが、Feed-Forward Network(FFN)です。Transformerの各ブロックは「Self-Attention層」と「FFN層」の2つのサブレイヤーで構成されており、FFNはモデルの表現力を高めるために不可欠な要素です。

本記事では、FFNの数学的な定義から、活性化関数の選択、隠れ層の次元設計、そしてPyTorchでの実装までを解説します。

本記事の内容

  • FFNのTransformerにおける位置づけ
  • 2層全結合ネットワークの数式
  • なぜFFNが必要か(表現力の観点から)
  • 活性化関数の選択(ReLU, GELU, SwiGLU)
  • 隠れ層の次元設計($d_{ff} = 4 \times d_{model}$)
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

TransformerにおけるFFNの位置づけ

Transformerの各Encoder/Decoderブロックは、以下の2つのサブレイヤーで構成されています。

  1. Multi-Head Self-Attention: 系列内の全トークン間の関連度を計算し、情報を集約
  2. Position-wise Feed-Forward Network (FFN): 各位置のベクトルを独立に非線形変換

FFNは Position-wise(位置ごと)に適用されます。つまり、系列中の各トークンに対して、同じ重みパラメータを持つFFNが独立に適用されます。

FFNの数式

基本形(原論文の定義)

原論文 “Attention Is All You Need” では、FFNは以下のように定義されています。

$$ \begin{equation} \text{FFN}(\bm{x}) = \max(0, \bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2 \end{equation} $$

より一般的に書くと、

$$ \begin{equation} \text{FFN}(\bm{x}) = \sigma(\bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2 \end{equation} $$

パラメータの次元

パラメータ 次元 説明
$\bm{x}$ $\mathbb{R}^{d_{model}}$ 入力ベクトル
$\bm{W}_1$ $\mathbb{R}^{d_{model} \times d_{ff}}$ 第1層の重み行列
$\bm{W}_2$ $\mathbb{R}^{d_{ff} \times d_{model}}$ 第2層の重み行列
出力 $\mathbb{R}^{d_{model}}$ 出力ベクトル

原論文では $d_{model} = 512$、$d_{ff} = 2048$(4倍)が使用されています。

なぜFFNが必要か

Self-Attentionの限界

Self-Attentionは強力なメカニズムですが、出力は入力Valueベクトルの加重和(線形結合)であり、入力を非線形に変換する能力が限られています

FFNの役割

  1. 非線形変換の導入: 活性化関数により、モデルに非線形性が加わる
  2. 高次元空間での特徴抽出: $d_{model} \to d_{ff} \to d_{model}$ という「拡張-縮小」構造
  3. 位置ごとの独立した変換: Self-Attentionが「どの情報を集めるか」を決め、FFNが「集めた情報をどう処理するか」を決める

活性化関数の選択

ReLU

原論文で使用されたのがReLUです。

$$ \text{ReLU}(x) = \max(0, x) $$

GELU

BERTやGPTで採用され、現在最も広く使われているのがGELUです。

$$ \text{GELU}(x) = x \cdot \Phi(x) \approx 0.5x\left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right)\right] $$

GELUは「入力 $x$ を確率的にゲーティングする」と解釈できます。

SwiGLU

LLaMAやPaLMで採用されたのがSwiGLUです。

$$ \text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}} $$

$$ \text{FFN}_{\text{SwiGLU}}(\bm{x}) = (\text{Swish}(\bm{x}\bm{W}_1) \odot \bm{x}\bm{W}_3)\bm{W}_2 $$

隠れ層の次元設計

なぜ $d_{ff} = 4 \times d_{model}$ なのか

  1. 表現力と計算コストのトレードオフ: $d_{ff}$ を大きくするほど表現力は上がるが、計算量も増加
  2. 経験的な知見: 2〜4倍の範囲が一般的に有効
  3. Self-Attentionとのバランス: FFNがAttention層の約2倍のパラメータを持つ設計

PyTorchでの実装

基本的なFFN

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Position-wise Feed-Forward Network(ReLU版)"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class FeedForwardGELU(nn.Module):
    """Position-wise Feed-Forward Network(GELU版)"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class SwiGLU(nn.Module):
    """SwiGLU活性化関数"""
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * F.silu(gate)


class FeedForwardSwiGLU(nn.Module):
    """Position-wise Feed-Forward Network(SwiGLU版)"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff * 2, bias=False)
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
        self.swiglu = SwiGLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.swiglu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x


# 動作確認
d_model = 512
d_ff = 2048
batch_size = 2
seq_len = 10

x = torch.randn(batch_size, seq_len, d_model)

ffn_relu = FeedForward(d_model, d_ff)
ffn_gelu = FeedForwardGELU(d_model, d_ff)
ffn_swiglu = FeedForwardSwiGLU(d_model, d_ff)

print(f"入力形状: {x.shape}")
print(f"ReLU FFN 出力形状: {ffn_relu(x).shape}")
print(f"GELU FFN 出力形状: {ffn_gelu(x).shape}")
print(f"SwiGLU FFN 出力形状: {ffn_swiglu(x).shape}")

まとめ

本記事では、TransformerにおけるFeed-Forward Network(FFN)について解説しました。

  • 位置づけ: FFNはSelf-Attentionと並ぶ重要なサブレイヤーで、各位置に独立に適用される
  • 数式: 2層の全結合ネットワーク $\text{FFN}(\bm{x}) = \sigma(\bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2$
  • 役割: 非線形変換の導入、高次元空間での特徴抽出、位置ごとの独立した変換
  • 活性化関数: ReLU(原論文)、GELU(BERT/GPT)、SwiGLU(LLaMA/PaLM)が主流
  • 次元設計: $d_{ff} = 4 \times d_{model}$ が標準。表現力と計算コストのバランス

次のステップとして、以下の記事も参考にしてください。