Bahdanau AttentionとLuong Attentionの違いを理論から実装まで完全解説

機械翻訳モデルに「この単語を訳すとき、原文のどこを見ればいいか」を教える仕組み — それがAttention機構です。2014年にBahdanauらが最初の実用的なAttentionを提案して以来、わずか3年でTransformerが登場するまでの間に、Attention機構には複数の重要なバリエーションが生まれました。中でもBahdanau Attention(加法的注意)とLuong Attention(乗法的注意)は、現代のAttention機構を理解する上で避けて通れない2つの原型です。

しかし、両者の違いを「スコア関数の形が違う」という一言で片付けてしまうと、設計思想の本質的な差異を見逃してしまいます。実は両者は、スコア関数の形だけでなく、Attentionを計算するタイミングデコーダの隠れ状態の使い方エンコーダの構造計算コストのトレードオフなど、アーキテクチャ全体にわたって異なる設計判断を行っています。この違いを理解することで、Transformerの Scaled Dot-Product Attentionがなぜあの形になったのか、という歴史的必然も見えてきます。

Bahdanau AttentionとLuong Attentionの違いを深く理解すると、以下のような場面で直接役立ちます。

  • 機械翻訳システムの設計: どのAttention機構をベースにするか判断できる
  • 文書要約: エンコーダ・デコーダ型モデルでAttentionのバリエーションを使い分けられる
  • 音声認識: CTC/Attentionハイブリッドモデルでのスコア関数の選択根拠を理解できる
  • Transformerの理解: Scaled Dot-Product Attentionに至る歴史的経緯がわかる

本記事の内容

  • Seq2Seqの固定長ボトルネック問題の復習
  • Bahdanau Attention(加法的注意)のスコア関数と計算フローの導出
  • Luong Attention(乗法的注意)の3種のスコア関数とGlobal/Localの区別
  • 両者の設計思想を7つの観点から比較
  • PyTorchによる両Attentionクラスの実装
  • 合成データでの注意重みヒートマップの比較実験

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

画像なし
Seq2Seqモデルの理論と実装
エンコーダ・デコーダアーキテクチャの基本をゼロから解説
画像なし
Attention機構の基礎 — なぜAttentionが必要なのか
Attention機構の基本概念とSeq2Seqへの導入をわかりやすく解説

Seq2Seqの固定長ボトルネック問題を復習する

Bahdanau AttentionとLuong Attentionはどちらも、Seq2Seqモデルの「ある欠陥」を解決するために提案されました。まずはその欠陥を明確にしておきましょう。

Seq2Seq(Sequence-to-Sequence)モデルは、入力系列をエンコーダで処理し、最終隠れ状態をコンテキストベクトル $\bm{c}$ としてデコーダに渡す構造です。エンコーダは入力系列 $(\bm{x}_1, \bm{x}_2, \ldots, \bm{x}_{T_x})$ を逐次的に処理します。

$$ \bm{h}_t = f_{\text{enc}}(\bm{x}_t, \bm{h}_{t-1}) $$

そしてコンテキストベクトルは最終隠れ状態そのものです。

$$ \bm{c} = \bm{h}_{T_x} $$

デコーダは、この固定されたコンテキストベクトル $\bm{c}$ を使って出力系列を生成します。

$$ \bm{s}_t = f_{\text{dec}}(\bm{y}_{t-1}, \bm{s}_{t-1}, \bm{c}) $$

ここで致命的な問題が生じます。入力が5単語でも50単語でも、同じ次元の1つのベクトル $\bm{c}$ に全情報を詰め込まなければなりません。たとえば隠れ状態の次元が256だとすると、256個の実数値で入力文全体の意味を表現することになります。短い文ならまだしも、長い文では情報が溢れ出してしまうのです。

Cho et al. (2014) の実験では、入力文の長さが20〜30単語を超えるとBLEUスコアが急激に低下することが報告されています。人間の翻訳者は長い文を訳すとき、原文の該当箇所を何度も見返します。モデルにも同じことができれば、この問題を解決できるはずです。

この発想を具体化する方法として、デコーダの各出力ステップにおいて、エンコーダの全ての隠れ状態を参照し、現在の出力に関連する部分に重みを付けて情報を取り出す仕組みが考案されました。これがAttention機構の核心的アイデアです。

では、具体的にどのように「関連度」を計算し、情報を取り出すのでしょうか。ここで2つの異なるアプローチが提案されました。Bahdanau et al. (2014) の加法的注意と、Luong et al. (2015) の乗法的注意です。まずはBahdanau Attentionから、その数理を丁寧に追っていきましょう。

Bahdanau Attention(加法的注意)の理論

基本的なアイデア — 翻訳者の「見返し」を数式にする

Bahdanau et al. (2014) の論文「Neural Machine Translation by Jointly Learning to Align and Translate」は、Attention機構を機械翻訳に初めて導入した画期的な研究です。

その基本的なアイデアを、日常的な場面で考えてみましょう。英語の長い文を日本語に翻訳しているとき、あなたは文全体を一度に記憶するのではなく、訳している箇所に応じて原文の該当部分を見返すはずです。「the cat sat on the mat」を訳すとき、「猫が」と書く瞬間は「the cat」の部分を見ていますし、「マットの上に」と書く瞬間は「on the mat」を見ています。

Bahdanau Attentionは、この「見返し」の行動を次のように数式化します。固定されたコンテキストベクトル $\bm{c}$ の代わりに、デコーダの各出力ステップ $i$ ごとに動的なコンテキストベクトル $\bm{c}_i$ を計算するのです。

$$ \bm{c}_i = \sum_{j=1}^{T_x} \alpha_{ij} \bm{h}_j $$

ここで $\bm{h}_j$ はエンコーダの時刻 $j$ における隠れ状態、$\alpha_{ij}$ は「デコーダの時刻 $i$ で、エンコーダの時刻 $j$ にどれだけ注目するか」を表す重みです。この重みが大きいほど、その位置の情報がコンテキストに多く反映されます。

アライメントスコアの計算

では、Attention重み $\alpha_{ij}$ をどう求めるのでしょうか。まず、デコーダの隠れ状態とエンコーダの各隠れ状態の「関連度」を測るアライメントスコア(整合性スコア)$e_{ij}$ を計算します。

Bahdanauらは、このスコアを以下のような加法的(additive)な形で定義しました。

$$ \begin{equation} e_{ij} = \bm{v}^T \tanh(\bm{W}_1 \bm{h}_j + \bm{W}_2 \bm{s}_{i-1}) \end{equation} $$

ここで重要なのは、使用するデコーダの隠れ状態が $\bm{s}_{i-1}$(前の時刻の状態)であることです。時刻 $i$ の出力を生成するためのAttentionを計算する時点では、まだ時刻 $i$ の隠れ状態 $\bm{s}_i$ は計算されていません。したがって、1つ前の隠れ状態 $\bm{s}_{i-1}$ を使います。

