Self-Attention機構の理論と実装をわかりやすく解説

Self-Attention(自己注意機構)は、Transformerアーキテクチャの中核をなすメカニズムです。2017年にGoogleの研究チームが “Attention Is All You Need” で提案して以来、自然言語処理のみならず、画像認識・音声認識・時系列解析など、あらゆる分野の深層学習モデルに革命をもたらしました。

Self-Attentionを理解することは、GPTやBERTといった大規模言語モデル(LLM)の仕組みを本質的に理解するための第一歩です。本記事では、RNNの限界からAttentionが生まれた背景を押さえた上で、Scaled Dot-Product Attentionの数式を1行ずつ丁寧に導出し、最終的にPyTorchでスクラッチ実装するところまでを解説します。

本記事の内容

  • RNNの限界とAttentionが必要になった背景
  • Query / Key / Value の直感的理解
  • Scaled Dot-Product Attention の数式と導出
  • 小さな行列による手計算の具体例
  • マルチヘッドAttentionの仕組み
  • Self-Attentionの計算量 $O(n^2 d)$
  • PyTorchでのスクラッチ実装と可視化

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

なぜAttentionが必要か — RNNの限界

Self-Attentionの意義を理解するために、まずRNN(Recurrent Neural Network)が抱えていた構造的な限界を確認しましょう。

RNNは時刻 $t$ の隠れ状態 $\bm{h}_t$ を、1つ前の隠れ状態 $\bm{h}_{t-1}$ と現在の入力 $\bm{x}_t$ から逐次的に計算します。

$$ \bm{h}_t = f(\bm{W}_h \bm{h}_{t-1} + \bm{W}_x \bm{x}_t + \bm{b}) $$

ここで $f$ は活性化関数(tanhなど)です。この逐次計算には、次の2つの本質的な問題があります。

問題1: 長距離依存の困難

文の先頭の単語が文末の単語に影響するとき、情報は $\bm{h}_1 \to \bm{h}_2 \to \cdots \to \bm{h}_n$ と逐次伝搬する必要があります。系列が長くなると、初期の情報は勾配消失によって失われやすくなります。LSTMやGRUはゲート機構でこの問題を緩和しましたが、根本的な解決には至りませんでした。

例えば、「私は東京で生まれ、… (100単語の文脈)… だから日本語が得意です」という文において、「東京で生まれ」の情報は100回の再帰的な変換を経るうちに減衰してしまい、「日本語が得意」との関連を捉えることが困難になります。

問題2: 並列化の不可能性

$\bm{h}_t$ の計算には $\bm{h}_{t-1}$ が必要なため、系列方向の計算を並列化できません。GPUの性能を十分に活かすことができず、長い系列の学習には多大な時間を要します。

Self-Attentionはこの2つの問題を同時に解決します。系列中の任意の2つの位置を直接結びつけるため、長距離依存の問題がなくなり、さらに全位置の計算を同時に実行できるため、並列化が可能です。

Attentionの直感的理解 — 辞書検索のアナロジー

Attention機構を直感的に理解するには、辞書を引く操作にたとえるのが最もわかりやすいでしょう。

辞書を使うとき、私たちは次のステップを踏みます。

  1. Query(問い合わせ): 「調べたいこと」を頭に浮かべる
  2. Key(見出し語): 辞書の見出し語を順に見て、Queryとの関連度を測る
  3. Value(本文): 関連度の高い見出し語に対応する本文を読み取る

Attentionでも全く同じ構造です。

概念 辞書での役割 Attentionでの役割
Query ($\bm{Q}$) 調べたいこと 注目元のトークンの特徴ベクトル
Key ($\bm{K}$) 見出し語 注目先の各トークンの特徴ベクトル
Value ($\bm{V}$) 本文 注目先の各トークンが持つ情報

ここで重要なのは、Self-Attentionの場合、Query・Key・Valueのすべてが同一の入力系列から生成されるという点です。つまり、文中の各単語が「自分自身を含む文中の他の全単語との関連度」を計算し、関連度に応じて情報を集約します。これが「Self(自己)」と呼ばれる所以です。

