Transformerをゼロから実装する — PyTorchで全コンポーネントを構築

Transformerのコードを読んで「何となく分かった気になっている」状態にとどまっていませんか? 論文を読み、解説記事をいくつも読んだけれど、いざ白紙の状態からTransformerを書けと言われると手が止まる — これは多くの学習者が経験する壁です。

この壁を越える最も確実な方法は、全てのコンポーネントを自分の手で一行ずつ書くことです。nn.TransformerEncoder のような高レベルAPIに頼らず、Scaled Dot-Product Attentionの行列演算から始めて、Multi-Head Attention、Feed-Forward Network、Positional Encoding、そしてEncoder・Decoderの積み上げまで、すべてを自分で構築します。この過程で、各パーツの「なぜこうなっているのか」が体感として理解できます。

Transformerをゼロから実装できる力は、以下のような場面で直接的に役立ちます。

  • 論文の再現実装: 新しいAttention機構や改良型Transformerが提案されたとき、論文のコードを読み解き、自分で再現するための土台になります
  • モデルのカスタマイズ: 既存のTransformerを自分のタスクに合わせて改造するとき、各コンポーネントの役割と接続関係を正確に把握していなければ、適切な変更ができません
  • デバッグ力: 学習がうまくいかないとき、「どの層で勾配が消失しているのか」「Attentionの重みが一様になっていないか」を切り分ける力がつきます

本記事の内容

  • Transformerの全体設計と実装の順序
  • Scaled Dot-Product Attention の数式とPyTorch実装
  • Multi-Head Attention のヘッド分割と実装
  • Position-wise Feed-Forward Network の実装
  • Positional Encoding の sin/cos 方式と可視化
  • Encoder Layer・Decoder Layer の組み立て
  • Transformer全体の統合と合成データでの学習実験

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。各コンポーネントの理論的な背景を個別に扱っているので、実装中に「なぜこうするのか」が気になったときに立ち返る先として活用してください。

画像なし
Self-Attentionの理論と実装
Query・Key・Valueの線形射影とScaled Dot-Product Attentionの数理を解説します。
画像なし
Multi-Head Attentionの仕組み
複数のAttentionヘッドで異なる部分空間の情報を捉える仕組みを解説します。
画像なし
Positional Encodingの理論
sin/cosによる位置符号化の数式と、なぜこの設計が有効なのかを解説します。
画像なし
Layer Normalizationの理論と実装
Batch Normalizationとの違い、Transformerで使われる理由を解説します。
画像なし
Position-wise Feed-Forward Networkの役割
Transformerにおける2層の全結合ネットワークの役割と設計を解説します。
画像なし
Transformer Encoderの構造
EncoderブロックのSelf-Attention→Add&Norm→FFN→Add&Normの流れを解説します。
画像なし
Transformer Decoderの構造
DecoderブロックのMasked Self-Attention、Cross-Attention、FFNの3段構成を解説します。
画像なし
Transformerアーキテクチャの全体像
Encoder-Decoder構造、位置エンコーディング、マスク付きAttentionなどTransformerの全体像を解説します。

全体設計 — 実装の順序を把握する

コードを書き始める前に、Transformerの全体構造と、各コンポーネントの依存関係を把握しましょう。これにより、「今どこを作っていて、次に何が必要なのか」が常に明確になります。

Transformerは大きく分けて以下の部品で構成されています。

  1. Scaled Dot-Product Attention — 全てのAttentionの基本演算。Query、Key、Valueの3つの行列から、文脈を考慮した出力を計算します
  2. Multi-Head Attention — Scaled Dot-Product Attentionを複数の「ヘッド」で並列に実行し、異なる観点の情報を捉えます
  3. Position-wise Feed-Forward Network (FFN) — 各位置に独立に適用される2層の全結合ネットワーク。Attentionで得た情報を非線形変換します
  4. Positional Encoding — トークンの位置情報をsin/cos関数で符号化し、埋め込みに加算します
  5. Encoder Layer — Multi-Head Attention → Add & Norm → FFN → Add & Norm の1ブロック。これを $N$ 層積み重ねてEncoderを構成します
  6. Decoder Layer — Masked Self-Attention → Add & Norm → Cross-Attention → Add & Norm → FFN → Add & Norm の1ブロック。これを $N$ 層積み重ねてDecoderを構成します
  7. Transformer — Encoder + Decoder + 最終線形層を統合した完全なモデル

依存関係を見ると、上から順に作れば自然に組み上がることがわかります。Scaled Dot-Product AttentionがなければMulti-Head Attentionは作れませんし、Multi-Head AttentionとFFNがなければEncoder Layerは作れません。したがって、ボトムアップで小さい部品から順に作るのが最も効率的です。

また、Encoder LayerとDecoder Layerの中核には同じMulti-Head Attentionが使われています。違いはマスクの有無と、Cross-Attentionの有無だけです。この「同じ部品の再利用」がTransformerの設計の美しさであり、実装もシンプルにする要因です。

実装に使うPyTorchの基本方針は以下のとおりです。

  • 各コンポーネントを nn.Module のサブクラスとして定義する
  • forward メソッドに順伝播のロジックを書く
  • テンソルの形状を明示的にコメントで記述する(デバッグの生命線)

それでは、最小の部品であるScaled Dot-Product Attentionから実装を始めましょう。

Step 1: Scaled Dot-Product Attention

数式の復習

Attentionの計算は、図書館での情報検索に例えると分かりやすいです。あなたが「量子力学の入門書が欲しい」という検索クエリ(Query)を持って図書館に行きます。書棚には本ごとに分類ラベル(Key)が付いており、クエリとラベルの一致度が高い本ほど関連性が高いと判断されます。そして、見つかった本の中身(Value)を、関連度に応じて重み付けして読み取る — これがAttentionの本質です。

