DPO(Direct Preference Optimization)の数学的導出と実装

LLMを人間の好みに沿うように調整する手法として、DPO(Direct Preference Optimization)が注目されています。従来のRLHF(Reinforcement Learning from Human Feedback)と比べて、報酬モデルの学習が不要で、実装がシンプルという利点があります。

本記事では、DPOの数学的な導出から実装まで、詳しく解説します。

本記事の内容

  • RLHFの課題とDPOの動機
  • Bradley-Terryモデルによる選好の定式化
  • DPO損失関数の導出
  • Pythonでの実装例

前提知識

この記事を読む前に、以下の概念を理解しておくと役立ちます。

RLHFの概要と課題

RLHFのパイプライン

RLHF(Reinforcement Learning from Human Feedback)は以下の3段階で構成されます。

  1. SFT(Supervised Fine-Tuning): 指示チューニング
  2. 報酬モデル学習: 人間の選好データから報酬関数を学習
  3. RL最適化: 報酬を最大化するようにPPOで方策を更新

報酬モデルの学習

選好データ $(x, y_w, y_l)$($y_w$ が好まれる応答、$y_l$ が好まれない応答)から報酬モデル $r_\phi$ を学習:

$$ \mathcal{L}_{\text{RM}}(\phi) = -\mathbb{E}_{(x, y_w, y_l)}[\log \sigma(r_\phi(x, y_w) – r_\phi(x, y_l))] $$

RL最適化

学習した報酬を最大化しつつ、参照モデル $\pi_{\text{ref}}$ から大きく逸脱しないように制約:

$$ \max_\pi \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi}[r(x, y)] – \beta \mathbb{D}_{\text{KL}}[\pi(y|x) \| \pi_{\text{ref}}(y|x)] $$

RLHFの課題

  1. 複雑なパイプライン: 3段階の学習が必要
  2. 報酬モデルの誤差: 報酬モデルの品質がボトルネック
  3. RLの不安定性: PPOのハイパーパラメータ調整が困難
  4. 計算コスト: 報酬モデルとポリシーの両方を保持する必要

DPOの発想

DPOの核心的なアイデアは、最適方策と報酬関数の間には閉形式の関係があることを利用し、報酬モデルを経由せずに直接方策を最適化することです。

数学的導出

KL制約付き報酬最大化問題

以下の最適化問題を考えます:

$$ \max_\pi \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi}[r(x, y)] – \beta \mathbb{D}_{\text{KL}}[\pi(y|x) \| \pi_{\text{ref}}(y|x)] $$

最適方策の導出

この問題の最適解は閉形式で求まります。

$$ \pi^*(y|x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right) $$

ここで $Z(x)$ は正規化定数:

$$ Z(x) = \sum_y \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right) $$

報酬関数の陽的表現

最適方策の式を報酬 $r$ について解くと:

$$ r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x) $$

Bradley-Terryモデル

人間の選好は、Bradley-Terryモデルで定式化されます。応答 $y_w$ が $y_l$ より好まれる確率:

$$ P(y_w \succ y_l | x) = \sigma(r(x, y_w) – r(x, y_l)) $$

ここで $\sigma$ はシグモイド関数です。

DPO損失関数の導出

報酬関数の陽的表現を代入すると、$Z(x)$ がキャンセルされます:

$$ \begin{align} r(x, y_w) – r(x, y_l) &= \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \\ &= \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} \cdot \frac{\pi_{\text{ref}}(y_l|x)}{\pi^*(y_l|x)} \end{align} $$

これにより、報酬モデルなしで直接選好確率を表現できます:

$$ P(y_w \succ y_l | x) = \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right) $$

DPO損失関数

最終的なDPO損失関数は以下になります:

$$ \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right] $$

対数確率の比を使うと:

$$ \mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[\log \sigma\left(\beta (\hat{r}_\theta(x, y_w) – \hat{r}_\theta(x, y_l))\right)\right] $$

ここで暗黙的報酬:

$$ \hat{r}_\theta(x, y) = \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)} $$

Pythonでの実装

