KVキャッシュ(Key-Value Cache)は、大規模言語モデル(LLM)の推論を高速化するための重要な技術です。GPT系のモデルがテキストを生成する際、すでに計算したKeyとValueを再利用することで、計算量を大幅に削減できます。
LLMの推論速度がなぜ遅いのか、そしてKVキャッシュがどのようにこの問題を解決するのかを理解することは、LLMの効率的な運用に欠かせません。本記事では、KVキャッシュの数学的な背景から実装まで、段階的に解説します。
本記事の内容
- 自己回帰生成における冗長な計算の問題
- KVキャッシュの基本原理と数式
- メモリ使用量の計算
- PyTorchでのKVキャッシュ実装
- KVキャッシュの最適化手法
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
自己回帰生成の問題点
GPTの生成プロセス
GPTなどのDecoder-onlyモデルは、テキストを自己回帰的(autoregressive)に生成します。つまり、1トークンずつ順番に生成し、各ステップで過去に生成したすべてのトークンを条件として次のトークンを予測します。
$$ P(x_t \mid x_1, x_2, \ldots, x_{t-1}) $$
長さ $T$ のテキストを生成するには、このプロセスを $T$ 回繰り返す必要があります。
素朴な実装の問題
素朴な実装では、各生成ステップで過去のすべてのトークンに対してSelf-Attentionを再計算します。時刻 $t$ での計算量は $O(t^2 d)$ であり、長さ $T$ のテキスト全体を生成する総計算量は次のようになります。
$$ \sum_{t=1}^{T} O(t^2 d) = O(T^3 d) $$
これは非常に非効率です。例えば、$t=100$ のとき、過去99トークン分のKeyとValueは前のステップで計算済みであるにもかかわらず、毎回再計算されています。
冗長な計算の可視化
時刻 $t$ と $t+1$ での計算を比較してみましょう。
時刻 $t$ での計算: – Query: $\bm{q}_1, \bm{q}_2, \ldots, \bm{q}_t$ – Key: $\bm{k}_1, \bm{k}_2, \ldots, \bm{k}_t$ – Value: $\bm{v}_1, \bm{v}_2, \ldots, \bm{v}_t$
時刻 $t+1$ での計算: – Query: $\bm{q}_1, \bm{q}_2, \ldots, \bm{q}_t, \bm{q}_{t+1}$ – Key: $\bm{k}_1, \bm{k}_2, \ldots, \bm{k}_t, \bm{k}_{t+1}$ – Value: $\bm{v}_1, \bm{v}_2, \ldots, \bm{v}_t, \bm{v}_{t+1}$
$\bm{k}_1, \ldots, \bm{k}_t$ と $\bm{v}_1, \ldots, \bm{v}_t$ は前のステップと全く同じです。この冗長な計算を省くのがKVキャッシュです。
KVキャッシュの原理
基本アイデア
KVキャッシュの基本アイデアは単純です。一度計算したKeyとValueをメモリに保存し、次のステップで再利用するというものです。
各生成ステップでは、新しいトークン $x_t$ に対応するKeyとValueのみを計算し、キャッシュに追加します。
$$ \bm{K}_{\text{cache}} \leftarrow \text{concat}(\bm{K}_{\text{cache}}, \bm{k}_t) $$
$$ \bm{V}_{\text{cache}} \leftarrow \text{concat}(\bm{V}_{\text{cache}}, \bm{v}_t) $$
Attention計算では、キャッシュされた全KeyとValueを使います。
$$ \text{Attention}(\bm{q}_t, \bm{K}_{\text{cache}}, \bm{V}_{\text{cache}}) = \text{softmax}\left(\frac{\bm{q}_t \bm{K}_{\text{cache}}^\top}{\sqrt{d_k}}\right) \bm{V}_{\text{cache}} $$
計算量の削減
KVキャッシュを使用した場合、時刻 $t$ での計算量は以下のようになります。
新しいトークンのKey/Value計算: $$ O(d^2) $$
Attention計算(1つのQueryに対して $t$ 個のKey/Value): $$ O(t \cdot d) $$
したがって、長さ $T$ のテキスト全体を生成する総計算量は次のようになります。
$$ \sum_{t=1}^{T} O(d^2 + t \cdot d) = O(T d^2 + T^2 d) $$
素朴な実装の $O(T^3 d)$ と比較すると、大幅な削減です。特に $T \gg d$ の場合(長いテキスト生成)で効果が顕著です。
数式による厳密な理解
Self-Attentionの計算を改めて確認しましょう。入力系列 $\bm{X} \in \mathbb{R}^{n \times d}$ に対して、
$$ \bm{Q} = \bm{X} \bm{W}^Q, \quad \bm{K} = \bm{X} \bm{W}^K, \quad \bm{V} = \bm{X} \bm{W}^V $$
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$
ここで重要なのは、Query・Key・Valueは入力の各トークンに対して独立に計算されるという点です。つまり、トークン $i$ のQuery/Key/Valueは、他のトークンの影響を受けません。
$$ \bm{q}_i = \bm{x}_i \bm{W}^Q, \quad \bm{k}_i = \bm{x}_i \bm{W}^K, \quad \bm{v}_i = \bm{x}_i \bm{W}^V $$
この独立性により、過去のトークンのKey/Valueを再計算する必要がなく、キャッシュして再利用できるのです。
メモリ使用量
KVキャッシュのメモリ要件
KVキャッシュのメモリ使用量は、モデルサイズと系列長に依存します。
各レイヤーのKVキャッシュサイズは以下のように計算されます。
$$ \text{Memory}_{\text{layer}} = 2 \times \text{batch\_size} \times \text{seq\_len} \times d_{\text{model}} \times \text{sizeof(dtype)} $$
ここで、$2$ はKeyとValueの両方を保存するためです。
全レイヤーのKVキャッシュサイズは以下のようになります。
$$ \text{Memory}_{\text{total}} = n_{\text{layers}} \times \text{Memory}_{\text{layer}} $$
具体例:Llama 2 7B
Llama 2 7Bの場合のKVキャッシュサイズを計算してみましょう。
- レイヤー数: $n_{\text{layers}} = 32$
- 隠れ層次元: $d_{\text{model}} = 4096$
- Key/Value次元(GQA後): $d_{kv} = 4096 / 8 = 512$(8 KV heads)
- バッチサイズ: $B = 1$
- 系列長: $T = 4096$
- データ型: float16 (2 bytes)
$$ \text{Memory} = 32 \times 2 \times 1 \times 4096 \times 512 \times 2 \text{ bytes} = 268 \text{ MB} $$
系列長が長くなると、KVキャッシュのメモリ使用量も線形に増加します。バッチサイズを増やすと更に増加するため、大規模サービスではKVキャッシュのメモリ管理が重要な課題となります。
PyTorchでの実装
シンプルなKVキャッシュの実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CausalSelfAttentionWithKVCache(nn.Module):
"""KVキャッシュ付きCausal Self-Attention"""
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Q, K, V の射影
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, kv_cache=None, use_cache=False):
"""
Args:
x: (batch_size, seq_len, d_model) - 新しいトークンの埋め込み
kv_cache: タプル (cached_k, cached_v) または None
use_cache: キャッシュを返すかどうか
Returns:
output: (batch_size, seq_len, d_model)
new_kv_cache: (new_k, new_v) if use_cache else None
"""
batch_size, seq_len, _ = x.size()
# Q, K, V を計算
Q = self.W_q(x) # (B, seq_len, d_model)
K = self.W_k(x)
V = self.W_v(x)
# ヘッドに分割: (B, seq_len, d_model) -> (B, n_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# KVキャッシュを結合
if kv_cache is not None:
cached_k, cached_v = kv_cache
# キャッシュされたK, Vと新しいK, Vを結合
K = torch.cat([cached_k, K], dim=2)
V = torch.cat([cached_v, V], dim=2)
# Attention計算
# Q: (B, n_heads, seq_len, d_k)
# K: (B, n_heads, total_len, d_k)
total_len = K.size(2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: (B, n_heads, seq_len, total_len)
# 因果マスク(新しいトークンが過去のトークンのみを参照)
if seq_len > 1:
# 学習時や複数トークン処理時
mask = torch.triu(
torch.ones(seq_len, total_len, device=x.device),
diagonal=total_len - seq_len + 1
).bool()
scores = scores.masked_fill(mask, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Valueの加重和
output = torch.matmul(attn_weights, V) # (B, n_heads, seq_len, d_k)
# ヘッドを結合
output = output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
output = self.W_o(output)
if use_cache:
return output, (K, V)
return output, None
使用例:自己回帰生成
# パラメータ設定
d_model = 256
n_heads = 8
vocab_size = 1000
# モデルの初期化
attn = CausalSelfAttentionWithKVCache(d_model, n_heads)
embedding = nn.Embedding(vocab_size, d_model)
# 初期トークン(プロンプト)
torch.manual_seed(42)
prompt_tokens = torch.randint(0, vocab_size, (1, 5)) # (batch=1, seq=5)
print(f"プロンプト: {prompt_tokens.shape}")
# プロンプトの処理(キャッシュを構築)
x = embedding(prompt_tokens)
output, kv_cache = attn(x, kv_cache=None, use_cache=True)
print(f"初期KVキャッシュのKの形状: {kv_cache[0].shape}")
# (1, 8, 5, 32) = (batch, heads, seq_len, d_k)
# 自己回帰生成のシミュレーション
for step in range(10):
# 新しいトークンを生成(ここではランダム)
new_token = torch.randint(0, vocab_size, (1, 1))
# 新しいトークンのみを処理
new_x = embedding(new_token)
output, kv_cache = attn(new_x, kv_cache=kv_cache, use_cache=True)
print(f"ステップ {step+1}: KVキャッシュの長さ = {kv_cache[0].shape[2]}")
# 最終的なKVキャッシュの長さ
print(f"最終KVキャッシュのKの形状: {kv_cache[0].shape}")
# (1, 8, 15, 32) = 初期5 + 生成10 = 15トークン
出力例:
プロンプト: torch.Size([1, 5])
初期KVキャッシュのKの形状: torch.Size([1, 8, 5, 32])
ステップ 1: KVキャッシュの長さ = 6
ステップ 2: KVキャッシュの長さ = 7
...
ステップ 10: KVキャッシュの長さ = 15
最終KVキャッシュのKの形状: torch.Size([1, 8, 15, 32])
KVキャッシュの最適化手法
Grouped-Query Attention(GQA)
KVキャッシュのメモリ使用量を削減する手法の1つがGrouped-Query Attention(GQA)です。
通常のMulti-Head Attentionでは、すべてのヘッドが独自のK/Vを持ちます。
$$ \text{MHA: } h \text{ Query heads}, \quad h \text{ Key heads}, \quad h \text{ Value heads} $$
GQAでは、複数のQueryヘッドで1組のK/Vを共有します。
$$ \text{GQA: } h \text{ Query heads}, \quad g \text{ Key heads}, \quad g \text{ Value heads} $$
ここで $g < h$ であり、$h$ は $g$ で割り切れる必要があります。例えば、$h=32$、$g=8$ の場合、4つのQueryヘッドが1組のK/Vを共有します。
KVキャッシュのメモリ使用量は $\frac{g}{h}$ 倍に削減されます。
PagedAttention
PagedAttentionは、vLLMで導入されたメモリ管理手法です。KVキャッシュを固定サイズのブロック(ページ)に分割し、OSの仮想メモリのように管理します。
利点: – メモリの断片化を防止 – 動的なバッチ処理が可能 – メモリ使用率の向上
KVキャッシュの圧縮
長い系列でのKVキャッシュのメモリ爆発を防ぐため、以下の手法が研究されています。
- Sliding Window Attention: 最近の $w$ トークンのみをキャッシュ
- H2O (Heavy-Hitter Oracle): 重要なトークンのみを選択的にキャッシュ
- 量子化: KVキャッシュをINT8やINT4に量子化
可視化:キャッシュの成長
import matplotlib.pyplot as plt
import numpy as np
# KVキャッシュのメモリ使用量をシミュレート
def calculate_kv_cache_memory(
seq_len, d_model, n_layers, n_kv_heads, batch_size=1, dtype_bytes=2
):
"""KVキャッシュのメモリ使用量(MB)を計算"""
d_k = d_model // n_kv_heads if n_kv_heads else d_model // 8
memory = 2 * batch_size * seq_len * n_kv_heads * d_k * n_layers * dtype_bytes
return memory / (1024 ** 2) # MB
# パラメータ(Llama 2 7B相当)
d_model = 4096
n_layers = 32
n_kv_heads = 8 # GQA
seq_lengths = np.arange(0, 8192, 100)
# 異なるバッチサイズでのメモリ使用量
batch_sizes = [1, 4, 16, 64]
plt.figure(figsize=(10, 6))
for batch_size in batch_sizes:
memory = [
calculate_kv_cache_memory(seq, d_model, n_layers, n_kv_heads, batch_size)
for seq in seq_lengths
]
plt.plot(seq_lengths, memory, label=f'Batch size = {batch_size}', linewidth=2)
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('KV Cache Memory (MB)', fontsize=12)
plt.title('KV Cache Memory Usage vs Sequence Length', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
このグラフは、系列長とバッチサイズに応じてKVキャッシュのメモリ使用量が線形に増加することを示しています。
まとめ
本記事では、KVキャッシュの仕組みと実装について解説しました。
- 自己回帰生成の問題: 素朴な実装では過去トークンのKey/Valueを毎回再計算し、$O(T^3 d)$ の計算量がかかる
- KVキャッシュの原理: 一度計算したKey/Valueをメモリに保存し再利用することで、計算量を $O(T d^2 + T^2 d)$ に削減
- メモリ使用量: 系列長とバッチサイズに比例して増加し、大規模サービスでは重要な課題となる
- 最適化手法: Grouped-Query Attention、PagedAttention、量子化などでメモリ使用量を削減可能
KVキャッシュは、LLMの推論を実用的な速度で行うための必須技術です。次のステップとして、Flash Attentionやモデル並列化など、さらなる高速化手法を学ぶとよいでしょう。
次のステップとして、以下の記事も参考にしてください。