数式で表すと、Query行列 $\bm{Q} \in \mathbb{R}^{n \times d_k}$、Key行列 $\bm{K} \in \mathbb{R}^{m \times d_k}$、Value行列 $\bm{V} \in \mathbb{R}^{m \times d_v}$ に対して、Scaled Dot-Product Attentionは次のように定義されます。

$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$

計算を3ステップに分解します。

ステップ1: スコア行列の計算

$$ \bm{S} = \bm{Q}\bm{K}^\top \in \mathbb{R}^{n \times m} $$

$\bm{S}_{ij}$ はQuery $i$ とKey $j$ の内積であり、この値が大きいほど「Query $i$ はKey $j$ に強く注目している」ことを意味します。

ステップ2: スケーリング

$$ \bm{S}_{\text{scaled}} = \frac{\bm{S}}{\sqrt{d_k}} $$

$\sqrt{d_k}$ で割る理由は、内積の値が次元数 $d_k$ に比例して大きくなるためです。$d_k$ が大きいとスコアの絶対値が大きくなり、softmaxの出力が一つの要素に集中してしまいます(勾配消失)。スケーリングにより、softmaxの入力を適切な範囲に収めます。

ステップ3: softmaxとValueの重み付き和

$$ \bm{A} = \text{softmax}(\bm{S}_{\text{scaled}}) \in \mathbb{R}^{n \times m} $$

$$ \text{Output} = \bm{A}\bm{V} \in \mathbb{R}^{n \times d_v} $$

softmaxにより各行が確率分布(合計1)になり、その重みでValueを混合します。

マスクがある場合は、softmaxの前にマスク位置のスコアを $-\infty$(実装上は $-10^9$ などの大きな負の値)に設定します。これにより、softmax後の重みが0になり、未来のトークンを参照できなくなります。

PyTorch実装

3ステップの数式をそのままコードに落とし込みます。テンソルの形状をコメントで追跡することで、次元の不整合を防ぎます。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled Dot-Product Attention
    Args:
        Q: (batch, heads, seq_len, d_k)
        K: (batch, heads, seq_len, d_k)
        V: (batch, heads, seq_len, d_v)
        mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
    Returns:
        output: (batch, heads, seq_len, d_v)
        attn_weights: (batch, heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)
    # スコア行列: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # マスク適用(未来のトークンやパディングを無視)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # softmaxで重みを正規化
    attn_weights = F.softmax(scores, dim=-1)
    # 重み付き和
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

この関数は、バッチ次元とヘッド次元を含む4次元テンソルを受け取る設計にしています。これは後でMulti-Head Attentionと組み合わせるときに自然に接続できるようにするためです。mask 引数を None にすればマスクなしのAttention、適切なマスクを渡せばMasked Attentionとして機能します。

簡単な動作確認をしてみましょう。

import torch
import math
import torch.nn.functional as F

# テスト: バッチ2、ヘッド1、系列長4、次元8
batch, heads, seq_len, d_k = 2, 1, 4, 8
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# マスクなし
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1)
print("Attention weights shape:", attn_weights.shape)
print("各行の合計(1になるはず):", attn_weights[0, 0].sum(dim=-1))

出力のAttention weightsの形状は (2, 1, 4, 4) となり、各行の合計が1.0になることが確認できます。これはsoftmaxが正しく適用されている証拠です。各行が「そのQueryがどのKeyにどれだけ注目しているか」を表す確率分布になっています。

ここまでで、Attentionの最小単位が完成しました。しかし、1つのAttentionでは「1つの観点」でしか情報を集約できません。たとえば、あるトークンに対して「構文的な関係」と「意味的な関係」を同時に捉えたい場合、1つのAttentionでは不十分です。次のステップでは、この制限を解決するMulti-Head Attentionを実装します。

Step 2: Multi-Head Attention

ヘッド分割の仕組み

Multi-Head Attentionの発想は、「1つの大きなAttentionより、複数の小さなAttentionを並列に走らせた方が豊かな表現を学習できる」というものです。

写真を撮るときのアナロジーで考えてみましょう。1台のカメラで1枚の写真を撮るより、複数のカメラを異なる角度に設置して同時に撮影した方が、被写体の全体像をより正確に捉えられます。Multi-Head Attentionでは、入力を複数の「ヘッド」に分割し、それぞれが異なる部分空間でAttentionを計算します。

具体的には、モデルの次元 $d_{\text{model}}$ をヘッド数 $h$ で分割し、各ヘッドの次元を $d_k = d_v = d_{\text{model}} / h$ とします。

$$ \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\bm{W}_O $$

各ヘッドは次のように計算されます。

$$ \text{head}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q, \bm{K}\bm{W}_i^K, \bm{V}\bm{W}_i^V) $$

ここで $\bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$、$\bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$、$\bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$、$\bm{W}_O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$ です。

実装上のポイント: reshape によるヘッド分割

理論上は各ヘッドごとに別々の重み行列を持ちますが、実装では1つの大きな線形層で射影してから reshape でヘッドに分割する方が効率的です。

たとえば $d_{\text{model}} = 512$、$h = 8$ の場合、入力 $(B, n, 512)$ を1つの線形層 $\bm{W}_Q \in \mathbb{R}^{512 \times 512}$ で射影して $(B, n, 512)$ を得た後、$(B, n, 8, 64)$ に reshape し、$(B, 8, n, 64)$ に転置します。これにより、8つのヘッドがバッチ次元と同様に並列に処理できます。

PyTorch実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_modelはnum_headsで割り切れる必要があります"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q, K, V の線形射影(1つの大きな行列で一括計算)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # 出力の線形射影
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 線形射影: (batch, seq_len, d_model)
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        # ヘッドに分割: (batch, seq_len, d_model)
        #   → (batch, seq_len, num_heads, d_k)
        #   → (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # ヘッドを結合: (batch, num_heads, seq_len, d_k)
        #   → (batch, seq_len, num_heads, d_k)
        #   → (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.d_model)

        # 出力射影
        output = self.W_o(context)
        return output