DPO損失関数の実装

import torch
import torch.nn as nn
import torch.nn.functional as F

class DPOLoss(nn.Module):
    def __init__(self, beta=0.1):
        """DPO損失関数

        Args:
            beta: KL制約の強さ(温度パラメータ)
        """
        super().__init__()
        self.beta = beta

    def forward(self, policy_chosen_logps, policy_rejected_logps,
                reference_chosen_logps, reference_rejected_logps):
        """
        DPO損失を計算

        Args:
            policy_chosen_logps: 学習モデルの選好応答の対数確率
            policy_rejected_logps: 学習モデルの非選好応答の対数確率
            reference_chosen_logps: 参照モデルの選好応答の対数確率
            reference_rejected_logps: 参照モデルの非選好応答の対数確率

        Returns:
            loss: DPO損失
            metrics: 追加メトリクス
        """
        # 対数確率比を計算
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)

        # DPO損失
        logits = chosen_rewards - rejected_rewards
        loss = -F.logsigmoid(logits).mean()

        # メトリクス
        chosen_rewards_mean = chosen_rewards.mean().item()
        rejected_rewards_mean = rejected_rewards.mean().item()
        reward_margin = (chosen_rewards - rejected_rewards).mean().item()
        accuracy = (logits > 0).float().mean().item()

        metrics = {
            'chosen_rewards': chosen_rewards_mean,
            'rejected_rewards': rejected_rewards_mean,
            'reward_margin': reward_margin,
            'accuracy': accuracy
        }

        return loss, metrics

# 使用例
dpo_loss = DPOLoss(beta=0.1)

# サンプルデータ(バッチサイズ4)
batch_size = 4
policy_chosen_logps = torch.randn(batch_size)
policy_rejected_logps = torch.randn(batch_size) - 0.5  # 非選好は低い確率
reference_chosen_logps = torch.randn(batch_size)
reference_rejected_logps = torch.randn(batch_size)

loss, metrics = dpo_loss(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps
)

print(f"DPO Loss: {loss.item():.4f}")
print(f"Metrics: {metrics}")

対数確率の計算

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def compute_log_probs(model, tokenizer, prompt, response):
    """応答の対数確率を計算"""
    # プロンプトと応答を連結
    full_text = prompt + response

    # トークン化
    inputs = tokenizer(full_text, return_tensors='pt')
    prompt_inputs = tokenizer(prompt, return_tensors='pt')
    prompt_length = prompt_inputs['input_ids'].shape[1]

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # 応答部分のトークンに対する対数確率を計算
    labels = inputs['input_ids']
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # クロスエントロピーを計算(負の対数確率)
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = torch.gather(
        log_probs, 2, shift_labels.unsqueeze(-1)
    ).squeeze(-1)

    # 応答部分のみを合計
    response_log_probs = token_log_probs[:, prompt_length-1:].sum()

    return response_log_probs

# 使用例(小さなモデルでデモ)
# model_name = "gpt2"
# model = AutoModelForCausalLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
#
# prompt = "What is machine learning?"
# response = " Machine learning is a subset of AI."
# log_prob = compute_log_probs(model, tokenizer, prompt, response)
# print(f"Log probability: {log_prob.item():.4f}")

DPOトレーナーの実装

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

