LLMを人間の好みに沿うように調整する手法として、DPO(Direct Preference Optimization)が注目されています。従来のRLHF(Reinforcement Learning from Human Feedback)と比べて、報酬モデルの学習が不要で、実装がシンプルという利点があります。
本記事では、DPOの数学的な導出から実装まで、詳しく解説します。
本記事の内容
- RLHFの課題とDPOの動機
- Bradley-Terryモデルによる選好の定式化
- DPO損失関数の導出
- Pythonでの実装例
前提知識
この記事を読む前に、以下の概念を理解しておくと役立ちます。
- Instruction Tuning
- 強化学習の基礎(方策、報酬)
- KL発散
RLHFの概要と課題
RLHFのパイプライン
RLHF(Reinforcement Learning from Human Feedback)は以下の3段階で構成されます。
- SFT(Supervised Fine-Tuning): 指示チューニング
- 報酬モデル学習: 人間の選好データから報酬関数を学習
- RL最適化: 報酬を最大化するようにPPOで方策を更新
報酬モデルの学習
選好データ $(x, y_w, y_l)$($y_w$ が好まれる応答、$y_l$ が好まれない応答)から報酬モデル $r_\phi$ を学習:
$$ \mathcal{L}_{\text{RM}}(\phi) = -\mathbb{E}_{(x, y_w, y_l)}[\log \sigma(r_\phi(x, y_w) – r_\phi(x, y_l))] $$
RL最適化
学習した報酬を最大化しつつ、参照モデル $\pi_{\text{ref}}$ から大きく逸脱しないように制約:
$$ \max_\pi \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi}[r(x, y)] – \beta \mathbb{D}_{\text{KL}}[\pi(y|x) \| \pi_{\text{ref}}(y|x)] $$
RLHFの課題
- 複雑なパイプライン: 3段階の学習が必要
- 報酬モデルの誤差: 報酬モデルの品質がボトルネック
- RLの不安定性: PPOのハイパーパラメータ調整が困難
- 計算コスト: 報酬モデルとポリシーの両方を保持する必要
DPOの発想
DPOの核心的なアイデアは、最適方策と報酬関数の間には閉形式の関係があることを利用し、報酬モデルを経由せずに直接方策を最適化することです。
数学的導出
KL制約付き報酬最大化問題
以下の最適化問題を考えます:
$$ \max_\pi \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi}[r(x, y)] – \beta \mathbb{D}_{\text{KL}}[\pi(y|x) \| \pi_{\text{ref}}(y|x)] $$
最適方策の導出
この問題の最適解は閉形式で求まります。
$$ \pi^*(y|x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right) $$
ここで $Z(x)$ は正規化定数:
$$ Z(x) = \sum_y \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right) $$
報酬関数の陽的表現
最適方策の式を報酬 $r$ について解くと:
$$ r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x) $$
Bradley-Terryモデル
人間の選好は、Bradley-Terryモデルで定式化されます。応答 $y_w$ が $y_l$ より好まれる確率:
$$ P(y_w \succ y_l | x) = \sigma(r(x, y_w) – r(x, y_l)) $$
ここで $\sigma$ はシグモイド関数です。
DPO損失関数の導出
報酬関数の陽的表現を代入すると、$Z(x)$ がキャンセルされます:
$$ \begin{align} r(x, y_w) – r(x, y_l) &= \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \\ &= \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} \cdot \frac{\pi_{\text{ref}}(y_l|x)}{\pi^*(y_l|x)} \end{align} $$
これにより、報酬モデルなしで直接選好確率を表現できます:
$$ P(y_w \succ y_l | x) = \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right) $$
DPO損失関数
最終的なDPO損失関数は以下になります:
$$ \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right] $$
対数確率の比を使うと:
$$ \mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[\log \sigma\left(\beta (\hat{r}_\theta(x, y_w) – \hat{r}_\theta(x, y_l))\right)\right] $$
ここで暗黙的報酬:
$$ \hat{r}_\theta(x, y) = \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)} $$
Pythonでの実装
DPO損失関数の実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class DPOLoss(nn.Module):
def __init__(self, beta=0.1):
"""DPO損失関数
Args:
beta: KL制約の強さ(温度パラメータ)
"""
super().__init__()
self.beta = beta
def forward(self, policy_chosen_logps, policy_rejected_logps,
reference_chosen_logps, reference_rejected_logps):
"""
DPO損失を計算
Args:
policy_chosen_logps: 学習モデルの選好応答の対数確率
policy_rejected_logps: 学習モデルの非選好応答の対数確率
reference_chosen_logps: 参照モデルの選好応答の対数確率
reference_rejected_logps: 参照モデルの非選好応答の対数確率
Returns:
loss: DPO損失
metrics: 追加メトリクス
"""
# 対数確率比を計算
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
# DPO損失
logits = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(logits).mean()
# メトリクス
chosen_rewards_mean = chosen_rewards.mean().item()
rejected_rewards_mean = rejected_rewards.mean().item()
reward_margin = (chosen_rewards - rejected_rewards).mean().item()
accuracy = (logits > 0).float().mean().item()
metrics = {
'chosen_rewards': chosen_rewards_mean,
'rejected_rewards': rejected_rewards_mean,
'reward_margin': reward_margin,
'accuracy': accuracy
}
return loss, metrics
# 使用例
dpo_loss = DPOLoss(beta=0.1)
# サンプルデータ(バッチサイズ4)
batch_size = 4
policy_chosen_logps = torch.randn(batch_size)
policy_rejected_logps = torch.randn(batch_size) - 0.5 # 非選好は低い確率
reference_chosen_logps = torch.randn(batch_size)
reference_rejected_logps = torch.randn(batch_size)
loss, metrics = dpo_loss(
policy_chosen_logps, policy_rejected_logps,
reference_chosen_logps, reference_rejected_logps
)
print(f"DPO Loss: {loss.item():.4f}")
print(f"Metrics: {metrics}")
対数確率の計算
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def compute_log_probs(model, tokenizer, prompt, response):
"""応答の対数確率を計算"""
# プロンプトと応答を連結
full_text = prompt + response
# トークン化
inputs = tokenizer(full_text, return_tensors='pt')
prompt_inputs = tokenizer(prompt, return_tensors='pt')
prompt_length = prompt_inputs['input_ids'].shape[1]
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# 応答部分のトークンに対する対数確率を計算
labels = inputs['input_ids']
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
# クロスエントロピーを計算(負の対数確率)
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = torch.gather(
log_probs, 2, shift_labels.unsqueeze(-1)
).squeeze(-1)
# 応答部分のみを合計
response_log_probs = token_log_probs[:, prompt_length-1:].sum()
return response_log_probs
# 使用例(小さなモデルでデモ)
# model_name = "gpt2"
# model = AutoModelForCausalLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
#
# prompt = "What is machine learning?"
# response = " Machine learning is a subset of AI."
# log_prob = compute_log_probs(model, tokenizer, prompt, response)
# print(f"Log probability: {log_prob.item():.4f}")
DPOトレーナーの実装
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
class DPOTrainer:
def __init__(self, model, ref_model, tokenizer, beta=0.1, lr=1e-6):
"""DPOトレーナー
Args:
model: 学習対象のモデル
ref_model: 参照モデル(凍結)
tokenizer: トークナイザー
beta: KL制約の強さ
lr: 学習率
"""
self.model = model
self.ref_model = ref_model
self.tokenizer = tokenizer
self.beta = beta
# 参照モデルは凍結
for param in self.ref_model.parameters():
param.requires_grad = False
self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
def compute_logps(self, model, input_ids, attention_mask, labels):
"""対数確率を計算"""
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
# シフトして対応を合わせる
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
shift_mask = attention_mask[:, 1:]
# 対数確率
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = torch.gather(
log_probs, 2, shift_labels.unsqueeze(-1)
).squeeze(-1)
# マスクを適用して合計
masked_log_probs = token_log_probs * shift_mask
sequence_log_probs = masked_log_probs.sum(dim=1)
return sequence_log_probs
def dpo_loss(self, policy_chosen_logps, policy_rejected_logps,
ref_chosen_logps, ref_rejected_logps):
"""DPO損失を計算"""
chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)
logits = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(logits).mean()
return loss
def train_step(self, batch):
"""1ステップの学習"""
self.model.train()
self.optimizer.zero_grad()
# 選好応答の対数確率
policy_chosen_logps = self.compute_logps(
self.model,
batch['chosen_input_ids'],
batch['chosen_attention_mask'],
batch['chosen_labels']
)
# 非選好応答の対数確率
policy_rejected_logps = self.compute_logps(
self.model,
batch['rejected_input_ids'],
batch['rejected_attention_mask'],
batch['rejected_labels']
)
# 参照モデルの対数確率(勾配不要)
with torch.no_grad():
ref_chosen_logps = self.compute_logps(
self.ref_model,
batch['chosen_input_ids'],
batch['chosen_attention_mask'],
batch['chosen_labels']
)
ref_rejected_logps = self.compute_logps(
self.ref_model,
batch['rejected_input_ids'],
batch['rejected_attention_mask'],
batch['rejected_labels']
)
# 損失計算
loss = self.dpo_loss(
policy_chosen_logps, policy_rejected_logps,
ref_chosen_logps, ref_rejected_logps
)
# 逆伝播
loss.backward()
self.optimizer.step()
return loss.item()
def train(self, dataloader, epochs=3):
"""学習ループ"""
for epoch in range(epochs):
total_loss = 0
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
for batch in pbar:
loss = self.train_step(batch)
total_loss += loss
pbar.set_postfix({'loss': f'{loss:.4f}'})
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")
# 使用例
# trainer = DPOTrainer(model, ref_model, tokenizer, beta=0.1)
# trainer.train(dataloader, epochs=3)
DPOの派生手法
IPO(Identity Preference Optimization)
DPOの改良版で、過学習を防ぐための正則化を追加:
$$ \mathcal{L}_{\text{IPO}} = \mathbb{E}\left[\left(\log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} – \frac{1}{2\beta}\right)^2\right] $$
KTO(Kahneman-Tversky Optimization)
ペアでなく、単一の応答とそのラベル(良い/悪い)から学習:
$$ \mathcal{L}_{\text{KTO}} = \mathbb{E}_{(x,y,\text{label})}[w(y) \cdot \ell(\text{label}, \hat{r}_\theta(x, y))] $$
cDPO(Conservative DPO)
参照モデルからの乖離にペナルティを追加:
$$ \mathcal{L}_{\text{cDPO}} = \mathcal{L}_{\text{DPO}} + \alpha \cdot \mathbb{D}_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}] $$
RLHFとDPOの比較
| 観点 | RLHF | DPO |
|---|---|---|
| パイプライン | SFT → RM → RL | SFT → DPO |
| 報酬モデル | 必要 | 不要(暗黙的) |
| 安定性 | PPOの調整が困難 | 安定 |
| メモリ | 4モデル必要 | 2モデルで十分 |
| 実装難易度 | 高い | 低い |
ハイパーパラメータの選択
beta の選択
$\beta$ は参照モデルからの逸脱を制御します。
- 小さい $\beta$: 参照モデルから大きく逸脱可能
- 大きい $\beta$: 参照モデルに近い出力
推奨範囲: $0.1 \sim 0.5$
学習率
DPOは敏感なので、低い学習率を推奨します。
推奨範囲: $1 \times 10^{-7} \sim 5 \times 10^{-6}$
まとめ
本記事では、DPO(Direct Preference Optimization)の数学的原理と実装を解説しました。
- 核心アイデア: 最適方策と報酬関数の閉形式関係を利用
- 利点: 報酬モデル不要、安定した学習、シンプルな実装
- 損失関数: 選好データから直接方策を最適化
- 派生手法: IPO、KTO、cDPOなど多数
次のステップとして、以下の記事も参考にしてください。