実装の核心は viewtranspose の組み合わせです。view(batch_size, -1, self.num_heads, self.d_k) で最後の次元をヘッド数と各ヘッドの次元に分割し、transpose(1, 2) でヘッド次元を前に持ってきます。こうすることで、torch.matmul がバッチ次元とヘッド次元の両方を並列に処理してくれます。.contiguous()transpose 後にメモリレイアウトが不連続になるため、view の前に必要です。

動作確認をしましょう。

import torch
import torch.nn as nn

# Multi-Head Attentionのテスト
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

# ダミー入力: (batch=2, seq_len=10, d_model=512)
x = torch.randn(2, 10, 512)
output = mha(x, x, x)  # Self-Attentionの場合、Q=K=V=x
print("入力形状:", x.shape)
print("出力形状:", output.shape)
print("形状が一致:", x.shape == output.shape)

入力と出力の形状が (2, 10, 512) で一致することが確認できます。これはMulti-Head Attentionの重要な性質で、入力と出力の次元が同じであるため、何層でも積み重ねることができます。この性質があるからこそ、Transformerは6層や12層のEncoder・Decoderを構成できるのです。

Multi-Head Attentionは入力の「どこに注目するか」を学習しますが、各位置の情報を個別に変換する仕組みがまだありません。次のステップでは、この役割を担うFeed-Forward Networkを実装します。

Step 3: Position-wise Feed-Forward Network

2層の全結合ネットワーク

Multi-Head Attentionが「系列全体を見て、どの情報を集めるか」を決める役割だとすると、Feed-Forward Network(FFN)は「集めた情報を各位置で独立に加工する」役割を担います。

料理に例えると、Attentionは食材を集めてくる段階(市場で良い食材を選ぶ)、FFNはそれを調理する段階(切る・焼く・味付けする)です。食材をどう組み合わせるか(Attention)と、個々の食材をどう加工するか(FFN)は別の操作であり、両方が必要です。

FFNの数式は非常にシンプルです。

