投機的デコーディングの数理と実装

投機的デコーディング(Speculative Decoding)は、小さなドラフトモデルを使ってLLMの推論を高速化する手法です。Leviathan et al. (2022) と Chen et al. (2023) によって独立に提案され、出力品質を完全に維持しながら2-3倍の高速化を実現できます。

本記事では、投機的デコーディングの原理、検証アルゴリズム、高速化の理論的背景、実装の考え方を解説します。

本記事の内容

  • 自己回帰生成のボトルネック
  • 投機的デコーディングの基本アイデア
  • 検証アルゴリズムの数学
  • 高速化の理論的分析
  • PyTorchでの実装の考え方

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

自己回帰生成のボトルネック

メモリバウンド問題

LLMの推論はメモリバウンドです。各トークンの生成で、モデルの全重みをGPUメモリから読み込む必要があります。

1トークン生成の計算: $$ \text{計算量} = O(d^2 \cdot L) \quad (\text{パラメータ数に比例}) $$

しかし、計算自体は非常に高速で、メモリからパラメータを読み込む時間がボトルネックになります。

逐次生成の問題

自己回帰生成では、トークンを1つずつ順番に生成します。

時間 →
[Token 1] → [Token 2] → [Token 3] → ... → [Token N]

各ステップで全パラメータを読み込むため、N トークンの生成には N 回の読み込みが必要です。

バッチ処理の限界

複数のリクエストをバッチ処理すれば、パラメータ読み込みを共有できます。しかし、単一リクエストのレイテンシは改善しません。

投機的デコーディングの基本アイデア

概要

投機的デコーディングは、2つのモデルを使用します:

  1. ドラフトモデル(Draft Model): 小さく高速なモデル
  2. ターゲットモデル(Target Model): 大きく高品質なモデル(元のLLM)

基本的な流れ:

  1. ドラフトモデルで $K$ 個のトークンを連続生成(投機的生成)
  2. ターゲットモデルで $K$ 個のトークンを並列に検証
  3. 検証に成功したトークンを採用、失敗した位置から再開

なぜ高速化できるのか

ターゲットモデルは、$K$ 個のトークンを1回の順伝播で検証できます。

通常の生成($K$ トークン): – ターゲットモデルの順伝播: $K$ 回

投機的デコーディング($K$ トークン): – ドラフトモデルの順伝播: $K$ 回(高速) – ターゲットモデルの順伝播: 1-2 回

ドラフトモデルが十分高速で、受理率が高ければ、全体として高速化できます。

重要な性質:出力分布の保存

投機的デコーディングは、ターゲットモデルと完全に同じ出力分布を維持します。つまり、品質の劣化は一切ありません。

これは、検証アルゴリズムが確率的に設計されているためです。

検証アルゴリズム

セットアップ

  • ドラフトモデルの確率分布: $q(x)$
  • ターゲットモデルの確率分布: $p(x)$
  • ドラフトモデルが生成したトークン: $\tilde{x}$

採択・棄却サンプリング

投機的デコーディングの検証は、採択・棄却サンプリング(Rejection Sampling)に基づいています。

アルゴリズム:

  1. ドラフトモデルからトークン $\tilde{x} \sim q(x)$ をサンプリング
  2. 一様乱数 $u \sim U(0, 1)$ をサンプリング
  3. 以下の条件で採択/棄却:

$$ \text{採択} \iff u < \min\left(1, \frac{p(\tilde{x})}{q(\tilde{x})}\right) $$

  1. 棄却された場合、修正分布からリサンプリング:

$$ p'(x) = \frac{\max(0, p(x) – q(x))}{\sum_y \max(0, p(y) – q(y))} $$

数学的証明:分布の保存

このアルゴリズムがターゲット分布 $p(x)$ を正確に再現することを証明します。

トークン $x$ が最終的に選ばれる確率:

$$ P(\text{output} = x) = P(\tilde{x} = x, \text{採択}) + P(\text{棄却}) \cdot P(\text{リサンプル} = x) $$

採択確率:

$$ P(\tilde{x} = x, \text{採択}) = q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) $$

$q(x) \leq p(x)$ の場合: $q(x) \cdot 1 = q(x)$

$q(x) > p(x)$ の場合: $q(x) \cdot \frac{p(x)}{q(x)} = p(x)$

まとめると: $\min(p(x), q(x))$

棄却確率:

