Cross-Attention(クロスアテンション)の理論と実装を完全解説

機械翻訳モデルに「I love cats」を入力したとき、デコーダが「猫」という単語を出力するためには、ソース文の「cats」に注目する必要があります。では、デコーダはどのようにしてエンコーダが処理した情報を「参照」しているのでしょうか?

答えがCross-Attention(クロスアテンション)です。Self-Attentionが「自分自身の系列内」で各要素間の関係を計算するのに対し、Cross-Attentionは異なる2つの系列間の関係を計算します。翻訳モデルであれば、ターゲット言語の生成中にソース言語の情報を参照する橋渡し役を担います。

Self-AttentionとCross-Attentionの違い

図が両者の違いです。Self-Attention(左)はQuery・Key・Valueをすべて1つの系列から作るのに対し、Cross-Attention(右)はQueryを出力系列から、Key・Valueを別の入力系列から作ります。この「別系列を参照する」点だけが本質的な違いで、計算式そのものは共通です。

Cross-Attentionはもはや機械翻訳だけの技術ではありません。現代の深層学習において広く使われています。

  • Stable Diffusion: テキストプロンプトの情報を画像生成プロセスに注入するためにCross-Attentionを使用
  • マルチモーダルモデル: 画像の特徴量とテキストの特徴量を結びつけるためにCross-Attentionが活躍
  • 音声認識: 音声のエンコーダ出力をテキストデコーダが参照する際にCross-Attentionを利用
  • Vision Transformer(ViT)系列のモデル: 異なるモダリティ間の融合にCross-Attentionが不可欠

本記事の内容

  • Self-Attentionの復習とCross-Attentionへの自然な拡張
  • Cross-Attentionの数学的定式化と直感的理解
  • Self-Attention vs Cross-Attention の明確な比較
  • Transformerデコーダ内でのCross-Attentionの位置と役割
  • Multi-Head Cross-Attentionへの拡張
  • Cross-Attentionの応用例(機械翻訳、Stable Diffusion、マルチモーダル)
  • PyTorchによるスクラッチ実装と注意重みの可視化

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

画像なし
Self-Attentionの理論と実装
Query・Key・Valueの線形射影とScaled Dot-Product Attentionの導出
画像なし
Multi-Head Attentionの理論と実装
複数のAttentionヘッドによる多角的な特徴抽出の仕組み
画像なし
Transformerのアーキテクチャ
Encoder-Decoderの全体構成と各コンポーネントの役割

Self-Attentionの復習

Cross-Attentionを理解するための最も自然な出発点は、Self-Attentionの仕組みを改めて確認することです。Self-Attentionでは、1つの系列が「自分自身に質問し、自分自身から答えを得る」という構造になっています。

Query・Key・Valueの生成

入力系列 $\bm{X} \in \mathbb{R}^{n \times d}$ ($n$ はトークン数、$d$ は埋め込み次元)が与えられたとき、Self-Attentionでは以下の3つの行列を同じ入力 $\bm{X}$ から生成します。

$$ \begin{align} \bm{Q} &= \bm{X}\bm{W}_Q \quad \in \mathbb{R}^{n \times d_k} \\ \bm{K} &= \bm{X}\bm{W}_K \quad \in \mathbb{R}^{n \times d_k} \\ \bm{V} &= \bm{X}\bm{W}_V \quad \in \mathbb{R}^{n \times d_v} \end{align} $$

ここで $\bm{W}_Q \in \mathbb{R}^{d \times d_k}$、$\bm{W}_K \in \mathbb{R}^{d \times d_k}$、$\bm{W}_V \in \mathbb{R}^{d \times d_v}$ は学習可能な重み行列です。

直感的に言えば、各トークンは以下の3つの役割を同時に担っています。

  • Query(質問): 「自分はどんな情報を必要としているか」を表現するベクトル
  • Key(鍵): 「自分はどんな情報を持っているか」を表現するベクトル
  • Value(値): 「自分が実際に渡す情報の中身」を表現するベクトル

Scaled Dot-Product Attention

Query、Key、Valueが生成されたら、Attentionは次の式で計算されます。

$$ \begin{equation} \text{SelfAttn}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} \end{equation} $$

この計算の流れを分解すると、まず $\bm{Q}\bm{K}^\top$ で全ペア間のスコア(内積による類似度)を計算します。次に $\sqrt{d_k}$ で割ることで、内積の値が大きくなりすぎるのを防ぎます($d_k$ が大きいと内積の分散が $d_k$ に比例して大きくなるため)。その後 $\text{softmax}$ で行ごとに正規化し、注意重み $\bm{A} \in \mathbb{R}^{n \times n}$ を得ます。最後に注意重みで $\bm{V}$ を加重平均し、文脈を考慮した出力を得ます。

Self-Attentionの重要なポイントは、Query、Key、Valueが全て同一の系列 $\bm{X}$ から生成されるという点です。つまり、系列内の各トークンが他の全トークンとの関係を計算し、文脈を加味した表現を作ります。

しかし、翻訳や要約、画像キャプションのように、入力と出力が異なる系列であるタスクでは、デコーダが生成中に「エンコーダ側の情報」を参照する必要があります。Self-Attentionだけではエンコーダとデコーダの間に橋を架けることができません。ここでCross-Attentionが登場します。

Cross-Attentionとは

直感的な理解 — 「質問する人」と「答えを持つ人」の対話

Cross-Attentionを理解するための最もわかりやすいアナロジーは、図書館での調べものです。

あなた(デコーダ)が論文を書いているとします。ある段落を書くために、どんな情報が必要かはあなたの頭の中(デコーダの隠れ状態)にあります。そこで図書館に行き、棚に並んだ本(エンコーダの出力)の中から関連するものを探します。本の背表紙(Key)を見て、あなたの質問(Query)に合う本を見つけ、その本の中身(Value)を読んで自分の論文に活かします。

このアナロジーにおけるポイントは以下の通りです。

  • 質問する人(Query)はデコーダ側から来る — 「今の出力生成に必要な情報は何か」
  • 答えを持つ人(Key, Value)はエンコーダ側から来る — 「入力系列にはどんな情報があるか」
  • 質問と答えが異なるソースに属する — これがSelf-Attentionとの決定的な違い