$$ \text{FFN}(\bm{x}) = \text{Activation}(\bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2 $$

ここで $\bm{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$、$\bm{W}_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$ です。典型的には $d_{ff} = 4 \times d_{\text{model}}$ とします($d_{\text{model}} = 512$ なら $d_{ff} = 2048$)。

中間層の次元を4倍に拡大してから元に戻す「ボトルネック構造の逆」が特徴的です。一旦高次元空間に射影することで、より豊かな非線形変換が可能になります。

活性化関数について、原論文ではReLUを使用していますが、最近のTransformerではGELU(Gaussian Error Linear Unit)が広く使われています。GELUはReLUと違い、入力が0付近で滑らかに遷移するため、勾配の流れがより安定します。

PyTorch実装

import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.linear1(x)    # → (batch, seq_len, d_ff)
        x = self.activation(x) # GELU活性化
        x = self.dropout(x)
        x = self.linear2(x)    # → (batch, seq_len, d_model)
        return x

「Position-wise」という名前が示すとおり、このFFNは系列の各位置に対して同じパラメータで独立に適用されます。位置1のトークンも位置10のトークンも、同じ $\bm{W}_1, \bm{W}_2$ で変換されます。位置間の情報交換はAttentionの仕事であり、FFNは各位置の表現を個別に豊かにするという明確な役割分担があります。

import torch

# FFNのテスト
ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 10, 512)
output = ffn(x)
print("入力形状:", x.shape)
print("出力形状:", output.shape)
print("パラメータ数:", sum(p.numel() for p in ffn.parameters()))

出力形状は入力と同じ (2, 10, 512) です。パラメータ数は $512 \times 2048 + 2048 + 2048 \times 512 + 512 = 2{,}099{,}712$ 個(約210万)となります。これはMulti-Head Attention(約105万パラメータ)の約2倍であり、Transformer全体のパラメータの約3分の2がFFNに集中していることがわかります。つまり、Transformerの「知識」の大部分はFFNに格納されているのです。

AttentionとFFNという2つの演算ブロックが揃いました。しかし、Transformerには系列の順序情報を扱う仕組みがまだありません。Attentionは本質的に順序に無関心(置換不変)なので、「1番目のトークン」と「10番目のトークン」の区別がつきません。次のステップでは、この問題を解決するPositional Encodingを実装します。

Step 4: Positional Encoding

sin/cosによる位置符号化

RNNは逐次処理を行うため、トークンの順序情報が構造的に組み込まれています。一方、Self-Attentionは全ペアの内積を並列に計算するため、入力の順序を入れ替えても出力(の集合としては)変わりません。「I love you」と「you love I」を区別できないのは困ります。

この問題を解決するために、各位置に固有の「位置ベクトル」を加算します。原論文(Vaswani et al., 2017)では、sin関数とcos関数を使った以下の符号化を提案しています。

$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$

$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$

ここで $pos$ はトークンの位置(0, 1, 2, …)、$i$ は次元のインデックス(0, 1, …, $d_{\text{model}}/2 – 1$)です。

この設計の直感的な意味を説明します。$10000^{2i/d_{\text{model}}}$ は次元ごとに異なる「周期」を生み出します。低い次元($i$ が小さい)では周期が短く、隣接する位置を細かく区別します。高い次元($i$ が大きい)では周期が長く、大まかな位置関係を捉えます。これは、二進数で整数を表現するのに似ています。下位ビットは0と1が頻繁に切り替わり(短い周期)、上位ビットはゆっくり変化します(長い周期)。

もう一つの重要な性質は、$PE_{pos+k}$ を $PE_{pos}$ の線形変換で表せることです。これにより、モデルは相対的な位置関係を容易に学習できます。

PyTorch実装

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 位置エンコーディング行列を事前計算
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 分母の計算: 10000^(2i/d_model) = exp(2i * log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数次元
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数次元
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # 学習対象ではないのでbufferとして登録
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

実装上のポイントが2つあります。まず、div_term の計算では $10000^{2i/d_{\text{model}}}$ を直接計算する代わりに、対数空間で計算しています。$10000^{2i/d_{\text{model}}} = \exp(2i \cdot \ln(10000) / d_{\text{model}})$ という関係を使い、数値的により安定な計算を実現しています。次に、register_buffer を使うことで、位置エンコーディング行列を学習パラメータではなくバッファとして登録しています。これにより、model.parameters() には含まれませんが、model.to(device) でGPUに転送されます。

可視化

位置エンコーディングの値がどのようなパターンになるか、可視化して確認しましょう。

import torch
import math
import matplotlib.pyplot as plt
import numpy as np

# 位置エンコーディングの計算
d_model = 128
max_len = 100
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# ヒートマップ
plt.figure(figsize=(12, 5))
plt.imshow(pe.numpy(), cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
plt.colorbar(label='Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding (sin/cos)')
plt.tight_layout()
plt.savefig('positional_encoding_heatmap.png', dpi=150)
plt.show()

ヒートマップを見ると、左側(低次元)では値が位置に応じて高速に振動し、右側(高次元)に行くほど振動がゆっくりになることが確認できます。これはまさに先ほど説明した「二進数のビット」のアナロジーに対応しています。各位置は異なるパターンの組み合わせで一意に表現されるため、モデルは位置の区別が可能になります。

特定の次元の振る舞いも見てみましょう。

import torch
import math
import matplotlib.pyplot as plt

d_model = 128
max_len = 100
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

plt.figure(figsize=(12, 4))
for dim in [0, 10, 30, 60]:
    plt.plot(pe[:, dim].numpy(), label=f'dim {dim}')
plt.xlabel('Position')
plt.ylabel('Encoding Value')
plt.title('Positional Encoding for Different Dimensions')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('positional_encoding_dims.png', dpi=150)
plt.show()

グラフから、dim 0(青)は非常に速い周期でsin波が振動しているのに対し、dim 60(赤)はほぼ直線的にゆっくり変化していることがわかります。この多スケールの周期構造により、モデルは「隣接する位置」と「離れた位置」の両方を区別できるようになっています。dim 0の情報だけでは離れた位置のsin値が周期的に繰り返して区別しにくくなりますが、dim 60の長い周期がそれを補完します。

ここまでで、Transformerの4つの基本部品(Attention、Multi-Head Attention、FFN、Positional Encoding)が全て揃いました。次のステップでは、これらを組み合わせてEncoder Layerを構築します。

Step 5: Encoder Layer と Encoder

1層の構成

Encoder Layerは、Multi-Head AttentionとFFNを「残差接続(Residual Connection)+ Layer Normalization」で包んだ構造です。原論文ではこの構造を「サブレイヤー」と呼んでいます。

1つのEncoder Layerの処理フローは以下のとおりです。

  1. Multi-Head Self-Attention: 入力 $\bm{X}$ をQuery、Key、Valueの全てに使う
  2. Add & Norm: 残差接続(入力を出力に足す)+ Layer Normalization
  3. Feed-Forward Network: 各位置を独立に非線形変換
  4. Add & Norm: 再び残差接続 + Layer Normalization

数式で書くと次のようになります。

$$ \bm{Z} = \text{LayerNorm}(\bm{X} + \text{MultiHead}(\bm{X}, \bm{X}, \bm{X})) $$

$$ \bm{H} = \text{LayerNorm}(\bm{Z} + \text{FFN}(\bm{Z})) $$

ここで重要なのは残差接続です。「入力をそのまま出力に足す」というシンプルな仕組みですが、深いネットワークの学習を劇的に安定させます。残差接続がなければ、勾配がAttentionやFFNの何層もの変換を逆伝播する際に消失しやすくなります。入力を直接足すことで、勾配が「近道」を通って流れ、深い層のパラメータも効果的に更新できます。

Layer Normalizationは、各サンプルの各位置で特徴量を正規化します。Batch Normalizationと異なり、バッチサイズに依存しないため、可変長の系列を扱うTransformerに適しています。

なお、原論文のLayer Normalizationの位置(Post-LN)に対して、最近の実装ではAttention/FFNの前にLayer Normalizationを適用する「Pre-LN」方式も広く使われています。Pre-LNの方が学習が安定しやすいことが知られていますが、本記事では原論文に忠実なPost-LNで実装します。

PyTorch実装

import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Multi-Head Self-Attention + Add & Norm
        attn_output = self.self_attn(x, x, x, src_mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Feed-Forward Network + Add & Norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        return x

self.self_attn(x, x, x, src_mask) の3つの x は、Query、Key、Valueの全てが同じ入力であることを意味しています。これが「Self-Attention」と呼ばれる所以です。Decoderでは、Cross-AttentionとしてQueryとKey/Valueが異なるソースから来る場合もあります。

Encoderのスタック

Encoder Layerを $N$ 層積み重ねてEncoderを構成します。原論文では $N = 6$ です。

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.norm(x)

nn.ModuleList を使うことで、各層のパラメータがPyTorchに正しく登録されます。通常のPythonリストではパラメータが追跡されないため、model.parameters() に含まれず、最適化が行われません。最後に self.norm を適用していますが、これはPre-LN方式との整合性を取るためのオプションです。Post-LN方式では省略しても大きな影響はありませんが、多くの実装で採用されている慣習に従います。

import torch

# Encoderのテスト
encoder = Encoder(d_model=512, num_heads=8, d_ff=2048, num_layers=6)
x = torch.randn(2, 20, 512)  # (batch=2, seq_len=20, d_model=512)
output = encoder(x)
print("Encoder入力形状:", x.shape)
print("Encoder出力形状:", output.shape)
total_params = sum(p.numel() for p in encoder.parameters())
print(f"Encoderパラメータ数: {total_params:,}")

Encoderの出力形状は入力と同じ (2, 20, 512) です。6層のEncoderのパラメータ数は約1,890万個になります。入力系列の各位置について、6回のAttention + FFNの変換を経ることで、文脈を深く反映した表現が得られます。

Encoderが完成しました。次は、Encoderの出力を参照しながら出力系列を生成するDecoderを実装します。Decoder LayerはEncoder Layerより複雑で、Masked Self-AttentionとCross-Attentionという2種類のAttentionを含みます。

Step 6: Decoder Layer と Decoder

Decoder Layerの構造

Decoder Layerは、Encoder Layerの構成に加えて、Masked Self-AttentionCross-Attentionという2つの追加的な仕組みを持ちます。

翻訳タスクを例に考えましょう。英語の文「I love cats」をフランス語「J’aime les chats」に翻訳するとき、「les」を生成する時点では、まだ「chats」は生成されていません。Decoderは「J’aime」までの情報と、Encoderが処理した「I love cats」の文脈情報だけを使って「les」を予測しなければなりません。

この制約を実現するために、Decoder Layerは以下の3つのサブレイヤーで構成されます。

サブレイヤー1: Masked Self-Attention

出力系列内でのSelf-Attentionですが、未来のトークンを参照できないようにマスクをかけます。位置 $t$ のトークンは、位置 $1, 2, \ldots, t$ までしか見ることができません。これは上三角行列を $-\infty$ にすることで実現します。

サブレイヤー2: Cross-Attention(Encoder-Decoder Attention)

DecoderのQueryがEncoderの出力をKey/Valueとして参照するAttentionです。ここでDecoderは「入力系列のどの部分に注目すべきか」を学習します。翻訳タスクでは、「les」を出力するときに「cats」(複数形)に注目するような対応関係を学びます。

サブレイヤー3: Feed-Forward Network

Encoder Layerと同じPosition-wise FFNです。

数式で書くと以下のようになります。

$$ \bm{Z}_1 = \text{LayerNorm}(\bm{Y} + \text{MaskedMultiHead}(\bm{Y}, \bm{Y}, \bm{Y})) $$

$$ \bm{Z}_2 = \text{LayerNorm}(\bm{Z}_1 + \text{MultiHead}(\bm{Z}_1, \bm{M}_{\text{enc}}, \bm{M}_{\text{enc}})) $$

$$ \bm{H} = \text{LayerNorm}(\bm{Z}_2 + \text{FFN}(\bm{Z}_2)) $$

ここで $\bm{Y}$ はDecoder入力、$\bm{M}_{\text{enc}}$ はEncoder出力です。Cross-Attentionでは、Queryが $\bm{Z}_1$(Decoderの中間表現)、Key/Valueが $\bm{M}_{\text{enc}}$(Encoder出力)である点に注目してください。

因果マスクの作成

Masked Self-Attentionで使う因果マスク(causal mask)は、上三角部分が0(マスクされる)、下三角部分と対角が1(参照可能)の行列です。

import torch

def make_causal_mask(size):
    """因果マスク: 未来のトークンを参照できないようにする
    Args:
        size: 系列長
    Returns:
        mask: (1, 1, size, size) — 参照可能な位置が1、マスク位置が0
    """
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)
    return mask  # (1, 1, size, size)

# 確認: 系列長5の因果マスク
print(make_causal_mask(5).squeeze())

出力は下三角行列になります。1行目は位置0のトークンで、自分自身(位置0)だけを参照できます。3行目は位置2のトークンで、位置0、1、2を参照できます。この「過去と現在だけ見える、未来は見えない」という制約が、自己回帰的な生成を可能にします。

PyTorch実装

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # 3つのサブレイヤー
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)

        # 3つのLayerNorm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # 3つのDropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # サブレイヤー1: Masked Self-Attention
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # サブレイヤー2: Cross-Attention
        # Query=x (Decoderの中間表現), Key=Value=enc_output
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout2(attn_output))

        # サブレイヤー3: Feed-Forward Network
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout3(ffn_output))
        return x