class DPOTrainer:
    def __init__(self, model, ref_model, tokenizer, beta=0.1, lr=1e-6):
        """DPOトレーナー

        Args:
            model: 学習対象のモデル
            ref_model: 参照モデル(凍結)
            tokenizer: トークナイザー
            beta: KL制約の強さ
            lr: 学習率
        """
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.beta = beta

        # 参照モデルは凍結
        for param in self.ref_model.parameters():
            param.requires_grad = False

        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def compute_logps(self, model, input_ids, attention_mask, labels):
        """対数確率を計算"""
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # シフトして対応を合わせる
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        shift_mask = attention_mask[:, 1:]

        # 対数確率
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = torch.gather(
            log_probs, 2, shift_labels.unsqueeze(-1)
        ).squeeze(-1)

        # マスクを適用して合計
        masked_log_probs = token_log_probs * shift_mask
        sequence_log_probs = masked_log_probs.sum(dim=1)

        return sequence_log_probs

    def dpo_loss(self, policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps):
        """DPO損失を計算"""
        chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)

        logits = chosen_rewards - rejected_rewards
        loss = -F.logsigmoid(logits).mean()

        return loss

    def train_step(self, batch):
        """1ステップの学習"""
        self.model.train()
        self.optimizer.zero_grad()

        # 選好応答の対数確率
        policy_chosen_logps = self.compute_logps(
            self.model,
            batch['chosen_input_ids'],
            batch['chosen_attention_mask'],
            batch['chosen_labels']
        )

        # 非選好応答の対数確率
        policy_rejected_logps = self.compute_logps(
            self.model,
            batch['rejected_input_ids'],
            batch['rejected_attention_mask'],
            batch['rejected_labels']
        )

        # 参照モデルの対数確率(勾配不要)
        with torch.no_grad():
            ref_chosen_logps = self.compute_logps(
                self.ref_model,
                batch['chosen_input_ids'],
                batch['chosen_attention_mask'],
                batch['chosen_labels']
            )
            ref_rejected_logps = self.compute_logps(
                self.ref_model,
                batch['rejected_input_ids'],
                batch['rejected_attention_mask'],
                batch['rejected_labels']
            )

        # 損失計算
        loss = self.dpo_loss(
            policy_chosen_logps, policy_rejected_logps,
            ref_chosen_logps, ref_rejected_logps
        )

        # 逆伝播
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, dataloader, epochs=3):
        """学習ループ"""
        for epoch in range(epochs):
            total_loss = 0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in pbar:
                loss = self.train_step(batch)
                total_loss += loss
                pbar.set_postfix({'loss': f'{loss:.4f}'})

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

# 使用例
# trainer = DPOTrainer(model, ref_model, tokenizer, beta=0.1)
# trainer.train(dataloader, epochs=3)

DPOの派生手法

IPO(Identity Preference Optimization)

DPOの改良版で、過学習を防ぐための正則化を追加:

$$ \mathcal{L}_{\text{IPO}} = \mathbb{E}\left[\left(\log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} – \frac{1}{2\beta}\right)^2\right] $$

KTO(Kahneman-Tversky Optimization)

ペアでなく、単一の応答とそのラベル(良い/悪い)から学習:

$$ \mathcal{L}_{\text{KTO}} = \mathbb{E}_{(x,y,\text{label})}[w(y) \cdot \ell(\text{label}, \hat{r}_\theta(x, y))] $$

cDPO(Conservative DPO)

参照モデルからの乖離にペナルティを追加:

$$ \mathcal{L}_{\text{cDPO}} = \mathcal{L}_{\text{DPO}} + \alpha \cdot \mathbb{D}_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}] $$

RLHFとDPOの比較

観点 RLHF DPO
パイプライン SFT → RM → RL SFT → DPO
報酬モデル 必要 不要(暗黙的)
安定性 PPOの調整が困難 安定
メモリ 4モデル必要 2モデルで十分
実装難易度 高い 低い

ハイパーパラメータの選択

beta の選択

$\beta$ は参照モデルからの逸脱を制御します。

  • 小さい $\beta$: 参照モデルから大きく逸脱可能
  • 大きい $\beta$: 参照モデルに近い出力

推奨範囲: $0.1 \sim 0.5$

学習率

DPOは敏感なので、低い学習率を推奨します。

推奨範囲: $1 \times 10^{-7} \sim 5 \times 10^{-6}$

まとめ

本記事では、DPO(Direct Preference Optimization)の数学的原理と実装を解説しました。

  • 核心アイデア: 最適方策と報酬関数の閉形式関係を利用
  • 利点: 報酬モデル不要、安定した学習、シンプルな実装
  • 損失関数: 選好データから直接方策を最適化
  • 派生手法: IPO、KTO、cDPOなど多数

次のステップとして、以下の記事も参考にしてください。