もう1つのアナロジーを挙げましょう。機械翻訳では、Cross-Attentionは「通訳者」のような役割を果たします。通訳者(デコーダ)が日本語の文を組み立てているとき、原文(エンコーダ出力)のどの部分を今参照すべきかを判断するメカニズムがCross-Attentionです。「猫」を出力するタイミングでは「cats」に強く注目し、「好きです」を出力するタイミングでは「love」に強く注目します。

Cross-Attentionの情報の流れ

図がこのアナロジーを表しています。質問する側(Query=出力系列)が、参照される側(Key・Value=入力系列)のどこを見るかを各位置で決め、Value の加重和として「文脈ベクトル」を取り出します。これがまさに「図書館で本を探して中身を読む」流れに対応します。

数学的定式化

Cross-Attentionの数式は、Self-Attentionの形とほとんど同じです。違いは、Query と Key/Value の出所が異なることだけです。

エンコーダの出力を $\bm{H}_\text{enc} \in \mathbb{R}^{m \times d}$ とします。ここで $m$ はソース系列の長さ、$d$ は隠れ状態の次元です。デコーダのある層の隠れ状態を $\bm{H}_\text{dec} \in \mathbb{R}^{n \times d}$ とします。$n$ はターゲット系列の(現在までの)長さです。

Cross-Attentionでは、Query はデコーダ側から、Key と Value はエンコーダ側から生成します。

Cross-AttentionのQ・K・Vの生成(2系列から)

図の通り、Query は出力系列(デコーダ)から、Key と Value は入力系列(エンコーダ)から生成されます。Self-Attentionとの違いはこの「出所」だけで、線形射影 $\bm{W}_Q,\bm{W}_K,\bm{W}_V$ を掛ける操作も、続く計算式もまったく同じです。

$$ \begin{align} \bm{Q} &= \bm{H}_\text{dec}\bm{W}_Q \quad \in \mathbb{R}^{n \times d_k} \\ \bm{K} &= \bm{H}_\text{enc}\bm{W}_K \quad \in \mathbb{R}^{m \times d_k} \\ \bm{V} &= \bm{H}_\text{enc}\bm{W}_V \quad \in \mathbb{R}^{m \times d_v} \end{align} $$

ここで $\bm{W}_Q \in \mathbb{R}^{d \times d_k}$、$\bm{W}_K \in \mathbb{R}^{d \times d_k}$、$\bm{W}_V \in \mathbb{R}^{d \times d_v}$ は学習可能な重み行列です。

Cross-Attentionの計算式はScaled Dot-Product Attentionと同じ形です。

$$ \begin{equation} \text{CrossAttn}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} \end{equation} $$

Scaled Dot-Product Attentionの計算フロー

計算は図のように左から順に進みます。$\bm{Q}\bm{K}^\top$ で類似度を求め、$\sqrt{d_k}$ でスケールし、softmaxで重み化し、最後に $\bm{V}$ を加重和する——この4ステップです。Self-AttentionでもCross-Attentionでもこのフロー自体は同一で、入る $\bm{Q},\bm{K},\bm{V}$ の出所だけが違います。

注意重み行列の形状

Self-Attentionとの重要な違いは、注意重み行列の形状にあります。

Self-Attentionでは、$\bm{Q}$ と $\bm{K}$ が同じ系列長 $n$ を持つため、スコア行列は $\bm{Q}\bm{K}^\top \in \mathbb{R}^{n \times n}$ の正方行列です。一方、Cross-Attentionでは、$\bm{Q} \in \mathbb{R}^{n \times d_k}$ と $\bm{K} \in \mathbb{R}^{m \times d_k}$ の系列長が異なるため、スコア行列は $\bm{Q}\bm{K}^\top \in \mathbb{R}^{n \times m}$ の長方形行列になります。

この $n \times m$ の注意重み行列 $\bm{A}$ の各要素 $A_{ij}$ は、「デコーダの位置 $i$ がエンコーダの位置 $j$ にどれだけ注目しているか」を表します。行方向に $\text{softmax}$ をとるので、各デコーダ位置 $i$ について $\sum_{j=1}^{m} A_{ij} = 1$ が成り立ちます。

注意重み行列の形状(query長×key長)

図のように、注意重み行列の形状は「(queryの長さ)×(keyの長さ)」です。Self-Attentionでは同一系列なので正方行列 $n\times n$ ですが、Cross-Attentionでは入力と出力の長さが違うため長方形 $n\times m$ になります。各行が1に正規化され、出力位置 $i$ が入力位置 $j$ をどれだけ参照するかを表します。

出力の意味

Cross-Attentionの出力 $\bm{C} \in \mathbb{R}^{n \times d_v}$ の $i$ 行目は次のようになります。

$$ \bm{c}_i = \sum_{j=1}^{m} A_{ij} \bm{v}_j $$

これはエンコーダ出力のValue $\bm{v}_1, \ldots, \bm{v}_m$ の重み付き平均であり、デコーダの位置 $i$ にとって「最も関連性の高いソース情報」を集約したベクトルです。この集約されたコンテキストベクトルが、デコーダの後続の層に渡され、次のトークンの予測に活用されます。

ここまでで、Cross-Attentionの基本的な仕組みを理解しました。では、Self-Attentionと具体的にどのような点が異なるのか、表形式で整理してみましょう。

Self-Attention vs Cross-Attention の比較

Self-AttentionとCross-Attentionは数式の「骨格」はまったく同じScaled Dot-Product Attentionです。違いはQuery、Key、Valueがどこから来るかだけです。しかし、この違いが役割や使われ方に大きな差をもたらします。

比較項目 Self-Attention Cross-Attention
Queryの出所 入力系列 $\bm{X}$(自分自身) デコーダの隠れ状態 $\bm{H}_\text{dec}$
Key/Valueの出所 入力系列 $\bm{X}$(自分自身) エンコーダの出力 $\bm{H}_\text{enc}$
注意重みの形状 $n \times n$(正方行列) $n \times m$(長方形行列)
主な役割 系列内の各トークン間の関係を捉える 異なる系列間の対応関係を捉える
使われる場所 エンコーダ層、デコーダの第1サブ層 デコーダの第2サブ層
マスク 因果マスク(デコーダ側のSelf-Attention) 通常は不要(パディングマスクのみ)
直感的な解釈 「自問自答」 「異なるソースへの問い合わせ」
典型的な応用 BERT、GPTの内部表現学習 機械翻訳、画像キャプション、Stable Diffusion

