Flash Attentionの仕組み — IO-Aware なアテンション高速化

Flash Attention は、Transformer のアテンション計算を大幅に高速化するアルゴリズムです。Dao et al. (2022) によって提案され、従来の実装と比較して2-4倍の高速化と、大幅なメモリ削減を実現します。現在、ほぼすべての最新LLMの推論・学習で使用されています。

Flash Attention の核心的アイデアは、IO-awareness(入出力を意識した最適化)です。計算量ではなく、GPUのメモリアクセスパターンを最適化することで、実際の実行速度を向上させます。本記事では、Flash Attention の背景から アルゴリズムまで解説します。

本記事の内容

  • 標準アテンションのボトルネック
  • GPUメモリ階層とIO-awareness
  • Flash Attentionのアルゴリズム(タイリング)
  • オンラインsoftmaxの仕組み
  • 速度とメモリ使用量の比較

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

標準アテンションのボトルネック

計算フロー

標準的なアテンション計算は以下の手順で行われます。

$$ \bm{S} = \bm{Q}\bm{K}^\top \quad \in \mathbb{R}^{N \times N} $$

$$ \bm{P} = \text{softmax}\left(\frac{\bm{S}}{\sqrt{d}}\right) \quad \in \mathbb{R}^{N \times N} $$

$$ \bm{O} = \bm{P}\bm{V} \quad \in \mathbb{R}^{N \times d} $$

ここで $N$ は系列長、$d$ はヘッド次元です。

問題点:メモリ使用量

中間行列 $\bm{S}$ と $\bm{P}$ は $N \times N$ のサイズを持ちます。系列長 $N = 8192$ の場合:

$$ \text{Memory} = 2 \times N^2 \times \text{sizeof(float16)} = 2 \times 8192^2 \times 2 \text{ bytes} = 256 \text{ MB} $$

これはバッチサイズ1、ヘッド1つの場合です。実際のモデルでは、バッチサイズとヘッド数を掛けた分だけメモリが必要になります。

問題点:メモリアクセス

GPUの計算速度は非常に高速ですが、メモリアクセスがボトルネックになることがあります。標準アテンションでは:

  1. HBM(High Bandwidth Memory)から $\bm{Q}, \bm{K}$ を読み込み
  2. $\bm{S} = \bm{Q}\bm{K}^\top$ を計算し、HBMに書き込み
  3. HBMから $\bm{S}$ を読み込み
  4. softmaxを計算し、$\bm{P}$ をHBMに書き込み
  5. HBMから $\bm{P}, \bm{V}$ を読み込み
  6. $\bm{O} = \bm{P}\bm{V}$ を計算

$N \times N$ の行列を何度も読み書きするため、メモリ帯域が制限要因になります。

GPUメモリ階層

メモリ階層の理解

GPUには複数レベルのメモリ階層があります。

メモリ種類 容量 帯域幅 用途
HBM(グローバルメモリ) 40-80 GB 1.5-3 TB/s 大規模データ
L2キャッシュ 40-50 MB 〜5 TB/s キャッシュ
SRAM(共有メモリ) 20-200 KB/SM 〜20 TB/s 高速アクセス
レジスタ 数百KB全体 最高速 演算

HBMは容量が大きいですが、アクセスが遅いです。SRAMは非常に高速ですが、容量が限られています。

IO-awareness

Flash Attention の key insight は、計算量ではなくメモリ転送量を最小化することです。

標準アテンションの計算量: $O(N^2 d)$ 標準アテンションのメモリ転送量: $O(N^2 + Nd)$

現代のGPUでは、メモリ転送が計算より遅いことが多いため、メモリ転送を減らすことが高速化につながります。

Flash Attentionのアルゴリズム

タイリング(Tiling)

Flash Attention はタイリングという手法を使います。$\bm{Q}, \bm{K}, \bm{V}$ を小さなブロック(タイル)に分割し、各ブロックをSRAMに載せて計算します。

ブロックサイズを $B_r$(Query方向)と $B_c$(Key/Value方向)とします。

Q を (N/B_r) 個のブロックに分割: Q_1, Q_2, ..., Q_{N/B_r}
K を (N/B_c) 個のブロックに分割: K_1, K_2, ..., K_{N/B_c}
V を (N/B_c) 個のブロックに分割: V_1, V_2, ..., V_{N/B_c}

アルゴリズムの流れ

for i = 1 to N/B_r:
    Q_i を HBM から SRAM にロード
    O_i = 0, m_i = -inf, l_i = 0 を初期化

    for j = 1 to N/B_c:
        K_j, V_j を HBM から SRAM にロード

        # ブロック単位でアテンション計算
        S_ij = Q_i @ K_j^T / sqrt(d)

        # オンラインsoftmaxの更新
        m_new = max(m_i, rowmax(S_ij))
        l_new = exp(m_i - m_new) * l_i + rowsum(exp(S_ij - m_new))
        O_i = exp(m_i - m_new) * O_i + exp(S_ij - m_new) @ V_j
        m_i, l_i = m_new, l_new

    # 最終正規化
    O_i = O_i / l_i
    O_i を HBM に書き込み