Cross-Attentionの呼び出し self.cross_attn(x, enc_output, enc_output, src_mask) で、第1引数(Query)がDecoderの中間表現、第2・第3引数(Key, Value)がEncoder出力であることに注目してください。これにより、Decoderは「自分が今生成しようとしているトークンに関連する、入力系列のどの部分を見るべきか」を動的に決定します。

Decoderのスタック

import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)

構造はEncoderとほぼ同じですが、各層が enc_output を受け取る点と、因果マスク tgt_mask を使う点が異なります。

import torch

# Decoderのテスト
decoder = Decoder(d_model=512, num_heads=8, d_ff=2048, num_layers=6)
# Decoder入力とEncoder出力
tgt = torch.randn(2, 15, 512)   # 出力系列(長さ15)
enc_out = torch.randn(2, 20, 512)  # Encoder出力(長さ20)
# 因果マスク
tgt_mask = make_causal_mask(15)

output = decoder(tgt, enc_out, tgt_mask=tgt_mask)
print("Decoder入力形状:", tgt.shape)
print("Encoder出力形状:", enc_out.shape)
print("Decoder出力形状:", output.shape)

Decoder入力 (2, 15, 512) とEncoder出力 (2, 20, 512) の系列長が異なっていても正しく動作することが確認できます。Cross-Attentionがこの「異なる長さの系列」を橋渡ししています。翻訳タスクでは入力文と出力文の長さが異なるのが普通なので、この柔軟性は不可欠です。

EncoderとDecoderの両方が完成しました。最後のステップでは、これらを統合し、入力トークン列から出力トークン列の確率分布を予測するTransformerモデル全体を組み立てます。

Step 7: Transformer全体の組み立て

全体構成