この表からわかるように、Self-Attentionは1つの系列の「内部構造」を理解するための機構であり、Cross-Attentionは2つの系列を「結びつける」ための機構です。

もう少し具体的に考えてみましょう。機械翻訳で “I love cats” → “私は猫が好きです” を翻訳する場合を考えます。

  • エンコーダのSelf-Attention: “I”, “love”, “cats” の3トークン間の関係を計算。たとえば “love” が “I”(主語)と “cats”(目的語)の両方に注目
  • デコーダのSelf-Attention(Masked): “私は”, “猫が” など、これまでに生成したターゲットトークン間の関係を計算。未来のトークンは参照不可
  • Cross-Attention: デコーダの各位置がエンコーダ出力を参照。”猫が” を生成するときは “cats” に強く注目し、”好きです” を生成するときは “love” に強く注目

このように、Self-AttentionとCross-Attentionはそれぞれ異なる役割を担いながら、協調してTransformerの生成能力を実現しています。では、Transformerのデコーダ内でこれらがどのように配置され、連携しているのかを詳しく見ていきましょう。

Transformerデコーダ内でのCross-Attentionの位置と役割

Transformerデコーダ内のCross-Attentionの位置

図はデコーダ層の3つのサブ層です。①Masked Self-Attentionで出力系列の内部を処理し、②Cross-Attentionでエンコーダ出力(Key・Value)を参照し、③Feed-Forwardで位置ごとに変換します。Cross-Attentionは②に位置し、エンコーダとデコーダをつなぐ橋の役割を担います。

デコーダ層の3つのサブ層

Transformerの各デコーダ層は、以下の3つのサブ層から構成されています。

サブ層1: Masked Self-Attention

デコーダの入力トークン(これまでに生成されたトークン系列)に対してSelf-Attentionを計算します。ただし、自己回帰的な生成を保証するために、因果マスク(Causal Mask)を適用し、各位置が未来の位置を参照できないようにします。

このサブ層の役割は、「ターゲット系列内部の依存関係を捉えること」です。たとえば日本語の「猫が好きです」を生成する際、「好きです」は「猫が」に依存しているため、このような関係をモデリングします。

サブ層2: Cross-Attention(Encoder-Decoder Attention)

Masked Self-Attentionの出力をQueryとし、エンコーダの最終出力をKey/Valueとして、Cross-Attentionを計算します。

このサブ層の役割は、「デコーダの各位置がソース系列のどの部分を参照すべきかを決定すること」です。翻訳であれば、ターゲット言語の各単語がソース言語のどの単語に対応するかを学習します。

サブ層3: Feed-Forward Network(FFN)

位置ごとに独立した2層のニューラルネットワークを適用します。