重要なのは、$N \times N$ の中間行列をHBMに保存しないことです。各ブロックの計算はSRAM内で完結し、最終結果のみをHBMに書き込みます。

オンラインsoftmax

通常のsoftmaxは、全要素を見てからでないと計算できません。

$$ \text{softmax}(\bm{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}} $$

しかし、Flash Attention はブロックごとに処理するため、オンラインsoftmaxが必要です。

オンラインsoftmaxの導出

ベクトル $\bm{x}$ を2つのブロック $\bm{x}^{(1)}, \bm{x}^{(2)}$ に分割した場合を考えます。

各ブロックの最大値と正規化係数: $$ m^{(1)} = \max(\bm{x}^{(1)}), \quad l^{(1)} = \sum_j e^{x_j^{(1)} – m^{(1)}} $$

$$ m^{(2)} = \max(\bm{x}^{(2)}), \quad l^{(2)} = \sum_j e^{x_j^{(2)} – m^{(2)}} $$

全体の最大値と正規化係数: $$ m = \max(m^{(1)}, m^{(2)}) $$

$$ l = e^{m^{(1)} – m} l^{(1)} + e^{m^{(2)} – m} l^{(2)} $$

最大値を引くのは数値安定性のためです。$e^{x}$ は $x$ が大きいとオーバーフローするため、最大値を引いて安定化します。

出力の増分更新

Valueの加重和も同様に増分更新できます。

ブロック1の処理後: $$ \bm{o}^{(1)} = \text{diag}(l^{(1)})^{-1} \exp(\bm{S}^{(1)} – m^{(1)}) \bm{V}^{(1)} $$

ブロック2を処理するとき、ブロック1の結果を更新: $$ \bm{o} = \text{diag}(l)^{-1} \left( e^{m^{(1)} – m} l^{(1)} \bm{o}^{(1)} + \exp(\bm{S}^{(2)} – m) \bm{V}^{(2)} \right) $$

メモリ使用量の分析

標準アテンション

$$ O(N^2 + Nd) $$

中間行列 $\bm{S}, \bm{P}$ が $N^2$ を占めます。

Flash Attention

$$ O(Nd) $$

$N \times N$ の中間行列を保存しないため、メモリ使用量は入出力のサイズのみ。

具体例

系列長 $N = 8192$、ヘッド次元 $d = 64$、float16の場合:

手法 メモリ使用量
標準アテンション $\approx 256$ MB
Flash Attention $\approx 2$ MB

100倍以上のメモリ削減です。

速度の改善

IO複雑度の比較

標準アテンションのHBMアクセス: $$ O(N^2 + Nd) $$

Flash AttentionのHBMアクセス: $$ O\left(\frac{N^2 d}{M}\right) $$

ここで $M$ はSRAMのサイズです。SRAMが大きいほど、HBMアクセスが減ります。

実測値

Flash Attention 2 の論文によると、A100 GPUでの性能:

系列長 標準アテンション Flash Attention 2 高速化率
512 39 TFLOPs 122 TFLOPs 3.1x
1024 42 TFLOPs 149 TFLOPs 3.5x
2048 43 TFLOPs 170 TFLOPs 4.0x
4096 44 TFLOPs 187 TFLOPs 4.2x

系列長が長いほど、Flash Attention の効果が大きくなります。

擬似コードによる理解

標準アテンション

def standard_attention(Q, K, V):
    """
    標準的なアテンション計算
    Q, K, V: (N, d)
    """
    # S = QK^T を計算し、HBMに保存
    S = Q @ K.T / sqrt(d)  # (N, N) をHBMに保存

    # softmaxを計算し、HBMに保存
    P = softmax(S, dim=-1)  # (N, N) をHBMに保存

    # 出力を計算
    O = P @ V  # (N, d)

    return O

Flash Attention(概念的な擬似コード)

def flash_attention(Q, K, V, B_r, B_c):
    """
    Flash Attention(概念的な擬似コード)
    Q, K, V: (N, d)
    B_r, B_c: ブロックサイズ
    """
    N, d = Q.shape
    O = zeros(N, d)

    for i in range(0, N, B_r):
        # Queryブロック
        Q_i = Q[i:i+B_r]  # SRAMにロード
        m_i = full(B_r, -inf)
        l_i = zeros(B_r)
        O_i = zeros(B_r, d)

        for j in range(0, N, B_c):
            # Key/Valueブロック
            K_j = K[j:j+B_c]  # SRAMにロード
            V_j = V[j:j+B_c]  # SRAMにロード

            # ブロック内でアテンション計算(SRAMで完結)
            S_ij = Q_i @ K_j.T / sqrt(d)  # (B_r, B_c)

            # オンラインsoftmax更新
            m_new = maximum(m_i, S_ij.max(dim=-1))
            P_ij = exp(S_ij - m_new.unsqueeze(-1))
            l_new = exp(m_i - m_new) * l_i + P_ij.sum(dim=-1)
            O_i = exp(m_i - m_new).unsqueeze(-1) * O_i + P_ij @ V_j

            m_i = m_new
            l_i = l_new

        # 正規化してHBMに書き込み
        O[i:i+B_r] = O_i / l_i.unsqueeze(-1)

    return O

