アテンションマスクの種類と実装を完全理解する

アテンションマスク(Attention Mask)は、Transformerにおいて特定のトークンへの注意を制御するための仕組みです。主に2種類のマスクが使われます。Causalマスク(因果マスク)は未来の情報を遮断し、パディングマスクは可変長系列を処理する際にパディングトークンを無視します。

これらのマスクを正しく理解し実装することは、Transformerを使った開発において必須のスキルです。本記事では、各マスクの目的、数式、実装方法を詳しく解説します。

本記事の内容

  • Causalマスク(因果マスク)の仕組みと必要性
  • パディングマスクの仕組み
  • マスクの数学的定義
  • PyTorchでの実装
  • 複合マスクの作成方法

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

アテンションスコアの計算

復習:Scaled Dot-Product Attention

アテンションの計算を復習しましょう。Query $\bm{Q}$、Key $\bm{K}$、Value $\bm{V}$ に対して、

$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$

ここで、$\bm{Q}\bm{K}^\top$ はアテンションスコア行列 $\bm{S} \in \mathbb{R}^{n \times n}$ を形成します。$S_{ij}$ は位置 $i$ のQueryが位置 $j$ のKeyにどれだけ注目するかを表します。

マスクの役割

マスクは、softmax計算前にアテンションスコアを修正することで、特定の位置への注意を遮断します。

$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}, \bm{M}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}\right)\bm{V} $$

マスク行列 $\bm{M}$ は、注意を許可する位置には $0$、遮断する位置には $-\infty$ を持ちます。softmaxの性質により、$-\infty$ が入力された位置の出力は $0$ になります。

$$ \text{softmax}(-\infty) = \frac{e^{-\infty}}{\sum_j e^{s_j}} = \frac{0}{\sum_j e^{s_j}} = 0 $$

Causalマスク(因果マスク)

目的

Causalマスク(因果マスク)は、未来の情報を遮断するために使用されます。自己回帰生成モデル(GPTなど)では、時刻 $t$ のトークンを生成する際に、時刻 $t+1, t+2, \ldots$ のトークンを参照してはいけません。

これは因果律(causality)に基づきます。過去が未来に影響を与えることはあっても、未来が過去に影響を与えることはありません。

数学的定義

Causalマスク行列 $\bm{M}_{\text{causal}}$ は、下三角行列の形をとります。

$$ M_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases} $$

系列長 $n = 4$ の場合のマスク行列:

$$ \bm{M}_{\text{causal}} = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix} $$

この行列を見ると: – 1行目(位置1): 自分自身のみ参照可能 – 2行目(位置2): 位置1と自分自身を参照可能 – 3行目(位置3): 位置1、2、3を参照可能 – 4行目(位置4): 全位置を参照可能

softmax適用後のアテンション重み

マスクを適用した後のアテンション重み行列は、下三角行列になります。

$$ \bm{A} = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}_{\text{causal}}\right) $$

例えば、スコア行列が以下の場合:

$$ \frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} = \begin{pmatrix} 1.0 & 0.5 & 0.3 & 0.2 \\ 0.8 & 1.2 & 0.6 & 0.4 \\ 0.5 & 0.7 & 1.1 & 0.9 \\ 0.6 & 0.4 & 0.8 & 1.0 \end{pmatrix} $$

マスク適用後:

$$ \frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}} + \bm{M}_{\text{causal}} = \begin{pmatrix} 1.0 & -\infty & -\infty & -\infty \\ 0.8 & 1.2 & -\infty & -\infty \\ 0.5 & 0.7 & 1.1 & -\infty \\ 0.6 & 0.4 & 0.8 & 1.0 \end{pmatrix} $$

1行目に softmax を適用すると: $$ \text{softmax}([1.0, -\infty, -\infty, -\infty]) = [1.0, 0, 0, 0] $$

パディングマスク

目的

実際のシステムでは、長さの異なる系列をバッチ処理するため、短い系列にはパディングトークンが追加されます。しかし、パディングトークンは実際の内容を持たないため、アテンション計算で無視する必要があります。

2つの文をバッチ処理する場合: – 文A: “I love cats” (3トークン) – 文B: “Hello” (1トークン)

パディング後(最大長3): – 文A: [“I”, “love”, “cats”] – 文B: [“Hello”, ““, ““]

数学的定義

パディングマスクは、パディング位置を示すベクトルから生成されます。

パディングマスクベクトル $\bm{m}_{\text{pad}} \in \{0, 1\}^n$: $$ m_{\text{pad}, i} = \begin{cases} 0 & \text{if position } i \text{ is a real token} \\ 1 & \text{if position } i \text{ is a padding token} \end{cases} $$

アテンション計算では、Key側のパディング位置を無視します。マスク行列 $\bm{M}_{\text{pad}}$ は:

$$ M_{\text{pad}, ij} = \begin{cases} 0 & \text{if } m_{\text{pad}, j} = 0 \\ -\infty & \text{if } m_{\text{pad}, j} = 1 \end{cases} $$

これは、すべてのQuery位置に対して、パディングされたKey位置への注意を遮断します。

文B: [“Hello”, ““, ““] のパディングマスク:

$$ \bm{m}_{\text{pad}} = [0, 1, 1] $$

アテンションマスク行列:

$$ \bm{M}_{\text{pad}} = \begin{pmatrix} 0 & -\infty & -\infty \\ 0 & -\infty & -\infty \\ 0 & -\infty & -\infty \end{pmatrix} $$

すべての行で、パディング位置(列2、3)が $-\infty$ になっています。

PyTorchでの実装

Causalマスクの生成

import torch
import torch.nn.functional as F


def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Causalマスク(因果マスク)を生成

    Args:
        seq_len: 系列長
        device: デバイス

    Returns:
        mask: (seq_len, seq_len) のマスク行列
              参照可能な位置は 0、参照不可の位置は -inf
    """
    # 上三角行列を作成(対角線を含まない)
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    # True/False を -inf/0 に変換
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask


# 使用例
seq_len = 4
causal_mask = create_causal_mask(seq_len)
print("Causalマスク:")
print(causal_mask)

出力:

Causalマスク:
tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

パディングマスクの生成

def create_padding_mask(
    input_ids: torch.Tensor,
    pad_token_id: int,
) -> torch.Tensor:
    """
    パディングマスクを生成

    Args:
        input_ids: トークンID列 (batch_size, seq_len)
        pad_token_id: パディングトークンのID

    Returns:
        mask: (batch_size, 1, 1, seq_len) のマスク
              実トークンは 0、パディングは -inf
    """
    # パディング位置を特定
    padding_positions = (input_ids == pad_token_id)  # (batch_size, seq_len)

    # アテンション計算用に次元を拡張
    # (batch_size, 1, 1, seq_len) -> ブロードキャストで (batch_size, n_heads, seq_len, seq_len) に対応
    mask = padding_positions.unsqueeze(1).unsqueeze(2)

    # True/False を -inf/0 に変換
    mask = mask.float().masked_fill(mask == 1, float('-inf'))

    return mask


# 使用例
# バッチ2つ、系列長5、パディングID = 0
input_ids = torch.tensor([
    [1, 2, 3, 4, 5],     # 文A: パディングなし
    [1, 2, 0, 0, 0],     # 文B: 位置2,3,4がパディング
])
pad_token_id = 0

padding_mask = create_padding_mask(input_ids, pad_token_id)
print("パディングマスクの形状:", padding_mask.shape)
print("文Bのパディングマスク:")
print(padding_mask[1, 0, 0, :])

出力:

パディングマスクの形状: torch.Size([2, 1, 1, 5])
文Bのパディングマスク:
tensor([0., 0., -inf, -inf, -inf])

複合マスクの作成

Decoderでは、Causalマスクとパディングマスクを組み合わせて使用します。

def create_decoder_mask(
    input_ids: torch.Tensor,
    pad_token_id: int,
) -> torch.Tensor:
    """
    Decoder用の複合マスクを生成(Causal + Padding)

    Args:
        input_ids: トークンID列 (batch_size, seq_len)
        pad_token_id: パディングトークンのID

    Returns:
        mask: (batch_size, 1, seq_len, seq_len) のマスク
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # Causalマスク (1, 1, seq_len, seq_len)
    causal_mask = create_causal_mask(seq_len, device=device)
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

    # パディングマスク (batch_size, 1, 1, seq_len)
    padding_mask = create_padding_mask(input_ids, pad_token_id)

    # 複合マスク(両方の条件を満たす位置のみ参照可能)
    # -inf + 0 = -inf, -inf + -inf = -inf, 0 + 0 = 0
    combined_mask = causal_mask + padding_mask

    return combined_mask


# 使用例
input_ids = torch.tensor([
    [1, 2, 3, 0, 0],  # 文A: 位置3,4がパディング
])
pad_token_id = 0

decoder_mask = create_decoder_mask(input_ids, pad_token_id)
print("Decoder用複合マスク:")
print(decoder_mask[0, 0])

出力:

Decoder用複合マスク:
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., -inf, -inf]])

アテンション計算での使用

import math


class MaskedSelfAttention(torch.nn.Module):
    """マスク付きSelf-Attention"""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model)
            mask: (batch_size, 1, seq_len, seq_len) または (1, 1, seq_len, seq_len)

        Returns:
            output: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        # Q, K, V を計算
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # マルチヘッドに分割
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        # (batch_size, n_heads, seq_len, d_k)

        # アテンションスコア
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # (batch_size, n_heads, seq_len, seq_len)

        # マスクを適用
        if mask is not None:
            scores = scores + mask

        # softmax
        attn_weights = F.softmax(scores, dim=-1)

        # Valueの加重和
        context = torch.matmul(attn_weights, V)
        # (batch_size, n_heads, seq_len, d_k)

        # ヘッドを結合
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        output = self.W_o(context)
        return output


# 使用例
d_model = 64
n_heads = 4
batch_size = 2
seq_len = 5
pad_token_id = 0

# モデル
attention = MaskedSelfAttention(d_model, n_heads)

# 入力
torch.manual_seed(42)
input_ids = torch.tensor([
    [1, 2, 3, 4, 5],
    [1, 2, 3, 0, 0],
])
x = torch.randn(batch_size, seq_len, d_model)