$$ \text{FFN}(\bm{x}) = \text{ReLU}(\bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2 $$

Attentionで集約した情報を、位置ごとに非線形変換して表現を豊かにする役割を担います。

残差接続とLayer Normalization

各サブ層の周りには残差接続(Residual Connection)Layer Normalizationが配置されています。具体的には、各サブ層の出力は次のように計算されます。

$$ \bm{H}_\text{out} = \text{LayerNorm}(\bm{H}_\text{in} + \text{Sublayer}(\bm{H}_\text{in})) $$

この構造により、2つの重要な効果が得られます。

  1. 勾配の安定化: 残差接続により、誤差逆伝播時に勾配が直接前の層に伝わるパスが確保され、勾配消失を防ぎます
  2. 学習の安定性: Layer Normalizationにより、各層の入力の分布が正規化され、学習が安定します

デコーダ1層の処理フロー

デコーダ1層の処理フローを数式で整理しましょう。デコーダ層への入力を $\bm{H}_\text{dec}^{(l-1)}$、エンコーダの最終出力を $\bm{H}_\text{enc}$ とします。

まずMasked Self-Attentionを適用します。

$$ \bm{H}_1 = \text{LayerNorm}\left(\bm{H}_\text{dec}^{(l-1)} + \text{MaskedSelfAttn}(\bm{H}_\text{dec}^{(l-1)})\right) $$

次にCross-Attentionを適用します。ここで $\bm{H}_1$ からQueryを生成し、$\bm{H}_\text{enc}$ からKey/Valueを生成します。

$$ \bm{H}_2 = \text{LayerNorm}\left(\bm{H}_1 + \text{CrossAttn}(\bm{Q}=\bm{H}_1, \bm{K}=\bm{H}_\text{enc}, \bm{V}=\bm{H}_\text{enc})\right) $$

最後にFFNを適用します。

$$ \bm{H}_\text{dec}^{(l)} = \text{LayerNorm}\left(\bm{H}_2 + \text{FFN}(\bm{H}_2)\right) $$

この処理を $L$ 層繰り返すことで、デコーダは入力系列の文脈を段階的に取り込みながら、ターゲット系列の表現を洗練していきます。

Cross-Attentionの「情報ゲート」としての役割

Cross-Attentionを別の角度から見ると、エンコーダ側の情報をデコーダに流入させる「ゲート」の役割を果たしています。

学習が進むと、Cross-Attentionの注意重みは以下のようなパターンを示すことが知られています。

  • 整列パターン(Alignment): 機械翻訳では、ソースとターゲットの単語対応関係にほぼ沿った対角的な注意パターンが現れる
  • 広域参照パターン: 要約タスクでは、1つの出力トークンが入力の広い範囲に注目するパターンが現れる
  • 条件付けパターン: Stable Diffusionでは、テキストトークンのうち特定のキーワード(色、形状、オブジェクト名)への注意が特に強くなる

このように、Cross-Attentionはタスクに応じて柔軟に情報の流れを制御しています。

ここまでで1つのAttentionヘッドによるCross-Attentionを理解しました。しかし、1つのヘッドでは1つの「視点」からしか情報を集約できません。異なる側面の情報を同時に捉えるためには、Multi-Head構造が必要です。

Multi-Head Cross-Attention

Multi-Head Cross-Attention

図のように、Multi-Head Cross-Attentionは複数のヘッドで並列にCross-Attentionを行い、結果を結合して線形変換します。各ヘッドが異なる観点(文法的な対応、意味的な対応など)で別系列を参照できるため、1つのヘッドだけより豊かな対応関係を捉えられます。

なぜ複数のヘッドが必要なのか

1つのAttentionヘッドは、Query-Key間の1つの類似度パターンしか捉えられません。しかし、翻訳タスクを考えてみましょう。「私は猫が好きです」を生成するとき、デコーダは次のような複数の情報を同時に参照する必要があります。

  • 構文的な対応: 「好きです」→ “love”(動詞の翻訳関係)
  • 主語の参照: 「好きです」→ “I”(省略された主語の情報)
  • 局所的な文脈: 「好きです」→ “cats”(目的語との関係)

1つのヘッドでこれらを全て捉えるのは難しいため、複数のヘッドで異なる部分空間に射影し、それぞれ独立にCross-Attentionを計算してから結果を統合します。

定式化

$h$ 個のヘッドを用いたMulti-Head Cross-Attentionは次のように定義されます。

$$ \text{MultiHeadCrossAttn}(\bm{H}_\text{dec}, \bm{H}_\text{enc}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\bm{W}_O $$

各ヘッド $i$ は、ヘッドごとに独立な重み行列を使ってCross-Attentionを計算します。

$$ \text{head}_i = \text{CrossAttn}(\bm{H}_\text{dec}\bm{W}_Q^{(i)}, \bm{H}_\text{enc}\bm{W}_K^{(i)}, \bm{H}_\text{enc}\bm{W}_V^{(i)}) $$

ここで各ヘッドの射影行列のサイズは以下の通りです。

$$ \bm{W}_Q^{(i)} \in \mathbb{R}^{d \times d_k}, \quad \bm{W}_K^{(i)} \in \mathbb{R}^{d \times d_k}, \quad \bm{W}_V^{(i)} \in \mathbb{R}^{d \times d_v} $$

ヘッドごとの次元は $d_k = d_v = d / h$ とし、出力射影行列は $\bm{W}_O \in \mathbb{R}^{d \times d}$ です。

計算の流れを整理する

Multi-Head Cross-Attentionの計算の流れを段階的に追ってみましょう。

ステップ1: デコーダの隠れ状態 $\bm{H}_\text{dec} \in \mathbb{R}^{n \times d}$ を $h$ 個のQuery部分空間に射影します。同様に、エンコーダの出力 $\bm{H}_\text{enc} \in \mathbb{R}^{m \times d}$ を $h$ 個のKey部分空間と $h$ 個のValue部分空間に射影します。

ステップ2: 各ヘッド $i$ で独立にScaled Dot-Product Attentionを計算します。

$$ \text{head}_i = \text{softmax}\left(\frac{(\bm{H}_\text{dec}\bm{W}_Q^{(i)})(\bm{H}_\text{enc}\bm{W}_K^{(i)})^\top}{\sqrt{d_k}}\right)(\bm{H}_\text{enc}\bm{W}_V^{(i)}) $$

各ヘッドの出力は $\text{head}_i \in \mathbb{R}^{n \times d_v}$ です。

ステップ3: $h$ 個のヘッド出力を連結します。

$$ \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} = \mathbb{R}^{n \times d} $$

$d_v = d/h$ なので、連結後の次元は元の $d$ に戻ります。

ステップ4: 出力射影 $\bm{W}_O$ を掛けて最終出力を得ます。

$$ \text{MultiHeadCrossAttn} = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\bm{W}_O \in \mathbb{R}^{n \times d} $$

パラメータ数

Multi-Head Cross-Attentionのパラメータ数を確認しましょう。各ヘッドは $d \times (d/h)$ のQ, K, V射影行列を持ち、ヘッドが $h$ 個あるので、射影行列のパラメータ数は $3 \times h \times d \times (d/h) = 3d^2$ です。出力射影 $\bm{W}_O$ のパラメータ数は $d^2$ です。合計すると

$$ \text{パラメータ数} = 3d^2 + d^2 = 4d^2 $$

これはSelf-Attentionの場合と全く同じです。ヘッド数 $h$ を変えてもパラメータ数は変わりません。ヘッド数を増やすと各ヘッドの次元 $d/h$ が小さくなりますが、全体のパラメータ数は一定です。

Transformer原論文(Vaswani et al., 2017)では $d = 512$、$h = 8$ で $d_k = d_v = 64$ を使用しています。

Multi-Head Cross-Attentionの理論を理解したところで、次はCross-Attentionが実際にどのようなタスクで活躍しているか、具体的な応用例を見ていきましょう。

Cross-Attentionの応用例

Cross-Attentionの応用例

図のように、Cross-Attentionは「2つの異なる情報源を結びつける」汎用機構として、機械翻訳・テキストからの画像生成・マルチモーダルなど幅広く使われます。いずれも「Query側がKey/Value側の必要な情報を引き出す」橋渡しという点で共通しています。

機械翻訳(Transformer原論文)

Cross-Attentionの最も古典的かつ重要な応用は、Transformer原論文(”Attention Is All You Need”, Vaswani et al., 2017)における機械翻訳です。

Transformerのエンコーダはソース言語の文(例: 英語)を処理し、文脈を考慮した表現ベクトルの系列 $\bm{H}_\text{enc}$ を出力します。デコーダは自己回帰的にターゲット言語の文(例: ドイツ語)を1トークンずつ生成しますが、その際にCross-Attentionを通じてソース側の情報を参照します。

学習された注意重みを可視化すると、ソースとターゲットの単語アライメントに対応するパターンが観察されます。語順の異なる言語ペア(英語→日本語など)では、対角線から外れた注意パターンが現れ、Cross-Attentionが語順の入れ替えを自然にモデリングしていることがわかります。

Stable Diffusionのテキスト→画像条件付け

近年のテキストから画像を生成するモデル、特にStable Diffusionでは、Cross-Attentionが中心的な役割を果たしています。

Stable DiffusionのU-Netアーキテクチャでは、テキストプロンプトをCLIPテキストエンコーダで処理し、得られたテキスト埋め込みの系列をKey/Valueとして使います。U-Netの各解像度レベルで、画像の特徴マップ(Query)がテキスト埋め込み(Key/Value)に対してCross-Attentionを計算します。

$$ \begin{align} \bm{Q} &= \bm{W}_Q \cdot \text{flatten}(\bm{z}_\text{image}) \\ \bm{K} &= \bm{W}_K \cdot \bm{H}_\text{text} \\ \bm{V} &= \bm{W}_V \cdot \bm{H}_\text{text} \end{align} $$

ここで $\bm{z}_\text{image}$ はU-Net内部の画像特徴マップ、$\bm{H}_\text{text}$ はCLIPテキストエンコーダの出力です。

この構造により、「画像のどの空間位置がテキストのどの単語に対応すべきか」をモデルが学習します。たとえば「A red cat sitting on a blue chair」というプロンプトに対して、画像の猫の領域は “red” と “cat” に強く注目し、椅子の領域は “blue” と “chair” に強く注目する — というパターンが注意重みに現れます。

Cross-Attentionの注意重みを操作することで、画像編集(Prompt-to-Prompt editing)やアテンションマップの可視化といった応用も生まれています。

マルチモーダルモデル

Vision-Language Model(VLM)やマルチモーダルTransformerでは、異なるモダリティ間の情報融合にCross-Attentionが広く使われています。

Flamingo(DeepMind, 2022)では、言語モデルの各層にCross-Attention層を挿入し、視覚エンコーダの出力をKey/Valueとして使います。これにより、テキスト生成時に画像の情報を柔軟に参照できます。

音声認識(Whisper等)では、音声特徴のエンコーダ出力をKey/Valueとし、テキストデコーダがCross-Attentionを通じて音声情報を参照しながらテキストを生成します。

PerceiverIO(DeepMind, 2022)では、任意のモダリティの入力をKey/Valueとし、学習可能な潜在変数をQueryとしたCross-Attentionで情報を圧縮します。これにより、入力の長さや次元に依存しない統一的なアーキテクチャを実現しています。

これらの応用に共通するのは、Cross-Attentionが異なるモダリティや異なるソースの情報を、柔軟かつ動的に結びつける汎用的な仕組みとして機能しているということです。

理論と応用を一通り理解したところで、いよいよPyTorchを使ってCross-Attentionをスクラッチ実装し、実際に動作を確認していきましょう。

PyTorchでの実装

CrossAttentionクラスの実装

まず、シングルヘッドのCross-Attentionクラスを実装します。Self-Attentionの実装と比較しながら、Queryの出所が異なるという構造上の違いを確認しましょう。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class CrossAttention(nn.Module):
    """シングルヘッドのCross-Attention"""

    def __init__(self, d_model, d_k, d_v):
        """
        Args:
            d_model: 入力の埋め込み次元
            d_k: Query/Keyの次元
            d_v: Valueの次元
        """
        super().__init__()
        self.d_k = d_k

        # Queryはデコーダ入力から生成
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        # Key, Valueはエンコーダ出力から生成
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)

    def forward(self, decoder_input, encoder_output, mask=None):
        """
        Args:
            decoder_input: デコーダの隠れ状態 (batch, n, d_model)
            encoder_output: エンコーダの出力 (batch, m, d_model)
            mask: パディングマスク (batch, n, m) — Trueの位置をマスク
        Returns:
            output: Cross-Attentionの出力 (batch, n, d_v)
            attn_weights: 注意重み (batch, n, m)
        """
        # Queryはデコーダ側から生成
        Q = self.W_Q(decoder_input)   # (batch, n, d_k)
        # Key, Valueはエンコーダ側から生成
        K = self.W_K(encoder_output)  # (batch, m, d_k)
        V = self.W_V(encoder_output)  # (batch, m, d_v)

        # スコア行列の計算: (batch, n, d_k) @ (batch, d_k, m) = (batch, n, m)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (self.d_k ** 0.5)

        # マスクの適用(パディング位置を-infにする)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        # Softmaxで正規化して注意重みを得る
        attn_weights = F.softmax(scores, dim=-1)  # (batch, n, m)

        # Valueの重み付き和
        output = torch.bmm(attn_weights, V)  # (batch, n, d_v)

        return output, attn_weights