各パラメータの意味を整理しましょう。

  • $\bm{W}_1 \in \mathbb{R}^{d_a \times d_h}$: エンコーダの隠れ状態 $\bm{h}_j \in \mathbb{R}^{d_h}$ をアライメント空間 $\mathbb{R}^{d_a}$ に射影する重み行列
  • $\bm{W}_2 \in \mathbb{R}^{d_a \times d_s}$: デコーダの隠れ状態 $\bm{s}_{i-1} \in \mathbb{R}^{d_s}$ をアライメント空間に射影する重み行列
  • $\bm{v} \in \mathbb{R}^{d_a}$: アライメント空間のベクトルをスカラーに変換するパラメータ
  • $d_a$: アライメント空間の次元(ハイパーパラメータ)

この数式が何をしているのか、段階的に理解しましょう。

まず、$\bm{W}_1 \bm{h}_j$ と $\bm{W}_2 \bm{s}_{i-1}$ の2つの線形変換で、エンコーダとデコーダの隠れ状態を同じ次元のアライメント空間に写像します。これにより、もともと異なる次元を持つかもしれない2つのベクトルを比較可能にしています。

次に、$\tanh$ 非線形活性化関数を適用します。$\tanh$ を使うことで、加算しただけでは捉えられない2つのベクトル間の非線形な関係を表現できます。$\tanh$ の出力は $[-1, 1]$ の範囲に収まるため、数値的な安定性も確保されます。

最後に、$\bm{v}^T$ との内積でスカラー値 $e_{ij}$ を得ます。$d_a$ 次元のベクトルを1つの実数値に集約する操作です。このスカラーが、「デコーダの時刻 $i$ から見て、エンコーダの時刻 $j$ がどれだけ関連しているか」を表します。

Softmaxによる正規化

アライメントスコア $e_{ij}$ はそのままでは比較しにくいため、Softmax関数で確率分布に変換します。

$$ \begin{equation} \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} \end{equation} $$

Softmax正規化により、以下の2つの性質が保証されます。

$$ \alpha_{ij} \geq 0, \quad \sum_{j=1}^{T_x} \alpha_{ij} = 1 $$

つまり、Attention重み $\alpha_{ij}$ は入力系列上の確率分布になります。「デコーダの時刻 $i$ で、入力のどの位置に注目する確率がどれだけか」を表現しているのです。

コンテキストベクトルの計算

Attention重みが得られたら、エンコーダの隠れ状態の重み付き和でコンテキストベクトルを計算します。

$$ \begin{equation} \bm{c}_i = \sum_{j=1}^{T_x} \alpha_{ij} \bm{h}_j \end{equation} $$

この操作は、「関連度が高い入力位置の情報を多く集め、関連度が低い位置の情報を少なく集める」ことに相当します。いわば「ソフトな選択」であり、微分可能であるため勾配ベースの学習が可能です。

デコーダの更新規則

Bahdanau Attentionでは、コンテキストベクトル $\bm{c}_i$ はデコーダのRNN(GRU)への入力の一部として使われます。具体的には以下の順序で計算が進みます。

ステップ1: 前の隠れ状態 $\bm{s}_{i-1}$ と全エンコーダ隠れ状態 $\bm{h}_1, \ldots, \bm{h}_{T_x}$ からAttention重み $\alpha_{ij}$ を計算する。

ステップ2: コンテキストベクトル $\bm{c}_i = \sum_j \alpha_{ij} \bm{h}_j$ を計算する。

ステップ3: コンテキストベクトル $\bm{c}_i$ と前の出力 $\bm{y}_{i-1}$ を入力として、デコーダの隠れ状態を更新する。

$$ \bm{s}_i = f_{\text{dec}}(\bm{y}_{i-1}, \bm{s}_{i-1}, \bm{c}_i) $$

ステップ4: 更新された隠れ状態 $\bm{s}_i$ とコンテキストベクトル $\bm{c}_i$ を連結し、出力確率分布を計算する。

