マルチヘッドアテンション(Multi-Head Attention)の仕組みと実装

マルチヘッドアテンション(Multi-Head Attention)は、Transformerアーキテクチャの中核をなす機構です。単一のAttentionヘッドでは捉えきれない多様な関係性を、複数のヘッドで並列に学習することで、モデルの表現力を飛躍的に高めます。

GPTやBERTといった大規模言語モデル(LLM)の成功は、このマルチヘッドアテンションの設計に大きく依存しています。本記事では、なぜ単一ヘッドでは不十分なのかを明らかにした上で、マルチヘッドアテンションの数式を丁寧に導出し、Pythonでスクラッチ実装するところまでを解説します。

本記事の内容

  • なぜマルチヘッドが必要か
  • Scaled Dot-Product Attentionの復習
  • マルチヘッドアテンションの数式と図解
  • Query, Key, Valueの役割と射影の意味
  • Pythonでのスクラッチ実装と可視化

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

なぜマルチヘッドが必要か

単一ヘッドAttentionの限界

Self-Attentionの基本形であるScaled Dot-Product Attentionは、入力 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ から Query $\bm{Q}$、Key $\bm{K}$、Value $\bm{V}$ を生成し、次の式で出力を計算します。

$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\!\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$

この単一ヘッドのAttentionには、本質的な限界があります。

限界1: 単一の関係性パターンしか捉えられない

1組の射影行列 $(\bm{W}^Q, \bm{W}^K, \bm{W}^V)$ は、1種類の「類似度の測り方」しか学習できません。しかし、自然言語の文には複数の関係性が同時に存在します。

例えば、”The cat that sat on the mat was black.” という文を考えましょう。

  • 構文的関係: “cat” と “was” は主語と述語の関係
  • 修飾関係: “cat” と “sat” は関係代名詞節による修飾関係
  • 位置的近接: “on” と “the” や “mat” は隣接した位置関係
  • 意味的関係: “cat” と “black” は属性の関係

単一ヘッドでこれらすべてを同時に捉えることは困難です。

限界2: 表現部分空間の制約

単一ヘッドでは、$d_k$ 次元の部分空間で類似度を計算します。この空間が小さすぎると表現力が不足し、大きすぎると計算コストが増大します。

マルチヘッドによる解決

マルチヘッドアテンションは、複数の異なる部分空間で並列にAttentionを計算することで、これらの限界を克服します。

$$ \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, \bm{W}^O $$

$h$ 個のヘッドがそれぞれ異なる関係性パターンを学習し、それらを統合することで、豊かな表現を獲得できます。

Scaled Dot-Product Attentionの復習

マルチヘッドアテンションの構成要素であるScaled Dot-Product Attentionを復習しましょう。

入力と出力

  • 入力: Query $\bm{Q} \in \mathbb{R}^{n \times d_k}$、Key $\bm{K} \in \mathbb{R}^{n \times d_k}$、Value $\bm{V} \in \mathbb{R}^{n \times d_v}$
  • 出力: $\text{Attention}(\bm{Q}, \bm{K}, \bm{V}) \in \mathbb{R}^{n \times d_v}$

計算ステップ

ステップ1: 類似度行列の計算

$$ \bm{S} = \bm{Q}\bm{K}^\top \in \mathbb{R}^{n \times n} $$

ステップ2: スケーリング

$$ \bm{S}’ = \frac{\bm{S}}{\sqrt{d_k}} $$

ステップ3: softmaxによる正規化

$$ \bm{A} = \text{softmax}(\bm{S}’) \in \mathbb{R}^{n \times n} $$

ステップ4: 加重和

$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \bm{A}\bm{V} \in \mathbb{R}^{n \times d_v} $$

マルチヘッドアテンションの数式

基本構造

マルチヘッドアテンションでは、$h$ 個のAttentionヘッドを並列に計算します。各ヘッドは、入力を異なる部分空間に射影し、独立にAttentionを計算します。

入力 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ に対して、各ヘッド $i$ ($i = 1, \dots, h$) は固有の射影行列を持ちます。

$$ \bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad \bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad \bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v} $$

各ヘッドの計算

各ヘッドでは、まず入力を射影してQuery, Key, Valueを生成します。

$$ \bm{Q}_i = \bm{X}\bm{W}_i^Q, \quad \bm{K}_i = \bm{X}\bm{W}_i^K, \quad \bm{V}_i = \bm{X}\bm{W}_i^V $$

次に、Scaled Dot-Product Attentionを計算します。

$$ \begin{equation} \text{head}_i = \text{Attention}(\bm{Q}_i, \bm{K}_i, \bm{V}_i) = \text{softmax}\!\left(\frac{\bm{Q}_i\bm{K}_i^\top}{\sqrt{d_k}}\right)\bm{V}_i \end{equation} $$