このコードのポイントは、forward メソッドの引数に decoder_inputencoder_output2つの入力を受け取る点です。Self-Attentionの場合は1つの入力から Q, K, V の全てを生成しますが、Cross-Attentionでは Q はデコーダ側、K と V はエンコーダ側と、出所が明確に分かれています。また、スコア行列の形状が (batch, n, m) と長方形になる点も Self-Attention((batch, n, n))との違いです。

MultiHeadCrossAttentionクラスの実装

次に、Multi-Head版を実装します。実際のTransformerでは、この Multi-Head 版が使われます。

class MultiHeadCrossAttention(nn.Module):
    """Multi-Head Cross-Attention"""

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: 入力の埋め込み次元
            num_heads: ヘッド数
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_modelはnum_headsで割り切れる必要があります"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # ヘッドあたりの次元

        # Q, K, V の射影行列(全ヘッド分をまとめて計算)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

        # 出力射影
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, decoder_input, encoder_output, mask=None):
        """
        Args:
            decoder_input: (batch, n, d_model)
            encoder_output: (batch, m, d_model)
            mask: (batch, 1, 1, m) または (batch, 1, n, m)
        Returns:
            output: (batch, n, d_model)
            attn_weights: (batch, num_heads, n, m)
        """
        batch_size = decoder_input.size(0)
        n = decoder_input.size(1)
        m = encoder_output.size(1)

        # 全ヘッド分の Q, K, V をまとめて計算
        Q = self.W_Q(decoder_input)   # (batch, n, d_model)
        K = self.W_K(encoder_output)  # (batch, m, d_model)
        V = self.W_V(encoder_output)  # (batch, m, d_model)

        # ヘッドごとに分割: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, n, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, m, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, m, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        # scores: (batch, num_heads, n, m)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)  # (batch, num_heads, n, m)
        context = torch.matmul(attn_weights, V)    # (batch, num_heads, n, d_k)

        # ヘッドを連結: (batch, num_heads, n, d_k) -> (batch, n, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, n, self.d_model)

        # 出力射影
        output = self.W_O(context)  # (batch, n, d_model)

        return output, attn_weights