数学的には、入力行列 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$($n$:系列長、$d_{\text{model}}$:埋め込み次元)に対して、学習可能な重み行列を用いて3つの射影を生成します。

$$ \bm{Q} = \bm{X} \bm{W}^Q, \quad \bm{K} = \bm{X} \bm{W}^K, \quad \bm{V} = \bm{X} \bm{W}^V $$

ここで $\bm{W}^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $\bm{W}^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $\bm{W}^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ は学習可能なパラメータです。同じ入力 $\bm{X}$ から異なる線形変換で $\bm{Q}, \bm{K}, \bm{V}$ を生成することにより、同じ単語でも「質問する側」「検索される側」「情報を提供する側」として異なる表現を持つことができます。

Scaled Dot-Product Attention の数式

Attention関数の定義

Scaled Dot-Product Attentionは次の式で定義されます。

$$ \begin{equation} \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\!\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} \end{equation} $$

この式を段階的に分解して見ていきましょう。

ステップ1: 類似度の計算 $\bm{Q}\bm{K}^\top$

QueryとKeyの内積を計算します。$\bm{Q} \in \mathbb{R}^{n \times d_k}$ と $\bm{K}^\top \in \mathbb{R}^{d_k \times n}$ の積なので、結果は $\bm{S} \in \mathbb{R}^{n \times n}$ の正方行列になります。

$$ S_{ij} = \bm{q}_i \cdot \bm{k}_j = \sum_{l=1}^{d_k} q_{il} \, k_{jl} $$

$S_{ij}$ は $i$ 番目のQueryと $j$ 番目のKeyの類似度を表します。内積が大きいほど、2つのベクトルの方向が近い、つまり関連度が高いことを意味します。

ステップ2: スケーリング $\frac{1}{\sqrt{d_k}}$

内積の値を $\sqrt{d_k}$ で割ります。なぜこのスケーリングが必要なのでしょうか。

QueryとKeyの各成分が平均0、分散1の独立な確率変数だと仮定します。このとき、内積 $S_{ij} = \sum_{l=1}^{d_k} q_{il} k_{jl}$ の期待値と分散を計算してみましょう。

期待値の計算:

$$ \begin{align} \mathbb{E}[S_{ij}] &= \mathbb{E}\!\left[\sum_{l=1}^{d_k} q_{il} k_{jl}\right] \\ &= \sum_{l=1}^{d_k} \mathbb{E}[q_{il}] \mathbb{E}[k_{jl}] \quad (\because q_{il} \text{ と } k_{jl} \text{ は独立}) \\ &= \sum_{l=1}^{d_k} 0 \cdot 0 = 0 \end{align} $$

分散の計算:

独立な確率変数の積の分散公式を用いると、

$$ \begin{align} \text{Var}[q_{il} k_{jl}] &= \mathbb{E}[q_{il}^2 k_{jl}^2] – (\mathbb{E}[q_{il} k_{jl}])^2 \\ &= \mathbb{E}[q_{il}^2] \cdot \mathbb{E}[k_{jl}^2] – (\mathbb{E}[q_{il}] \cdot \mathbb{E}[k_{jl}])^2 \quad (\because \text{独立}) \\ &= (\text{Var}[q_{il}] + (\mathbb{E}[q_{il}])^2)(\text{Var}[k_{jl}] + (\mathbb{E}[k_{jl}])^2) – 0 \\ &= (1 + 0)(1 + 0) = 1 \end{align} $$

各項が独立なので、

$$ \begin{align} \text{Var}[S_{ij}] &= \sum_{l=1}^{d_k} \text{Var}[q_{il} k_{jl}] = \sum_{l=1}^{d_k} 1 = d_k \end{align} $$

つまり、内積の分散は次元 $d_k$ に比例して大きくなります。$d_k$ が大きいとき(例えば $d_k = 512$)、内積の値は非常に大きな絶対値を取り得ます。この大きな値がsoftmaxに入力されると、softmaxの出力がone-hotに近い極端な分布になり、勾配が極めて小さくなってしまいます(勾配消失)。

$\sqrt{d_k}$ で割ることにより、スケーリング後の値の分散を1に正規化できます。

