Temperature・Top-k・Top-pサンプリングを比較して理解する

大規模言語モデル(LLM)がテキストを生成する際、次のトークンの選び方によって出力の性質が大きく変わります。毎回最も確率の高いトークンを選ぶ貪欲法では、決定論的で単調な出力になりがちです。一方、確率分布からサンプリングすることで、多様で創造的なテキストを生成できます。

本記事では、Temperature、Top-k、Top-p(Nucleus)サンプリングという3つの主要なサンプリング手法の数学的な定義から実装まで解説します。

本記事の内容

  • Temperatureサンプリングの仕組みと数式
  • Top-kサンプリング
  • Top-p(Nucleus)サンプリング
  • 各手法の組み合わせ
  • PyTorchでの実装と可視化

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

サンプリングの基本

なぜサンプリングが必要か

言語モデルは各ステップで、語彙 $\mathcal{V}$ 上の確率分布 $P(y_t \mid y_{

決定論的手法(Greedy / Beam Search): – 常に最も確率の高いトークン(または上位 $k$ 個)を選択 – 再現性がある – 出力が単調で、同じフレーズの繰り返しが発生しやすい

確率的手法(サンプリング): – 確率分布に従ってランダムにトークンを選択 – 多様な出力が得られる – 適切に制御しないと、不自然な出力になることがある

サンプリング手法は、対話システムや創作支援など、多様性が求められるタスクで特に有効です。

基本的なサンプリング

最も単純なサンプリングは、モデルが出力した確率分布をそのまま使うことです。

$$ y_t \sim P(y \mid y_{

しかし、この方法では低確率のトークンも選ばれる可能性があり、不自然な出力が生成されることがあります。そこで、分布を調整するさまざまな手法が考案されました。

Temperatureサンプリング

定義

Temperatureサンプリングは、softmax関数にTemperatureパラメータ $T$ を導入して確率分布の「鋭さ」を調整する手法です。

モデルが出力するロジット $z_i$(softmax前の値)に対して、Temperature $T$ を適用したsoftmaxは次のように定義されます。

$$ P(y = i) = \frac{\exp(z_i / T)}{\sum_{j=1}^{V} \exp(z_j / T)} $$

Temperatureの効果

$T = 1$(デフォルト): – 元の分布をそのまま使用 – モデルの学習時と同じ分布

$T < 1$(低温度): – 分布が鋭くなる(高確率トークンがより選ばれやすい) – $T \to 0$ で貪欲法に近づく – より確実で一貫性のある出力

$T > 1$(高温度): – 分布が平坦になる(低確率トークンも選ばれやすくなる) – $T \to \infty$ で一様分布に近づく – より多様で創造的(時に不自然)な出力

数学的な理解

Temperature $T$ による分布の変化を数学的に理解しましょう。

2つのトークン $i, j$ の確率比を考えます。

$$ \frac{P(y = i)}{P(y = j)} = \frac{\exp(z_i / T)}{\exp(z_j / T)} = \exp\left(\frac{z_i – z_j}{T}\right) $$

$z_i > z_j$(トークン $i$ の方がロジットが大きい)の場合:

  • $T < 1$: 指数の引数が大きくなり、確率比が拡大
  • $T > 1$: 指数の引数が小さくなり、確率比が縮小

エントロピーとの関係

Temperature を変えると、分布のエントロピーも変化します。

$$ H(P_T) = -\sum_{i=1}^{V} P_T(i) \log P_T(i) $$

  • 低Temperature: エントロピーが低い(不確実性が低い)
  • 高Temperature: エントロピーが高い(不確実性が高い)

Top-kサンプリング

定義

Top-kサンプリングは、確率上位 $k$ 個のトークンのみを候補として、それらの中からサンプリングする手法です。

  1. 確率上位 $k$ 個のトークン集合 $\mathcal{V}_k$ を選択
  2. $\mathcal{V}_k$ 内で確率を再正規化
  3. 再正規化した分布からサンプリング

$$ P'(y = i) = \begin{cases} \dfrac{P(y = i)}{\sum_{j \in \mathcal{V}_k} P(y = j)} & \text{if } i \in \mathcal{V}_k \\ 0 & \text{otherwise} \end{cases} $$

長所と短所

長所: – 低確率の不自然なトークンを確実に排除 – 実装が単純

短所: – 固定の $k$ がすべての文脈に適しているとは限らない – 確信度の高い場面では $k$ 個の候補は多すぎる – 不確実な場面では $k$ 個の候補は少なすぎる

例えば、「日本の首都は」の次は「東京」がほぼ確実ですが、「私の好きな色は」の次は多くの選択肢があります。固定の $k$ では、この違いに対応できません。

Top-p(Nucleus)サンプリング

定義

Top-p(Nucleus)サンプリングは、累積確率が $p$ を超える最小のトークン集合からサンプリングする手法です。Holtzman et al. (2020) で提案されました。

  1. トークンを確率の降順にソート
  2. 累積確率が $p$ を超える最小の集合 $\mathcal{V}_p$ を選択
  3. $\mathcal{V}_p$ 内で確率を再正規化
  4. 再正規化した分布からサンプリング

$$ \mathcal{V}_p = \min \left\{ \mathcal{V}’ \subseteq \mathcal{V} : \sum_{i \in \mathcal{V}’} P(y = i) \geq p \right\} $$

ここで「最小」は、集合の要素数が最小という意味です。

Top-pの利点

Top-pの大きな利点は、文脈に応じて候補数が動的に変化することです。

確信度が高い場合(分布が鋭い): – 少数のトークンで累積確率 $p$ に到達 – 候補が絞られ、一貫性のある出力

不確実な場合(分布が平坦): – 多くのトークンが必要 – 多様な候補から選択可能

語彙サイズ5、確率分布 $P = [0.5, 0.3, 0.1, 0.08, 0.02]$ の場合:

$p = 0.9$ の場合: – 累積確率: $0.5 + 0.3 + 0.1 = 0.9$ – $\mathcal{V}_p = \{1, 2, 3\}$(上位3トークン)

$p = 0.5$ の場合: – 累積確率: $0.5$ – $\mathcal{V}_p = \{1\}$(最上位トークンのみ)

PyTorchでの実装

各サンプリング手法の実装

import torch
import torch.nn.functional as F
from typing import Optional


def temperature_sampling(
    logits: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    Temperatureサンプリング

    Args:
        logits: モデル出力のロジット (batch_size, vocab_size)
        temperature: Temperature値(0 < T)

    Returns:
        sampled_ids: サンプリングされたトークンID (batch_size,)
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")

    # Temperature を適用
    scaled_logits = logits / temperature

    # 確率分布を計算
    probs = F.softmax(scaled_logits, dim=-1)

    # サンプリング
    sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return sampled_ids


def top_k_sampling(
    logits: torch.Tensor,
    k: int,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    Top-kサンプリング

    Args:
        logits: モデル出力のロジット (batch_size, vocab_size)
        k: 上位k個のトークンを候補とする
        temperature: Temperature値

    Returns:
        sampled_ids: サンプリングされたトークンID (batch_size,)
    """
    # Temperature を適用
    scaled_logits = logits / temperature

    # 上位k個以外を -inf に
    top_k_logits, top_k_indices = torch.topk(scaled_logits, k, dim=-1)
    filtered_logits = torch.full_like(scaled_logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_logits)

    # 確率分布を計算(-inf は 0 になる)
    probs = F.softmax(filtered_logits, dim=-1)

    # サンプリング
    sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return sampled_ids


def top_p_sampling(
    logits: torch.Tensor,
    p: float,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    Top-p (Nucleus) サンプリング

    Args:
        logits: モデル出力のロジット (batch_size, vocab_size)
        p: 累積確率の閾値 (0 < p <= 1)
        temperature: Temperature値

    Returns:
        sampled_ids: サンプリングされたトークンID (batch_size,)
    """
    # Temperature を適用
    scaled_logits = logits / temperature

    # 確率を計算してソート
    probs = F.softmax(scaled_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    # 累積確率を計算
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

    # 累積確率が p を超える位置を特定
    # 最初に p を超えるトークンまでを含める
    sorted_mask = cumsum_probs - sorted_probs > p  # p を超えたら True
    sorted_probs[sorted_mask] = 0

    # 再正規化
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # ソートされた確率からサンプリング
    sampled_sorted_idx = torch.multinomial(sorted_probs, num_samples=1).squeeze(-1)

    # 元のインデックスに変換
    sampled_ids = sorted_indices.gather(-1, sampled_sorted_idx.unsqueeze(-1)).squeeze(-1)

    return sampled_ids


def combined_sampling(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> torch.Tensor:
    """
    Temperature、Top-k、Top-p を組み合わせたサンプリング

    Args:
        logits: モデル出力のロジット (batch_size, vocab_size)
        temperature: Temperature値
        top_k: Top-k のk値(None の場合は適用しない)
        top_p: Top-p のp値(None の場合は適用しない)

    Returns:
        sampled_ids: サンプリングされたトークンID (batch_size,)
    """
    # Temperature を適用
    scaled_logits = logits / temperature

    # Top-k フィルタリング
    if top_k is not None and top_k > 0:
        top_k_values, _ = torch.topk(scaled_logits, top_k, dim=-1)
        threshold = top_k_values[:, -1].unsqueeze(-1)
        scaled_logits = torch.where(
            scaled_logits >= threshold,
            scaled_logits,
            torch.full_like(scaled_logits, float('-inf'))
        )

    # Top-p フィルタリング
    if top_p is not None and top_p < 1.0:
        probs = F.softmax(scaled_logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

        # p を超えたトークンを除外
        sorted_mask = cumsum_probs - sorted_probs > top_p
        sorted_probs[sorted_mask] = 0

        # 元のインデックス順に戻す
        probs = torch.zeros_like(probs).scatter_(-1, sorted_indices, sorted_probs)

        # 再正規化
        probs = probs / probs.sum(dim=-1, keepdim=True)
    else:
        probs = F.softmax(scaled_logits, dim=-1)

    # サンプリング
    sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return sampled_ids

使用例

# ランダムなロジットを生成
torch.manual_seed(42)
batch_size = 1
vocab_size = 100
logits = torch.randn(batch_size, vocab_size)

# 各手法でサンプリング
print("=== サンプリング結果 ===")

# Temperature サンプリング
for temp in [0.5, 1.0, 1.5]:
    samples = [temperature_sampling(logits, temperature=temp).item() for _ in range(5)]
    print(f"Temperature={temp}: {samples}")

print()

# Top-k サンプリング
for k in [5, 10, 50]:
    samples = [top_k_sampling(logits, k=k).item() for _ in range(5)]
    print(f"Top-k (k={k}): {samples}")

print()

# Top-p サンプリング
for p in [0.5, 0.9, 0.95]:
    samples = [top_p_sampling(logits, p=p).item() for _ in range(5)]
    print(f"Top-p (p={p}): {samples}")

可視化

Temperature による分布の変化

import matplotlib.pyplot as plt
import numpy as np

# サンプルのロジット分布を作成
logits = torch.tensor([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0, -2.5])
vocab = [f'Token {i}' for i in range(len(logits))]

temperatures = [0.5, 1.0, 2.0]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, temp in zip(axes, temperatures):
    probs = F.softmax(logits / temp, dim=-1).numpy()
    ax.bar(range(len(probs)), probs, color='steelblue', alpha=0.7)
    ax.set_xlabel('Token Index', fontsize=11)
    ax.set_ylabel('Probability', fontsize=11)
    ax.set_title(f'Temperature = {temp}', fontsize=12)
    ax.set_ylim(0, 0.6)
    ax.grid(True, alpha=0.3)

plt.suptitle('Effect of Temperature on Probability Distribution', fontsize=14)
plt.tight_layout()
plt.show()

Top-k と Top-p の違い

import matplotlib.pyplot as plt

# 2つの異なるロジット分布(確信度が異なる)
# 分布1: 確信度が高い(1つのトークンが圧倒的)
logits_confident = torch.tensor([3.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5])

# 分布2: 確信度が低い(複数のトークンが同程度)
logits_uncertain = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.1, 0.0, -0.1, -0.2, -0.3, -0.4])

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for row, (logits, label) in enumerate([(logits_confident, 'High Confidence'),
                                        (logits_uncertain, 'Low Confidence')]):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)

    # 元の確率分布
    ax = axes[row, 0]
    ax.bar(range(len(probs)), probs.numpy(), color='steelblue', alpha=0.7)
    ax.set_title(f'{label}: Original Distribution')
    ax.set_ylabel('Probability')
    ax.set_xlabel('Token Index')

    # Top-k (k=3) の候補
    ax = axes[row, 1]
    top_k_mask = torch.zeros_like(probs)
    _, top_k_idx = torch.topk(probs, 3)
    top_k_mask[top_k_idx] = probs[top_k_idx]
    ax.bar(range(len(probs)), top_k_mask.numpy(), color='coral', alpha=0.7)
    ax.set_title(f'{label}: Top-k (k=3)\n{int(top_k_mask.sum() * 100)}% probability mass')
    ax.set_ylabel('Probability')
    ax.set_xlabel('Token Index')

    # Top-p (p=0.9) の候補
    ax = axes[row, 2]
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    nucleus_size = (cumsum <= 0.9).sum().item() + 1
    top_p_mask = torch.zeros_like(probs)
    top_p_mask[sorted_idx[:nucleus_size]] = probs[sorted_idx[:nucleus_size]]
    ax.bar(range(len(probs)), top_p_mask.numpy(), color='seagreen', alpha=0.7)
    ax.set_title(f'{label}: Top-p (p=0.9)\n{nucleus_size} tokens selected')
    ax.set_ylabel('Probability')
    ax.set_xlabel('Token Index')

plt.suptitle('Top-k vs Top-p: Adaptive Behavior', fontsize=14)
plt.tight_layout()
plt.show()

この図から、Top-pが文脈の確信度に応じて候補数を自動調整することがわかります。

パラメータ選択のガイドライン

タスクに応じた設定

タスク Temperature Top-k Top-p 特徴
事実質問応答 0.0-0.3 1-5 0.1-0.5 高精度、低多様性
機械翻訳 0.3-0.7 5-20 0.7-0.9 バランス型
対話 0.7-1.0 40-100 0.9-0.95 多様性重視
創作(詩、物語) 1.0-1.5 100+ 0.95-1.0 高創造性

実用的なヒント

  1. Temperature と Top-p の組み合わせ: 実際のシステムでは、Temperature と Top-p を組み合わせることが多い(例:Temperature=0.8、Top-p=0.95)

  2. Top-k の注意点: $k$ が小さすぎると多様性が失われ、大きすぎるとノイズが増える

  3. 反復生成でのパラメータ調整: 生成が長くなるにつれて Temperature を下げる手法もある

  4. 再現性: サンプリングは確率的なので、同じ結果を得るにはシードを固定する

まとめ

本記事では、LLMのテキスト生成を制御するサンプリング手法について解説しました。

  • Temperatureサンプリング: softmaxの温度を調整し、分布の鋭さを制御。低温度で確実な出力、高温度で多様な出力
  • Top-kサンプリング: 確率上位 $k$ 個のトークンのみを候補とする。シンプルだが $k$ が固定
  • Top-p(Nucleus)サンプリング: 累積確率が $p$ を超える最小の集合から選択。文脈に応じて候補数が動的に変化
  • 組み合わせ: 実用的には Temperature と Top-p を組み合わせることが多い
  • タスクに応じた選択: 精度重視なら低Temperature、多様性重視なら高Temperature・高Top-p

適切なサンプリング手法とパラメータの選択は、LLMの出力品質に大きく影響します。タスクの要件に応じて、これらのパラメータを調整しましょう。

次のステップとして、以下の記事も参考にしてください。