$$ P(\text{棄却}) = 1 – \sum_x \min(p(x), q(x)) = \sum_x \max(0, p(x) – q(x)) $$

リサンプル確率:

$$ p'(x) = \frac{\max(0, p(x) – q(x))}{\sum_y \max(0, p(y) – q(y))} $$

合計:

$$ P(\text{output} = x) = \min(p(x), q(x)) + \sum_y \max(0, p(y) – q(y)) \cdot \frac{\max(0, p(x) – q(x))}{\sum_z \max(0, p(z) – q(z))} $$

$$ = \min(p(x), q(x)) + \max(0, p(x) – q(x)) = p(x) $$

したがって、出力は正確に $p(x)$ に従います。

複数トークンの検証

$K$ 個のドラフトトークン $\tilde{x}_1, \tilde{x}_2, \ldots, \tilde{x}_K$ を順番に検証します。

for i = 1 to K:
    if 採択(x_i):
        出力に追加
    else:
        リサンプルして出力に追加
        break  # 残りのドラフトは破棄

最初の棄却が発生した位置で打ち切り、リサンプルしたトークンを1つ追加します。

高速化の分析

受理率

受理率 $\alpha$ は、ドラフトトークンが採択される確率の期待値です。

$$ \alpha = \mathbb{E}_{\tilde{x} \sim q}\left[\min\left(1, \frac{p(\tilde{x})}{q(\tilde{x})}\right)\right] = \sum_x \min(p(x), q(x)) $$

これは $p$ と $q$ の重なり度合いを表します。ドラフトモデルがターゲットモデルをよく近似するほど、$\alpha$ は1に近づきます。

期待トークン数

$K$ 個のドラフトを生成した場合、期待される受理トークン数:

$$ \mathbb{E}[\text{accepted tokens}] = \sum_{i=1}^{K} \alpha^{i-1} (1 – \alpha) \cdot i + \alpha^K \cdot K = \frac{1 – \alpha^{K+1}}{1 – \alpha} $$

常に少なくとも1つは出力される(棄却されてもリサンプルするため)ので:

$$ \mathbb{E}[\text{output tokens}] = \frac{1 – \alpha^{K+1}}{1 – \alpha} $$

高速化率

ドラフトモデルの速度を $c$ 倍($c > 1$ でドラフトが速い)とすると:

$$ \text{Speedup} = \frac{K}{K/c + 1} \cdot \frac{1 – \alpha^{K+1}}{1 – \alpha} \cdot \frac{1}{K} $$

簡略化すると($c \gg 1$ の場合):

$$ \text{Speedup} \approx \frac{1 – \alpha^{K+1}}{1 – \alpha} $$

$\alpha = 0.7$, $K = 5$ の場合:

$$ \text{Speedup} \approx \frac{1 – 0.7^6}{0.3} \approx 2.8 $$

PyTorchでの実装の考え方

基本的な実装

import torch
import torch.nn.functional as F
from typing import Tuple