$$ \text{Var}\!\left[\frac{S_{ij}}{\sqrt{d_k}}\right] = \frac{1}{d_k} \text{Var}[S_{ij}] = \frac{d_k}{d_k} = 1 $$

これにより、$d_k$ の大きさに依存せず、softmaxが適切に動作する温和な値域に収まります。

ステップ3: softmaxによる正規化

スケーリング後のスコアに対して、行方向にsoftmaxを適用します。

$$ \alpha_{ij} = \frac{\exp\!\left(\dfrac{S_{ij}}{\sqrt{d_k}}\right)}{\displaystyle\sum_{l=1}^{n} \exp\!\left(\dfrac{S_{il}}{\sqrt{d_k}}\right)} $$

$\alpha_{ij}$ はAttention重みと呼ばれ、$i$ 番目のトークンが $j$ 番目のトークンにどれだけ注目するかを表す確率値です。各行について $\sum_{j=1}^{n} \alpha_{ij} = 1$ かつ $\alpha_{ij} \geq 0$ が成り立ちます。

ステップ4: Valueの加重和

最後に、Attention重みを用いてValueの加重和を計算します。

$$ \bm{z}_i = \sum_{j=1}^{n} \alpha_{ij} \bm{v}_j $$

$i$ 番目のトークンの出力は、全トークンのValueベクトルを、Attention重みで重み付けした加重平均です。関連度の高いトークンの情報が強く反映されます。

全体をまとめた導出

以上を行列表記でまとめると、次のように整理できます。

$$ \begin{align} \bm{S} &= \bm{Q}\bm{K}^\top \quad &\in \mathbb{R}^{n \times n} \\ \bm{A} &= \text{softmax}\!\left(\frac{\bm{S}}{\sqrt{d_k}}\right) \quad &\in \mathbb{R}^{n \times n} \\ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) &= \bm{A}\bm{V} \quad &\in \mathbb{R}^{n \times d_v} \end{align} $$

Attention重みの計算例

小さな行列を用いて、手計算でAttentionの動作を確認しましょう。系列長 $n = 3$、次元 $d_k = 2$ とします。

$$ \bm{Q} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{pmatrix}, \quad \bm{K} = \begin{pmatrix} 1 & 1 \\ 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad \bm{V} = \begin{pmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{pmatrix} $$

ステップ1: $\bm{Q}\bm{K}^\top$ を計算

$$ \bm{Q}\bm{K}^\top = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{pmatrix} \begin{pmatrix} 1 & 0 & 1 \\ 1 & 1 & 0 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 1 \\ 1 & 1 & 0 \\ 2 & 1 & 1 \end{pmatrix} $$

例えば、$(1,1)$ 成分は $\bm{q}_1 \cdot \bm{k}_1 = 1 \times 1 + 0 \times 1 = 1$ です。

ステップ2: $\sqrt{d_k} = \sqrt{2} \approx 1.414$ で割る

$$ \frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} = \begin{pmatrix} 0.707 & 0.000 & 0.707 \\ 0.707 & 0.707 & 0.000 \\ 1.414 & 0.707 & 0.707 \end{pmatrix} $$

ステップ3: 行ごとにsoftmaxを適用

1行目を例にとると、

$$ \begin{align} \alpha_{11} &= \frac{e^{0.707}}{e^{0.707} + e^{0.000} + e^{0.707}} = \frac{2.028}{2.028 + 1.000 + 2.028} = \frac{2.028}{5.056} \approx 0.401 \\ \alpha_{12} &= \frac{e^{0.000}}{5.056} = \frac{1.000}{5.056} \approx 0.198 \\ \alpha_{13} &= \frac{e^{0.707}}{5.056} = \frac{2.028}{5.056} \approx 0.401 \end{align} $$

同様に全行を計算すると、Attention重み行列 $\bm{A}$ は以下のようになります。

$$ \bm{A} \approx \begin{pmatrix} 0.401 & 0.198 & 0.401 \\ 0.365 & 0.365 & 0.269 \\ 0.464 & 0.232 & 0.304 \end{pmatrix} $$

各行の和が1になっていることを確認してください(例: $0.401 + 0.198 + 0.401 = 1.000$)。

