Flash Attention は、Transformer のアテンション計算を大幅に高速化するアルゴリズムです。Dao et al. (2022) によって提案され、従来の実装と比較して2-4倍の高速化と、大幅なメモリ削減を実現します。現在、ほぼすべての最新LLMの推論・学習で使用されています。
Flash Attention の核心的アイデアは、IO-awareness(入出力を意識した最適化)です。計算量ではなく、GPUのメモリアクセスパターンを最適化することで、実際の実行速度を向上させます。本記事では、Flash Attention の背景から アルゴリズムまで解説します。
本記事の内容
- 標準アテンションのボトルネック
- GPUメモリ階層とIO-awareness
- Flash Attentionのアルゴリズム(タイリング)
- オンラインsoftmaxの仕組み
- 速度とメモリ使用量の比較
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
標準アテンションのボトルネック
計算フロー
標準的なアテンション計算は以下の手順で行われます。
$$ \bm{S} = \bm{Q}\bm{K}^\top \quad \in \mathbb{R}^{N \times N} $$
$$ \bm{P} = \text{softmax}\left(\frac{\bm{S}}{\sqrt{d}}\right) \quad \in \mathbb{R}^{N \times N} $$
$$ \bm{O} = \bm{P}\bm{V} \quad \in \mathbb{R}^{N \times d} $$
ここで $N$ は系列長、$d$ はヘッド次元です。
問題点:メモリ使用量
中間行列 $\bm{S}$ と $\bm{P}$ は $N \times N$ のサイズを持ちます。系列長 $N = 8192$ の場合:
$$ \text{Memory} = 2 \times N^2 \times \text{sizeof(float16)} = 2 \times 8192^2 \times 2 \text{ bytes} = 256 \text{ MB} $$
これはバッチサイズ1、ヘッド1つの場合です。実際のモデルでは、バッチサイズとヘッド数を掛けた分だけメモリが必要になります。
問題点:メモリアクセス
GPUの計算速度は非常に高速ですが、メモリアクセスがボトルネックになることがあります。標準アテンションでは:
- HBM(High Bandwidth Memory)から $\bm{Q}, \bm{K}$ を読み込み
- $\bm{S} = \bm{Q}\bm{K}^\top$ を計算し、HBMに書き込み
- HBMから $\bm{S}$ を読み込み
- softmaxを計算し、$\bm{P}$ をHBMに書き込み
- HBMから $\bm{P}, \bm{V}$ を読み込み
- $\bm{O} = \bm{P}\bm{V}$ を計算
$N \times N$ の行列を何度も読み書きするため、メモリ帯域が制限要因になります。
GPUメモリ階層
メモリ階層の理解
GPUには複数レベルのメモリ階層があります。
| メモリ種類 | 容量 | 帯域幅 | 用途 |
|---|---|---|---|
| HBM(グローバルメモリ) | 40-80 GB | 1.5-3 TB/s | 大規模データ |
| L2キャッシュ | 40-50 MB | 〜5 TB/s | キャッシュ |
| SRAM(共有メモリ) | 20-200 KB/SM | 〜20 TB/s | 高速アクセス |
| レジスタ | 数百KB全体 | 最高速 | 演算 |
HBMは容量が大きいですが、アクセスが遅いです。SRAMは非常に高速ですが、容量が限られています。
IO-awareness
Flash Attention の key insight は、計算量ではなくメモリ転送量を最小化することです。
標準アテンションの計算量: $O(N^2 d)$ 標準アテンションのメモリ転送量: $O(N^2 + Nd)$
現代のGPUでは、メモリ転送が計算より遅いことが多いため、メモリ転送を減らすことが高速化につながります。
Flash Attentionのアルゴリズム
タイリング(Tiling)
Flash Attention はタイリングという手法を使います。$\bm{Q}, \bm{K}, \bm{V}$ を小さなブロック(タイル)に分割し、各ブロックをSRAMに載せて計算します。
ブロックサイズを $B_r$(Query方向)と $B_c$(Key/Value方向)とします。
Q を (N/B_r) 個のブロックに分割: Q_1, Q_2, ..., Q_{N/B_r}
K を (N/B_c) 個のブロックに分割: K_1, K_2, ..., K_{N/B_c}
V を (N/B_c) 個のブロックに分割: V_1, V_2, ..., V_{N/B_c}
アルゴリズムの流れ
for i = 1 to N/B_r:
Q_i を HBM から SRAM にロード
O_i = 0, m_i = -inf, l_i = 0 を初期化
for j = 1 to N/B_c:
K_j, V_j を HBM から SRAM にロード
# ブロック単位でアテンション計算
S_ij = Q_i @ K_j^T / sqrt(d)
# オンラインsoftmaxの更新
m_new = max(m_i, rowmax(S_ij))
l_new = exp(m_i - m_new) * l_i + rowsum(exp(S_ij - m_new))
O_i = exp(m_i - m_new) * O_i + exp(S_ij - m_new) @ V_j
m_i, l_i = m_new, l_new
# 最終正規化
O_i = O_i / l_i
O_i を HBM に書き込み
重要なのは、$N \times N$ の中間行列をHBMに保存しないことです。各ブロックの計算はSRAM内で完結し、最終結果のみをHBMに書き込みます。
オンラインsoftmax
通常のsoftmaxは、全要素を見てからでないと計算できません。
$$ \text{softmax}(\bm{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}} $$
しかし、Flash Attention はブロックごとに処理するため、オンラインsoftmaxが必要です。
オンラインsoftmaxの導出
ベクトル $\bm{x}$ を2つのブロック $\bm{x}^{(1)}, \bm{x}^{(2)}$ に分割した場合を考えます。
各ブロックの最大値と正規化係数: $$ m^{(1)} = \max(\bm{x}^{(1)}), \quad l^{(1)} = \sum_j e^{x_j^{(1)} – m^{(1)}} $$
$$ m^{(2)} = \max(\bm{x}^{(2)}), \quad l^{(2)} = \sum_j e^{x_j^{(2)} – m^{(2)}} $$
全体の最大値と正規化係数: $$ m = \max(m^{(1)}, m^{(2)}) $$
$$ l = e^{m^{(1)} – m} l^{(1)} + e^{m^{(2)} – m} l^{(2)} $$
最大値を引くのは数値安定性のためです。$e^{x}$ は $x$ が大きいとオーバーフローするため、最大値を引いて安定化します。
出力の増分更新
Valueの加重和も同様に増分更新できます。
ブロック1の処理後: $$ \bm{o}^{(1)} = \text{diag}(l^{(1)})^{-1} \exp(\bm{S}^{(1)} – m^{(1)}) \bm{V}^{(1)} $$
ブロック2を処理するとき、ブロック1の結果を更新: $$ \bm{o} = \text{diag}(l)^{-1} \left( e^{m^{(1)} – m} l^{(1)} \bm{o}^{(1)} + \exp(\bm{S}^{(2)} – m) \bm{V}^{(2)} \right) $$
メモリ使用量の分析
標準アテンション
$$ O(N^2 + Nd) $$
中間行列 $\bm{S}, \bm{P}$ が $N^2$ を占めます。
Flash Attention
$$ O(Nd) $$
$N \times N$ の中間行列を保存しないため、メモリ使用量は入出力のサイズのみ。
具体例
系列長 $N = 8192$、ヘッド次元 $d = 64$、float16の場合:
| 手法 | メモリ使用量 |
|---|---|
| 標準アテンション | $\approx 256$ MB |
| Flash Attention | $\approx 2$ MB |
100倍以上のメモリ削減です。
速度の改善
IO複雑度の比較
標準アテンションのHBMアクセス: $$ O(N^2 + Nd) $$
Flash AttentionのHBMアクセス: $$ O\left(\frac{N^2 d}{M}\right) $$
ここで $M$ はSRAMのサイズです。SRAMが大きいほど、HBMアクセスが減ります。
実測値
Flash Attention 2 の論文によると、A100 GPUでの性能:
| 系列長 | 標準アテンション | Flash Attention 2 | 高速化率 |
|---|---|---|---|
| 512 | 39 TFLOPs | 122 TFLOPs | 3.1x |
| 1024 | 42 TFLOPs | 149 TFLOPs | 3.5x |
| 2048 | 43 TFLOPs | 170 TFLOPs | 4.0x |
| 4096 | 44 TFLOPs | 187 TFLOPs | 4.2x |
系列長が長いほど、Flash Attention の効果が大きくなります。
擬似コードによる理解
標準アテンション
def standard_attention(Q, K, V):
"""
標準的なアテンション計算
Q, K, V: (N, d)
"""
# S = QK^T を計算し、HBMに保存
S = Q @ K.T / sqrt(d) # (N, N) をHBMに保存
# softmaxを計算し、HBMに保存
P = softmax(S, dim=-1) # (N, N) をHBMに保存
# 出力を計算
O = P @ V # (N, d)
return O
Flash Attention(概念的な擬似コード)
def flash_attention(Q, K, V, B_r, B_c):
"""
Flash Attention(概念的な擬似コード)
Q, K, V: (N, d)
B_r, B_c: ブロックサイズ
"""
N, d = Q.shape
O = zeros(N, d)
for i in range(0, N, B_r):
# Queryブロック
Q_i = Q[i:i+B_r] # SRAMにロード
m_i = full(B_r, -inf)
l_i = zeros(B_r)
O_i = zeros(B_r, d)
for j in range(0, N, B_c):
# Key/Valueブロック
K_j = K[j:j+B_c] # SRAMにロード
V_j = V[j:j+B_c] # SRAMにロード
# ブロック内でアテンション計算(SRAMで完結)
S_ij = Q_i @ K_j.T / sqrt(d) # (B_r, B_c)
# オンラインsoftmax更新
m_new = maximum(m_i, S_ij.max(dim=-1))
P_ij = exp(S_ij - m_new.unsqueeze(-1))
l_new = exp(m_i - m_new) * l_i + P_ij.sum(dim=-1)
O_i = exp(m_i - m_new).unsqueeze(-1) * O_i + P_ij @ V_j
m_i = m_new
l_i = l_new
# 正規化してHBMに書き込み
O[i:i+B_r] = O_i / l_i.unsqueeze(-1)
return O
重要なのは、$(N, N)$ サイズの行列が一度もHBMに書き込まれないことです。
Flash Attention 2 の改善
Flash Attention 2 では、さらなる最適化が行われました。
並列化の改善
- Flash Attention 1: Keyブロックを順次処理
- Flash Attention 2: 複数のブロックを並列処理
ワークの分割
GPU上のワーカー(ワープ)間でより効率的にワークを分割し、idle時間を削減。
因果マスクの最適化
Causalマスクがある場合、計算が不要なブロックをスキップ。
PyTorchでの使用
torch.nn.functional.scaled_dot_product_attention
PyTorch 2.0以降、Flash Attention はネイティブでサポートされています。
import torch
import torch.nn.functional as F
# パラメータ
batch_size = 2
n_heads = 8
seq_len = 1024
head_dim = 64
# 入力
Q = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')
K = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')
V = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda')
# Flash Attentionが自動的に使用される(条件を満たす場合)
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"出力形状: {output.shape}") # (2, 8, 1024, 64)
使用条件
Flash Attention が使用されるための条件: – CUDAデバイス – 対応するGPU(Ampere以降推奨) – 適切なデータ型(float16, bfloat16)
flash-attn ライブラリ
より高度な機能(可変長バッチなど)には flash-attn ライブラリを使用できます。
# pip install flash-attn
from flash_attn import flash_attn_func
# 入力形状: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
output = flash_attn_func(q, k, v, causal=True)
可視化
メモリアクセスパターンの比較
import matplotlib.pyplot as plt
import numpy as np
# 概念的なメモリアクセスの可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 標準アテンション
ax = axes[0]
N = 8
access_pattern = np.ones((N, N))
im = ax.imshow(access_pattern, cmap='Reds', alpha=0.7)
ax.set_title('Standard Attention\n(Full N x N matrix in HBM)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_xticks(range(N))
ax.set_yticks(range(N))
# Flash Attention(タイリング)
ax = axes[1]
B = 2 # ブロックサイズ
access_pattern = np.zeros((N, N))
# 各ステップでアクセスされるブロックを示す
colors = plt.cm.Set3(np.linspace(0, 1, (N//B)**2))
for step, (i, j) in enumerate([(0, 0), (0, 2), (0, 4), (0, 6),
(2, 0), (2, 2), (2, 4), (2, 6)]):
for di in range(B):
for dj in range(B):
if i + di < N and j + dj < N:
access_pattern[i + di, j + dj] = step + 1
im = ax.imshow(access_pattern, cmap='tab20', alpha=0.7)
ax.set_title('Flash Attention\n(Blocks processed sequentially in SRAM)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_xticks(range(N))
ax.set_yticks(range(N))
# グリッドを追加
for i in range(0, N + 1, B):
ax.axhline(i - 0.5, color='black', linewidth=2)
ax.axvline(i - 0.5, color='black', linewidth=2)
plt.tight_layout()
plt.show()
速度比較(概念図)
import matplotlib.pyplot as plt
import numpy as np
seq_lengths = [512, 1024, 2048, 4096, 8192]
# 概念的な速度データ(実際の測定値を模した)
standard_time = [1.0, 4.2, 17.5, 72.0, 290.0] # O(N^2) に近い成長
flash_time = [0.5, 1.1, 2.5, 5.5, 12.0] # より緩やかな成長
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(seq_lengths, standard_time, 'o-', linewidth=2, markersize=8,
label='Standard Attention', color='coral')
ax.plot(seq_lengths, flash_time, 's-', linewidth=2, markersize=8,
label='Flash Attention', color='steelblue')
ax.set_xlabel('Sequence Length', fontsize=12)
ax.set_ylabel('Relative Time (lower is better)', fontsize=12)
ax.set_title('Speed Comparison: Standard vs Flash Attention', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')
# 高速化率を注釈
for i, (s, f, sl) in enumerate(zip(standard_time, flash_time, seq_lengths)):
speedup = s / f
ax.annotate(f'{speedup:.1f}x', xy=(sl, f), xytext=(sl + 200, f * 0.7),
fontsize=9, ha='left')
plt.tight_layout()
plt.show()
まとめ
本記事では、Flash Attention のアルゴリズムと原理について解説しました。
- 問題: 標準アテンションは $N \times N$ の中間行列をHBMに保存するため、メモリとIOがボトルネック
- IO-awareness: 計算量ではなくメモリ転送量を最適化するという発想
- タイリング: Q, K, V を小さなブロックに分割し、SRAMで処理
- オンラインsoftmax: ブロックごとに処理しながら正しいsoftmaxを計算
- 効果: メモリ使用量を $O(N^2)$ から $O(N)$ に削減、速度2-4倍向上
Flash Attention は現代のLLMにおいて必須の技術です。PyTorch 2.0以降では scaled_dot_product_attention で自動的に使用されるため、直接実装する必要はありませんが、原理を理解しておくことで、より効果的なモデル設計が可能になります。
次のステップとして、以下の記事も参考にしてください。