ビームサーチの理論とPython実装

ビームサーチ(Beam Search)は、機械翻訳やテキスト生成において広く使われる探索アルゴリズムです。言語モデルが次のトークンを予測する際、最も確率の高いトークンだけを選ぶ貪欲法(Greedy Decoding)では、全体として最適な系列を見つけられないことがあります。ビームサーチは、複数の候補を並列に探索することで、この問題を緩和します。

本記事では、ビームサーチの数学的な定義から実装まで、段階的に解説します。

本記事の内容

  • 貪欲法の限界とビームサーチの必要性
  • ビームサーチのアルゴリズム
  • 長さ正規化とカバレッジペナルティ
  • PyTorchでの実装
  • ビーム幅と生成品質のトレードオフ

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

テキスト生成の問題設定

自己回帰生成

言語モデルによるテキスト生成は、条件付き確率の連鎖として定式化されます。入力(プロンプト)$\bm{x}$ が与えられたとき、出力系列 $\bm{y} = (y_1, y_2, \ldots, y_T)$ の確率は次のように分解されます。

$$ P(\bm{y} \mid \bm{x}) = \prod_{t=1}^{T} P(y_t \mid y_1, y_2, \ldots, y_{t-1}, \bm{x}) $$

目標は、この確率を最大化する系列 $\bm{y}^*$ を見つけることです。

$$ \bm{y}^* = \arg\max_{\bm{y}} P(\bm{y} \mid \bm{x}) $$

探索空間の爆発

語彙サイズを $V$、最大系列長を $T$ とすると、可能な系列の数は $V^T$ 通りあります。例えば、$V = 50000$、$T = 100$ の場合、探索空間は $50000^{100}$ という天文学的な数になります。すべての系列を列挙して最適解を見つけることは不可能です。

貪欲法(Greedy Decoding)

アルゴリズム

最も単純な方法は、各ステップで最も確率の高いトークンを選ぶ貪欲法(Greedy Decoding)です。

$$ y_t = \arg\max_{y} P(y \mid y_1, \ldots, y_{t-1}, \bm{x}) $$

Step 1: y_1 = argmax P(y | x)
Step 2: y_2 = argmax P(y | y_1, x)
Step 3: y_3 = argmax P(y | y_1, y_2, x)
...

貪欲法の限界

貪欲法は計算効率が良いですが、局所最適に陥りやすいという問題があります。

例を考えましょう。以下のような確率分布があるとします。

P("I") = 0.6,   P("We") = 0.4
P("am" | "I") = 0.9,   P("go" | "I") = 0.1
P("are" | "We") = 0.8,  P("go" | "We") = 0.2

貪欲法では: – ステップ1: $P(\text{“I”}) = 0.6 > P(\text{“We”}) = 0.4$ なので “I” を選択 – ステップ2: $P(\text{“am”} \mid \text{“I”}) = 0.9$ なので “am” を選択 – 結果: “I am” で $P = 0.6 \times 0.9 = 0.54$

しかし、別の経路を見ると: – “We are” の確率: $P = 0.4 \times 0.8 = 0.32$

この例では貪欲法が最適ですが、もし以下のような場合はどうでしょう:

P("I") = 0.5,   P("We") = 0.5
P("am" | "I") = 0.3,   P("love" | "I") = 0.3
P("are" | "We") = 0.9

貪欲法(どちらかを選ぶ): – “I am”: $0.5 \times 0.3 = 0.15$ – “I love”: $0.5 \times 0.3 = 0.15$

最適解: – “We are”: $0.5 \times 0.9 = 0.45$

貪欲法で最初に “I” を選んでしまうと、全体として最適な “We are” に到達できません。

ビームサーチのアルゴリズム

基本アイデア

ビームサーチは、各ステップで上位 $k$ 個の候補(ビーム)を保持しながら探索を進めます。$k$ をビーム幅(beam width)と呼びます。

アルゴリズムの流れ

初期化: – ビーム $\mathcal{B}_0 = \{\text{}\}$(開始トークン)

各ステップ $t$ で:

  1. 現在のビーム内の各候補 $\bm{y}_{1:t-1}$ に対して、すべての可能な次トークン $y_t$ の確率を計算

