マルチヘッドアテンション(Multi-Head Attention)は、Transformerアーキテクチャの中核をなす機構です。単一のAttentionヘッドでは捉えきれない多様な関係性を、複数のヘッドで並列に学習することで、モデルの表現力を飛躍的に高めます。
GPTやBERTといった大規模言語モデル(LLM)の成功は、このマルチヘッドアテンションの設計に大きく依存しています。本記事では、なぜ単一ヘッドでは不十分なのかを明らかにした上で、マルチヘッドアテンションの数式を丁寧に導出し、Pythonでスクラッチ実装するところまでを解説します。
本記事の内容
- なぜマルチヘッドが必要か
- Scaled Dot-Product Attentionの復習
- マルチヘッドアテンションの数式と図解
- Query, Key, Valueの役割と射影の意味
- Pythonでのスクラッチ実装と可視化
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
なぜマルチヘッドが必要か
単一ヘッドAttentionの限界
Self-Attentionの基本形であるScaled Dot-Product Attentionは、入力 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ から Query $\bm{Q}$、Key $\bm{K}$、Value $\bm{V}$ を生成し、次の式で出力を計算します。
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\!\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$
この単一ヘッドのAttentionには、本質的な限界があります。
限界1: 単一の関係性パターンしか捉えられない
1組の射影行列 $(\bm{W}^Q, \bm{W}^K, \bm{W}^V)$ は、1種類の「類似度の測り方」しか学習できません。しかし、自然言語の文には複数の関係性が同時に存在します。
例えば、”The cat that sat on the mat was black.” という文を考えましょう。
- 構文的関係: “cat” と “was” は主語と述語の関係
- 修飾関係: “cat” と “sat” は関係代名詞節による修飾関係
- 位置的近接: “on” と “the” や “mat” は隣接した位置関係
- 意味的関係: “cat” と “black” は属性の関係
単一ヘッドでこれらすべてを同時に捉えることは困難です。
限界2: 表現部分空間の制約
単一ヘッドでは、$d_k$ 次元の部分空間で類似度を計算します。この空間が小さすぎると表現力が不足し、大きすぎると計算コストが増大します。
マルチヘッドによる解決
マルチヘッドアテンションは、複数の異なる部分空間で並列にAttentionを計算することで、これらの限界を克服します。
$$ \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, \bm{W}^O $$
$h$ 個のヘッドがそれぞれ異なる関係性パターンを学習し、それらを統合することで、豊かな表現を獲得できます。
Scaled Dot-Product Attentionの復習
マルチヘッドアテンションの構成要素であるScaled Dot-Product Attentionを復習しましょう。
入力と出力
- 入力: Query $\bm{Q} \in \mathbb{R}^{n \times d_k}$、Key $\bm{K} \in \mathbb{R}^{n \times d_k}$、Value $\bm{V} \in \mathbb{R}^{n \times d_v}$
- 出力: $\text{Attention}(\bm{Q}, \bm{K}, \bm{V}) \in \mathbb{R}^{n \times d_v}$
計算ステップ
ステップ1: 類似度行列の計算
$$ \bm{S} = \bm{Q}\bm{K}^\top \in \mathbb{R}^{n \times n} $$
ステップ2: スケーリング
$$ \bm{S}’ = \frac{\bm{S}}{\sqrt{d_k}} $$
ステップ3: softmaxによる正規化
$$ \bm{A} = \text{softmax}(\bm{S}’) \in \mathbb{R}^{n \times n} $$
ステップ4: 加重和
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \bm{A}\bm{V} \in \mathbb{R}^{n \times d_v} $$
マルチヘッドアテンションの数式
基本構造
マルチヘッドアテンションでは、$h$ 個のAttentionヘッドを並列に計算します。各ヘッドは、入力を異なる部分空間に射影し、独立にAttentionを計算します。
入力 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ に対して、各ヘッド $i$ ($i = 1, \dots, h$) は固有の射影行列を持ちます。
$$ \bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad \bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad \bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v} $$
各ヘッドの計算
各ヘッドでは、まず入力を射影してQuery, Key, Valueを生成します。
$$ \bm{Q}_i = \bm{X}\bm{W}_i^Q, \quad \bm{K}_i = \bm{X}\bm{W}_i^K, \quad \bm{V}_i = \bm{X}\bm{W}_i^V $$
次に、Scaled Dot-Product Attentionを計算します。
$$ \begin{equation} \text{head}_i = \text{Attention}(\bm{Q}_i, \bm{K}_i, \bm{V}_i) = \text{softmax}\!\left(\frac{\bm{Q}_i\bm{K}_i^\top}{\sqrt{d_k}}\right)\bm{V}_i \end{equation} $$
ヘッドの結合と出力射影
$h$ 個のヘッドの出力を横方向に結合(concatenate)します。
$$ \text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} $$
結合した出力に、出力射影行列 $\bm{W}^O \in \mathbb{R}^{(h \cdot d_v) \times d_{\text{model}}}$ を適用して、最終出力を得ます。
$$ \begin{equation} \text{MultiHead}(\bm{X}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, \bm{W}^O \end{equation} $$
次元設計の典型例
| パラメータ | 値 | 説明 |
|---|---|---|
| $d_{\text{model}}$ | 512 | モデルの埋め込み次元 |
| $h$ | 8 | ヘッド数 |
| $d_k = d_v$ | 64 | 各ヘッドの次元($d_{\text{model}} / h$) |
Query, Key, Valueの役割
直感的理解: 辞書検索のアナロジー
| 概念 | 辞書での役割 | Attentionでの役割 |
|---|---|---|
| Query ($\bm{Q}$) | 調べたいこと | 注目する側のトークンの特徴 |
| Key ($\bm{K}$) | 見出し語 | 注目される側のトークンの特徴 |
| Value ($\bm{V}$) | 本文 | 注目される側が持つ情報 |
各ヘッドが学習するパターン
学習が進むと、各ヘッドは異なる関係性パターンを獲得することが実験的に観察されています。
- ヘッドA: 隣接する単語への注目(ローカルな文脈)
- ヘッドB: 動詞と目的語の関係
- ヘッドC: 代名詞とその先行詞の関係(共参照)
- ヘッドD: 文の末尾への注目(文全体の要約)
Pythonでのスクラッチ実装
Scaled Dot-Product Attention
import numpy as np
import matplotlib.pyplot as plt
def softmax(x, axis=-1):
"""数値的に安定なsoftmax"""
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Scaled Dot-Product Attention
Args:
Q: Query行列 (batch_size, seq_len, d_k)
K: Key行列 (batch_size, seq_len, d_k)
V: Value行列 (batch_size, seq_len, d_v)
mask: マスク行列(オプション)
Returns:
output: Attention出力
attention_weights: Attention重み
"""
d_k = Q.shape[-1]
# ステップ1: QK^T を計算
scores = np.matmul(Q, K.swapaxes(-2, -1))
# ステップ2: スケーリング
scores = scores / np.sqrt(d_k)
# マスク適用(オプション)
if mask is not None:
scores = np.where(mask == 0, -1e9, scores)
# ステップ3: softmax
attention_weights = softmax(scores, axis=-1)
# ステップ4: Valueとの加重和
output = np.matmul(attention_weights, V)
return output, attention_weights
MultiHeadAttentionクラス
class MultiHeadAttention:
"""マルチヘッドアテンション(NumPy実装)"""
def __init__(self, d_model, num_heads, seed=42):
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 重みの初期化(Xavier初期化)
np.random.seed(seed)
scale = np.sqrt(2.0 / (d_model + self.d_k))
self.W_q = np.random.randn(d_model, d_model) * scale
self.W_k = np.random.randn(d_model, d_model) * scale
self.W_v = np.random.randn(d_model, d_model) * scale
self.W_o = np.random.randn(d_model, d_model) * scale
def split_heads(self, x):
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(0, 2, 1, 3)
def combine_heads(self, x):
batch_size, _, seq_len, _ = x.shape
x = x.transpose(0, 2, 1, 3)
return x.reshape(batch_size, seq_len, self.d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
Q = np.matmul(x, self.W_q)
K = np.matmul(x, self.W_k)
V = np.matmul(x, self.W_v)
Q = self.split_heads(Q)
K = self.split_heads(K)
V = self.split_heads(V)
context, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
context = self.combine_heads(context)
output = np.matmul(context, self.W_o)
return output, attention_weights
# 動作確認
d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
np.random.seed(0)
x = np.random.randn(batch_size, seq_len, d_model)
output, attention_weights = mha.forward(x)
print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
print(f"Attention重み形状: {attention_weights.shape}")
まとめ
本記事では、マルチヘッドアテンションの理論と実装について解説しました。
- 単一ヘッドの限界: 1種類の関係性パターンしか捉えられない制約を、複数ヘッドで克服
- Scaled Dot-Product Attention: $\text{softmax}(\bm{Q}\bm{K}^\top / \sqrt{d_k})\bm{V}$ の各ステップを復習
- マルチヘッドの数式: $h$ 個のヘッドを並列計算し、結合と出力射影を適用
- 次元設計: $d_k = d_{\text{model}} / h$ とすることで、計算量を維持しながら多様性を獲得
- Query, Key, Valueの役割: 辞書検索のアナロジーと、射影による役割分離の意味
- Pythonでのスクラッチ実装: NumPyでマルチヘッドアテンションをゼロから構築
次のステップとして、以下の記事も参考にしてください。