ステップ4: $\bm{A}\bm{V}$ を計算

$$ \bm{A}\bm{V} \approx \begin{pmatrix} 0.401 & 0.198 & 0.401 \\ 0.365 & 0.365 & 0.269 \\ 0.464 & 0.232 & 0.304 \end{pmatrix} \begin{pmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{pmatrix} = \begin{pmatrix} 3.000 & 4.000 \\ 2.690 & 3.884 \\ 2.680 & 3.752 \end{pmatrix} $$

1行目の出力を確認すると、$0.401 \times 1 + 0.198 \times 3 + 0.401 \times 5 = 3.000$ です。1番目のトークンは、1番目と3番目のトークン(Attention重み $\approx 0.401$)に強く注目し、2番目のトークン(重み $\approx 0.198$)への注目は弱いことがわかります。これは、Queryベクトル $(1, 0)$ が Keyベクトル $(1, 1)$ や $(1, 0)$ と高い内積を持つ一方で、$(0, 1)$ とは内積が0になることと整合しています。

マルチヘッドAttention

単一ヘッドの限界

単一のAttentionヘッドでは、1組の $(\bm{W}^Q, \bm{W}^K, \bm{W}^V)$ で1種類の関連度パターンしか捉えられません。しかし、自然言語では「文法的な係り受け」「意味的な類似性」「共参照関係」など、複数の関連度パターンが同時に存在します。

マルチヘッドの仕組み

マルチヘッドAttentionは、$h$ 個のAttentionヘッドを並列に走らせ、それぞれ異なる部分空間で関連度を学習します。

各ヘッド $i$ ($i = 1, \dots, h$) に対して、

$$ \text{head}_i = \text{Attention}(\bm{X}\bm{W}_i^Q, \, \bm{X}\bm{W}_i^K, \, \bm{X}\bm{W}_i^V) $$

ここで、$\bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $\bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $\bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ は各ヘッド固有の射影行列です。

全ヘッドの出力を結合(concatenate)し、出力射影行列 $\bm{W}^O$ で変換します。

$$ \begin{equation} \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, \bm{W}^O \end{equation} $$

ここで、$\bm{W}^O \in \mathbb{R}^{h \cdot d_v \times d_{\text{model}}}$ です。

次元の設計

典型的には $d_k = d_v = d_{\text{model}} / h$ と設定します。元論文 “Attention Is All You Need” では $d_{\text{model}} = 512$, $h = 8$ として、各ヘッドの次元は $d_k = d_v = 512 / 8 = 64$ です。

各ヘッドの次元を小さくすることで、マルチヘッドの総計算量は単一ヘッド($d_k = d_{\text{model}}$)の場合とほぼ同等に保たれます。これは次のように確認できます。

  • 単一ヘッド($d_k = d_{\text{model}}$): $O(n^2 d_{\text{model}})$
  • $h$ ヘッド($d_k = d_{\text{model}} / h$): $h \times O(n^2 \cdot d_{\text{model}} / h) = O(n^2 d_{\text{model}})$

計算コストは同等でありながら、$h$ 個の異なる部分空間で多様な関連度パターンを並列に学習できるのが、マルチヘッドAttentionの大きな利点です。

Self-Attentionの計算量

Self-Attentionの計算量を分析しましょう。系列長を $n$、次元を $d$ とします。

主要な計算:

  1. $\bm{Q}\bm{K}^\top$ の計算: $\bm{Q} \in \mathbb{R}^{n \times d}$ と $\bm{K}^\top \in \mathbb{R}^{d \times n}$ の行列積なので $O(n^2 d)$
  2. softmaxの適用: $n \times n$ 行列に対して $O(n^2)$
  3. $\bm{A}\bm{V}$ の計算: $\bm{A} \in \mathbb{R}^{n \times n}$ と $\bm{V} \in \mathbb{R}^{n \times d}$ の行列積なので $O(n^2 d)$

したがって、全体の計算量は以下のようになります。

$$ \begin{equation} O(n^2 d) \end{equation} $$

モデル 1層あたりの計算量 系列方向の逐次計算 最大パス長
RNN $O(n d^2)$ $O(n)$ $O(n)$
Self-Attention $O(n^2 d)$ $O(1)$ $O(1)$

RNNと比較すると、Self-Attentionは系列長 $n$ に対して2乗で計算量が増加するため、非常に長い系列($n \gg d$)では計算コストが高くなります。一方で、系列方向の逐次計算が $O(1)$ であるため完全に並列化が可能であり、任意の2トークン間の最大パス長も $O(1)$ であるため長距離依存を直接的に捉えられます。

$n^2$ のメモリ・計算コストに対処するために、Linformer($O(nd)$)、Performer($O(nd)$)、Flash Attention(IO効率の改善)などの効率的なAttention手法が提案されています。

PyTorchでのスクラッチ実装

Scaled Dot-Product Attention

まず、Scaled Dot-Product Attentionを実装します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention"""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: (batch_size, ..., seq_len, d_k)
            K: (batch_size, ..., seq_len, d_k)
            V: (batch_size, ..., seq_len, d_v)
            mask: マスク(オプション)
        Returns:
            output: (batch_size, ..., seq_len, d_v)
            attn_weights: (batch_size, ..., seq_len, seq_len)
        """
        d_k = Q.size(-1)

        # ステップ1, 2: 類似度計算 + スケーリング
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        # マスク適用(デコーダの因果マスクなどに使用)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # ステップ3: softmaxで正規化
        attn_weights = F.softmax(scores, dim=-1)

        # ステップ4: Valueの加重和
        output = torch.matmul(attn_weights, V)

        return output, attn_weights

SingleHeadAttention クラス

入力 $\bm{X}$ からQ, K, Vを生成し、Attention計算を行う単一ヘッドのクラスです。

class SingleHeadAttention(nn.Module):
    """単一ヘッドのSelf-Attention"""

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)
        self.attention = ScaledDotProductAttention()

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            output: (batch_size, seq_len, d_v)
            attn_weights: (batch_size, seq_len, seq_len)
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        return self.attention(Q, K, V)