$$ \text{score}(\bm{y}_{1:t}) = \log P(\bm{y}_{1:t} \mid \bm{x}) = \sum_{i=1}^{t} \log P(y_i \mid \bm{y}_{1:i-1}, \bm{x}) $$

  1. 全候補($k \times V$ 個)からスコア上位 $k$ 個を選択して新しいビーム $\mathcal{B}_t$ を形成

  2. 終了トークン を生成した候補は完成リストに追加

終了条件: – すべての候補が を生成 – または最大長に到達

具体例

ビーム幅 $k = 2$、語彙 $\{A, B, C, \text{}\}$ の例を見てみましょう。

ステップ1:

<BOS> → A (log P = -0.5)
<BOS> → B (log P = -0.7)
<BOS> → C (log P = -1.2)

上位2つを保持: $\mathcal{B}_1 = \{A, B\}$

ステップ2:

A → AA (log P = -0.5 + (-0.3) = -0.8)
A → AB (log P = -0.5 + (-0.6) = -1.1)
A → AC (log P = -0.5 + (-0.9) = -1.4)
B → BA (log P = -0.7 + (-0.4) = -1.1)
B → BB (log P = -0.7 + (-0.5) = -1.2)
B → BC (log P = -0.7 + (-0.8) = -1.5)

上位2つを保持: $\mathcal{B}_2 = \{AA, AB\}$ または $\{AA, BA\}$(スコアによる)

この過程を続け、最終的にスコア最大の完成系列を出力します。

数式による定式化

時刻 $t$ でのビーム $\mathcal{B}_t$ は、以下のように更新されます。

$$ \mathcal{B}_t = \text{top-}k\left(\bigcup_{\bm{y} \in \mathcal{B}_{t-1}} \left\{(\bm{y}, y_t) : y_t \in \mathcal{V}\right\}, \text{score}\right) $$

ここで $\mathcal{V}$ は語彙集合、$\text{top-}k$ はスコア上位 $k$ 個を選択する操作です。

長さ正規化

問題:短い系列へのバイアス

対数確率の和をスコアとして使う場合、短い系列が有利になります。各項 $\log P(y_t \mid \cdot)$ は負の値(確率 < 1)なので、項が増えるほどスコアは下がります。

$$ \log P(\bm{y}_{1:3}) = \log P(y_1) + \log P(y_2 \mid y_1) + \log P(y_3 \mid y_1, y_2) < \log P(\bm{y}_{1:2}) $$

解決策:長さ正規化

系列長で正規化したスコアを使います。

$$ \text{score}(\bm{y}) = \frac{1}{|\bm{y}|^\alpha} \sum_{t=1}^{|\bm{y}|} \log P(y_t \mid \bm{y}_{1:t-1}, \bm{x}) $$

ここで $\alpha$ は長さペナルティのハイパーパラメータです。

  • $\alpha = 0$: 正規化なし(短い系列を優先)
  • $\alpha = 1$: 平均対数確率
  • $\alpha > 1$: 長い系列を優先

Google Neural Machine Translation (GNMT) では $\alpha = 0.6 \sim 0.7$ が使われています。

より洗練された正規化として、以下の形式もあります。

$$ \text{lp}(|\bm{y}|) = \frac{(5 + |\bm{y}|)^\alpha}{(5 + 1)^\alpha} $$

PyTorchでの実装

シンプルなビームサーチ

import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BeamHypothesis:
    """ビームサーチの1つの仮説(候補系列)"""
    tokens: List[int]  # トークンID列
    score: float       # 累積対数確率


