アテンションマスク(Attention Mask)は、Transformerにおいて特定のトークンへの注意を制御するための仕組みです。主に2種類のマスクが使われます。Causalマスク(因果マスク)は未来の情報を遮断し、パディングマスクは可変長系列を処理する際にパディングトークンを無視します。
これらのマスクを正しく理解し実装することは、Transformerを使った開発において必須のスキルです。本記事では、各マスクの目的、数式、実装方法を詳しく解説します。
本記事の内容
- Causalマスク(因果マスク)の仕組みと必要性
- パディングマスクの仕組み
- マスクの数学的定義
- PyTorchでの実装
- 複合マスクの作成方法
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
アテンションスコアの計算
復習:Scaled Dot-Product Attention
アテンションの計算を復習しましょう。Query $\bm{Q}$、Key $\bm{K}$、Value $\bm{V}$ に対して、
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$
ここで、$\bm{Q}\bm{K}^\top$ はアテンションスコア行列 $\bm{S} \in \mathbb{R}^{n \times n}$ を形成します。$S_{ij}$ は位置 $i$ のQueryが位置 $j$ のKeyにどれだけ注目するかを表します。
マスクの役割
マスクは、softmax計算前にアテンションスコアを修正することで、特定の位置への注意を遮断します。
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}, \bm{M}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}\right)\bm{V} $$
マスク行列 $\bm{M}$ は、注意を許可する位置には $0$、遮断する位置には $-\infty$ を持ちます。softmaxの性質により、$-\infty$ が入力された位置の出力は $0$ になります。
$$ \text{softmax}(-\infty) = \frac{e^{-\infty}}{\sum_j e^{s_j}} = \frac{0}{\sum_j e^{s_j}} = 0 $$
Causalマスク(因果マスク)
目的
Causalマスク(因果マスク)は、未来の情報を遮断するために使用されます。自己回帰生成モデル(GPTなど)では、時刻 $t$ のトークンを生成する際に、時刻 $t+1, t+2, \ldots$ のトークンを参照してはいけません。
これは因果律(causality)に基づきます。過去が未来に影響を与えることはあっても、未来が過去に影響を与えることはありません。
数学的定義
Causalマスク行列 $\bm{M}_{\text{causal}}$ は、下三角行列の形をとります。
$$ M_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases} $$
系列長 $n = 4$ の場合のマスク行列:
$$ \bm{M}_{\text{causal}} = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix} $$
この行列を見ると: – 1行目(位置1): 自分自身のみ参照可能 – 2行目(位置2): 位置1と自分自身を参照可能 – 3行目(位置3): 位置1、2、3を参照可能 – 4行目(位置4): 全位置を参照可能
softmax適用後のアテンション重み
マスクを適用した後のアテンション重み行列は、下三角行列になります。
$$ \bm{A} = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}_{\text{causal}}\right) $$
例えば、スコア行列が以下の場合:
$$ \frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} = \begin{pmatrix} 1.0 & 0.5 & 0.3 & 0.2 \\ 0.8 & 1.2 & 0.6 & 0.4 \\ 0.5 & 0.7 & 1.1 & 0.9 \\ 0.6 & 0.4 & 0.8 & 1.0 \end{pmatrix} $$
マスク適用後:
$$ \frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}_{\text{causal}} = \begin{pmatrix} 1.0 & -\infty & -\infty & -\infty \\ 0.8 & 1.2 & -\infty & -\infty \\ 0.5 & 0.7 & 1.1 & -\infty \\ 0.6 & 0.4 & 0.8 & 1.0 \end{pmatrix} $$
1行目に softmax を適用すると: $$ \text{softmax}([1.0, -\infty, -\infty, -\infty]) = [1.0, 0, 0, 0] $$
パディングマスク
目的
実際のシステムでは、長さの異なる系列をバッチ処理するため、短い系列にはパディングトークンが追加されます。しかし、パディングトークンは実際の内容を持たないため、アテンション計算で無視する必要があります。
例
2つの文をバッチ処理する場合: – 文A: “I love cats” (3トークン) – 文B: “Hello” (1トークン)
パディング後(最大長3):
– 文A: [“I”, “love”, “cats”]
– 文B: [“Hello”, “
数学的定義
パディングマスクは、パディング位置を示すベクトルから生成されます。
パディングマスクベクトル $\bm{m}_{\text{pad}} \in \{0, 1\}^n$: $$ m_{\text{pad}, i} = \begin{cases} 0 & \text{if position } i \text{ is a real token} \\ 1 & \text{if position } i \text{ is a padding token} \end{cases} $$
アテンション計算では、Key側のパディング位置を無視します。マスク行列 $\bm{M}_{\text{pad}}$ は:
$$ M_{\text{pad}, ij} = \begin{cases} 0 & \text{if } m_{\text{pad}, j} = 0 \\ -\infty & \text{if } m_{\text{pad}, j} = 1 \end{cases} $$
これは、すべてのQuery位置に対して、パディングされたKey位置への注意を遮断します。
例
文B: [“Hello”, “
$$ \bm{m}_{\text{pad}} = [0, 1, 1] $$
アテンションマスク行列:
$$ \bm{M}_{\text{pad}} = \begin{pmatrix} 0 & -\infty & -\infty \\ 0 & -\infty & -\infty \\ 0 & -\infty & -\infty \end{pmatrix} $$
すべての行で、パディング位置(列2、3)が $-\infty$ になっています。
PyTorchでの実装
Causalマスクの生成
import torch
import torch.nn.functional as F
def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
"""
Causalマスク(因果マスク)を生成
Args:
seq_len: 系列長
device: デバイス
Returns:
mask: (seq_len, seq_len) のマスク行列
参照可能な位置は 0、参照不可の位置は -inf
"""
# 上三角行列を作成(対角線を含まない)
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
# True/False を -inf/0 に変換
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
# 使用例
seq_len = 4
causal_mask = create_causal_mask(seq_len)
print("Causalマスク:")
print(causal_mask)
出力:
Causalマスク:
tensor([[0., -inf, -inf, -inf],
[0., 0., -inf, -inf],
[0., 0., 0., -inf],
[0., 0., 0., 0.]])
パディングマスクの生成
def create_padding_mask(
input_ids: torch.Tensor,
pad_token_id: int,
) -> torch.Tensor:
"""
パディングマスクを生成
Args:
input_ids: トークンID列 (batch_size, seq_len)
pad_token_id: パディングトークンのID
Returns:
mask: (batch_size, 1, 1, seq_len) のマスク
実トークンは 0、パディングは -inf
"""
# パディング位置を特定
padding_positions = (input_ids == pad_token_id) # (batch_size, seq_len)
# アテンション計算用に次元を拡張
# (batch_size, 1, 1, seq_len) -> ブロードキャストで (batch_size, n_heads, seq_len, seq_len) に対応
mask = padding_positions.unsqueeze(1).unsqueeze(2)
# True/False を -inf/0 に変換
mask = mask.float().masked_fill(mask == 1, float('-inf'))
return mask
# 使用例
# バッチ2つ、系列長5、パディングID = 0
input_ids = torch.tensor([
[1, 2, 3, 4, 5], # 文A: パディングなし
[1, 2, 0, 0, 0], # 文B: 位置2,3,4がパディング
])
pad_token_id = 0
padding_mask = create_padding_mask(input_ids, pad_token_id)
print("パディングマスクの形状:", padding_mask.shape)
print("文Bのパディングマスク:")
print(padding_mask[1, 0, 0, :])
出力:
パディングマスクの形状: torch.Size([2, 1, 1, 5])
文Bのパディングマスク:
tensor([0., 0., -inf, -inf, -inf])
複合マスクの作成
Decoderでは、Causalマスクとパディングマスクを組み合わせて使用します。
def create_decoder_mask(
input_ids: torch.Tensor,
pad_token_id: int,
) -> torch.Tensor:
"""
Decoder用の複合マスクを生成(Causal + Padding)
Args:
input_ids: トークンID列 (batch_size, seq_len)
pad_token_id: パディングトークンのID
Returns:
mask: (batch_size, 1, seq_len, seq_len) のマスク
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Causalマスク (1, 1, seq_len, seq_len)
causal_mask = create_causal_mask(seq_len, device=device)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# パディングマスク (batch_size, 1, 1, seq_len)
padding_mask = create_padding_mask(input_ids, pad_token_id)
# 複合マスク(両方の条件を満たす位置のみ参照可能)
# -inf + 0 = -inf, -inf + -inf = -inf, 0 + 0 = 0
combined_mask = causal_mask + padding_mask
return combined_mask
# 使用例
input_ids = torch.tensor([
[1, 2, 3, 0, 0], # 文A: 位置3,4がパディング
])
pad_token_id = 0
decoder_mask = create_decoder_mask(input_ids, pad_token_id)
print("Decoder用複合マスク:")
print(decoder_mask[0, 0])
出力:
Decoder用複合マスク:
tensor([[0., -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf],
[0., 0., 0., -inf, -inf],
[0., 0., 0., -inf, -inf],
[0., 0., 0., -inf, -inf]])
アテンション計算での使用
import math
class MaskedSelfAttention(torch.nn.Module):
"""マスク付きSelf-Attention"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = torch.nn.Linear(d_model, d_model)
self.W_k = torch.nn.Linear(d_model, d_model)
self.W_v = torch.nn.Linear(d_model, d_model)
self.W_o = torch.nn.Linear(d_model, d_model)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
x: (batch_size, seq_len, d_model)
mask: (batch_size, 1, seq_len, seq_len) または (1, 1, seq_len, seq_len)
Returns:
output: (batch_size, seq_len, d_model)
"""
batch_size, seq_len, _ = x.shape
# Q, K, V を計算
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# マルチヘッドに分割
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# (batch_size, n_heads, seq_len, d_k)
# アテンションスコア
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# (batch_size, n_heads, seq_len, seq_len)
# マスクを適用
if mask is not None:
scores = scores + mask
# softmax
attn_weights = F.softmax(scores, dim=-1)
# Valueの加重和
context = torch.matmul(attn_weights, V)
# (batch_size, n_heads, seq_len, d_k)
# ヘッドを結合
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
output = self.W_o(context)
return output
# 使用例
d_model = 64
n_heads = 4
batch_size = 2
seq_len = 5
pad_token_id = 0
# モデル
attention = MaskedSelfAttention(d_model, n_heads)
# 入力
torch.manual_seed(42)
input_ids = torch.tensor([
[1, 2, 3, 4, 5],
[1, 2, 3, 0, 0],
])
x = torch.randn(batch_size, seq_len, d_model)
# マスク
mask = create_decoder_mask(input_ids, pad_token_id)
# 順伝播
output = attention(x, mask)
print(f"入力形状: {x.shape}")
print(f"マスク形状: {mask.shape}")
print(f"出力形状: {output.shape}")
可視化
Causalマスクの効果
import matplotlib.pyplot as plt
import numpy as np
# Causalマスクの可視化
seq_len = 8
# マスクなしのアテンション(仮想的な値)
np.random.seed(42)
attn_no_mask = np.random.rand(seq_len, seq_len)
attn_no_mask = attn_no_mask / attn_no_mask.sum(axis=1, keepdims=True)
# Causalマスク付きのアテンション
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1)
attn_with_mask = attn_no_mask.copy()
attn_with_mask[causal_mask == 1] = 0
attn_with_mask = attn_with_mask / attn_with_mask.sum(axis=1, keepdims=True)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Causalマスク
ax = axes[0]
im = ax.imshow(causal_mask, cmap='Reds', vmin=0, vmax=1)
ax.set_title('Causal Mask\n(1 = masked)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)
# マスクなしアテンション
ax = axes[1]
im = ax.imshow(attn_no_mask, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('Attention (No Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)
# マスク付きアテンション
ax = axes[2]
im = ax.imshow(attn_with_mask, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('Attention (With Causal Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
パディングマスクの効果
# パディングマスクの可視化
seq_len = 6
pad_start = 4 # 位置4,5がパディング
# パディングマスク
pad_mask = np.zeros((seq_len, seq_len))
pad_mask[:, pad_start:] = 1
# アテンション重み(パディングなし)
np.random.seed(123)
attn_no_pad = np.random.rand(seq_len, seq_len)
attn_no_pad = attn_no_pad / attn_no_pad.sum(axis=1, keepdims=True)
# アテンション重み(パディングマスク適用)
attn_with_pad = attn_no_pad.copy()
attn_with_pad[:, pad_start:] = 0
attn_with_pad = attn_with_pad / attn_with_pad.sum(axis=1, keepdims=True)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# パディングマスク
ax = axes[0]
im = ax.imshow(pad_mask, cmap='Oranges', vmin=0, vmax=1)
ax.set_title('Padding Mask\n(1 = padding)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)
# マスクなしアテンション
ax = axes[1]
im = ax.imshow(attn_no_pad, cmap='Greens', vmin=0, vmax=0.4)
ax.set_title('Attention (No Padding Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)
# マスク付きアテンション
ax = axes[2]
im = ax.imshow(attn_with_pad, cmap='Greens', vmin=0, vmax=0.4)
ax.set_title('Attention (With Padding Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
Encoder-Decoderのマスク
Encoderのマスク
Encoderでは、入力系列全体を双方向に参照できるため、Causalマスクは不要です。パディングマスクのみを適用します。
Cross-Attentionのマスク
Decoderの Cross-Attention(Encoder-Decoder Attention)では: – Query: Decoderの出力 – Key/Value: Encoderの出力
マスクは Encoderの入力に対するパディングマスクを使用します。Decoderの各位置から、Encoderのパディング位置への注意を遮断します。
def create_cross_attention_mask(
encoder_input_ids: torch.Tensor,
decoder_seq_len: int,
pad_token_id: int,
) -> torch.Tensor:
"""
Cross-Attention用のパディングマスクを生成
Args:
encoder_input_ids: Encoderの入力トークンID (batch_size, src_len)
decoder_seq_len: Decoderの系列長
pad_token_id: パディングトークンID
Returns:
mask: (batch_size, 1, decoder_seq_len, src_len)
"""
batch_size, src_len = encoder_input_ids.shape
# Encoderのパディング位置
padding_positions = (encoder_input_ids == pad_token_id) # (batch_size, src_len)
# Decoder の全位置から参照するので、decoder_seq_len 方向に拡張
mask = padding_positions.unsqueeze(1).unsqueeze(1) # (batch_size, 1, 1, src_len)
mask = mask.expand(-1, -1, decoder_seq_len, -1) # (batch_size, 1, decoder_seq_len, src_len)
# -inf/0 に変換
mask = mask.float().masked_fill(mask == 1, float('-inf'))
return mask
まとめ
本記事では、Transformerのアテンションマスクについて解説しました。
- Causalマスク: 未来の情報を遮断し、自己回帰生成を可能にする。上三角部分が $-\infty$ の行列
- パディングマスク: パディングトークンへの注意を遮断し、可変長系列のバッチ処理を可能にする
- 複合マスク: DecoderではCausalマスクとパディングマスクを組み合わせて使用
- Cross-Attentionのマスク: Encoderのパディング位置への注意を遮断
マスクの正しい実装は、Transformerの動作において非常に重要です。特に、$-\infty$ の代わりに大きな負の値(例:$-10000$)を使う実装もありますが、数値的な安定性に注意が必要です。
次のステップとして、以下の記事も参考にしてください。