def speculative_decode(
    target_model,
    draft_model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    num_speculative: int = 4,
) -> torch.Tensor:
    """
    投機的デコーディング

    Args:
        target_model: ターゲットモデル(大きなLLM)
        draft_model: ドラフトモデル(小さなLLM)
        input_ids: 入力トークン (1, seq_len)
        max_new_tokens: 生成する最大トークン数
        num_speculative: 投機的に生成するトークン数

    Returns:
        output_ids: 生成されたトークン列
    """
    device = input_ids.device
    output_ids = input_ids.clone()
    generated = 0

    while generated < max_new_tokens:
        # --- ドラフト生成フェーズ ---
        draft_tokens = []
        draft_probs = []
        current_ids = output_ids.clone()

        for _ in range(num_speculative):
            with torch.no_grad():
                logits = draft_model(current_ids)[:, -1, :]
                probs = F.softmax(logits, dim=-1)

            # サンプリング
            next_token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(next_token)
            draft_probs.append(probs)

            current_ids = torch.cat([current_ids, next_token], dim=1)

        # --- 検証フェーズ ---
        # ターゲットモデルで全ドラフトトークンを並列に検証
        verify_ids = torch.cat([output_ids] + draft_tokens, dim=1)

        with torch.no_grad():
            target_logits = target_model(verify_ids)

        # 各位置でターゲット確率を取得
        start_pos = output_ids.shape[1] - 1
        accepted_count = 0

        for i, (draft_token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
            pos = start_pos + i
            target_probs = F.softmax(target_logits[:, pos, :], dim=-1)

            token_id = draft_token.item()
            p = target_probs[0, token_id].item()
            q = draft_prob[0, token_id].item()

            # 採択判定
            acceptance_prob = min(1.0, p / max(q, 1e-10))
            u = torch.rand(1).item()

            if u < acceptance_prob:
                # 採択
                output_ids = torch.cat([output_ids, draft_token], dim=1)
                accepted_count += 1
                generated += 1

                if generated >= max_new_tokens:
                    break
            else:
                # 棄却 - 修正分布からリサンプル
                adjusted_probs = torch.clamp(target_probs - draft_prob, min=0)
                adjusted_probs = adjusted_probs / adjusted_probs.sum()

                new_token = torch.multinomial(adjusted_probs, num_samples=1)
                output_ids = torch.cat([output_ids, new_token], dim=1)
                generated += 1
                break

        # 全て採択された場合、追加で1トークンサンプル
        if accepted_count == num_speculative and generated < max_new_tokens:
            target_probs = F.softmax(target_logits[:, -1, :], dim=-1)
            new_token = torch.multinomial(target_probs, num_samples=1)
            output_ids = torch.cat([output_ids, new_token], dim=1)
            generated += 1

    return output_ids


# ダミーモデルでのテスト
class DummyLM(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, temperature=1.0):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)
        self.temperature = temperature

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = x.mean(dim=1, keepdim=True).expand(-1, input_ids.shape[1], -1)
        return self.fc(x) / self.temperature


vocab_size = 100
target = DummyLM(vocab_size, 256, temperature=1.0)
draft = DummyLM(vocab_size, 64, temperature=0.8)

input_ids = torch.tensor([[1, 2, 3]])
output = speculative_decode(target, draft, input_ids, max_new_tokens=20, num_speculative=4)
print(f"生成されたトークン数: {output.shape[1] - input_ids.shape[1]}")

受理率の計測

def measure_acceptance_rate(
    target_model,
    draft_model,
    input_ids: torch.Tensor,
    num_samples: int = 100,
) -> float:
    """受理率を計測"""
    accepted = 0
    total = 0

    for _ in range(num_samples):
        # ドラフトモデルでサンプリング
        with torch.no_grad():
            draft_logits = draft_model(input_ids)[:, -1, :]
            draft_probs = F.softmax(draft_logits, dim=-1)
            draft_token = torch.multinomial(draft_probs, num_samples=1)

            target_logits = target_model(input_ids)[:, -1, :]
            target_probs = F.softmax(target_logits, dim=-1)

        token_id = draft_token.item()
        p = target_probs[0, token_id].item()
        q = draft_probs[0, token_id].item()

        acceptance_prob = min(1.0, p / max(q, 1e-10))
        accepted += acceptance_prob
        total += 1

    return accepted / total


# 受理率を計測
acceptance_rate = measure_acceptance_rate(target, draft, input_ids)
print(f"推定受理率: {acceptance_rate:.3f}")

ドラフトモデルの選択

選択基準

  1. 速度: ターゲットモデルより十分に高速(10倍以上が理想)
  2. 品質: ターゲットモデルと似た分布を出力
  3. 互換性: 同じ語彙とトークナイザを使用

選択肢

  1. 同系列の小さなモデル: Llama 70B → Llama 7B
  2. 量子化されたターゲットモデル: FP16 → INT4
  3. 蒸留されたモデル: ターゲットから蒸留した専用モデル
  4. n-gramモデル: 非常に高速だが受理率は低い

可視化

受理率と高速化の関係

import matplotlib.pyplot as plt
import numpy as np

# 受理率と高速化率の関係
alphas = np.linspace(0.1, 0.95, 50)
Ks = [2, 4, 6, 8]

fig, ax = plt.subplots(figsize=(10, 6))

for K in Ks:
    speedups = (1 - alphas ** (K + 1)) / (1 - alphas)
    ax.plot(alphas, speedups, linewidth=2, label=f'K = {K}')