MultiHeadAttention クラス

複数のヘッドを効率的に計算するクラスです。全ヘッド分のQ, K, Vを1つの線形変換で一括計算し、viewtranspose でヘッド次元に分割します。

class MultiHeadAttention(nn.Module):
    """マルチヘッドAttention"""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 全ヘッド分を一括で線形変換
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # 出力射影
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # Attention計算
        self.attention = ScaledDotProductAttention()

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, d_model)
            mask: マスク(オプション)
        Returns:
            output: (batch_size, seq_len, d_model)
            attn_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.size()

        # 線形変換
        Q = self.W_q(x)  # (batch_size, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # ヘッド分割: (batch, seq, d_model) -> (batch, heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 各ヘッドでAttention計算
        context, attn_weights = self.attention(Q, K, V, mask)

        # ヘッド結合: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # 出力射影
        output = self.W_o(context)

        return output, attn_weights

実装のポイントを整理します。

  • nn.Linear(d_model, d_model) は、$h$ 個のヘッド分の射影を1つの行列にまとめて計算しています。これにより、ヘッドごとに別々の行列積を計算する必要がなく、効率的です。
  • viewtranspose を使い、$(B, n, d_{\text{model}})$ のテンソルを $(B, h, n, d_k)$ に変形しています。これにより、全ヘッドの計算をバッチ次元の一部として並列実行できます。
  • 出力時には transposecontiguous().view() で元の形状に戻し、出力射影 $\bm{W}^O$ を適用します。contiguous() は、transpose によりメモリ上で不連続になったテンソルを連続化するために必要です。

動作確認

# パラメータ設定
d_model = 32    # 埋め込み次元
num_heads = 4   # ヘッド数
seq_len = 6     # 系列長
batch_size = 2  # バッチサイズ

# ランダムな入力
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, d_model)

# 単一ヘッドAttention
single_attn = SingleHeadAttention(d_model, d_k=32, d_v=32)
out_s, weights_s = single_attn(x)
print(f"SingleHead 出力形状: {out_s.shape}")      # (2, 6, 32)
print(f"SingleHead 重み形状: {weights_s.shape}")   # (2, 6, 6)
print(f"Attention重みの行和: {weights_s[0, 0].sum().item():.4f}")  # 1.0000