def beam_search(
    model,
    encoder_output: torch.Tensor,
    bos_token_id: int,
    eos_token_id: int,
    beam_width: int = 5,
    max_length: int = 50,
    length_penalty: float = 0.6,
):
    """
    ビームサーチによるテキスト生成

    Args:
        model: デコーダーモデル(入力トークンとエンコーダ出力から次トークン確率を出力)
        encoder_output: エンコーダの出力 (1, src_len, d_model)
        bos_token_id: 開始トークンID
        eos_token_id: 終了トークンID
        beam_width: ビーム幅
        max_length: 最大生成長
        length_penalty: 長さペナルティ係数

    Returns:
        best_sequence: 最良の生成系列
    """
    device = encoder_output.device

    # 初期仮説
    beams = [BeamHypothesis(tokens=[bos_token_id], score=0.0)]
    completed = []

    for step in range(max_length):
        all_candidates = []

        for beam in beams:
            # 終了済みの仮説はスキップ
            if beam.tokens[-1] == eos_token_id:
                completed.append(beam)
                continue

            # 現在の系列をテンソルに変換
            input_ids = torch.tensor([beam.tokens], device=device)

            # 次トークンの確率を計算
            with torch.no_grad():
                logits = model(input_ids, encoder_output)  # (1, seq_len, vocab_size)
                next_token_logits = logits[0, -1, :]  # (vocab_size,)
                log_probs = F.log_softmax(next_token_logits, dim=-1)

            # 上位 beam_width 個の候補を取得
            top_log_probs, top_indices = torch.topk(log_probs, beam_width)

            for log_prob, token_id in zip(top_log_probs.tolist(), top_indices.tolist()):
                new_tokens = beam.tokens + [token_id]
                new_score = beam.score + log_prob
                all_candidates.append(BeamHypothesis(tokens=new_tokens, score=new_score))

        # 候補がない場合は終了
        if not all_candidates:
            break

        # スコア上位 beam_width 個を選択
        all_candidates.sort(key=lambda x: x.score, reverse=True)
        beams = all_candidates[:beam_width]

        # 全て完了したら終了
        if all(beam.tokens[-1] == eos_token_id for beam in beams):
            completed.extend(beams)
            break

    # 完了した仮説がない場合は現在のビームから選択
    if not completed:
        completed = beams

    # 長さ正規化スコアで最良の仮説を選択
    def normalized_score(hypothesis):
        length = len(hypothesis.tokens)
        return hypothesis.score / (length ** length_penalty)

    best = max(completed, key=normalized_score)
    return best.tokens


# ダミーモデルでテスト
class DummyDecoder:
    """テスト用のダミーデコーダー"""
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def __call__(self, input_ids, encoder_output):
        batch_size, seq_len = input_ids.shape
        # ランダムなロジットを返す
        torch.manual_seed(42 + seq_len)  # 再現性のため
        return torch.randn(batch_size, seq_len, self.vocab_size)


# 使用例
vocab_size = 100
model = DummyDecoder(vocab_size)
encoder_output = torch.randn(1, 10, 256)

result = beam_search(
    model=model,
    encoder_output=encoder_output,
    bos_token_id=1,
    eos_token_id=2,
    beam_width=3,
    max_length=20,
    length_penalty=0.6,
)
print(f"生成されたトークンID列: {result}")
print(f"系列長: {len(result)}")

効率的な実装(バッチ処理)

実際のシステムでは、ビーム内の候補をバッチとして並列処理します。

import torch
import torch.nn.functional as F


def beam_search_batch(
    model,
    encoder_output: torch.Tensor,
    bos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
    beam_width: int = 5,
    max_length: int = 50,
    length_penalty: float = 0.6,
):
    """
    バッチ処理による効率的なビームサーチ

    Args:
        model: デコーダーモデル
        encoder_output: エンコーダ出力 (1, src_len, d_model)
        bos_token_id: 開始トークンID
        eos_token_id: 終了トークンID
        pad_token_id: パディングトークンID
        beam_width: ビーム幅
        max_length: 最大生成長
        length_penalty: 長さペナルティ

    Returns:
        best_sequence: 最良の生成系列 (list of token ids)
    """
    device = encoder_output.device
    vocab_size = model.vocab_size if hasattr(model, 'vocab_size') else 100

    # エンコーダ出力をビーム数分複製
    encoder_output = encoder_output.repeat(beam_width, 1, 1)  # (beam_width, src_len, d_model)

    # 初期化
    input_ids = torch.full((beam_width, 1), bos_token_id, dtype=torch.long, device=device)
    beam_scores = torch.zeros(beam_width, device=device)
    beam_scores[1:] = float('-inf')  # 最初は1つのビームのみアクティブ

    # 完了したビーム
    done = torch.zeros(beam_width, dtype=torch.bool, device=device)

    for step in range(max_length):
        # 次トークンの確率を計算
        with torch.no_grad():
            logits = model(input_ids, encoder_output)  # (beam_width, seq_len, vocab_size)
            next_token_logits = logits[:, -1, :]  # (beam_width, vocab_size)
            log_probs = F.log_softmax(next_token_logits, dim=-1)

        # 完了したビームは pad のみ許可
        log_probs[done] = float('-inf')
        log_probs[done, pad_token_id] = 0

        # 累積スコア
        next_scores = beam_scores.unsqueeze(1) + log_probs  # (beam_width, vocab_size)

        # 全候補から上位 beam_width 個を選択
        next_scores = next_scores.view(-1)  # (beam_width * vocab_size,)
        top_scores, top_indices = torch.topk(next_scores, beam_width)

        # ビームインデックスとトークンIDを復元
        beam_indices = top_indices // vocab_size
        token_ids = top_indices % vocab_size

        # 更新
        input_ids = torch.cat([
            input_ids[beam_indices],
            token_ids.unsqueeze(1)
        ], dim=1)
        beam_scores = top_scores
        done = done[beam_indices] | (token_ids == eos_token_id)

        # 全て完了したら終了
        if done.all():
            break

    # 長さ正規化スコアで最良のビームを選択
    lengths = (input_ids != pad_token_id).sum(dim=1).float()
    normalized_scores = beam_scores / (lengths ** length_penalty)
    best_idx = normalized_scores.argmax()

    return input_ids[best_idx].tolist()