ax.set_xlabel('Acceptance Rate (alpha)', fontsize=12)
ax.set_ylabel('Expected Speedup', fontsize=12)
ax.set_title('Speculative Decoding: Speedup vs Acceptance Rate', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xlim(0.1, 0.95)
ax.set_ylim(1, 6)

plt.tight_layout()
plt.show()

デコーディングプロセスの可視化

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

fig, axes = plt.subplots(2, 1, figsize=(14, 6))

# 通常のデコーディング
ax = axes[0]
for i in range(8):
    ax.add_patch(mpatches.Rectangle((i, 0), 0.9, 0.8, color='steelblue', alpha=0.7))
    ax.text(i + 0.45, 0.4, f'T{i+1}', ha='center', va='center', fontsize=10, color='white')
ax.set_xlim(-0.5, 8.5)
ax.set_ylim(-0.2, 1)
ax.set_title('Standard Autoregressive Decoding (8 target model calls)', fontsize=12)
ax.axis('off')

# 投機的デコーディング
ax = axes[1]
# ラウンド1: ドラフト + 検証
for i in range(4):
    ax.add_patch(mpatches.Rectangle((i * 0.4, 0), 0.35, 0.8, color='coral', alpha=0.7))
    ax.text(i * 0.4 + 0.175, 0.4, f'D{i+1}', ha='center', va='center', fontsize=8, color='white')
ax.add_patch(mpatches.Rectangle((1.8, 0), 0.9, 0.8, color='steelblue', alpha=0.7))
ax.text(2.25, 0.4, 'Verify', ha='center', va='center', fontsize=10, color='white')

# 採択結果
ax.add_patch(mpatches.Rectangle((3, 0), 0.35, 0.8, color='seagreen', alpha=0.7))
ax.add_patch(mpatches.Rectangle((3.4, 0), 0.35, 0.8, color='seagreen', alpha=0.7))
ax.add_patch(mpatches.Rectangle((3.8, 0), 0.35, 0.8, color='seagreen', alpha=0.7))
ax.add_patch(mpatches.Rectangle((4.2, 0), 0.35, 0.8, color='indianred', alpha=0.7))
ax.text(3.175, 0.4, 'A', ha='center', va='center', fontsize=8, color='white')
ax.text(3.575, 0.4, 'A', ha='center', va='center', fontsize=8, color='white')
ax.text(3.975, 0.4, 'A', ha='center', va='center', fontsize=8, color='white')
ax.text(4.375, 0.4, 'R', ha='center', va='center', fontsize=8, color='white')

# ラウンド2
for i in range(4):
    ax.add_patch(mpatches.Rectangle((5 + i * 0.4, 0), 0.35, 0.8, color='coral', alpha=0.7))
ax.add_patch(mpatches.Rectangle((6.8, 0), 0.9, 0.8, color='steelblue', alpha=0.7))
ax.text(7.25, 0.4, 'Verify', ha='center', va='center', fontsize=10, color='white')

ax.set_xlim(-0.5, 8.5)
ax.set_ylim(-0.2, 1)
ax.set_title('Speculative Decoding (2 target model calls, 3+? tokens)', fontsize=12)
ax.axis('off')

# 凡例
legend_elements = [
    mpatches.Patch(color='coral', alpha=0.7, label='Draft model'),
    mpatches.Patch(color='steelblue', alpha=0.7, label='Target model'),
    mpatches.Patch(color='seagreen', alpha=0.7, label='Accepted'),
    mpatches.Patch(color='indianred', alpha=0.7, label='Rejected'),
]
fig.legend(handles=legend_elements, loc='lower center', ncol=4, fontsize=10)

plt.tight_layout()
plt.subplots_adjust(bottom=0.15)
plt.show()

まとめ

本記事では、投機的デコーディングの仕組みについて解説しました。

  • 基本アイデア: 小さなドラフトモデルで投機的に複数トークンを生成し、大きなターゲットモデルで並列に検証
  • 採択・棄却サンプリング: 確率比に基づく採択判定により、ターゲット分布を正確に再現
  • 出力分布の保存: 品質の劣化なく高速化を実現
  • 高速化率: 受理率 $\alpha$ と投機トークン数 $K$ に依存。$\alpha = 0.7$, $K = 5$ で約2.8倍
  • ドラフトモデルの選択: 速度と品質(受理率)のバランスが重要

投機的デコーディングは、出力品質を完全に維持しながら推論を高速化できる強力な手法です。ドラフトモデルの選択と $K$ の調整により、様々な状況で効果を発揮できます。

次のステップとして、以下の記事も参考にしてください。