Multi-Head版では、全ヘッド分の Q, K, V をまとめて1回の行列積で計算し、その後 viewtranspose でヘッドごとに分割しています。これは計算効率のための実装上の工夫です。数学的には各ヘッドで独立に射影して Attention を計算するのと全く同じ結果になりますが、GPU 上で並列に計算できるため高速です。

合成データでの動作確認

実装が正しく動作するか、合成データを使って確認しましょう。疑似的な「翻訳」タスクを想定し、ソース系列(英語4トークン)に対してターゲット系列(日本語3トークン)がCross-Attentionを計算する状況をシミュレーションします。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# --- CrossAttention, MultiHeadCrossAttention は上で定義済みとする ---

# 再現性のためシードを固定
torch.manual_seed(42)
np.random.seed(42)

# ハイパーパラメータ
d_model = 64    # 埋め込み次元
num_heads = 4   # ヘッド数
batch_size = 1

# 疑似的なエンコーダ出力(ソース系列: 4トークン)
# "I", "love", "cats", "." に対応すると想定
m = 4
encoder_output = torch.randn(batch_size, m, d_model)

# 疑似的なデコーダ入力(ターゲット系列: 3トークン)
# "私は", "猫が", "好きです" に対応すると想定
n = 3
decoder_input = torch.randn(batch_size, n, d_model)

# シングルヘッドの動作確認
single_head = CrossAttention(d_model, d_k=64, d_v=64)
output_single, attn_single = single_head(decoder_input, encoder_output)

print("=== シングルヘッド Cross-Attention ===")
print(f"デコーダ入力の形状:     {decoder_input.shape}")
print(f"エンコーダ出力の形状:   {encoder_output.shape}")
print(f"出力の形状:             {output_single.shape}")
print(f"注意重み行列の形状:     {attn_single.shape}")
print(f"注意重み(行ごとの和): {attn_single[0].sum(dim=-1)}")
=== シングルヘッド Cross-Attention ===
デコーダ入力の形状:     torch.Size([1, 3, 64])
エンコーダ出力の形状:   torch.Size([1, 4, 64])
出力の形状:             torch.Size([1, 3, 64])
注意重み行列の形状:     torch.Size([1, 3, 4])
注意重み(行ごとの和): tensor([1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

出力結果を見ると、注意重み行列の形状が (1, 3, 4) であることが確認できます。これはデコーダの3トークンがエンコーダの4トークンに対して注目度を計算した結果です。各行の和が1.0になっているのは、softmaxによって行ごとに正規化されていることの確認です。また、出力の形状 (1, 3, 64) はデコーダの系列長を保ったまま、エンコーダの情報が組み込まれていることを示しています。

Multi-Headでの動作確認

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# --- MultiHeadCrossAttention は上で定義済みとする ---

torch.manual_seed(42)

d_model = 64
num_heads = 4
batch_size = 1
m = 4  # エンコーダ系列長
n = 3  # デコーダ系列長

encoder_output = torch.randn(batch_size, m, d_model)
decoder_input = torch.randn(batch_size, n, d_model)

# Multi-Head Cross-Attention
multi_head = MultiHeadCrossAttention(d_model, num_heads)
output_multi, attn_multi = multi_head(decoder_input, encoder_output)

print("=== Multi-Head Cross-Attention ===")
print(f"出力の形状:         {output_multi.shape}")
print(f"注意重みの形状:     {attn_multi.shape}")
print(f"ヘッド0の注意重み:\n{attn_multi[0, 0].detach().numpy().round(3)}")
print(f"ヘッド1の注意重み:\n{attn_multi[0, 1].detach().numpy().round(3)}")
=== Multi-Head Cross-Attention ===
出力の形状:         torch.Size([1, 3, 64])
注意重みの形状:     torch.Size([1, 4, 3, 4])
ヘッド0の注意重み:
[[0.218 0.333 0.183 0.266]
 [0.135 0.286 0.35  0.229]
 [0.217 0.176 0.374 0.233]]
ヘッド1の注意重み:
[[0.296 0.226 0.284 0.194]
 [0.213 0.292 0.174 0.321]
 [0.253 0.249 0.202 0.296]]

Multi-Head版の出力を見ると、注意重みの形状が (1, 4, 3, 4) です。これは (batch, num_heads, n, m) に対応し、4つのヘッドがそれぞれ独立に 3×4 の注意重み行列を持っていることを意味します。ヘッド0とヘッド1の注意パターンを比較すると、初期状態(ランダム重み)でもヘッドごとに異なる分布を示していることがわかります。学習が進むと、各ヘッドはより特化した注意パターンを獲得します。

注意重みの可視化

注意重みの可視化(どの入力語を参照したか)

図は注意重み行列を可視化したものです。各出力語(query)がどの入力語(key)を強く参照したかが色の濃さでわかります。たとえば「cat」は「猫」に、「sat」「down」は「座」「った」に強く注目しており、Cross-Attentionが入力と出力の対応(アラインメント)を自動的に学習することが見て取れます。

Cross-Attentionの注意重みをヒートマップとして可視化しましょう。ソース言語とターゲット言語のトークン間の対応関係を視覚的に確認します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# --- MultiHeadCrossAttention は上で定義済みとする ---

torch.manual_seed(42)

d_model = 64
num_heads = 4
batch_size = 1
m = 4
n = 3

encoder_output = torch.randn(batch_size, m, d_model)
decoder_input = torch.randn(batch_size, n, d_model)

multi_head = MultiHeadCrossAttention(d_model, num_heads)
_, attn_weights = multi_head(decoder_input, encoder_output)

# ソーストークン(エンコーダ側)とターゲットトークン(デコーダ側)
source_tokens = ["I", "love", "cats", "."]
target_tokens = ["私は", "猫が", "好きです"]

# 注意重みをnumpy配列に変換
attn_np = attn_weights[0].detach().numpy()  # (num_heads, n, m)

fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))
fig.suptitle("Multi-Head Cross-Attention Weights", fontsize=14)