$$ P(y_i \mid y_{

この計算順序のポイントは、Attention計算がRNNの更新の前に行われることです。したがって、コンテキストベクトルがRNNの状態更新に直接影響を与えます。

双方向RNNエンコーダ

Bahdanauらは、エンコーダに双方向RNN(Bidirectional RNN)を採用しました。順方向RNNは $\bm{x}_1$ から $\bm{x}_{T_x}$ へ、逆方向RNNは $\bm{x}_{T_x}$ から $\bm{x}_1$ へ系列を処理します。

$$ \overrightarrow{\bm{h}}_j = \overrightarrow{f}_{\text{enc}}(\bm{x}_j, \overrightarrow{\bm{h}}_{j-1}) $$

$$ \overleftarrow{\bm{h}}_j = \overleftarrow{f}_{\text{enc}}(\bm{x}_j, \overleftarrow{\bm{h}}_{j+1}) $$

両方向の隠れ状態を連結して、エンコーダの隠れ状態とします。

$$ \bm{h}_j = [\overrightarrow{\bm{h}}_j ; \overleftarrow{\bm{h}}_j] $$

双方向RNNを使う理由は明確です。通常の(単方向)RNNでは、$\bm{h}_j$ は $\bm{x}_1, \ldots, \bm{x}_j$ の情報しか含みません。つまり、位置 $j$ より後ろの文脈情報が欠落しています。双方向RNNにすることで、$\bm{h}_j$ は入力系列全体の文脈を反映したベクトルになります。

たとえば「I saw her duck」という文を考えましょう。「duck」がアヒル(名詞)なのか、かがむ(動詞)なのかは、前後の文脈から判断する必要があります。双方向エンコーダなら、各位置の隠れ状態が全文脈を考慮できるため、より正確なAttentionの計算が可能になります。

ここまでで、Bahdanau Attentionの数理的な構造を把握しました。このAttentionは非線形活性化関数($\tanh$)を含む加法的なスコア関数を用い、デコーダのRNN更新のにAttentionを計算するのが特徴です。では、Luong Attentionはどのような異なるアプローチを取ったのでしょうか。

Luong Attention(乗法的注意)の理論

Bahdanauの問題意識を受け継ぎつつ再設計する

Luong et al. (2015) は「Effective Approaches to Attention-based Neural Machine Translation」において、Bahdanau Attentionのアイデアを受け継ぎつつ、いくつかの重要な設計変更を行いました。

Bahdanau Attentionは画期的でしたが、加法的スコア関数には $\bm{W}_1, \bm{W}_2, \bm{v}$ の3つのパラメータ行列が必要で、各ペアについて $\tanh$ の計算も行うため、計算コストがやや高いという課題がありました。Luongらは、よりシンプルで計算効率の高いスコア関数を提案するとともに、Attentionの計算タイミングを変更してアーキテクチャ全体を再設計しました。

3種のスコア関数

Luong Attentionの最も特徴的な点は、3種類のスコア関数を提案し、比較実験を行ったことです。

Dot Product(内積スコア)

最もシンプルな形で、追加のパラメータが一切不要です。

$$ \begin{equation} \text{score}(\bm{h}_t, \bar{\bm{h}}_s) = \bm{h}_t^T \bar{\bm{h}}_s \end{equation} $$

ここで $\bm{h}_t$ はデコーダの現在の隠れ状態、$\bar{\bm{h}}_s$ はエンコーダの位置 $s$ における隠れ状態です。内積は2つのベクトルの方向の類似度を測るため、「デコーダの現在の状態と似た方向を持つエンコーダの隠れ状態」に高いスコアが付きます。

ただし、内積スコアを使うには $\bm{h}_t$ と $\bar{\bm{h}}_s$ の次元が同じである必要があります。

General(一般化内積スコア)

パラメータ行列 $\bm{W}_a$ を1つ導入し、次元が異なるベクトル間の関連度も計算できるようにします。

$$ \begin{equation} \text{score}(\bm{h}_t, \bar{\bm{h}}_s) = \bm{h}_t^T \bm{W}_a \bar{\bm{h}}_s \end{equation} $$

$\bm{W}_a \in \mathbb{R}^{d_h \times d_{\bar{h}}}$ は学習可能な重み行列です。$\bm{W}_a \bar{\bm{h}}_s$ をまず計算すれば、エンコーダの隠れ状態を線形変換してからデコーダの隠れ状態との内積を取ることになります。$\bm{W}_a$ が恒等行列のとき、Dot Productスコアに退化します。

Concat(連結スコア)

Bahdanau Attentionのスコア関数に近い形です。

$$ \begin{equation} \text{score}(\bm{h}_t, \bar{\bm{h}}_s) = \bm{v}_a^T \tanh(\bm{W}_a [\bm{h}_t ; \bar{\bm{h}}_s]) \end{equation} $$

$[\bm{h}_t ; \bar{\bm{h}}_s]$ は2つのベクトルの連結(concatenation)です。Bahdanau Attentionとの違いは、2つの射影行列 $\bm{W}_1, \bm{W}_2$ を使って別々に変換する代わりに、連結ベクトルを1つの行列 $\bm{W}_a$ で変換する点です。

Luongらの実験では、Generalスコアが全体的に最も良い性能を示し、Dot Productは計算コストの低さに対して十分な性能を発揮しました。

Attentionの計算タイミング — Bahdanauとの決定的な違い

Luong Attentionの最も重要な設計変更は、Attentionの計算タイミングです。

Bahdanau Attentionでは、Attention計算にデコーダの前の時刻の隠れ状態 $\bm{s}_{i-1}$ を使いました。これは、コンテキストベクトルをRNNの更新に使うためです。

Luong Attentionでは、デコーダの現在の時刻の隠れ状態 $\bm{h}_t$ を使います。つまり、まずRNNの更新を行ってから、Attentionを計算するのです。

具体的な計算フローは以下の通りです。

ステップ1: デコーダRNNを通常通り更新して $\bm{h}_t$ を得る。

$$ \bm{h}_t = f_{\text{dec}}(\bm{y}_{t-1}, \bm{h}_{t-1}) $$

ステップ2: $\bm{h}_t$ と全エンコーダ隠れ状態からアライメントスコアを計算する。

$$ a_t(s) = \text{align}(\bm{h}_t, \bar{\bm{h}}_s) = \frac{\exp(\text{score}(\bm{h}_t, \bar{\bm{h}}_s))}{\sum_{s’} \exp(\text{score}(\bm{h}_t, \bar{\bm{h}}_{s’}))} $$

ステップ3: コンテキストベクトルを計算する。

$$ \bm{c}_t = \sum_s a_t(s) \bar{\bm{h}}_s $$

ステップ4: デコーダの隠れ状態とコンテキストベクトルを結合して、Attentional Hidden State $\tilde{\bm{h}}_t$ を生成する。

$$ \begin{equation} \tilde{\bm{h}}_t = \tanh(\bm{W}_c [\bm{c}_t ; \bm{h}_t]) \end{equation} $$

ステップ5: $\tilde{\bm{h}}_t$ から出力確率分布を計算する。

$$ P(y_t \mid y_{

この設計の利点は、RNNの更新とAttentionの計算を分離できることです。Bahdanau Attentionでは、Attention → RNN更新 → 出力 という流れで、Attentionの計算がRNNの内部に組み込まれていました。Luong Attentionでは、RNN更新 → Attention → 出力 という流れで、Attentionは独立したモジュールとして後付けされます。これにより、既存のSeq2Seqモデルにも容易にAttentionを追加できるようになりました。

Global Attention vs Local Attention

Luongらは、Attentionの範囲に関して2つの戦略を提案しました。

Global Attention は、エンコーダの全ての隠れ状態に対してAttentionを計算します。これはBahdanau Attentionと同じスコープです。全入力位置を考慮するため情報の取りこぼしはありませんが、入力が長い場合の計算コストは $O(T_x)$ です。

$$ \bm{c}_t = \sum_{s=1}^{T_x} a_t(s) \bar{\bm{h}}_s $$

Local Attention は、デコーダの各時刻 $t$ に対して、エンコーダ上の特定の位置周辺のみにAttentionを計算します。まず、注目する中心位置 $p_t$ を決定し、ウィンドウ $[p_t – D, p_t + D]$ の範囲のみで重み付き和を計算します。

中心位置 $p_t$ の決め方には2つの方法があります。

Monotonic Alignment(単調整列): 単純に $p_t = t$ とする。入力と出力が概ね同じ順序で対応している場合(例: 語順が近い言語間の翻訳)に適しています。

Predictive Alignment(予測整列): ニューラルネットワークで $p_t$ を予測する。

$$ p_t = T_x \cdot \sigma(\bm{v}_p^T \tanh(\bm{W}_p \bm{h}_t)) $$

ここで $\sigma$ はシグモイド関数で、出力を $[0, T_x]$ の範囲に制限します。さらに、中心位置からの距離に応じたガウシアン重みを掛けます。

$$ a_t(s) = \text{align}(\bm{h}_t, \bar{\bm{h}}_s) \cdot \exp\left(-\frac{(s – p_t)^2}{2\sigma^2}\right) $$

ここで $\sigma = D / 2$ とすることが多いです。この設計により、ウィンドウの中心に近い位置ほど高い重みを持ち、端に行くほど重みが減衰します。

Local Attentionの利点は、計算量が入力長 $T_x$ に依存せず $O(D)$ で済むことです。非常に長い入力系列を扱う場合にスケーラビリティの面で有利になります。ただし、ウィンドウの外にある重要な情報を見落とすリスクがあります。

ここまでで、Luong Attentionの3種のスコア関数、独特の計算タイミング、Global/Localの2つの戦略を理解しました。次に、Bahdanau AttentionとLuong Attentionの違いを体系的に整理し、それぞれの設計判断の意味を深掘りしましょう。

Bahdanau AttentionとLuong Attentionの体系的比較

7つの観点から比較する

ここまでの議論を踏まえ、両者の違いを7つの観点から整理します。

観点 Bahdanau Attention Luong Attention
論文 Bahdanau et al. (2014) Luong et al. (2015)
スコア関数 $\bm{v}^T \tanh(\bm{W}_1 \bm{h}_j + \bm{W}_2 \bm{s}_{i-1})$ dot / general / concat の3種
スコアの種別 加法的(Additive) 乗法的(Multiplicative)※dotとgeneral
使用するデコーダ状態 $\bm{s}_{i-1}$(前の時刻) $\bm{h}_t$(現在の時刻)
計算タイミング RNN更新の前 RNN更新の後
エンコーダ 双方向RNN(BiRNN) 単方向RNN(片方向)※論文のデフォルト
Attentionの範囲 Global のみ Global / Local を選択可

スコア関数の設計思想の違い

Bahdanau Attentionの加法的スコア関数は、2つのベクトルをそれぞれ線形変換してから加算し、非線形関数を通すという「ニューラルネットワーク的」な設計です。2層のフィードフォワードネットワークとみなせます。入力層が $[\bm{h}_j; \bm{s}_{i-1}]$、隠れ層が $\tanh(\bm{W}_1 \bm{h}_j + \bm{W}_2 \bm{s}_{i-1})$、出力層が $\bm{v}^T \cdot (\text{隠れ層})$ です。

一方、Luong Attentionの乗法的スコア関数(dot, general)は、2つのベクトルの内積を直接計算するという「幾何学的」な設計です。内積はベクトルの方向の類似度を測る最も基本的な操作であり、行列積として効率的に計算できます。

計算量の観点では、加法的スコアは $O(d_a \times (d_h + d_s))$ のパラメータと $O(d_a)$ の非線形演算が必要です。乗法的スコア(dot)は追加パラメータがゼロで、$O(d_h)$ の内積演算のみです。系列が長くなりエンコーダの位置数 $T_x$ が増えると、この計算量の差は無視できなくなります。

計算タイミングの違いが持つ意味

Bahdanau Attentionでは、$\bm{s}_{i-1}$ からAttentionを計算し、得られたコンテキストベクトル $\bm{c}_i$ をRNNの入力に含めて $\bm{s}_i$ を計算します。つまり、Attentionの結果がRNNの状態遷移に直接影響します。

$$ \bm{s}_i = f_{\text{dec}}(\bm{y}_{i-1}, \bm{s}_{i-1}, \bm{c}_i) $$

Luong Attentionでは、まずRNNを更新して $\bm{h}_t$ を得てから、$\bm{h}_t$ を使ってAttentionを計算します。コンテキストベクトルはRNNの状態遷移には影響せず、出力の計算にのみ使われます

$$ \bm{h}_t = f_{\text{dec}}(\bm{y}_{t-1}, \bm{h}_{t-1}), \quad \tilde{\bm{h}}_t = \tanh(\bm{W}_c [\bm{c}_t ; \bm{h}_t]) $$

Bahdanauの方式は、Attentionの情報がRNNの隠れ状態に蓄積されるため、原理的にはより豊かな表現が可能です。一方、Luongの方式は、RNNとAttentionが分離されているため、実装がシンプルで、モジュール性が高いという利点があります。

エンコーダの構造の違い

Bahdanau Attentionは双方向RNNエンコーダを前提としています。各位置の隠れ状態が文全体の文脈を反映するため、Attentionの計算がより正確になります。ただし、計算量はおよそ2倍になります。

Luongらの論文では、実験に多層の単方向LSTMを使っています。ただし、これはLuong Attentionが双方向エンコーダを使えないという意味ではありません。あくまでデフォルトの実験設定としての違いです。実用上は、どちらのAttention機構でも双方向エンコーダを組み合わせることが一般的です。

歴史的な文脈と後世への影響

Bahdanau AttentionとLuong Attentionの比較は、その後のAttention機構の発展に大きな影響を与えました。

Luong Attentionのdotスコアは、パラメータ不要で計算が高速という特性から、Vaswani et al. (2017) のTransformerにおけるScaled Dot-Product Attentionの直接的な先祖と言えます。Transformerのスコア関数 $\frac{\bm{Q}\bm{K}^T}{\sqrt{d_k}}$ は、Luongのdotスコアにスケーリングを加えたものです。

一方、Bahdanau Attentionの加法的スコアは、Attentionをフィードフォワードネットワークで計算するという発想を広めました。この考え方は、後のAttention機構のバリエーション(例: Locationベースの Attention)にも影響を与えています。

ここまでの理論的な比較で、両者の設計思想の違いが明確になりました。次は、これらの違いをPyTorchのコードで実装し、実際に動かして確認していきましょう。

PyTorchによる実装

実装の方針

ここでは、Bahdanau AttentionとLuong AttentionのAttentionモジュール単体をPyTorchで実装します。完全なSeq2Seqモデル(エンコーダ・デコーダ全体)の実装ではなく、Attention機構の核心部分に焦点を当てます。

実装するクラスは以下の3つです。

  1. BahdanauAttention: 加法的スコア関数によるAttention
  2. LuongAttention: 3種のスコア関数(dot, general, concat)を切り替え可能なAttention
  3. 合成データで両者の注意重みを可視化するスクリプト

Bahdanau Attentionの実装

まず、Bahdanau Attentionクラスを実装します。スコア関数 $e_{ij} = \bm{v}^T \tanh(\bm{W}_1 \bm{h}_j + \bm{W}_2 \bm{s}_{i-1})$ を忠実にコードに落とし込みます。

import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    """Bahdanau Attention(加法的注意機構)

    スコア関数: e_ij = v^T tanh(W1 h_j + W2 s_{i-1})

    Args:
        encoder_dim: エンコーダ隠れ状態の次元
        decoder_dim: デコーダ隠れ状態の次元
        attention_dim: アライメント空間の次元
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # W1: エンコーダ隠れ状態をアライメント空間に射影
        self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
        # W2: デコーダ隠れ状態をアライメント空間に射影
        self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
        # v: アライメント空間からスカラーへ
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim) エンコーダの全隠れ状態
            decoder_hidden: (batch, decoder_dim) デコーダの前時刻の隠れ状態

        Returns:
            context: (batch, encoder_dim) コンテキストベクトル
            attn_weights: (batch, src_len) Attention重み
        """
        # encoder_outputs を射影: (batch, src_len, attention_dim)
        encoder_proj = self.W1(encoder_outputs)

        # decoder_hidden を射影: (batch, attention_dim)
        # unsqueeze で (batch, 1, attention_dim) にしてブロードキャスト
        decoder_proj = self.W2(decoder_hidden).unsqueeze(1)

        # 加法的スコアの計算: tanh(W1*h_j + W2*s_{i-1})
        # (batch, src_len, attention_dim)
        combined = torch.tanh(encoder_proj + decoder_proj)

        # v^T で各位置のスコアをスカラーに: (batch, src_len, 1) → (batch, src_len)
        scores = self.v(combined).squeeze(-1)

        # Softmaxで正規化して確率分布に
        attn_weights = F.softmax(scores, dim=-1)

        # コンテキストベクトル: 重み付き和
        # (batch, 1, src_len) × (batch, src_len, encoder_dim) → (batch, 1, encoder_dim)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights

このコードのポイントを説明します。encoder_proj + decoder_proj の部分では、PyTorchのブロードキャスト機能を活用しています。encoder_proj(batch, src_len, attention_dim) で、decoder_proj(batch, 1, attention_dim) です。ブロードキャストにより、デコーダの射影が全てのエンコーダ位置に対して加算されます。これにより、forループを使わずに全位置のスコアを一度に計算できます。

torch.bmm(バッチ行列積)を使ったコンテキストベクトルの計算も重要です。Attention重みを行ベクトルとして扱い、エンコーダ出力行列との行列積で重み付き和を計算しています。これはエレガントで効率的な実装パターンです。

Luong Attentionの実装

次に、3種のスコア関数を切り替えられるLuong Attentionクラスを実装します。

import torch
import torch.nn as nn
import torch.nn.functional as F

class LuongAttention(nn.Module):
    """Luong Attention(乗法的注意機構)

    3種のスコア関数:
        - dot:     score = h_t^T * h_s
        - general: score = h_t^T * W_a * h_s
        - concat:  score = v_a^T * tanh(W_a * [h_t; h_s])

    Args:
        encoder_dim: エンコーダ隠れ状態の次元
        decoder_dim: デコーダ隠れ状態の次元
        score_fn: スコア関数の種類 ("dot", "general", "concat")
    """
    def __init__(self, encoder_dim, decoder_dim, score_fn="dot"):
        super().__init__()
        self.score_fn = score_fn
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        if score_fn == "general":
            # W_a: デコーダ次元 → エンコーダ次元 への線形変換
            self.W_a = nn.Linear(encoder_dim, decoder_dim, bias=False)
        elif score_fn == "concat":
            # 連結ベクトルを変換する行列とスカラー化するベクトル
            self.W_a = nn.Linear(encoder_dim + decoder_dim, decoder_dim, bias=False)
            self.v_a = nn.Linear(decoder_dim, 1, bias=False)
        # dot の場合は追加パラメータなし

    def _compute_score(self, decoder_hidden, encoder_outputs):
        """スコア関数を計算する

        Args:
            decoder_hidden: (batch, decoder_dim) 現在のデコーダ隠れ状態
            encoder_outputs: (batch, src_len, encoder_dim) エンコーダの全隠れ状態

        Returns:
            scores: (batch, src_len)
        """
        if self.score_fn == "dot":
            # h_t^T * h_s: 内積
            # (batch, decoder_dim) → (batch, decoder_dim, 1)
            # (batch, src_len, encoder_dim) × (batch, encoder_dim, 1) → (batch, src_len, 1)
            scores = torch.bmm(
                encoder_outputs, decoder_hidden.unsqueeze(-1)
            ).squeeze(-1)

        elif self.score_fn == "general":
            # h_t^T * W_a * h_s
            # まず W_a * h_s を計算: (batch, src_len, decoder_dim)
            encoder_proj = self.W_a(encoder_outputs)
            scores = torch.bmm(
                encoder_proj, decoder_hidden.unsqueeze(-1)
            ).squeeze(-1)

        elif self.score_fn == "concat":
            # v_a^T * tanh(W_a * [h_t; h_s])
            src_len = encoder_outputs.size(1)
            # decoder_hidden を src_len 分繰り返す
            decoder_expanded = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
            # [h_t; h_s] を連結: (batch, src_len, encoder_dim + decoder_dim)
            concat = torch.cat([decoder_expanded, encoder_outputs], dim=-1)
            scores = self.v_a(torch.tanh(self.W_a(concat))).squeeze(-1)

        return scores

    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim) エンコーダの全隠れ状態
            decoder_hidden: (batch, decoder_dim) デコーダの現在の隠れ状態

        Returns:
            context: (batch, encoder_dim) コンテキストベクトル
            attn_weights: (batch, src_len) Attention重み
        """
        # スコアを計算
        scores = self._compute_score(decoder_hidden, encoder_outputs)

        # Softmaxで正規化
        attn_weights = F.softmax(scores, dim=-1)

        # コンテキストベクトル: 重み付き和
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights

Luong Attentionの実装では、_compute_score メソッドで3種のスコア関数を切り替えています。dotスコアでは追加のパラメータなしで内積を計算し、generalスコアでは1つの線形層を通してから内積を取り、concatスコアでは連結・変換・非線形化の3段階を経ます。

Bahdanau Attentionとの実装上の重要な違いは、forward メソッドの引数にあります。Bahdanau Attentionでは decoder_hidden は「前の時刻」の隠れ状態 $\bm{s}_{i-1}$ を想定していますが、Luong Attentionでは「現在の時刻」の隠れ状態 $\bm{h}_t$ を想定しています。この違いは、呼び出し元のデコーダの実装に影響します。

動作確認 — 合成データでの注意重みの可視化

両方のAttentionモジュールが正しく動作することを、合成データで確認します。模擬的な翻訳シナリオを作り、Attention重みをヒートマップとして可視化します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# --- 再現性のためのシード固定 ---
torch.manual_seed(42)
np.random.seed(42)

# --- BahdanauAttention クラス定義 ---
class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        encoder_proj = self.W1(encoder_outputs)
        decoder_proj = self.W2(decoder_hidden).unsqueeze(1)
        combined = torch.tanh(encoder_proj + decoder_proj)
        scores = self.v(combined).squeeze(-1)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights

# --- LuongAttention クラス定義 ---
class LuongAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, score_fn="dot"):
        super().__init__()
        self.score_fn = score_fn
        if score_fn == "general":
            self.W_a = nn.Linear(encoder_dim, decoder_dim, bias=False)
        elif score_fn == "concat":
            self.W_a = nn.Linear(encoder_dim + decoder_dim, decoder_dim, bias=False)
            self.v_a = nn.Linear(decoder_dim, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        if self.score_fn == "dot":
            scores = torch.bmm(
                encoder_outputs, decoder_hidden.unsqueeze(-1)
            ).squeeze(-1)
        elif self.score_fn == "general":
            encoder_proj = self.W_a(encoder_outputs)
            scores = torch.bmm(
                encoder_proj, decoder_hidden.unsqueeze(-1)
            ).squeeze(-1)
        elif self.score_fn == "concat":
            src_len = encoder_outputs.size(1)
            dec_exp = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
            concat = torch.cat([dec_exp, encoder_outputs], dim=-1)
            scores = self.v_a(torch.tanh(self.W_a(concat))).squeeze(-1)

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights

# --- 合成データの生成 ---
batch_size = 1
src_len = 6       # 入力系列の長さ(英語の単語数)
tgt_len = 5       # 出力系列の長さ(日本語の単語数)
hidden_dim = 32   # 隠れ状態の次元
attn_dim = 16     # アライメント空間の次元

# エンコーダの全隠れ状態(ランダム生成)
encoder_outputs = torch.randn(batch_size, src_len, hidden_dim)
# デコーダの各ステップの隠れ状態(ランダム生成)
decoder_hiddens = torch.randn(tgt_len, batch_size, hidden_dim)

# --- 各Attentionのインスタンス生成 ---
bahdanau = BahdanauAttention(hidden_dim, hidden_dim, attn_dim)
luong_dot = LuongAttention(hidden_dim, hidden_dim, score_fn="dot")
luong_general = LuongAttention(hidden_dim, hidden_dim, score_fn="general")
luong_concat = LuongAttention(hidden_dim, hidden_dim, score_fn="concat")

# 全モデルを評価モードに
for m in [bahdanau, luong_dot, luong_general, luong_concat]:
    m.eval()

# --- 各ステップの Attention重みを収集 ---
attention_maps = {}
models = {
    "Bahdanau\n(Additive)": bahdanau,
    "Luong\n(Dot)": luong_dot,
    "Luong\n(General)": luong_general,
    "Luong\n(Concat)": luong_concat,
}

with torch.no_grad():
    for name, model in models.items():
        weights_list = []
        for t in range(tgt_len):
            _, attn_w = model(encoder_outputs, decoder_hiddens[t])
            weights_list.append(attn_w.squeeze(0).numpy())
        attention_maps[name] = np.stack(weights_list)  # (tgt_len, src_len)

# --- ヒートマップの可視化 ---
src_tokens = ["The", "cat", "sat", "on", "the", "mat"]
tgt_tokens = ["猫が", "マットの", "上に", "座って", "いた"]

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for ax, (name, attn_map) in zip(axes, attention_maps.items()):
    im = ax.imshow(attn_map, cmap="YlOrRd", vmin=0, vmax=1, aspect="auto")
    ax.set_xticks(range(src_len))
    ax.set_xticklabels(src_tokens, fontsize=10)
    ax.set_yticks(range(tgt_len))
    ax.set_yticklabels(tgt_tokens, fontsize=10)
    ax.set_xlabel("Source (English)", fontsize=11)
    ax.set_ylabel("Target (Japanese)", fontsize=11)
    ax.set_title(name, fontsize=12, fontweight="bold")

    # 各セルに数値を表示
    for i in range(tgt_len):
        for j in range(src_len):
            val = attn_map[i, j]
            color = "white" if val > 0.4 else "black"
            ax.text(j, i, f"{val:.2f}", ha="center", va="center",
                    fontsize=8, color=color)

plt.colorbar(im, ax=axes, fraction=0.02, pad=0.04, label="Attention Weight")
plt.suptitle("Attention Weight Comparison (Random Initialization)", fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("bahdanau_luong_comparison.png", dpi=150, bbox_inches="tight")
plt.show()

このコードでは、4種類のAttention機構(Bahdanau加法的、Luong dot、Luong general、Luong concat)のAttention重みを並べて比較しています。ランダム初期化された重みパラメータを使っているため、実際の翻訳タスクでの対応関係を反映しているわけではありません。しかし、各Attention機構のスコア分布の特性の違いを観察できます。

Bahdanau Attention(加法的)のヒートマップでは、$\tanh$ 非線形関数を通じてスコアが計算されるため、重みの分布が比較的滑らかになる傾向があります。$\tanh$ の出力が $[-1, 1]$ に制限されるため、極端に大きなスコアが生じにくいのです。

Luong Dot(内積)のヒートマップでは、スケーリングなしの内積を使っているため、隠れ状態の次元 $d$ が大きくなるとスコアの分散が大きくなり、Softmax後の重みが一部の位置に集中しやすくなります。これはTransformerの論文で指摘された問題と同じであり、$\sqrt{d_k}$ でのスケーリングが必要な理由を示しています。

Luong General(一般化内積)では、学習可能な行列 $\bm{W}_a$ によって内積の前にエンコーダ隠れ状態が変換されるため、dot スコアよりも柔軟な対応関係を表現できます。ランダム初期化の時点では分布の違いが微妙ですが、学習が進むにつれて差異が顕著になります。

Luong Concat(連結)はBahdanau Attentionに近い構造を持つため、重みの分布も類似した特性を示します。

ここまでで、両Attentionの実装が正しく動作することを確認しました。次に、もう少し踏み込んだ実験として、対角構造を持つ理想的な対応関係に近い合成データを作り、Attentionが意味のある重みを学習できるかを確認しましょう。

実験: 学習による注意重みの変化

実験の設計

ランダムな重みでの可視化では、各Attention機構の初期状態しか確認できません。ここでは、簡単なコピータスク(入力系列をそのまま出力する)を通じて、Attention重みが対角パターン(各出力位置が対応する入力位置に注目する)を学習できるかを検証します。

コピータスクは、理想的なAttention重みが明確に定義される(対角行列に近い)ため、Attention機構の学習能力を評価するのに適しています。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
np.random.seed(42)

# --- 簡易 Seq2Seq + Attention モデル ---
class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, attn_dim):
        super().__init__()
        self.W1 = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W2 = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, enc_out, dec_h):
        ep = self.W1(enc_out)
        dp = self.W2(dec_h).unsqueeze(1)
        scores = self.v(torch.tanh(ep + dp)).squeeze(-1)
        weights = F.softmax(scores, dim=-1)
        ctx = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)
        return ctx, weights

class LuongAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, score_fn="dot"):
        super().__init__()
        self.score_fn = score_fn
        if score_fn == "general":
            self.W_a = nn.Linear(enc_dim, dec_dim, bias=False)

    def forward(self, enc_out, dec_h):
        if self.score_fn == "dot":
            scores = torch.bmm(enc_out, dec_h.unsqueeze(-1)).squeeze(-1)
        elif self.score_fn == "general":
            scores = torch.bmm(
                self.W_a(enc_out), dec_h.unsqueeze(-1)
            ).squeeze(-1)
        weights = F.softmax(scores, dim=-1)
        ctx = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)
        return ctx, weights

class Seq2SeqWithAttention(nn.Module):
    """簡易Seq2Seq + Attention(コピータスク用)"""
    def __init__(self, vocab_size, embed_dim, hidden_dim, attn_type="bahdanau"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder_cell = nn.GRUCell(embed_dim + hidden_dim * 2, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

        enc_dim = hidden_dim * 2  # 双方向
        if attn_type == "bahdanau":
            self.attention = BahdanauAttention(enc_dim, hidden_dim, hidden_dim)
        elif attn_type == "luong_general":
            self.attention = LuongAttention(enc_dim, hidden_dim, score_fn="general")
        self.attn_type = attn_type

    def forward(self, src, tgt):
        """
        Args:
            src: (batch, src_len) 入力トークンID
            tgt: (batch, tgt_len) ターゲットトークンID(教師強制用)
        Returns:
            logits: (batch, tgt_len, vocab_size)
            all_weights: (batch, tgt_len, src_len)
        """
        # エンコード
        src_emb = self.embedding(src)
        enc_out, enc_h = self.encoder(src_emb)

        # デコーダ初期状態: 双方向の最終隠れ状態を加算
        dec_h = enc_h[0] + enc_h[1]  # (batch, hidden_dim)

        batch_size, tgt_len = tgt.size()
        logits_list = []
        weights_list = []

        for t in range(tgt_len):
            tgt_emb = self.embedding(tgt[:, t])  # (batch, embed_dim)

            # Attention計算
            ctx, attn_w = self.attention(enc_out, dec_h)
            weights_list.append(attn_w)

            # デコーダ更新
            dec_input = torch.cat([tgt_emb, ctx], dim=-1)
            dec_h = self.decoder_cell(dec_input, dec_h)

            logits_list.append(self.output_proj(dec_h))

        logits = torch.stack(logits_list, dim=1)
        all_weights = torch.stack(weights_list, dim=1)
        return logits, all_weights

# --- コピータスク用データ生成 ---
vocab_size = 20
seq_len = 8
num_samples = 2000
embed_dim = 32
hidden_dim = 64
num_epochs = 50
batch_size = 64

# ランダムなトークン列を生成(0はパディング用に予約)
data = torch.randint(1, vocab_size, (num_samples, seq_len))

# --- 2つのモデルを学習 ---
results = {}
for attn_type, label in [("bahdanau", "Bahdanau"), ("luong_general", "Luong (General)")]:
    model = Seq2SeqWithAttention(vocab_size, embed_dim, hidden_dim, attn_type)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    losses = []
    for epoch in range(num_epochs):
        # ミニバッチ学習
        perm = torch.randperm(num_samples)
        epoch_loss = 0.0
        for i in range(0, num_samples, batch_size):
            batch = data[perm[i:i+batch_size]]
            # コピータスク: 入力 = ターゲット
            logits, _ = model(batch, batch)
            loss = criterion(logits.reshape(-1, vocab_size), batch.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / (num_samples // batch_size))

    # 学習後の Attention 重みを取得
    model.eval()
    with torch.no_grad():
        test_seq = data[:1]  # テスト用に1サンプル
        _, attn_weights = model(test_seq, test_seq)
        results[label] = {
            "losses": losses,
            "weights": attn_weights.squeeze(0).numpy()
        }

print("学習完了")

上記のコードでは、Bahdanau AttentionとLuong Attention(General)を組み込んだ簡易Seq2Seqモデルを、コピータスクで学習させています。コピータスクでは入力系列をそのまま出力するだけなので、理想的なAttention重みは対角行列(各出力位置が対応する入力位置に注目する)に近くなるはずです。学習後のAttention重みが本当にそうなるか、次のコードで可視化します。

import numpy as np
import matplotlib.pyplot as plt

# --- 学習曲線と Attention 重みの可視化 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (1) 学習曲線
for label, res in results.items():
    axes[0].plot(res["losses"], label=label, linewidth=2)
axes[0].set_xlabel("Epoch", fontsize=12)
axes[0].set_ylabel("Cross-Entropy Loss", fontsize=12)
axes[0].set_title("Training Loss (Copy Task)", fontsize=13, fontweight="bold")
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# (2) Bahdanau の Attention 重みヒートマップ
attn_b = results["Bahdanau"]["weights"]
im1 = axes[1].imshow(attn_b, cmap="YlOrRd", vmin=0, vmax=1, aspect="auto")
for i in range(attn_b.shape[0]):
    for j in range(attn_b.shape[1]):
        val = attn_b[i, j]
        axes[1].text(j, i, f"{val:.2f}", ha="center", va="center",
                     fontsize=8, color="white" if val > 0.4 else "black")
axes[1].set_xlabel("Source Position", fontsize=12)
axes[1].set_ylabel("Target Position", fontsize=12)
axes[1].set_title("Bahdanau Attention Weights", fontsize=13, fontweight="bold")
plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

# (3) Luong の Attention 重みヒートマップ
attn_l = results["Luong (General)"]["weights"]
im2 = axes[2].imshow(attn_l, cmap="YlOrRd", vmin=0, vmax=1, aspect="auto")
for i in range(attn_l.shape[0]):
    for j in range(attn_l.shape[1]):
        val = attn_l[i, j]
        axes[2].text(j, i, f"{val:.2f}", ha="center", va="center",
                     fontsize=8, color="white" if val > 0.4 else "black")
axes[2].set_xlabel("Source Position", fontsize=12)
axes[2].set_ylabel("Target Position", fontsize=12)
axes[2].set_title("Luong (General) Attention Weights", fontsize=13, fontweight="bold")
plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("attention_copy_task.png", dpi=150, bbox_inches="tight")
plt.show()

この実験結果からは、いくつかの重要な知見が得られます。

まず、学習曲線(左図)を見ると、両方のモデルがコピータスクの損失を着実に減少させていることが確認できます。Bahdanau AttentionとLuong Attention(General)の収束速度に大きな差はありませんが、タスクの難易度やモデルの設定によっては差が出ることもあります。

Attention重みのヒートマップ(中央・右図)を見ると、学習後のAttention重みが対角線上に集中しているパターンが観察できるはずです。コピータスクでは、出力の位置 $t$ は入力の位置 $t$ に対応するのが正解であり、Attention機構がこの対応関係を学習できたことを示しています。

Bahdanau AttentionとLuong Attention(General)のヒートマップを比べると、対角パターンの「鮮明さ」に違いが見られることがあります。加法的スコアの方がスコアの範囲が制限されるため、重みの分布がやや滑らかになる傾向があります。一方、乗法的スコアではスコアの分散が大きくなりやすいため、よりシャープなAttentionパターンが得られることがあります。

パラメータ数の比較

両Attentionのパラメータ数を具体的に比較してみましょう。隠れ状態の次元を $d = 256$、アライメント空間の次元を $d_a = 256$ とした場合を計算します。

import torch
import torch.nn as nn

# --- パラメータ数の比較 ---
d_enc = 512   # 双方向エンコーダの場合 256 * 2
d_dec = 256
d_attn = 256

# Bahdanau Attention のパラメータ数
W1_params = d_enc * d_attn       # W1: (d_attn, d_enc)
W2_params = d_dec * d_attn       # W2: (d_attn, d_dec)
v_params = d_attn * 1            # v: (1, d_attn)
bahdanau_total = W1_params + W2_params + v_params

# Luong Attention のパラメータ数
dot_params = 0                            # dot: パラメータなし
general_params = d_enc * d_dec            # general: W_a: (d_dec, d_enc)
concat_W = (d_enc + d_dec) * d_dec        # concat: W_a
concat_v = d_dec * 1                      # concat: v_a
concat_total = concat_W + concat_v

print("=== パラメータ数の比較 ===")
print(f"Bahdanau Attention:      {bahdanau_total:>10,} パラメータ")
print(f"  W1: {W1_params:,}, W2: {W2_params:,}, v: {v_params:,}")
print(f"Luong (dot):             {dot_params:>10,} パラメータ")
print(f"Luong (general):         {general_params:>10,} パラメータ")
print(f"Luong (concat):          {concat_total:>10,} パラメータ")
print(f"  W_a: {concat_W:,}, v_a: {concat_v:,}")

# PyTorchモデルで検証
class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, attn_dim):
        super().__init__()
        self.W1 = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W2 = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

class LuongAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, score_fn="dot"):
        super().__init__()
        if score_fn == "general":
            self.W_a = nn.Linear(enc_dim, dec_dim, bias=False)

b = BahdanauAttention(d_enc, d_dec, d_attn)
l_g = LuongAttention(d_enc, d_dec, "general")

print(f"\n=== PyTorchで検証 ===")
print(f"Bahdanau: {sum(p.numel() for p in b.parameters()):,} パラメータ")
print(f"Luong (general): {sum(p.numel() for p in l_g.parameters()):,} パラメータ")

この比較から、Attention機構のパラメータ数の違いが明確になります。Bahdanau Attentionは $d_{\text{enc}} \times d_a + d_{\text{dec}} \times d_a + d_a$ 個のパラメータを持ちます。典型的な設定($d_{\text{enc}} = 512$, $d_{\text{dec}} = 256$, $d_a = 256$)では約197,000パラメータです。

一方、Luong Attention(dot)はパラメータがゼロです。内積の計算だけなので追加の学習パラメータは不要です。Luong Attention(general)は $d_{\text{enc}} \times d_{\text{dec}}$ 個で約131,000パラメータです。

パラメータ数が少ないということは、学習すべきものが少なく収束が速い可能性がある一方で、表現力が制限される可能性もあります。タスクの複雑さとモデルの容量のバランスが重要です。

この定量的な比較を踏まえて、最後に両Attentionの設計選択がどのような場面で有利になるかをまとめましょう。

実用的なガイドライン — どちらを使うべきか

タスクとモデル規模に応じた選択

理論と実装を踏まえて、実用的な場面でどちらのAttentionを選ぶべきかを考えます。

Bahdanau Attention(加法的)が適する場面

  • エンコーダとデコーダの隠れ状態の次元が異なる場合。加法的スコアは射影行列で次元を揃えるため、柔軟に対応できます
  • 非線形な対応関係を捉える必要がある場合。$\tanh$ による非線形変換が表現力を高めます
  • 比較的小規模なモデルで、Attention機構自体に表現力を持たせたい場合

Luong Attention(乗法的)が適する場面

  • 計算効率を重視する場合。特にdotスコアは追加パラメータなしで高速に計算できます
  • 大規模なモデルで、Attention以外の部分(エンコーダ・デコーダ本体)に十分な表現力がある場合
  • Transformerへの移行を見据えている場合。Scaled Dot-Product Attentionとの親和性が高いです

現代的な観点

2017年のTransformerの登場以降、実用的な多くのモデル(BERT、GPT、T5など)はScaled Dot-Product Attention(Luongのdotスコアにスケーリングを加えたもの)を採用しています。したがって、現代のNLPにおいてはLuong系のAttentionが主流と言えます。

しかし、Bahdanau Attentionの設計思想 — 非線形変換を通じてスコアを計算するというアイデア — は、音声認識のLocation-Sensitive Attentionや、画像キャプショニングのAdaptive Attentionなど、特殊なドメインでは依然として活用されています。

両者の違いを理解していることは、これらの応用分野でAttention機構を設計・改良する際の重要な基盤になります。

まとめ

本記事では、Bahdanau Attention(加法的注意)とLuong Attention(乗法的注意)の理論的な違いを7つの観点から整理し、PyTorchで実装して注意重みを比較しました。

  • Seq2Seqの固定長ボトルネック: 入力全体を1つのベクトルに圧縮する限界がAttention機構の動機
  • Bahdanau Attention: $\bm{v}^T \tanh(\bm{W}_1 \bm{h}_j + \bm{W}_2 \bm{s}_{i-1})$ の加法的スコア。非線形変換による高い表現力を持つが、パラメータ数と計算コストがやや大きい
  • Luong Attention: dot / general / concat の3種のスコア関数。特にdotスコアは追加パラメータなしで高速に計算でき、Transformerの先駆けとなった
  • 計算タイミングの違い: BahdanauはRNN更新の前、LuongはRNN更新の後にAttentionを計算する。これによりアーキテクチャ全体の情報フローが異なる
  • Global vs Local: LuongはAttentionの範囲を限定するLocal Attentionも提案し、長い系列での計算効率を改善した
  • 実装上の違い: パラメータ数、モジュール性、既存モデルへの組み込みやすさに差がある

次のステップとして、これらのAttention機構を発展させたSelf-AttentionTransformerアーキテクチャの記事も参考にしてください。

画像なし
Self-Attentionの理論と実装 — Query・Key・Valueの計算
Self-Attentionの理論をQuery・Key・Valueの線形射影から導出し、Scaled Dot-Product AttentionとMulti-Head Attentionを実装
画像なし
Transformerアーキテクチャの全体像
Self-Attention、Multi-Head Attention、Position Encodingを組み合わせたTransformerの構造を完全解説