重要なのは、$(N, N)$ サイズの行列が一度もHBMに書き込まれないことです。

Flash Attention 2 の改善

Flash Attention 2 では、さらなる最適化が行われました。

並列化の改善

  • Flash Attention 1: Keyブロックを順次処理
  • Flash Attention 2: 複数のブロックを並列処理

ワークの分割

GPU上のワーカー(ワープ)間でより効率的にワークを分割し、idle時間を削減。

因果マスクの最適化

Causalマスクがある場合、計算が不要なブロックをスキップ。

PyTorchでの使用

torch.nn.functional.scaled_dot_product_attention

PyTorch 2.0以降、Flash Attention はネイティブでサポートされています。

import torch
import torch.nn.functional as F

# パラメータ
batch_size = 2
n_heads = 8
seq_len = 1024
head_dim = 64

# 入力
Q = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')
K = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')
V = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')

# Flash Attentionが自動的に使用される(条件を満たす場合)
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(f"出力形状: {output.shape}")  # (2, 8, 1024, 64)

使用条件

Flash Attention が使用されるための条件: – CUDAデバイス – 対応するGPU(Ampere以降推奨) – 適切なデータ型(float16, bfloat16)

flash-attn ライブラリ

より高度な機能(可変長バッチなど)には flash-attn ライブラリを使用できます。

# pip install flash-attn
from flash_attn import flash_attn_func

# 入力形状: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

output = flash_attn_func(q, k, v, causal=True)

可視化

メモリアクセスパターンの比較

import matplotlib.pyplot as plt
import numpy as np

# 概念的なメモリアクセスの可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 標準アテンション
ax = axes[0]
N = 8
access_pattern = np.ones((N, N))
im = ax.imshow(access_pattern, cmap='Reds', alpha=0.7)
ax.set_title('Standard Attention\n(Full N x N matrix in HBM)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_xticks(range(N))
ax.set_yticks(range(N))

# Flash Attention(タイリング)
ax = axes[1]
B = 2  # ブロックサイズ
access_pattern = np.zeros((N, N))

# 各ステップでアクセスされるブロックを示す
colors = plt.cm.Set3(np.linspace(0, 1, (N//B)**2))
for step, (i, j) in enumerate([(0, 0), (0, 2), (0, 4), (0, 6),
                                (2, 0), (2, 2), (2, 4), (2, 6)]):
    for di in range(B):
        for dj in range(B):
            if i + di < N and j + dj < N:
                access_pattern[i + di, j + dj] = step + 1

im = ax.imshow(access_pattern, cmap='tab20', alpha=0.7)
ax.set_title('Flash Attention\n(Blocks processed sequentially in SRAM)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_xticks(range(N))
ax.set_yticks(range(N))

# グリッドを追加
for i in range(0, N + 1, B):
    ax.axhline(i - 0.5, color='black', linewidth=2)
    ax.axvline(i - 0.5, color='black', linewidth=2)

plt.tight_layout()
plt.show()

速度比較(概念図)

import matplotlib.pyplot as plt
import numpy as np

seq_lengths = [512, 1024, 2048, 4096, 8192]

# 概念的な速度データ(実際の測定値を模した)
standard_time = [1.0, 4.2, 17.5, 72.0, 290.0]  # O(N^2) に近い成長
flash_time = [0.5, 1.1, 2.5, 5.5, 12.0]  # より緩やかな成長

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(seq_lengths, standard_time, 'o-', linewidth=2, markersize=8,
        label='Standard Attention', color='coral')
ax.plot(seq_lengths, flash_time, 's-', linewidth=2, markersize=8,
        label='Flash Attention', color='steelblue')

ax.set_xlabel('Sequence Length', fontsize=12)
ax.set_ylabel('Relative Time (lower is better)', fontsize=12)
ax.set_title('Speed Comparison: Standard vs Flash Attention', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# 高速化率を注釈
for i, (s, f, sl) in enumerate(zip(standard_time, flash_time, seq_lengths)):
    speedup = s / f
    ax.annotate(f'{speedup:.1f}x', xy=(sl, f), xytext=(sl + 200, f * 0.7),
                fontsize=9, ha='left')

plt.tight_layout()
plt.show()

まとめ

本記事では、Flash Attention のアルゴリズムと原理について解説しました。

  • 問題: 標準アテンションは $N \times N$ の中間行列をHBMに保存するため、メモリとIOがボトルネック
  • IO-awareness: 計算量ではなくメモリ転送量を最適化するという発想
  • タイリング: Q, K, V を小さなブロックに分割し、SRAMで処理
  • オンラインsoftmax: ブロックごとに処理しながら正しいsoftmaxを計算
  • 効果: メモリ使用量を $O(N^2)$ から $O(N)$ に削減、速度2-4倍向上

Flash Attention は現代のLLMにおいて必須の技術です。PyTorch 2.0以降では scaled_dot_product_attention で自動的に使用されるため、直接実装する必要はありませんが、原理を理解しておくことで、より効果的なモデル設計が可能になります。

次のステップとして、以下の記事も参考にしてください。