Transformer全体は以下の要素で構成されます。

  1. 入力埋め込み(Source Embedding): 入力トークンID → 埋め込みベクトル
  2. 出力埋め込み(Target Embedding): 出力トークンID → 埋め込みベクトル
  3. Positional Encoding: 位置情報の付加
  4. Encoder: 入力系列の文脈表現を計算
  5. Decoder: Encoder出力を参照しながら出力系列を生成
  6. 最終線形層 + Softmax: Decoderの出力を語彙サイズの確率分布に変換

ここで一つ重要な設計判断があります。原論文では、入力埋め込みと出力埋め込みの重み、さらに最終線形層の重みを共有(Weight Tying)しています。語彙が同じ場合(例: 同一言語内のタスク)、これによりパラメータ数を削減しつつ性能を維持できます。本実装ではシンプルさのために独立した重みを使いますが、共有する場合は self.generator.weight = self.tgt_embedding.weight とするだけです。

また、埋め込みベクトルに $\sqrt{d_{\text{model}}}$ を掛ける処理があります。これは、埋め込みの値のスケールをPositional Encodingと揃えるためです。Positional Encodingのsin/cosは値域が $[-1, 1]$ ですが、埋め込みの初期値は通常もっと小さいため、そのままでは位置情報が支配的になってしまいます。

PyTorch実装