for head_idx in range(num_heads):
    ax = axes[head_idx]
    im = ax.imshow(attn_np[head_idx], cmap="Blues", vmin=0, vmax=1, aspect="auto")

    ax.set_xticks(range(m))
    ax.set_xticklabels(source_tokens, fontsize=11)
    ax.set_yticks(range(n))
    ax.set_yticklabels(target_tokens, fontsize=11)
    ax.set_xlabel("Source (Encoder)", fontsize=10)
    ax.set_title(f"Head {head_idx + 1}", fontsize=12)

    # 各セルに数値を表示
    for i in range(n):
        for j in range(m):
            ax.text(j, i, f"{attn_np[head_idx][i, j]:.2f}",
                    ha="center", va="center", fontsize=9,
                    color="white" if attn_np[head_idx][i, j] > 0.5 else "black")

axes[0].set_ylabel("Target (Decoder)", fontsize=10)
plt.tight_layout()
plt.savefig("cross_attention_heads.png", dpi=150, bbox_inches="tight")
plt.show()

この可視化から、各ヘッドが異なる注意パターンを持っていることが確認できます。ランダムな初期化状態では注意重みはほぼ均一に分布していますが、実際の学習済みモデルでは、あるヘッドは「猫が→cats」のような直接的な翻訳対応を捉え、別のヘッドは「好きです→I」のように省略された主語への参照を捉えるなど、ヘッドごとに異なる言語的な関係を学習します。ヒートマップの色が濃い部分ほど注意重みが高く、そのソーストークンの情報がターゲットトークンの生成に強く影響していることを意味します。

Self-Attention と Cross-Attention の注意重み形状の比較

最後に、Self-AttentionとCross-Attentionの注意重み行列の形状の違いを視覚的に比較してみましょう。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)

d_model = 64
seq_len_dec = 3   # デコーダ系列長
seq_len_enc = 4   # エンコーダ系列長

# --- Self-Attention用の簡易クラス ---
class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (self.d_k ** 0.5)
        attn = F.softmax(scores, dim=-1)
        return torch.bmm(attn, V), attn

# --- CrossAttention は上で定義済みとする ---

# 入力データ
decoder_input = torch.randn(1, seq_len_dec, d_model)
encoder_output = torch.randn(1, seq_len_enc, d_model)

# Self-Attention(デコーダ系列に対して)
self_attn = SelfAttention(d_model, d_k=64)
_, self_attn_weights = self_attn(decoder_input)

# Cross-Attention
cross_attn = CrossAttention(d_model, d_k=64, d_v=64)
_, cross_attn_weights = cross_attn(decoder_input, encoder_output)

# 可視化
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Self-Attention: 正方行列 (n x n)
target_tokens = ["私は", "猫が", "好きです"]
source_tokens = ["I", "love", "cats", "."]

sa_np = self_attn_weights[0].detach().numpy()
ax0 = axes[0]
im0 = ax0.imshow(sa_np, cmap="Oranges", vmin=0, vmax=1, aspect="auto")
ax0.set_xticks(range(seq_len_dec))
ax0.set_xticklabels(target_tokens, fontsize=11)
ax0.set_yticks(range(seq_len_dec))
ax0.set_yticklabels(target_tokens, fontsize=11)
ax0.set_title(f"Self-Attention ({seq_len_dec}x{seq_len_dec})", fontsize=12)
ax0.set_xlabel("Key (same sequence)", fontsize=10)
ax0.set_ylabel("Query (same sequence)", fontsize=10)
for i in range(seq_len_dec):
    for j in range(seq_len_dec):
        ax0.text(j, i, f"{sa_np[i, j]:.2f}", ha="center", va="center", fontsize=9)

# Cross-Attention: 長方形行列 (n x m)
ca_np = cross_attn_weights[0].detach().numpy()
ax1 = axes[1]
im1 = ax1.imshow(ca_np, cmap="Blues", vmin=0, vmax=1, aspect="auto")
ax1.set_xticks(range(seq_len_enc))
ax1.set_xticklabels(source_tokens, fontsize=11)
ax1.set_yticks(range(seq_len_dec))
ax1.set_yticklabels(target_tokens, fontsize=11)
ax1.set_title(f"Cross-Attention ({seq_len_dec}x{seq_len_enc})", fontsize=12)
ax1.set_xlabel("Key (Encoder output)", fontsize=10)
ax1.set_ylabel("Query (Decoder input)", fontsize=10)
for i in range(seq_len_dec):
    for j in range(seq_len_enc):
        ax1.text(j, i, f"{ca_np[i, j]:.2f}", ha="center", va="center", fontsize=9)

plt.tight_layout()
plt.savefig("self_vs_cross_attention.png", dpi=150, bbox_inches="tight")
plt.show()

左のSelf-Attentionの注意重みは $3 \times 3$ の正方行列であり、デコーダ系列の各トークンが同じ系列内の他のトークンとの関係を計算しています。一方、右のCross-Attentionの注意重みは $3 \times 4$ の長方形行列であり、デコーダの3トークンがエンコーダの4トークンへの注目度をそれぞれ計算しています。この形状の違いこそが、Self-AttentionとCross-Attentionの本質的な違いを視覚的に示しています。Self-Attentionは「同じ空間の中での関係」を、Cross-Attentionは「異なる空間をまたぐ関係」を計算しているのです。

Transformerデコーダブロックの統合実装

最後に、Masked Self-Attention → Cross-Attention → FFN の3つのサブ層を統合したTransformerデコーダブロックを実装し、全体の動作を確認します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# --- MultiHeadCrossAttention は上で定義済みとする ---

