私たちが文章を読むとき、無意識のうちに複数の視点から情報を処理しています。たとえば「彼女は銀行の前で友人と待ち合わせをしていた」という文を読んだとき、頭の中では少なくとも3つの処理が同時に走っています。1つ目は文法的な構造の把握(「彼女は」が主語で「待ち合わせをしていた」が述語)、2つ目は単語の意味的な関係(「銀行」は金融機関か川岸か、文脈から判断する)、そして3つ目は位置的な近接関係(「前で」が「銀行」にかかる)です。
人間がこのような多面的な言語処理を自然にこなしているのに対し、Single-Head Attentionは1つの「視点」しか持ちません。1つの注意パターンで文法・意味・位置関係をすべてカバーしようとするのは、1つの目で奥行き・色・動きを同時に知覚しようとするようなものです。この制約を突破するために生まれたのが、Multi-Head Attention(MHA)です。
Multi-Head Attentionを理解することで、以下のような幅広い応用への道が開けます。
- 自然言語処理: GPTアーキテクチャやBERTアーキテクチャの中核機構として、機械翻訳・文章生成・質問応答など、現代のNLPモデルのほぼ全てがMHAに依存しています
- コンピュータビジョン: Vision Transformer(ViT)では画像パッチ間の関係をMHAで捉え、CNNに匹敵する画像認識性能を実現しています
- 音声処理・タンパク質構造予測: Whisper(音声認識)やAlphaFold2(タンパク質構造予測)など、シーケンスデータを扱うあらゆる分野でMHAが活躍しています
本記事の内容
- Single-Head Attentionの復習と限界
- Multi-Head Attentionの直感的理解
- 部分空間への射影の幾何学的意味
- MHAの数学的定義と数式の読み方
- 計算量分析 — なぜコストが増えないのか
- NumPyによるスクラッチ実装と可視化
- PyTorchによる実装
- ヘッド数による表現力の違いの実験
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
- Self-Attention機構 — Scaled Dot-Product Attentionの仕組み
- Transformerアーキテクチャ — MHAがどこで使われるかの全体像
- 行列の基本演算 — 行列積・転置の基本操作
Single-Head Attention(復習)
Multi-Head Attentionを理解するためには、まずSingle-Head Attention(Scaled Dot-Product Attention)の仕組みをしっかり押さえておく必要があります。詳細はSelf-Attention機構の記事で解説していますが、ここでは要点を振り返ります。
直感的な理解
Attentionの本質は「検索エンジン」のようなものです。あなたが図書館で「量子力学の入門書」を探しているとしましょう。あなたの頭の中にある「量子力学・入門」というキーワードがクエリ($\bm{Q}$)です。本棚に並んでいる各書籍のタイトル・キーワードがキー($\bm{K}$)に対応します。そして、各書籍の中身そのものがバリュー($\bm{V}$)です。
Attentionは、クエリとキーの類似度を計算し、類似度が高いバリューを重点的に取り出す仕組みです。これにより、シーケンス中のどの位置の情報が「今注目すべきか」を動的に決定できます。
数学的な定義
Scaled Dot-Product Attentionは、以下の式で定義されます。
$$ \begin{equation} \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} \end{equation} $$
この式を段階的に読み解いていきましょう。
まず、$\bm{Q}\bm{K}^\top$ はクエリとキーの内積を計算しています。内積が大きいほど、2つのベクトルの方向が近い、つまり「関連性が高い」ことを意味します。ただし、次元数 $d_k$ が大きくなると内積の値も大きくなり、ソフトマックス関数の出力が極端な値(ほぼ0かほぼ1)に偏ってしまいます。これを防ぐために $\sqrt{d_k}$ で割るスケーリングを行います。
各変数の意味は以下のとおりです。
- $\bm{Q} \in \mathbb{R}^{n \times d_k}$: クエリ行列($n$ 個のトークンそれぞれの「検索キーワード」を $d_k$ 次元ベクトルで表現)
- $\bm{K} \in \mathbb{R}^{n \times d_k}$: キー行列(各トークンが持つ「索引ラベル」)
- $\bm{V} \in \mathbb{R}^{n \times d_v}$: バリュー行列(各トークンが持つ「実際の情報」)
- $d_k$: キー・クエリの次元数(スケーリング係数の計算に使用)
Single-Head Attentionの限界
Single-Head Attentionは強力ですが、1つの重大な制約があります。それは、softmaxが1つの注意分布しか生成しないということです。
softmax関数は確率分布を出力するため、その出力値の合計は必ず1になります。これは、あるトークンが注意を向ける「総量」が固定されていることを意味します。すると、たとえば「文法構造に注意したい」と「意味的類似性に注意したい」という2つの需要がある場合、1つのsoftmax出力で両方を同時に表現することは困難です。
具体的に言えば、「The cat sat on the mat because it was tired」という文で「it」が何を指すかを理解するには、「it」→「cat」への意味的注意と、「it」→「sat」への構文的注意の両方が必要です。Single-Head Attentionでは、これらを1つの注意パターンに詰め込まなければなりません。
この制約を克服するために、「複数のヘッドを並列に走らせる」というアイデアが生まれました。それがMulti-Head Attentionです。
なぜマルチヘッドにするのか — 直感的理解
アナロジー: 騒がしいカフェでの聴覚処理
Multi-Head Attentionの直感を掴むために、騒がしいカフェにいる場面を想像してみましょう。
あなたはカフェで友人と会話しています。このとき、あなたの脳は同時に複数の「聴覚チャンネル」を走らせています。
- チャンネル1(会話追跡): 友人の声を選択的に追跡し、周囲のノイズから分離する
- チャンネル2(環境モニタリング): BGMのリズムや周囲の雰囲気を無意識に感じ取る
- チャンネル3(危険検知): 皿が割れる音やドアが急に開く音など、異常な音に瞬時に反応する準備をする
これらのチャンネルは、同じ音の入力を受け取りながら、それぞれ異なる「フィルター」を通して情報を抽出しています。会話追跡チャンネルは人間の声の周波数帯に感度が高く、危険検知チャンネルは突発的な大きな音に反応するように調整されています。
Multi-Head Attentionはまさにこの仕組みを模しています。同じ入力シーケンスに対して、各ヘッドが異なる「フィルター」(射影行列)を通して情報を見ることで、多面的な関係性を同時に捉えるのです。
自然言語処理における具体例
Attention Is All You Needの原論文で使われた8ヘッドのTransformerでは、学習後に各ヘッドが以下のような異なる言語的パターンを学習することが観察されています。
- ヘッドA(構文ヘッド): 主語と述語動詞の対応を追跡します。「The students who passed the exam were happy」のような文では、「students」と「were」の間の長距離依存関係を捉えます
- ヘッドB(共参照ヘッド): 代名詞とその先行詞の関係を学習します。「Alice told Bob that she would be late」で「she」→「Alice」の対応を見つけます
- ヘッドC(近傍ヘッド): 隣接するトークン間の局所的なパターン(形容詞-名詞、冠詞-名詞などのバイグラム的関係)に注目します
- ヘッドD(位置ヘッド): 特定の相対位置(例えば2つ前のトークン)に固定的に注目する、位置ベースの注意パターンを形成します
このように、各ヘッドが異なる言語的「役割」を自己組織的に獲得することで、モデル全体として豊かな言語理解を実現しています。これは人間が文章を読むときに文法・意味・文脈を同時処理するのと驚くほど似ています。
数学的な直感: 表現力の向上
もう少し厳密に考えてみましょう。Single-Head Attentionのsoftmax出力 $\bm{A} \in \mathbb{R}^{n \times n}$ は、各行の合計が1になるという制約を持ちます。これはランク1に近い行列になりやすいことを意味します。つまり、注意パターンが「特定の少数のトークンに集中する」か「全体に均等に分散する」かの二択に偏りがちです。
一方、Multi-Head Attentionでは $h$ 個の異なる注意行列 $\bm{A}_1, \bm{A}_2, \dots, \bm{A}_h$ が独立に学習されます。これらをそれぞれ異なる部分空間に適用して連結し、出力射影行列 $\bm{W}^O$ で混合することで、はるかに表現力の高い出力を生成できます。直感的には、1枚のフィルターで写真を撮るより、複数のフィルターで撮った画像を合成するほうが豊かな情報を得られるのと同じです。
では、「異なる部分空間」とは具体的に何を意味するのでしょうか。次のセクションでは、射影行列が果たす幾何学的な役割を詳しく見ていきます。
部分空間への射影 — 幾何学的な直感
高次元空間を「複数の窓」から覗く
Multi-Head Attentionの中核にあるのは、入力ベクトルを複数の低次元部分空間に射影するという操作です。この操作の意味を幾何学的に理解しましょう。
たとえば、3次元空間にある物体を考えてください。この物体を正面から見た図($xy$平面への射影)と、上から見た図($xz$平面への射影)と、横から見た図($yz$平面への射影)を合わせると、元の3次元の形状をかなり正確に復元できます。1つの視点(1つの平面への射影)だけでは見えない特徴が、別の視点からは明確に見えることがあります。
Multi-Head Attentionでも同じことが起きています。$d_{\text{model}}$ 次元の入力ベクトルを、$h$ 個の $d_k$ 次元部分空間に線形変換(射影)します。各部分空間では、元の高次元空間での異なる「側面」が強調されます。
射影行列の役割
具体的に、第 $i$ ヘッドのクエリ射影行列 $\bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$ が何をしているかを考えます。
入力ベクトル $\bm{x} \in \mathbb{R}^{d_{\text{model}}}$(あるトークンの埋め込み表現)に $\bm{W}_i^Q$ を掛けると、結果は $d_k$ 次元のベクトル $\bm{q}_i = \bm{x}\bm{W}_i^Q \in \mathbb{R}^{d_k}$ になります。
この操作は、$d_{\text{model}}$ 次元空間から $d_k$ 次元の部分空間への射影です。射影行列 $\bm{W}_i^Q$ の各列は、射影先の部分空間の基底ベクトルを(暗黙的に)定義しています。学習によってこの基底ベクトルが最適化されることで、各ヘッドが「見るべき方向」を自動的に獲得するのです。
重要なのは、異なるヘッドの射影行列は独立に学習されるということです。ヘッド1の $\bm{W}_1^Q$ とヘッド2の $\bm{W}_2^Q$ は一般に異なる部分空間への射影を表し、それぞれが入力の異なる側面を切り出します。
直感的なまとめ
射影の意味をまとめると、以下のようになります。
| 操作 | 幾何学的意味 | アナロジー |
|---|---|---|
| $\bm{Q}_i = \bm{X}\bm{W}_i^Q$ | 入力を第 $i$ 部分空間のクエリ表現に射影 | 「第 $i$ の視点」から物体を見る |
| $\bm{K}_i = \bm{X}\bm{W}_i^K$ | 入力を第 $i$ 部分空間のキー表現に射影 | その視点での「索引」を作る |
| $\bm{V}_i = \bm{X}\bm{W}_i^V$ | 入力を第 $i$ 部分空間のバリュー表現に射影 | その視点での「情報」を取り出す |
| $\text{Concat} \cdot \bm{W}^O$ | 全ヘッドの出力を統合 | 複数の視点を合成して全体像を復元 |
各ヘッドは「異なる窓」から情報を覗き見て、最後にそれらを統合することで、元の高次元空間での豊かな関係性を捉えます。
この幾何学的直感を踏まえたうえで、次のセクションではMulti-Head Attentionを数式で厳密に定義します。
Multi-Head Attentionの数学的定義
全体の流れ
Multi-Head Attentionの計算は、大きく3つのステップで構成されます。
- 射影: 入力を $h$ 個の部分空間にそれぞれ射影する
- 並列Attention: 各部分空間で独立にScaled Dot-Product Attentionを計算する
- 統合: 全ヘッドの出力を連結し、出力射影行列で元の次元に戻す
定義式
この3ステップを数式で表すと、Multi-Head Attentionは以下のように定義されます。
$$ \begin{equation} \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\bm{W}^O \end{equation} $$
ここで、各ヘッドの出力 $\text{head}_i$ は、入力をヘッド固有の射影行列で変換したうえでAttentionを適用して得られます。
$$ \begin{equation} \text{head}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q,\; \bm{K}\bm{W}_i^K,\; \bm{V}\bm{W}_i^V) \end{equation} $$
式(2)を展開すると、各ヘッドの内部でやっていることが見えてきます。
まず、入力行列にヘッド固有の射影行列を掛けて、部分空間上のクエリ・キー・バリューを作ります。
$$ \bm{Q}_i = \bm{Q}\bm{W}_i^Q, \quad \bm{K}_i = \bm{K}\bm{W}_i^K, \quad \bm{V}_i = \bm{V}\bm{W}_i^V $$
次に、この部分空間上でScaled Dot-Product Attentionを計算します。
$$ \text{head}_i = \text{softmax}\left(\frac{\bm{Q}_i \bm{K}_i^\top}{\sqrt{d_k}}\right)\bm{V}_i $$
最後に、全ヘッドの出力を連結して出力射影行列を掛けます。
$$ \text{output} = [\text{head}_1 ; \text{head}_2 ; \dots ; \text{head}_h] \bm{W}^O $$
パラメータの形状
各パラメータの次元を整理しておきましょう。
- $\bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$: 第 $i$ ヘッドのクエリ射影行列
- $\bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$: 第 $i$ ヘッドのキー射影行列
- $\bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$: 第 $i$ ヘッドのバリュー射影行列
- $\bm{W}^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$: 出力射影行列
通常、$d_k = d_v = d_{\text{model}} / h$ と設定します。これにより、各ヘッドの次元は元の次元の $1/h$ になります。たとえば $d_{\text{model}} = 512$、$h = 8$ のとき、各ヘッドは $d_k = 64$ 次元の部分空間で動作します。
連結と出力射影の意味
式(1)の $\text{Concat}(\cdot)\bm{W}^O$ の部分について、もう少し詳しく考えてみましょう。
各ヘッドの出力 $\text{head}_i \in \mathbb{R}^{n \times d_v}$ を連結すると、$\mathbb{R}^{n \times hd_v}$ のテンソルが得られます。$hd_v = d_{\text{model}}$ なので、連結後のテンソルは $\mathbb{R}^{n \times d_{\text{model}}}$ です。
しかし、単に連結しただけでは、各ヘッドの情報が独立したまま並んでいるだけです。出力射影行列 $\bm{W}^O$ を掛けることで、異なるヘッドが捉えた情報を混合します。これは、前述の「複数の視点を合成して全体像を復元する」操作に対応します。
$\bm{W}^O$ も学習可能なパラメータなので、どのヘッドの情報をどの程度重視するかはデータから自動的に最適化されます。
この数式の定義から、Multi-Head Attentionは単にAttentionを並列化しているだけでなく、射影→並列処理→統合という巧みな構造を持っていることがわかります。次に、この構造が計算コストの面でどのような利点を持つかを分析します。
計算量の分析 — なぜコストが同じなのか
直感的な理解
Multi-Head Attentionは「ヘッド数分だけ計算が増えるのでは?」と思われがちですが、実はSingle-Head Attentionと総計算量は同じです。これは非常に巧妙な設計です。
たとえ話で説明しましょう。512ページの本を1人で読むか、8人で64ページずつ分担して読むかを考えてください。1人で512ページ読む作業量と、8人が各64ページ読む総作業量は同じです。しかし、8人が並列に読めば、時間は $1/8$ で済みます。さらに、それぞれが異なる観点(文法チェック担当、事実確認担当、文体チェック担当、…)で読めば、1人で全観点をカバーしようとするより質の高いレビューが可能です。
Multi-Head Attentionは、まさにこの「分担読み」の仕組みです。各ヘッドの次元を $d_k = d_{\text{model}} / h$ と設定することで、計算量を維持しながら多様性を獲得しています。
数式による計算量比較
これを数式で確認しましょう。
Single-Head Attention($d_{\text{model}}$ 次元でのAttention)の場合を考えます。
Attentionの計算の主要部分は $\bm{Q}\bm{K}^\top$($n \times d_{\text{model}}$ と $d_{\text{model}} \times n$ の行列積)です。この計算量は以下のとおりです。
$$ O(n^2 \cdot d_{\text{model}}) $$
Multi-Head Attention($h$ ヘッド、各 $d_k = d_{\text{model}} / h$)の場合、各ヘッドの計算量は以下です。
$$ O(n^2 \cdot d_k) = O\left(n^2 \cdot \frac{d_{\text{model}}}{h}\right) $$
$h$ 個のヘッドの計算量を合計すると、ヘッド数 $h$ とヘッドあたりの次元 $d_{\text{model}}/h$ が打ち消し合います。
$$ h \times O\left(n^2 \cdot \frac{d_{\text{model}}}{h}\right) = O(n^2 \cdot d_{\text{model}}) $$
このように、ヘッド数を増やしても総計算量は変わりません。これがMulti-Head Attentionの設計の鍵です。
射影行列の計算コスト
厳密には、射影行列 $\bm{W}_i^Q, \bm{W}_i^K, \bm{W}_i^V$ による線形変換の計算コストも考慮する必要があります。
全ヘッドのクエリ射影行列をまとめると $\bm{W}^Q = [\bm{W}_1^Q ; \bm{W}_2^Q ; \dots ; \bm{W}_h^Q] \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ となり、これは $d_{\text{model}} \times d_{\text{model}}$ の行列1つと同じです。同様に $\bm{W}^K$、$\bm{W}^V$、$\bm{W}^O$ もそれぞれ $d_{\text{model}} \times d_{\text{model}}$ の行列です。
したがって、射影に要する計算量は以下のとおりです。
$$ 4 \times O(n \cdot d_{\text{model}}^2) $$
Single-Head Attentionで同等の射影を行う場合も同じオーダーの計算が必要なので、射影のコストも増えていません。
パラメータ数の比較
パラメータ数についても確認しておきましょう。
| 構成要素 | パラメータ数 |
|---|---|
| クエリ射影 $\bm{W}^Q$(全ヘッド合計) | $d_{\text{model}}^2$ |
| キー射影 $\bm{W}^K$(全ヘッド合計) | $d_{\text{model}}^2$ |
| バリュー射影 $\bm{W}^V$(全ヘッド合計) | $d_{\text{model}}^2$ |
| 出力射影 $\bm{W}^O$ | $d_{\text{model}}^2$ |
| 合計 | $4 d_{\text{model}}^2$ |
ヘッド数 $h$ に依存しないことに注目してください。ヘッド数を変えても、学習すべきパラメータの総数は変わりません。変わるのは「パラメータをどう分割するか」だけです。
計算量とパラメータ数が同じなら、Multi-Head Attentionは「タダで多様性を手に入れている」ことになります。では、実際のTransformerではどのような次元設定が使われているのでしょうか。次のセクションで具体的な数値を確認します。
具体的な次元の例
原論文の設定
Attention Is All You Needの原論文で使われた設定を確認しましょう。
| パラメータ | 値 | 意味 |
|---|---|---|
| $d_{\text{model}}$ | 512 | モデルの隠れ層の次元数 |
| $h$ | 8 | ヘッド数 |
| $d_k = d_v$ | 64 | 各ヘッドの次元数($512 / 8$) |
各ヘッドは64次元の部分空間でAttentionを計算します。512次元の空間を8つの64次元部分空間に分割して、それぞれ異なる「視点」からAttentionを計算するわけです。
他のモデルの設定
比較のために、代表的なモデルの設定も見てみましょう。
| モデル | $d_{\text{model}}$ | $h$ | $d_k$ |
|---|---|---|---|
| Transformer Base | 512 | 8 | 64 |
| Transformer Big | 1024 | 16 | 64 |
| BERT-Base | 768 | 12 | 64 |
| BERT-Large | 1024 | 16 | 64 |
| GPT-2 | 768 | 12 | 64 |
| GPT-3 (175B) | 12288 | 96 | 128 |
興味深いことに、$d_k = 64$ という値が多くのモデルで共通しています。これは、64次元が「1つのヘッドが意味のある注意パターンを学習するのに十分な次元数」であることを示唆しています。モデルの規模が大きくなると、$d_k$ を増やすのではなく、ヘッド数 $h$ を増やす(つまり「視点の数を増やす」)方向でスケールアップする傾向があります。
GPT-3のように超大規模なモデルでは $d_k = 128$ とやや大きくなっていますが、それでもヘッド数96という多数のヘッドを使うことで、入力を非常に多くの「観点」から分析しています。
具体的な数値のイメージが掴めたところで、いよいよMulti-Head Attentionをゼロから実装してみましょう。まずはNumPyによるスクラッチ実装で、各ステップの処理を細かく追いかけます。
NumPyによるスクラッチ実装
実装の方針
まずはNumPyを使って、Multi-Head Attentionの各ステップを明示的に実装します。ループでヘッドごとに処理を書くことで、「射影→Attention→連結→出力射影」の流れを一つ一つ確認できます。
最初に、基本部品であるsoftmax関数とScaled Dot-Product Attentionを実装します。
import numpy as np
import matplotlib.pyplot as plt
def softmax(x, axis=-1):
"""数値的に安定なsoftmax関数
最大値を引くことでオーバーフローを防ぎます。
softmax(x) = softmax(x - max(x)) の性質を利用しています。
"""
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V):
"""Scaled Dot-Product Attention
Parameters
----------
Q : ndarray, shape (seq_len, d_k) — クエリ行列
K : ndarray, shape (seq_len, d_k) — キー行列
V : ndarray, shape (seq_len, d_v) — バリュー行列
Returns
-------
output : ndarray, shape (seq_len, d_v) — Attention適用後の出力
weights : ndarray, shape (seq_len, seq_len) — Attention重み行列
"""
d_k = Q.shape[-1]
# 1. クエリとキーの内積でスコアを計算
scores = Q @ K.T # (seq_len, seq_len)
# 2. sqrt(d_k)でスケーリング(勾配消失を防ぐ)
scores = scores / np.sqrt(d_k)
# 3. softmaxで確率分布に変換
weights = softmax(scores)
# 4. Attention重みでバリューの加重和を計算
output = weights @ V # (seq_len, d_v)
return output, weights
このコードでは、Scaled Dot-Product Attentionを4つのステップに分解しています。ステップ1で類似度スコアを計算し、ステップ2でスケーリングを行い、ステップ3で正規化して確率分布にし、ステップ4でバリューの加重和を求めます。softmax関数では数値安定性のために最大値を引く工夫を入れています。これは $\text{softmax}(\bm{x}) = \text{softmax}(\bm{x} – \max(\bm{x}))$ という性質を利用したもので、大きな指数関数のオーバーフローを防ぎます。
次に、Multi-Head Attentionクラスを実装します。
class MultiHeadAttention:
"""Multi-Head Attention(NumPyスクラッチ実装)
射影 → 並列Attention → 連結 → 出力射影 の4ステップを
明示的にループで実装しています。
"""
def __init__(self, d_model, n_heads):
assert d_model % n_heads == 0, \
f"d_model({d_model})はn_heads({n_heads})で割り切れる必要があります"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 各ヘッドの次元
# Xavier初期化(分散を適切に保つ)
np.random.seed(42)
scale = np.sqrt(2.0 / d_model)
# 各ヘッドの射影行列を個別に作成
self.W_Q = [np.random.randn(d_model, self.d_k) * scale
for _ in range(n_heads)]
self.W_K = [np.random.randn(d_model, self.d_k) * scale
for _ in range(n_heads)]
self.W_V = [np.random.randn(d_model, self.d_k) * scale
for _ in range(n_heads)]
# 出力射影行列: (n_heads * d_k, d_model) = (d_model, d_model)
self.W_O = np.random.randn(n_heads * self.d_k, d_model) * scale
def forward(self, X):
"""Multi-Head Attentionの順伝播
Parameters
----------
X : ndarray, shape (seq_len, d_model)
入力シーケンス(各トークンのd_model次元埋め込み)
Returns
-------
output : ndarray, shape (seq_len, d_model)
MHAの出力
attention_weights : list of ndarray
各ヘッドのAttention重み行列
"""
head_outputs = []
attention_weights = []
for i in range(self.n_heads):
# ステップ1: 部分空間への射影
Q_i = X @ self.W_Q[i] # (seq_len, d_model) @ (d_model, d_k) = (seq_len, d_k)
K_i = X @ self.W_K[i] # 同上
V_i = X @ self.W_V[i] # 同上
# ステップ2: 部分空間でのAttention計算
head_out, weights = scaled_dot_product_attention(Q_i, K_i, V_i)
head_outputs.append(head_out) # (seq_len, d_k)
attention_weights.append(weights) # (seq_len, seq_len)
# ステップ3: 全ヘッドの出力を連結
concat = np.concatenate(head_outputs, axis=-1) # (seq_len, n_heads * d_k)
# ステップ4: 出力射影で元の次元に戻す
output = concat @ self.W_O # (seq_len, d_model)
return output, attention_weights
このクラスのforwardメソッドは、先ほど説明した4つのステップを忠実に実装しています。ステップ1では入力行列 $\bm{X}$ にヘッド固有の射影行列を掛けて、$d_k$ 次元の部分空間にクエリ・キー・バリューを射影しています。ステップ2では各部分空間で独立にScaled Dot-Product Attentionを計算します。ステップ3で全ヘッドの出力を横方向に連結し、ステップ4で出力射影行列 $\bm{W}^O$ を掛けて元の $d_{\text{model}}$ 次元に戻しています。
実装が正しいことを確認するために、具体的なトークン列でMHAを実行し、各ヘッドのAttention重みを可視化してみましょう。
# デモ: 6つのトークンに対するMulti-Head Attention
tokens = ["I", "love", "machine", "learning", "very", "much"]
seq_len = len(tokens)
d_model = 16 # 小さめの次元で実験
n_heads = 4 # 4つのヘッド
# ランダムな埋め込みベクトルを生成(実際にはWord2VecやBPEで得る)
np.random.seed(0)
X = np.random.randn(seq_len, d_model)
# Multi-Head Attentionの実行
mha = MultiHeadAttention(d_model, n_heads)
output, attn_weights = mha.forward(X)
# 入出力の形状を確認
print(f"入力形状: {X.shape}") # (6, 16)
print(f"出力形状: {output.shape}") # (6, 16)
print(f"ヘッド数: {n_heads}")
print(f"各ヘッドの次元: {d_model // n_heads}")
print(f"Attention重みの形状(各ヘッド): {attn_weights[0].shape}") # (6, 6)
出力から、入力と出力の形状が同じ $(6, 16)$ であることが確認できます。これはMulti-Head Attentionが入力の次元を変えない、つまり残差接続と組み合わせやすい設計になっていることを意味しています。Transformer EncoderやTransformer Decoderでは、MHAの出力を入力に足し合わせる残差接続が使われており、この次元の一致は不可欠です。
続いて、各ヘッドのAttention重みをヒートマップで可視化します。
# 各ヘッドのAttention重みをヒートマップで可視化
fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 4))
for i in range(n_heads):
im = axes[i].imshow(attn_weights[i], cmap='Blues', vmin=0, vmax=1)
axes[i].set_xticks(range(seq_len))
axes[i].set_yticks(range(seq_len))
axes[i].set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
axes[i].set_yticklabels(tokens, fontsize=9)
axes[i].set_title(f'Head {i+1}', fontsize=14)
axes[i].set_xlabel('Key (注目先)', fontsize=10)
if i == 0:
axes[i].set_ylabel('Query (注目元)', fontsize=10)
plt.colorbar(im, ax=axes[i], fraction=0.046)
plt.suptitle('Multi-Head Attention Weights', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig("mha_attention_weights.png", dpi=150, bbox_inches='tight')
plt.show()
このヒートマップは、各ヘッドがどのトークンからどのトークンに注意を向けているかを示しています。行が「注目元」(クエリ)、列が「注目先」(キー)に対応し、色が濃いほど強い注意を意味します。
ここではランダムな重みを使っているため、学習済みモデルのように「構文ヘッド」「意味ヘッド」といった明確な役割分担は見えません。しかし、4つのヘッドがそれぞれ異なる注意パターンを生成していることが確認できます。あるヘッドは対角線付近(自分自身や隣接トークン)に注意が集中し、別のヘッドはより広範囲に注意を分散させています。この多様性こそが、Multi-Head Attentionの表現力の源泉です。
学習が進むにつれて、各ヘッドは入力データの統計的構造に合わせて射影行列を最適化し、先に述べたような言語的に意味のある注意パターンを獲得していきます。
NumPyでの実装は各ステップの処理を明示的に追えるという利点がありますが、実際の深層学習ではGPU並列計算に対応したフレームワークが必要です。次のセクションでは、PyTorchを使った効率的な実装を見ていきます。
PyTorchによる実装
効率的な実装のポイント
NumPy版ではヘッドごとにループで処理しましたが、PyTorch版では全ヘッドの射影を1つの大きな行列積としてまとめて計算します。これにより、GPU上での並列計算が効率的に行えます。
具体的には、$h$ 個の射影行列 $\bm{W}_1^Q, \bm{W}_2^Q, \dots, \bm{W}_h^Q$(それぞれ $d_{\text{model}} \times d_k$)を横に並べて1つの $d_{\text{model}} \times d_{\text{model}}$ 行列としてまとめます。こうすると、1回の行列積で全ヘッドの射影を同時に計算できます。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttentionPyTorch(nn.Module):
"""Multi-Head Attention(PyTorch実装)
全ヘッドの射影を1つのnn.Linearでまとめて計算し、
その後view + transposeでヘッドごとに分割する効率的な実装です。
"""
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0, \
f"d_model({d_model})はn_heads({n_heads})で割り切れる必要があります"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 全ヘッドの射影をまとめた線形層
# 内部的には (d_model, d_model) の重み行列を持つ
self.W_Q = nn.Linear(d_model, d_model) # 全ヘッドのQ射影
self.W_K = nn.Linear(d_model, d_model) # 全ヘッドのK射影
self.W_V = nn.Linear(d_model, d_model) # 全ヘッドのV射影
self.W_O = nn.Linear(d_model, d_model) # 出力射影
def forward(self, x):
"""
Parameters
----------
x : Tensor, shape (batch_size, seq_len, d_model)
Returns
-------
output : Tensor, shape (batch_size, seq_len, d_model)
attn_weights : Tensor, shape (batch_size, n_heads, seq_len, seq_len)
"""
batch_size, seq_len, _ = x.shape
# ---- ステップ1: 全ヘッドの射影を一括計算し、ヘッドに分割 ----
# (batch, seq, d_model) → (batch, seq, n_heads, d_k) → (batch, n_heads, seq, d_k)
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# ---- ステップ2: Scaled Dot-Product Attention ----
# (batch, n_heads, seq, d_k) @ (batch, n_heads, d_k, seq) = (batch, n_heads, seq, seq)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
# (batch, n_heads, seq, seq) @ (batch, n_heads, seq, d_k) = (batch, n_heads, seq, d_k)
attn_output = torch.matmul(attn_weights, V)
# ---- ステップ3: ヘッドを連結 ----
# (batch, n_heads, seq, d_k) → (batch, seq, n_heads, d_k) → (batch, seq, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
# ---- ステップ4: 出力射影 ----
output = self.W_O(attn_output) # (batch, seq, d_model)
return output, attn_weights
このPyTorch実装のキーポイントは、view と transpose 操作によるテンソルの形状変換です。self.W_Q(x) は (batch, seq, d_model) 形状のテンソルを出力しますが、これは実質的に全ヘッドのクエリ射影が横に連結されたものです。.view(batch_size, seq_len, self.n_heads, self.d_k) で最後の次元を「ヘッド数 x ヘッド次元」に分割し、.transpose(1, 2) でヘッド次元をバッチ次元の隣に持ってくることで、(batch, n_heads, seq, d_k) という形状に変換しています。この形状であれば、torch.matmul が自動的にバッチ次元とヘッド次元にまたがってブロードキャストし、全ヘッドのAttentionを1回の演算で並列計算できます。
ステップ3の .contiguous() は、transpose 後にメモリ上のデータ配置が不連続になっている可能性があるため、.view() の前に連続配置を保証するために呼んでいます。
それでは、この実装を実際に動かして動作確認をしてみましょう。
# 動作確認: Transformer Baseの設定で実行
d_model, n_heads = 512, 8
mha_pt = MultiHeadAttentionPyTorch(d_model, n_heads)
# ダミー入力: バッチサイズ2, シーケンス長10, 次元512
x = torch.randn(2, 10, d_model)
# 推論モードで実行
with torch.no_grad():
output, weights = mha_pt(x)
print(f"入力形状: {x.shape}") # (2, 10, 512)
print(f"出力形状: {output.shape}") # (2, 10, 512)
print(f"Attention重み形状: {weights.shape}") # (2, 8, 10, 10)
print(f"パラメータ数: {sum(p.numel() for p in mha_pt.parameters()):,}")
出力形状が入力と同じ (2, 10, 512) であることが確認できます。Attention重みの形状は (2, 8, 10, 10) で、これは「バッチサイズ2 x 8ヘッド x シーケンス長10 x シーケンス長10」を表しています。各ヘッドが独立に $10 \times 10$ の注意行列を持っていることがわかります。
パラメータ数は $4 \times 512^2 + 4 \times 512 = 1,050,624$ 個(バイアス項含む)です。先ほどの計算量分析で見た $4d_{\text{model}}^2$ に、バイアス項の $4d_{\text{model}}$ を加えた値と一致しています。
PyTorch版の実装を確認できたところで、最後にヘッド数が注意パターンの多様性にどう影響するかを実験で確かめてみましょう。
ヘッド数による注意パターンの違い
実験の目的
Multi-Head Attentionの理論的な利点は「各ヘッドが異なる注意パターンを学習できる」ことでした。ここでは、ヘッド数を変えたときに注意パターンの多様性がどう変化するかを、可視化を通じて観察します。
ランダム初期化の段階でも、射影行列が異なるためヘッドごとに異なるパターンが生まれます。ヘッド数が多いほど多様なパターンが出現するかを確認しましょう。
def visualize_heads_comparison(d_model=32, head_configs=[1, 2, 4, 8]):
"""ヘッド数を変えたときの注意パターンを比較する
Parameters
----------
d_model : int
モデルの次元数
head_configs : list of int
比較するヘッド数のリスト
"""
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(tokens)
# 同じ入力を使用
np.random.seed(123)
X = np.random.randn(seq_len, d_model)
fig, axes = plt.subplots(
len(head_configs), max(head_configs),
figsize=(3 * max(head_configs), 3 * len(head_configs))
)
for row, n_heads in enumerate(head_configs):
mha = MultiHeadAttention(d_model, n_heads)
_, attn_weights = mha.forward(X)
for col in range(max(head_configs)):
ax = axes[row][col] if len(head_configs) > 1 else axes[col]
if col < n_heads:
im = ax.imshow(attn_weights[col], cmap='Blues', vmin=0, vmax=1)
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=7)
ax.set_yticklabels(tokens, fontsize=7)
ax.set_title(f'h={n_heads}, Head {col+1}', fontsize=10)
else:
ax.axis('off')
plt.suptitle('Attention Patterns with Different Head Counts',
fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("mha_head_comparison.png", dpi=150, bbox_inches='tight')
plt.show()
visualize_heads_comparison(d_model=32, head_configs=[1, 2, 4, 8])
この可視化から、いくつかの重要な特徴が観察できます。
1つ目に、ヘッド数が1(Single-Head)のとき、1つの注意パターンしかないため、全ての言語的関係を1枚の注意マップに押し込む必要があります。結果として、特定のトークンペアへの注意が支配的になりがちです。
2つ目に、ヘッド数を増やすにつれて、注意パターンの多様性が増すことが見て取れます。あるヘッドは対角成分(自分自身への注意)が強く、別のヘッドは特定のトークンペアに集中し、また別のヘッドはより均等に分散した注意パターンを示します。
3つ目に、各ヘッドの注意パターンは独立に生成されていることが確認できます。同じ入力に対しても、射影行列が異なるため、異なるヘッドはまったく異なるパターンを生成しています。
注意パターンのエントロピー分析
注意パターンの「集中度」を定量的に評価するために、各ヘッドのAttention重みのエントロピーを計算してみましょう。エントロピーが高いほど注意が分散しており、低いほど特定のトークンに集中していることを意味します。
def attention_entropy(weights):
"""Attention重み行列の平均エントロピーを計算する
各行(クエリ)ごとのエントロピーを計算し、平均を返します。
エントロピーが高い = 注意が分散, 低い = 注意が集中
"""
# ゼロ除算を避けるための微小値
eps = 1e-10
# 各行のエントロピー: -Σ p_i * log(p_i)
entropy = -np.sum(weights * np.log(weights + eps), axis=-1)
return np.mean(entropy)
# ヘッド数ごとのエントロピーを計算
np.random.seed(42)
d_model = 64
tokens = ["I", "love", "machine", "learning", "very", "much"]
seq_len = len(tokens)
X = np.random.randn(seq_len, d_model)
head_counts = [1, 2, 4, 8, 16]
fig, ax = plt.subplots(figsize=(8, 5))
for n_heads in head_counts:
mha = MultiHeadAttention(d_model, n_heads)
_, attn_weights = mha.forward(X)
entropies = [attention_entropy(w) for w in attn_weights]
ax.scatter([n_heads] * len(entropies), entropies, alpha=0.7, s=50)
ax.plot(n_heads, np.mean(entropies), 'k_', markersize=15, markeredgewidth=2)
ax.set_xlabel('Number of Heads', fontsize=12)
ax.set_ylabel('Average Entropy of Attention Weights', fontsize=12)
ax.set_title('Attention Entropy vs. Number of Heads', fontsize=14)
ax.set_xticks(head_counts)
# 最大エントロピー(一様分布のとき)
max_entropy = np.log(seq_len)
ax.axhline(y=max_entropy, color='r', linestyle='--', alpha=0.5,
label=f'Maximum entropy (uniform): {max_entropy:.2f}')
ax.legend(fontsize=10)
plt.tight_layout()
plt.savefig("mha_entropy_analysis.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"最大エントロピー(一様分布): {max_entropy:.4f}")
print(f"各ヘッド数でのエントロピーの範囲:")
for n_heads in head_counts:
mha = MultiHeadAttention(d_model, n_heads)
_, attn_weights = mha.forward(X)
entropies = [attention_entropy(w) for w in attn_weights]
print(f" h={n_heads:2d}: min={min(entropies):.4f}, "
f"max={max(entropies):.4f}, "
f"range={max(entropies)-min(entropies):.4f}")
この散布図と数値結果から、いくつかの傾向が読み取れます。
まず、ヘッド数が増えるにつれて、各ヘッドのエントロピー値のばらつき(レンジ)が大きくなる傾向があります。これは、ヘッド数が多いほど、あるヘッドは集中的な注意を、別のヘッドは分散的な注意をというように、異なるタイプの注意パターンが共存しやすくなることを意味します。
また、赤い点線は一様分布のエントロピー($\ln(6) \approx 1.79$)を示しています。多くのヘッドのエントロピーはこの最大値よりも低く、各ヘッドが何らかの構造を持った注意パターン(完全に均等ではない分布)を生成していることがわかります。
なお、ここでの実験はランダム初期化によるものなので、学習後のモデルではこの多様性がさらに顕著になります。学習によって各ヘッドの射影行列が最適化されると、構文ヘッド、意味ヘッド、位置ヘッドなどの明確な役割分担が生じ、エントロピーの分布にも明確なクラスター構造が現れることが知られています。
Self-Attentionの中でのMHAの位置づけ
ここまでMulti-Head Attentionの理論と実装を詳しく見てきましたが、最後にTransformerアーキテクチャ全体の中でMHAがどのような役割を果たしているかを整理しておきましょう。
Transformerアーキテクチャでは、MHAは以下の場所で使われています。
- Encoderの自己注意層: 入力シーケンスの各トークンが、同じシーケンス内の他の全トークンとの関係を学習します。ここでは $\bm{Q} = \bm{K} = \bm{V} = \bm{X}$(入力自身)です
- Decoderの自己注意層(マスク付き): 出力シーケンスの生成時に、未来のトークンを参照しないようマスクを適用します
- Decoderのクロス注意層: Decoderの各トークン($\bm{Q}$)がEncoderの出力($\bm{K}, \bm{V}$)に注意を向けます。これにより、翻訳タスクなどで入力文の情報を参照しながら出力を生成できます
MHAの後には必ずLayer Normalizationと残差接続が適用され、さらにTransformer FFN(Feed-Forward Network)が続きます。MHAが「トークン間の関係性」を捉えるのに対し、FFNは「各トークンの表現を非線形変換で豊かにする」役割を担っています。
また、入力トークンには位置エンコーディングが加算されます。Attention自体には位置情報が含まれないため、位置エンコーディングがなければ「語順」を区別できません。位置エンコーディングによってトークンの位置情報が埋め込みベクトルに加えられ、MHAの各ヘッドがその位置情報を利用できるようになります。
まとめ
本記事では、Multi-Head Attentionの直感的理解から数学的定義、計算量分析、そしてNumPyとPyTorchによるスクラッチ実装まで解説しました。
- Multi-Head Attentionの本質: 入力を $h$ 個の低次元部分空間にそれぞれ射影し、各部分空間で独立にAttentionを計算することで、多様な依存関係(構文的・意味的・位置的)を同時に捉えます
- 数式: $\text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\bm{W}^O$、各ヘッドは $\text{head}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q, \bm{K}\bm{W}_i^K, \bm{V}\bm{W}_i^V)$
- 幾何学的直感: 各ヘッドの射影行列は「異なる窓から高次元空間を覗く」操作であり、異なる側面の情報を切り出します
- 計算量の不変性: $d_k = d_{\text{model}} / h$ の設計により、ヘッド数を増やしても総計算量 $O(n^2 d_{\text{model}})$ は変わりません
- 実装のポイント: NumPy版はヘッドごとのループで教育的、PyTorch版は全ヘッドを1つの行列積にまとめてGPU並列化に対応
- 注意パターンの多様性: ヘッド数が増えるほど、異なるタイプの注意パターン(集中型・分散型)が共存し、表現力が向上します
Multi-Head Attentionは、Transformerファミリーの全モデルに共通する最も重要な構成要素です。この機構を深く理解することは、GPTアーキテクチャやBERTアーキテクチャ、Vision Transformerなど、現代の深層学習モデルを理解するための土台となります。
次のステップとして、以下の記事も参考にしてください。
- Transformer Encoder — MHAを含むEncoderブロックの全体構造
- Transformer Decoder — マスク付きMHAとクロスAttentionの仕組み
- Layer Normalization — MHAの後に適用される正規化手法
- Transformer FFN — MHAと対をなすFeed-Forward Networkの役割