ビーム幅の影響

ビーム幅と品質のトレードオフ

import matplotlib.pyplot as plt
import numpy as np

# ビーム幅と品質の関係(概念的な図)
beam_widths = [1, 2, 4, 8, 16, 32, 64]

# BLEU スコア(典型的な傾向)
bleu_scores = [25.0, 27.5, 29.0, 29.8, 30.2, 30.3, 30.4]

# 計算時間(ビーム幅に比例)
computation_time = [1.0, 1.8, 3.2, 6.0, 11.5, 22.0, 43.0]

fig, ax1 = plt.subplots(figsize=(10, 6))

# BLEU スコア
color1 = 'tab:blue'
ax1.set_xlabel('Beam Width', fontsize=12)
ax1.set_ylabel('BLEU Score', color=color1, fontsize=12)
ax1.plot(beam_widths, bleu_scores, 'o-', color=color1, linewidth=2, markersize=8, label='BLEU Score')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_xscale('log', base=2)

# 計算時間
ax2 = ax1.twinx()
color2 = 'tab:red'
ax2.set_ylabel('Relative Computation Time', color=color2, fontsize=12)
ax2.plot(beam_widths, computation_time, 's--', color=color2, linewidth=2, markersize=8, label='Computation Time')
ax2.tick_params(axis='y', labelcolor=color2)

plt.title('Trade-off between Beam Width, Quality, and Computation', fontsize=14)
fig.tight_layout()
plt.grid(True, alpha=0.3)
plt.show()

観察される傾向

  1. ビーム幅 1(= 貪欲法): 最も高速だが品質は最低
  2. ビーム幅 4-8: 品質と速度のバランスが良い(実用的な選択)
  3. ビーム幅 > 16: 品質向上は頭打ち、計算コストのみ増加

ビームサーチの課題

  1. 多様性の欠如: 上位ビームが似た系列になりやすい
  2. 退化(Degeneration): 同じフレーズの繰り返しが発生することがある
  3. 最適性の非保証: ビームサーチは近似アルゴリズムであり、真の最適解を保証しない

これらの課題に対処するため、Diverse Beam Search、Nucleus Sampling(Top-p)などの手法が提案されています。

他のデコーディング手法との比較

手法 計算量 多様性 品質 用途
Greedy $O(TV)$ 高速推論
Beam Search $O(TkV)$ 翻訳、要約
Sampling $O(TV)$ 対話、創作
Nucleus (Top-p) $O(TV)$ 中〜高 中〜高 対話、創作

ここで $T$ は系列長、$V$ は語彙サイズ、$k$ はビーム幅です。

まとめ

本記事では、ビームサーチのアルゴリズムと実装について解説しました。

  • 貪欲法の限界: 各ステップで局所最適を選ぶため、全体最適を逃す可能性がある
  • ビームサーチの原理: 上位 $k$ 個の候補を並列に探索し、より良い系列を見つける確率を高める
  • 長さ正規化: 短い系列へのバイアスを補正するため、系列長で正規化したスコアを使用
  • ビーム幅のトレードオフ: 大きいほど品質は向上するが、計算コストも増加。実用的には 4-8 が多い
  • 課題: 多様性の欠如、退化現象があり、サンプリング手法との組み合わせも検討される

ビームサーチは、特に機械翻訳や要約など、正確性が重要なタスクで広く使われています。一方、対話や創作など多様性が求められるタスクでは、Temperature サンプリングや Nucleus サンプリングが適しています。

次のステップとして、以下の記事も参考にしてください。