class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Self-Attention(マスク対応)"""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        Q = self.W_Q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_O(context), attn_weights


class TransformerDecoderBlock(nn.Module):
    """Transformerデコーダの1ブロック"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: 埋め込み次元
            num_heads: ヘッド数
            d_ff: FFNの中間層次元
            dropout: ドロップアウト率
        """
        super().__init__()

        # サブ層1: Masked Self-Attention
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # サブ層2: Cross-Attention
        self.cross_attn = MultiHeadCrossAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # サブ層3: FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        """
        Args:
            decoder_input: (batch, n, d_model)
            encoder_output: (batch, m, d_model)
            causal_mask: 因果マスク (1, 1, n, n)
        """
        # サブ層1: Masked Self-Attention + 残差接続 + LayerNorm
        self_attn_out, self_attn_weights = self.self_attn(decoder_input, causal_mask)
        h1 = self.norm1(decoder_input + self.dropout1(self_attn_out))

        # サブ層2: Cross-Attention + 残差接続 + LayerNorm
        cross_attn_out, cross_attn_weights = self.cross_attn(h1, encoder_output)
        h2 = self.norm2(h1 + self.dropout2(cross_attn_out))

        # サブ層3: FFN + 残差接続 + LayerNorm
        ffn_out = self.ffn(h2)
        h3 = self.norm3(h2 + self.dropout3(ffn_out))

        return h3, self_attn_weights, cross_attn_weights

統合テストを実行して、3つのサブ層が正しく連携することを確認します。

import torch
import torch.nn as nn
import numpy as np

# --- 上で定義した各クラスは既に定義済みとする ---

torch.manual_seed(42)

# ハイパーパラメータ
d_model = 64
num_heads = 4
d_ff = 256
batch_size = 1
m = 4  # ソース系列長
n = 3  # ターゲット系列長

# 入力データ
encoder_output = torch.randn(batch_size, m, d_model)
decoder_input = torch.randn(batch_size, n, d_model)

# 因果マスクの作成(上三角をTrue=マスクにする)
causal_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, n, n)

# デコーダブロック
decoder_block = TransformerDecoderBlock(d_model, num_heads, d_ff, dropout=0.0)
decoder_block.eval()

with torch.no_grad():
    output, sa_weights, ca_weights = decoder_block(
        decoder_input, encoder_output, causal_mask
    )

print("=== Transformer Decoder Block ===")
print(f"デコーダ入力:           {decoder_input.shape}")
print(f"エンコーダ出力:         {encoder_output.shape}")
print(f"デコーダブロック出力:   {output.shape}")
print(f"Self-Attention重み:     {sa_weights.shape}")
print(f"Cross-Attention重み:    {ca_weights.shape}")
print(f"\n因果マスク:\n{causal_mask[0, 0].int().numpy()}")
print(f"\nSelf-Attention重み(ヘッド0):\n{sa_weights[0, 0].numpy().round(3)}")
print(f"\nCross-Attention重み(ヘッド0):\n{ca_weights[0, 0].numpy().round(3)}")
=== Transformer Decoder Block ===
デコーダ入力:           torch.Size([1, 3, 64])
エンコーダ出力:         torch.Size([1, 4, 64])
デコーダブロック出力:   torch.Size([1, 3, 64])
Self-Attention重み:     torch.Size([1, 4, 3, 3])
Cross-Attention重み:    torch.Size([1, 4, 3, 4])

因果マスク:
[[0 1 1]
 [0 0 1]
 [0 0 0]]

Self-Attention重み(ヘッド0):
[[1.    0.    0.   ]
 [0.537 0.463 0.   ]
 [0.289 0.352 0.359]]

Cross-Attention重み(ヘッド0):
[[0.218 0.333 0.183 0.266]
 [0.134 0.287 0.349 0.23 ]
 [0.219 0.177 0.373 0.231]]

この出力から3つの重要なことが確認できます。第一に、デコーダブロックの出力形状は (1, 3, 64) で入力と同じ形状を保っています。これはTransformerのデコーダ層をスタックできることを意味します。第二に、Self-Attentionの注意重み (1, 4, 3, 3) は正方行列で、因果マスクが正しく効いています。ヘッド0の重みを見ると、1行目は自分自身のみに注目(1.0)、2行目は位置0と位置1のみに注目(未来の位置2は0)、3行目は全位置に注目しています。第三に、Cross-Attentionの注意重み (1, 4, 3, 4) は $3 \times 4$ の長方形行列で、デコーダの各位置がエンコーダの全位置を自由に参照していることが確認できます。

条件注入の機構としてのCross-Attention

ここまで機械翻訳のエンコーダ・デコーダを軸に見てきましたが、Cross-Attentionはより一般的に「条件付けの機構」として捉えられます。本系列を Query、与えたい条件(指示文や系列状のメタdata)を Key・Value に置けば、本系列の各位置が「条件のどの部分を見るか」を動的に決められます。

条件注入としてのCross-Attention

連結(concat)やFiLMが「条件全体から固定の効き方」を作るのに対し、Cross-Attentionは位置ごと・可変長で参照先を変えられるのが強みです。Stable Diffusionがテキストプロンプトを画像生成に注入できるのも、この性質によります。条件が単一のベクトルなら連結やFiLMで十分ですが、条件が系列(文章・履歴)で「どこに注目するか」を選びたいときは、Cross-Attentionが最も適した条件付けになります。

まとめ

本記事では、Cross-Attention(クロスアテンション)の理論と実装について解説しました。

  • Cross-Attentionの本質: Self-Attentionと数式の骨格は同じだが、Queryがデコーダ側、Key/Valueがエンコーダ側から来る点が異なる。これにより「異なる系列間の対応関係」を動的に学習できる
  • Transformerデコーダ内の役割: Masked Self-Attention → Cross-Attention → FFN の3段構成で、Cross-Attentionはエンコーダ側の情報をデコーダに注入する「ゲート」として機能する
  • Multi-Head拡張: 複数のヘッドにより、構文的対応、意味的対応、参照関係など、複数の異なる側面の情報を同時に捉える
  • 幅広い応用: 機械翻訳だけでなく、Stable Diffusionのテキスト条件付けやマルチモーダルモデルの情報融合など、現代の深層学習の広範なタスクで活躍
  • 注意重み行列の形状: Self-Attentionは $n \times n$ の正方行列、Cross-Attentionは $n \times m$ の長方形行列。この違いが両者の役割の違いを象徴的に表している

Cross-Attentionは、2つの異なる情報源を結びつけるための基本的かつ強力な仕組みです。この理解を基盤として、Transformerのデコーダ全体の設計や、テキスト→画像生成の仕組みなど、さらに発展的なトピックに進むことができます。

次のステップとして、以下の記事も参考にしてください。

画像なし
Transformerのアーキテクチャ
Encoder-Decoderの全体構成と各コンポーネントの役割を理解する
画像なし
Attention Is All You Need 論文解説
Transformer原論文の詳細な解説と各コンポーネントの設計思想