ヘッドの結合と出力射影

$h$ 個のヘッドの出力を横方向に結合(concatenate)します。

$$ \text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} $$

結合した出力に、出力射影行列 $\bm{W}^O \in \mathbb{R}^{(h \cdot d_v) \times d_{\text{model}}}$ を適用して、最終出力を得ます。

$$ \begin{equation} \text{MultiHead}(\bm{X}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, \bm{W}^O \end{equation} $$

次元設計の典型例

パラメータ 説明
$d_{\text{model}}$ 512 モデルの埋め込み次元
$h$ 8 ヘッド数
$d_k = d_v$ 64 各ヘッドの次元($d_{\text{model}} / h$)

Query, Key, Valueの役割

直感的理解: 辞書検索のアナロジー

概念 辞書での役割 Attentionでの役割
Query ($\bm{Q}$) 調べたいこと 注目する側のトークンの特徴
Key ($\bm{K}$) 見出し語 注目される側のトークンの特徴
Value ($\bm{V}$) 本文 注目される側が持つ情報

各ヘッドが学習するパターン

学習が進むと、各ヘッドは異なる関係性パターンを獲得することが実験的に観察されています。

  • ヘッドA: 隣接する単語への注目(ローカルな文脈)
  • ヘッドB: 動詞と目的語の関係
  • ヘッドC: 代名詞とその先行詞の関係(共参照)
  • ヘッドD: 文の末尾への注目(文全体の要約)

Pythonでのスクラッチ実装

Scaled Dot-Product Attention

import numpy as np
import matplotlib.pyplot as plt


def softmax(x, axis=-1):
    """数値的に安定なsoftmax"""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention

    Args:
        Q: Query行列 (batch_size, seq_len, d_k)
        K: Key行列 (batch_size, seq_len, d_k)
        V: Value行列 (batch_size, seq_len, d_v)
        mask: マスク行列(オプション)

    Returns:
        output: Attention出力
        attention_weights: Attention重み
    """
    d_k = Q.shape[-1]

    # ステップ1: QK^T を計算
    scores = np.matmul(Q, K.swapaxes(-2, -1))

    # ステップ2: スケーリング
    scores = scores / np.sqrt(d_k)

    # マスク適用(オプション)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)

    # ステップ3: softmax
    attention_weights = softmax(scores, axis=-1)

    # ステップ4: Valueとの加重和
    output = np.matmul(attention_weights, V)

    return output, attention_weights

MultiHeadAttentionクラス

class MultiHeadAttention:
    """マルチヘッドアテンション(NumPy実装)"""

    def __init__(self, d_model, num_heads, seed=42):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 重みの初期化(Xavier初期化)
        np.random.seed(seed)
        scale = np.sqrt(2.0 / (d_model + self.d_k))

        self.W_q = np.random.randn(d_model, d_model) * scale
        self.W_k = np.random.randn(d_model, d_model) * scale
        self.W_v = np.random.randn(d_model, d_model) * scale
        self.W_o = np.random.randn(d_model, d_model) * scale

    def split_heads(self, x):
        batch_size, seq_len, _ = x.shape
        x = x.reshape(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(0, 2, 1, 3)

    def combine_heads(self, x):
        batch_size, _, seq_len, _ = x.shape
        x = x.transpose(0, 2, 1, 3)
        return x.reshape(batch_size, seq_len, self.d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = np.matmul(x, self.W_q)
        K = np.matmul(x, self.W_k)
        V = np.matmul(x, self.W_v)

        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)

        context, attention_weights = scaled_dot_product_attention(Q, K, V, mask)

        context = self.combine_heads(context)
        output = np.matmul(context, self.W_o)

        return output, attention_weights


# 動作確認
d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
np.random.seed(0)
x = np.random.randn(batch_size, seq_len, d_model)

output, attention_weights = mha.forward(x)

print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
print(f"Attention重み形状: {attention_weights.shape}")

まとめ

本記事では、マルチヘッドアテンションの理論と実装について解説しました。

  • 単一ヘッドの限界: 1種類の関係性パターンしか捉えられない制約を、複数ヘッドで克服
  • Scaled Dot-Product Attention: $\text{softmax}(\bm{Q}\bm{K}^\top / \sqrt{d_k})\bm{V}$ の各ステップを復習
  • マルチヘッドの数式: $h$ 個のヘッドを並列計算し、結合と出力射影を適用
  • 次元設計: $d_k = d_{\text{model}} / h$ とすることで、計算量を維持しながら多様性を獲得
  • Query, Key, Valueの役割: 辞書検索のアナロジーと、射影による役割分離の意味
  • Pythonでのスクラッチ実装: NumPyでマルチヘッドアテンションをゼロから構築

次のステップとして、以下の記事も参考にしてください。