# マスク
mask = create_decoder_mask(input_ids, pad_token_id)

# 順伝播
output = attention(x, mask)
print(f"入力形状: {x.shape}")
print(f"マスク形状: {mask.shape}")
print(f"出力形状: {output.shape}")

可視化

Causalマスクの効果

import matplotlib.pyplot as plt
import numpy as np

# Causalマスクの可視化
seq_len = 8

# マスクなしのアテンション(仮想的な値)
np.random.seed(42)
attn_no_mask = np.random.rand(seq_len, seq_len)
attn_no_mask = attn_no_mask / attn_no_mask.sum(axis=1, keepdims=True)

# Causalマスク付きのアテンション
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1)
attn_with_mask = attn_no_mask.copy()
attn_with_mask[causal_mask == 1] = 0
attn_with_mask = attn_with_mask / attn_with_mask.sum(axis=1, keepdims=True)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Causalマスク
ax = axes[0]
im = ax.imshow(causal_mask, cmap='Reds', vmin=0, vmax=1)
ax.set_title('Causal Mask\n(1 = masked)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)

# マスクなしアテンション
ax = axes[1]
im = ax.imshow(attn_no_mask, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('Attention (No Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)

# マスク付きアテンション
ax = axes[2]
im = ax.imshow(attn_with_mask, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('Attention (With Causal Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

パディングマスクの効果

# パディングマスクの可視化
seq_len = 6
pad_start = 4  # 位置4,5がパディング

# パディングマスク
pad_mask = np.zeros((seq_len, seq_len))
pad_mask[:, pad_start:] = 1

# アテンション重み(パディングなし)
np.random.seed(123)
attn_no_pad = np.random.rand(seq_len, seq_len)
attn_no_pad = attn_no_pad / attn_no_pad.sum(axis=1, keepdims=True)

# アテンション重み(パディングマスク適用)
attn_with_pad = attn_no_pad.copy()
attn_with_pad[:, pad_start:] = 0
attn_with_pad = attn_with_pad / attn_with_pad.sum(axis=1, keepdims=True)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# パディングマスク
ax = axes[0]
im = ax.imshow(pad_mask, cmap='Oranges', vmin=0, vmax=1)
ax.set_title('Padding Mask\n(1 = padding)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)

# マスクなしアテンション
ax = axes[1]
im = ax.imshow(attn_no_pad, cmap='Greens', vmin=0, vmax=0.4)
ax.set_title('Attention (No Padding Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)

# マスク付きアテンション
ax = axes[2]
im = ax.imshow(attn_with_pad, cmap='Greens', vmin=0, vmax=0.4)
ax.set_title('Attention (With Padding Mask)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.axvline(x=pad_start - 0.5, color='red', linestyle='--', linewidth=2)
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

Encoder-Decoderのマスク

Encoderのマスク

Encoderでは、入力系列全体を双方向に参照できるため、Causalマスクは不要です。パディングマスクのみを適用します。

Cross-Attentionのマスク

Decoderの Cross-Attention(Encoder-Decoder Attention)では: – Query: Decoderの出力 – Key/Value: Encoderの出力

マスクは Encoderの入力に対するパディングマスクを使用します。Decoderの各位置から、Encoderのパディング位置への注意を遮断します。

def create_cross_attention_mask(
    encoder_input_ids: torch.Tensor,
    decoder_seq_len: int,
    pad_token_id: int,
) -> torch.Tensor:
    """
    Cross-Attention用のパディングマスクを生成

    Args:
        encoder_input_ids: Encoderの入力トークンID (batch_size, src_len)
        decoder_seq_len: Decoderの系列長
        pad_token_id: パディングトークンID

    Returns:
        mask: (batch_size, 1, decoder_seq_len, src_len)
    """
    batch_size, src_len = encoder_input_ids.shape

    # Encoderのパディング位置
    padding_positions = (encoder_input_ids == pad_token_id)  # (batch_size, src_len)

    # Decoder の全位置から参照するので、decoder_seq_len 方向に拡張
    mask = padding_positions.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, src_len)
    mask = mask.expand(-1, -1, decoder_seq_len, -1)  # (batch_size, 1, decoder_seq_len, src_len)

    # -inf/0 に変換
    mask = mask.float().masked_fill(mask == 1, float('-inf'))

    return mask

まとめ

本記事では、Transformerのアテンションマスクについて解説しました。

  • Causalマスク: 未来の情報を遮断し、自己回帰生成を可能にする。上三角部分が $-\infty$ の行列
  • パディングマスク: パディングトークンへの注意を遮断し、可変長系列のバッチ処理を可能にする
  • 複合マスク: DecoderではCausalマスクとパディングマスクを組み合わせて使用
  • Cross-Attentionのマスク: Encoderのパディング位置への注意を遮断

マスクの正しい実装は、Transformerの動作において非常に重要です。特に、$-\infty$ の代わりに大きな負の値(例:$-10000$)を使う実装もありますが、数値的な安定性に注意が必要です。

次のステップとして、以下の記事も参考にしてください。