import torch
import torch.nn as nn
import math

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, d_ff=2048, num_layers=6,
                 max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # 埋め込み層
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # 位置エンコーディング
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder と Decoder
        self.encoder = Encoder(d_model, num_heads, d_ff, num_layers, dropout)
        self.decoder = Decoder(d_model, num_heads, d_ff, num_layers, dropout)

        # 最終出力層(語彙サイズへの射影)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

        # パラメータの初期化
        self._init_parameters()

    def _init_parameters(self):
        """Xavier初期化で学習を安定させる"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

パラメータの初期化に xavier_uniform_ を使っています。Xavierの初期化は、各層の入出力の分散を揃えることで、勾配の消失・爆発を防ぎます。Transformerのような深いネットワークでは、適切な初期化が学習の成否を左右します。

続いて、forward メソッドとマスク生成のヘルパーメソッドを実装します。

    def make_src_mask(self, src, pad_idx=0):
        """パディングマスク: パディングトークンを無視する"""
        # src: (batch, src_len)
        src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask  # (batch, 1, 1, src_len)

    def make_tgt_mask(self, tgt, pad_idx=0):
        """パディングマスク + 因果マスクの組み合わせ"""
        batch_size, tgt_len = tgt.size()
        # パディングマスク
        pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
        # 因果マスク
        causal_mask = torch.tril(
            torch.ones(tgt_len, tgt_len, device=tgt.device)
        ).unsqueeze(0).unsqueeze(0)
        # 両方のマスクを結合(論理AND)
        tgt_mask = pad_mask & (causal_mask.bool())
        return tgt_mask  # (batch, 1, tgt_len, tgt_len)

    def forward(self, src, tgt, pad_idx=0):
        # マスク生成
        src_mask = self.make_src_mask(src, pad_idx)
        tgt_mask = self.make_tgt_mask(tgt, pad_idx)

        # Encoder側: 埋め込み → 位置エンコーディング → Encoder
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoding(src_emb)
        enc_output = self.encoder(src_emb, src_mask)

        # Decoder側: 埋め込み → 位置エンコーディング → Decoder
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoding(tgt_emb)
        dec_output = self.decoder(tgt_emb, enc_output, src_mask, tgt_mask)

        # 最終出力: 語彙サイズへの射影
        output = self.generator(dec_output)
        return output  # (batch, tgt_len, tgt_vocab_size)

make_tgt_mask では、パディングマスクと因果マスクを論理ANDで結合しています。パディングマスクは「パディングトークンを無視する」ためのもので、因果マスクは「未来のトークンを見ない」ためのものです。両方を満たす位置だけが参照可能になります。

import torch

# Transformer全体のテスト
model = Transformer(
    src_vocab_size=1000,
    tgt_vocab_size=1000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_layers=6
)

# ダミー入力
src = torch.randint(1, 1000, (2, 20))  # (batch=2, src_len=20)
tgt = torch.randint(1, 1000, (2, 15))  # (batch=2, tgt_len=15)

output = model(src, tgt)
print("入力(src)形状:", src.shape)
print("入力(tgt)形状:", tgt.shape)
print("出力形状:", output.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}")

出力形状は (2, 15, 1000) で、バッチ内の各サンプルの各出力位置について、語彙サイズ1000の確率分布(のlogits)が得られています。総パラメータ数は約4,500万で、原論文の “base” モデル(約6,500万パラメータ)よりやや少ないですが、これは語彙サイズ(1000 vs 37000)の違いによるものです。埋め込み層と最終線形層のパラメータは語彙サイズに比例するため、実用的なモデルではパラメータの大部分をこれらの層が占めます。

全てのコンポーネントが統合されたTransformerが完成しました。しかし、モデルを作っただけでは動作が正しいかわかりません。次のステップでは、合成データを使った学習実験により、このモデルが実際に学習できることを確認します。

動作確認: コピータスクでの学習実験

コピータスクとは

Transformerが正しく実装できているかを確認する最もシンプルな方法は、コピータスクを使うことです。コピータスクとは、入力系列をそのまま出力系列としてコピーするタスクです。

たとえば、入力が [5, 3, 8, 2, 7] なら、出力も [5, 3, 8, 2, 7] です。一見簡単に見えますが、Transformer(Encoder-Decoder構造)がこれを解くには、Encoderが入力系列の情報を適切に符号化し、DecoderがCross-Attentionを通じてその情報を正確に取り出す必要があります。つまり、Attentionのメカニズムが正しく機能しなければ、このシンプルなタスクですら解けません。

学習が成功すれば、損失が急速に0に近づき、推論時にも入力を正確にコピーできるようになります。逆に損失が下がらない場合は、実装のどこかにバグがあることを意味します。

データ生成とモデル設定

コピータスクでは、特殊トークンとして <pad>=0<bos>=1(系列の開始)、<eos>=2(系列の終了)を定義し、データトークンは3以上の整数を使います。Decoder入力は <bos> から始まり、正解ラベルは <eos> で終わるように構成します。

import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np

# 再現性のためのシード固定
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# 特殊トークン
PAD_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
VOCAB_SIZE = 20   # 0-19のトークン(0,1,2は特殊トークン)
SEQ_LEN = 10      # コピーする系列の長さ

def generate_copy_data(batch_size, seq_len, vocab_size):
    """コピータスクのデータを生成する"""
    # データトークン: 3 ~ vocab_size-1 の範囲
    data = torch.randint(3, vocab_size, (batch_size, seq_len))

    # Encoder入力: [data tokens]
    src = data

    # Decoder入力: [BOS, data tokens](右シフト)
    tgt_input = torch.cat([
        torch.full((batch_size, 1), BOS_IDX, dtype=torch.long),
        data
    ], dim=1)

    # 正解ラベル: [data tokens, EOS]
    tgt_output = torch.cat([
        data,
        torch.full((batch_size, 1), EOS_IDX, dtype=torch.long)
    ], dim=1)

    return src, tgt_input, tgt_output

Decoderの入力が [BOS, x1, x2, ..., xn] で、対応する正解ラベルが [x1, x2, ..., xn, EOS] であることに注意してください。これはTeacher Forcingと呼ばれる学習手法で、Decoderの各位置に「正解の前のトークン」を入力として与え、次のトークンを予測させます。つまり、位置0では BOS を見て x1 を予測し、位置1では x1 を見て x2 を予測する、というように1ステップずつずれた予測をします。

学習ループ

小さなモデルで効率的に学習するため、ハイパーパラメータを控えめに設定します。

import torch
import torch.nn as nn
import torch.optim as optim

# ハイパーパラメータ(コピータスク用の小さな設定)
D_MODEL = 64
NUM_HEADS = 4
D_FF = 128
NUM_LAYERS = 2
DROPOUT = 0.1
LR = 1e-3
NUM_EPOCHS = 50
BATCH_SIZE = 64

# モデル、損失関数、最適化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(
    src_vocab_size=VOCAB_SIZE,
    tgt_vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    d_ff=D_FF,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)

print(f"デバイス: {device}")
total_params = sum(p.numel() for p in model.parameters())
print(f"パラメータ数: {total_params:,}")

CrossEntropyLossignore_index=PAD_IDX により、パディングトークンに対する損失は計算されません。Adamのハイパーパラメータ betas=(0.9, 0.98)eps=1e-9 は原論文の設定に倣っています。

import torch

# 学習ループ
losses = []
for epoch in range(NUM_EPOCHS):
    model.train()
    # バッチごとに新しいデータを生成
    src, tgt_input, tgt_output = generate_copy_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
    src = src.to(device)
    tgt_input = tgt_input.to(device)
    tgt_output = tgt_output.to(device)

    # 順伝播
    output = model(src, tgt_input, pad_idx=PAD_IDX)
    # output: (batch, tgt_len, vocab_size)
    # CrossEntropyLossは (N, C) と (N,) を期待するのでreshape
    output = output.view(-1, VOCAB_SIZE)
    tgt_output = tgt_output.view(-1)

    # 損失計算と逆伝播
    loss = criterion(output, tgt_output)
    optimizer.zero_grad()
    loss.backward()
    # 勾配クリッピング(勾配爆発を防ぐ)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    losses.append(loss.item())
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {loss.item():.4f}")

勾配クリッピング clip_grad_norm_ は、勾配のノルムが閾値(ここでは1.0)を超えた場合にスケーリングして抑える処理です。Transformerでは、特に学習初期に勾配が大きくなりやすいため、この処理が安定した学習に寄与します。

損失の推移を可視化

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Copy Task Training Loss')
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.tight_layout()
plt.savefig('copy_task_loss.png', dpi=150)
plt.show()

損失曲線が指数的に減少していることが確認できます。最初の数エポックで損失が急激に下がり、その後緩やかに収束していくパターンは、モデルがタスクを学習している典型的な挙動です。対数スケールで直線的に減少しているということは、損失が指数関数的に減少していることを意味します。もし損失が全く下がらない場合は、マスクの実装やテンソルの形状に問題がある可能性が高いです。

推論テスト

学習したモデルを使って、実際にコピータスクが解けるか確認しましょう。推論時はTeacher Forcingは使わず、自己回帰的に1トークンずつ生成します。<BOS> から始めて、モデルの出力する最も確率の高いトークンを次の入力として追加し、<EOS> が出るか最大長に達するまで繰り返します。

import torch

def greedy_decode(model, src, max_len, bos_idx, eos_idx, pad_idx, device):
    """貪欲法による自己回帰デコード"""
    model.eval()
    src = src.to(device)

    # Decoder入力を<BOS>で初期化
    tgt = torch.full((src.size(0), 1), bos_idx, dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_len):
            output = model(src, tgt, pad_idx=pad_idx)
            # 最後の位置の予測を取得
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            # 全バッチでEOSが出たら終了
            if (next_token == eos_idx).all():
                break
    return tgt
import torch

# 推論テスト
src_test, _, _ = generate_copy_data(5, SEQ_LEN, VOCAB_SIZE)
print("入力系列:")
print(src_test.numpy())

predicted = greedy_decode(
    model, src_test, max_len=SEQ_LEN + 2,
    bos_idx=BOS_IDX, eos_idx=EOS_IDX, pad_idx=PAD_IDX, device=device
)
# <BOS>を除いた予測結果
pred_tokens = predicted[:, 1:SEQ_LEN+1].cpu().numpy()
print("\n予測系列:")
print(pred_tokens)

# 一致率の確認
correct = (src_test.numpy() == pred_tokens).mean()
print(f"\nトークン一致率: {correct * 100:.1f}%")

学習が十分であれば、トークン一致率は100%(またはそれに近い値)になります。入力系列と予測系列を比較すると、モデルが各位置のトークンを正確にコピーできていることが確認できます。これは、Encoderが入力系列の情報を正しく符号化し、DecoderのCross-Attentionがその情報を適切に取り出し、因果マスクの下での自己回帰生成が正しく機能していることの証拠です。

もし一致率が低い場合、以下の点をデバッグしてください。

  • マスクの向き: 因果マスクの上三角と下三角が逆になっていないか
  • Teacher Forcingのずれ: Decoder入力と正解ラベルが1ステップずれているか
  • 学習不足: エポック数を増やす、またはバッチサイズを大きくする
  • 学習率: 高すぎると発散、低すぎると学習が遅い

Attention重みの可視化

最後に、学習済みモデルのAttention重みを可視化して、モデルが何を学習したか確認しましょう。コピータスクでは、Cross-Attentionが「対応する位置に強く注目する」対角パターンを学習するはずです。

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math

# Attention重みを取得するための順伝播
model.eval()
src_vis = src_test[:1].to(device)
tgt_vis = torch.cat([
    torch.full((1, 1), BOS_IDX, dtype=torch.long),
    src_test[:1]
], dim=1).to(device)

with torch.no_grad():
    # Encoder
    src_mask = model.make_src_mask(src_vis, PAD_IDX)
    src_emb = model.src_embedding(src_vis) * math.sqrt(model.d_model)
    src_emb = model.pos_encoding(src_emb)
    enc_out = model.encoder(src_emb, src_mask)

    # Decoder(最初の層のCross-Attentionの重みを取得)
    tgt_mask = model.make_tgt_mask(tgt_vis, PAD_IDX)
    tgt_emb = model.tgt_embedding(tgt_vis) * math.sqrt(model.d_model)
    tgt_emb = model.pos_encoding(tgt_emb)

    # Decoder第1層のCross-Attentionの重みを手動計算
    dec_layer = model.decoder.layers[0]
    # Masked Self-Attention
    z1 = dec_layer.self_attn(tgt_emb, tgt_emb, tgt_emb, tgt_mask)
    z1 = dec_layer.norm1(tgt_emb + z1)
    # Cross-Attentionのスコアを直接計算
    Q = dec_layer.cross_attn.W_q(z1)
    K = dec_layer.cross_attn.W_k(enc_out)
    d_k = dec_layer.cross_attn.d_k
    num_heads = dec_layer.cross_attn.num_heads
    bs = Q.size(0)
    Q = Q.view(bs, -1, num_heads, d_k).transpose(1, 2)
    K = K.view(bs, -1, num_heads, d_k).transpose(1, 2)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    attn_w = F.softmax(scores, dim=-1)
import matplotlib.pyplot as plt

# 4つのヘッドのAttention重みを可視化
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for h in range(4):
    ax = axes[h]
    im = ax.imshow(
        attn_w[0, h].cpu().numpy(), cmap='Blues', aspect='auto'
    )
    ax.set_title(f'Head {h}')
    ax.set_xlabel('Source Position')
    ax.set_ylabel('Target Position')
fig.suptitle('Cross-Attention Weights (Decoder Layer 0)', fontsize=14)
plt.tight_layout()
plt.savefig('cross_attention_weights.png', dpi=150)
plt.show()

コピータスクを十分に学習したモデルでは、少なくとも一部のヘッドで対角線上に高い重みが集中するパターンが見られます。これは「出力位置 $i$ が入力位置 $i$ に注目している」ことを意味し、コピーという操作を正確に学習できた証拠です。他のヘッドは異なるパターンを示すかもしれませんが、それはMulti-Head Attentionの「複数の観点で情報を捉える」性質の表れです。

まとめ

本記事では、Transformerの全コンポーネントをPyTorchでゼロから実装しました。以下に、各ステップで構築した要素と、その役割を振り返ります。

  • Scaled Dot-Product Attention: Query、Key、Valueの3つの行列から、文脈を考慮した出力を計算する基本演算。$\sqrt{d_k}$ によるスケーリングが勾配の安定性に寄与する
  • Multi-Head Attention: Attentionを複数のヘッドで並列実行し、異なる部分空間の情報を捉える。viewtranspose による効率的なヘッド分割が実装の鍵
  • Position-wise FFN: 各位置に独立に適用される2層全結合ネットワーク。中間層の拡大($4 \times d_{\text{model}}$)により豊かな非線形変換を実現する
  • Positional Encoding: sin/cosの多スケール周期構造により、トークンの位置情報を符号化する
  • Encoder: Self-Attention + FFN を残差接続とLayer Normalizationで包んだ層を $N$ 段積み重ねる
  • Decoder: Masked Self-Attention + Cross-Attention + FFN の3サブレイヤー構成。因果マスクにより自己回帰的な生成を保証する
  • Transformer全体: 埋め込み → Positional Encoding → Encoder/Decoder → 線形射影の一気通貫構造

コピータスクでの実験により、全てのメカニズム(Attention、マスク、残差接続、層の積み重ね)が正しく機能し、学習・推論が行えることを確認しました。

ゼロからの実装を通じて得た理解は、既存ライブラリの nn.Transformer を使う際にも活きます。各引数の意味、マスクの仕様、出力の形状が「なぜそうなっているか」を体感として知っていることで、デバッグや拡張が格段に楽になります。

次のステップとして、以下の記事も参考にしてください。

画像なし
Transformerアーキテクチャの全体像
原論文の設計思想や学習テクニック(warmup、ラベル平滑化)など、本記事でカバーしなかった側面を解説します。