Attention機構は、入力データに対して「注目すべき箇所」を動的に特定する仕組みです。2017年のTransformer論文「Attention Is All You Need」で注目を集め、現在では自然言語処理(NLP)、画像認識、音声処理など、深層学習のあらゆる分野で利用されています。
本記事では、Attention機構の基本であるQuery(Q)、Key(K)、Value(V)の概念から、Scaled Dot-Product AttentionとMulti-Head Attentionの数学的定式化とPyTorchでの実装を解説します。
本記事の内容
- Attention機構のQuery, Key, Valueの概念
- Scaled Dot-Product Attentionの数学的定義
- Multi-Head Attentionの構造
- Self-Attentionの仕組み
- PyTorchでの実装
前提知識
この記事を読む前に、以下の概念を押さえておくと理解が深まります。
- ニューラルネットワークの基礎
- 行列演算(行列積、softmax関数)
Attentionの基本概念
Query, Key, Valueの直感
Attentionは「検索」に例えるとわかりやすいです。
- Query (Q): 「何を探しているか」(検索クエリ)
- Key (K): 「各要素の見出し」(検索対象のインデックス)
- Value (V): 「各要素の内容」(検索対象の値)
QueryとKeyの類似度を計算し、類似度が高いKeyに対応するValueを重み付き平均で取得します。
Scaled Dot-Product Attention
最も基本的なAttentionの定式化です。
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^T}{\sqrt{d_k}}\right)\bm{V} $$
ここで、
- $\bm{Q} \in \mathbb{R}^{n \times d_k}$: Queryの行列($n$ はクエリ数)
- $\bm{K} \in \mathbb{R}^{m \times d_k}$: Keyの行列($m$ は要素数)
- $\bm{V} \in \mathbb{R}^{m \times d_v}$: Valueの行列
- $d_k$: Keyの次元数
なぜ $\sqrt{d_k}$ で割るのか
$\bm{Q}\bm{K}^T$ の各要素は、$d_k$ 個の要素の内積です。$Q$, $K$ の各要素が平均0、分散1に従うとき、内積の分散は $d_k$ になります。
$$ \text{Var}(\bm{q}^T\bm{k}) = d_k $$
$d_k$ が大きいと内積の値が大きくなり、softmaxが飽和して勾配が消失します。$\sqrt{d_k}$ で割ることで分散を1に正規化し、学習を安定させます。
Attention重みの計算過程
$$ \bm{A} = \frac{\bm{Q}\bm{K}^T}{\sqrt{d_k}} \in \mathbb{R}^{n \times m} $$
$$ \bm{W} = \text{softmax}(\bm{A}) \in \mathbb{R}^{n \times m} $$
$$ \text{Output} = \bm{W}\bm{V} \in \mathbb{R}^{n \times d_v} $$
$\bm{W}$ の各行は確率分布(和が1)になっており、各Queryが各Valueにどれだけ「注目」しているかを表します。
Multi-Head Attention
1つのAttentionだけでは、複数の異なる観点での「注目」を同時に学習できません。Multi-Head Attentionは、複数のAttentionヘッドを並列に計算し、結果を連結します。
$$ \text{MultiHead}(\bm{Q}, \bm{K}, \bm{V}) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\bm{W}^O $$
$$ \text{head}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q, \bm{K}\bm{W}_i^K, \bm{V}\bm{W}_i^V) $$
- $\bm{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$
- $\bm{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$
- $\bm{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$
- $\bm{W}^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$
各ヘッドは異なる射影を学習し、異なる関係性(構文的、意味的など)を捉えることができます。
Self-Attention
Self-Attentionは、Q, K, Vが全て同じ入力から生成される場合です。入力系列の各要素が、系列内の他の全ての要素との関連度を計算します。
入力 $\bm{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ に対して、
$$ \bm{Q} = \bm{X}\bm{W}^Q, \quad \bm{K} = \bm{X}\bm{W}^K, \quad \bm{V} = \bm{X}\bm{W}^V $$
これにより、「文中の各単語が他のどの単語に関連しているか」を学習できます。
PyTorchでの実装
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.manual_seed(42)
# --- Scaled Dot-Product Attention ---
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super().__init__()
def forward(self, Q, K, V, mask=None):
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ V
return output, attn_weights
# --- Multi-Head Attention ---
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention()
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 線形射影
Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# Attention
output, attn_weights = self.attention(Q, K, V, mask)
# ヘッドの連結
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(output)
return output, attn_weights
# --- 動作確認 ---
d_model = 64
n_heads = 8
seq_len = 10
batch_size = 2
mha = MultiHeadAttention(d_model, n_heads)
x = torch.randn(batch_size, seq_len, d_model)
# Self-Attention: Q=K=V=x
output, attn_weights = mha(x, x, x)
print(f"入力: {x.shape}")
print(f"出力: {output.shape}")
print(f"Attention重み: {attn_weights.shape}")
# --- Attention重みの可視化 ---
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i in range(8):
ax = axes[i // 4, i % 4]
w = attn_weights[0, i].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title(f'Head {i+1}')
ax.set_xlabel('Key')
ax.set_ylabel('Query')
plt.suptitle('Multi-Head Attention Weights', fontsize=14)
plt.tight_layout()
plt.show()
各ヘッドが異なるAttentionパターンを学習していることが可視化で確認できます。あるヘッドは近くの要素に注目し、別のヘッドは遠くの要素に注目するなど、多様な関係性を捉えています。
まとめ
本記事では、Attention機構の理論と実装について解説しました。
- Attentionは Query-Key-Value の枠組みで、入力の重要な部分に動的に注目する機構
- Scaled Dot-Product Attentionは $\sqrt{d_k}$ でスケーリングすることでsoftmaxの飽和を防ぐ
- Multi-Head Attentionは複数の射影を並列に計算し、異なる観点からの関係性を捉える
- Self-Attentionは系列内の各要素間の関連度を学習し、Transformerの基盤となっている
次のステップとして、以下の記事も参考にしてください。