# マルチヘッドAttention
multi_attn = MultiHeadAttention(d_model, num_heads)
out_m, weights_m = multi_attn(x)
print(f"MultiHead  出力形状: {out_m.shape}")       # (2, 6, 32)
print(f"MultiHead  重み形状: {weights_m.shape}")   # (2, 4, 6, 6)
print(f"Attention重みの行和: {weights_m[0, 0, 0].sum().item():.4f}")  # 1.0000

入力と出力の形状が同じ $(B, n, d_{\text{model}})$ であること、Attention重みの各行の和が1であること(softmaxの性質)が確認できます。

Attention重みの可視化

Attention重みをヒートマップとして可視化することで、モデルがどのトークン間の関連を捉えているかを確認できます。

import numpy as np
import matplotlib.pyplot as plt

# トークン列を定義
tokens = ["The", "cat", "sat", "on", "the", "mat"]

# 再現性のためシードを固定
torch.manual_seed(0)

# 入力を生成(バッチサイズ1)
x = torch.randn(1, len(tokens), d_model)

# マルチヘッドAttentionでAttention重みを取得
mha = MultiHeadAttention(d_model=32, num_heads=4)
with torch.no_grad():
    _, attn_weights = mha(x)

# Attention重みを NumPy に変換
attn = attn_weights[0].detach().numpy()  # (num_heads, seq_len, seq_len)

# 各ヘッドのAttention重みをヒートマップで表示
fig, axes = plt.subplots(1, 4, figsize=(18, 4))
for i, ax in enumerate(axes):
    im = ax.imshow(attn[i], cmap="Blues", vmin=0, vmax=attn[i].max())
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right")
    ax.set_yticklabels(tokens)
    ax.set_title(f"Head {i+1}")
    ax.set_xlabel("Key")
    ax.set_ylabel("Query")
fig.colorbar(im, ax=axes, shrink=0.8, label="Attention Weight")
fig.suptitle("Multi-Head Self-Attention Weights", fontsize=14)
plt.tight_layout()
plt.savefig("attention_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()

このヒートマップの読み方は以下の通りです。

  • 縦軸(Query): 注目する側のトークン
  • 横軸(Key): 注目される側のトークン
  • 色の濃さ: Attention重みの大きさ(濃いほど強い注目)

ランダム初期化の段階では各ヘッドのAttentionパターンはほぼ均一ですが、学習が進むにつれて、各ヘッドが異なるパターンを獲得していきます。例えば、学習済みのTransformerでは、あるヘッドは構文関係(主語-述語のペア)に、別のヘッドは位置的な近接関係に、また別のヘッドは共参照関係に注目するパターンが観察されることが知られています。

まとめ

本記事では、Self-Attention(自己注意)機構の理論と実装について解説しました。

  • RNNの限界: 逐次計算による長距離依存の困難と並列化の不可能性が、Attention機構の誕生を促した
  • Query / Key / Value: 辞書検索のアナロジーで理解でき、Self-Attentionでは同一の入力系列からQ, K, Vすべてが生成される
  • Scaled Dot-Product Attention: $\text{softmax}(\bm{Q}\bm{K}^\top / \sqrt{d_k})\bm{V}$ の各ステップの意味と、$\sqrt{d_k}$ によるスケーリングの数学的根拠を導出した
  • 手計算の具体例: 小さな行列で全ステップを追うことで、Attention重みの意味を確認した
  • マルチヘッドAttention: 複数の部分空間で異なる関連度パターンを並列に学習し、結合して出力する。計算コストは単一ヘッドと同等
  • 計算量: $O(n^2 d)$ であり、系列長の2乗に比例するが、完全な並列化が可能

Self-Attentionは、Transformerアーキテクチャの心臓部です。次のステップとして、Self-Attentionに加えてPositional Encoding、Feed-Forward Network、残差結合、Layer Normalizationを組み合わせたTransformerアーキテクチャ全体を学ぶことで、GPTやBERTの構造を包括的に理解できるようになります。

次のステップとして、以下の記事も参考にしてください。