RLHF(Reinforcement Learning from Human Feedback)は、大規模言語モデル(LLM)を人間の意図に沿うように調整する手法です。ChatGPT、Claude、Geminiなど、現代の対話型AIの多くがRLHFを用いて学習されています。
本記事では、RLHFの概念、学習プロセス、数学的な定式化、そして実装の概要について解説します。
本記事の内容
- RLHFの概念と必要性
- 3段階の学習プロセス
- 報酬モデルの学習
- PPOによる方策最適化
- 数学的定式化
- 課題と発展
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
RLHFとは
定義
RLHF(Reinforcement Learning from Human Feedback)は、人間の評価(フィードバック)を報酬信号として、言語モデルを強化学習で最適化する手法です。
なぜRLHFが必要か
事前学習の限界:
事前学習(次トークン予測)だけでは、モデルは「もっともらしい」テキストを生成しますが、必ずしも「有益」「安全」「正確」ではありません。
問題のある応答の例:
- 有害なコンテンツの生成
- 事実と異なる情報の主張
- 指示に従わない応答
- 冗長で焦点のぼやけた回答
RLHFの目的:
人間の選好(preference)を学習し、モデルを「アライン」(align)させます。
$$ \text{目標: } P(\text{人間が好む応答} \mid \text{プロンプト}) \text{ を最大化} $$
従来手法との比較
| 手法 | 最適化対象 | 信号 |
|---|---|---|
| 事前学習 | 次トークン確率 | テキストコーパス |
| SFT | 条件付き生成 | デモンストレーション |
| RLHF | 人間の選好 | 比較フィードバック |
RLHFの3段階プロセス
RLHFは、以下の3段階で構成されます。
┌─────────────────────────────────────────────┐
│ Stage 1: 教師ありファインチューニング (SFT) │
│ - 高品質なデモデータで言語モデルを調整 │
│ - 望ましい応答のスタイルを学習 │
└─────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────┐
│ Stage 2: 報酬モデルの学習 │
│ - 人間の比較データを収集 │
│ - 「どちらの応答が良いか」を学習 │
└─────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────┐
│ Stage 3: PPOによる方策最適化 │
│ - 報酬モデルのスコアを最大化 │
│ - 元のモデルから離れすぎないよう制約 │
└─────────────────────────────────────────────┘
Stage 1: 教師ありファインチューニング (SFT)
目的
事前学習済みモデルを、対話形式や指示追従のスタイルに適応させます。
データ
人間が作成した高品質な(プロンプト, 応答)ペア:
プロンプト: "機械学習を簡単に説明してください"
応答: "機械学習は、コンピュータがデータから
パターンを学習し、予測や判断を行う
技術です。明示的にプログラムせずに、
経験から学ぶことができます。"
学習
標準的な言語モデルのファインチューニング:
$$
\mathcal{L}_{\text{SFT}} = -\mathbb{E}_{(x, y) \sim D} \left[ \sum_{t=1}^{T} \log P_\theta(y_t \mid x, y_{ 人間の選好を数値化する報酬モデル $r_\phi(x, y)$ を学習します。 同じプロンプトに対して複数の応答を生成し、人間が順位付けします。 2つの応答 $y_w$(勝者)と $y_l$(敗者)が与えられたとき、選好確率をBradley-Terryモデルでモデル化します。 $$
P(y_w \succ y_l \mid x) = \sigma(r_\phi(x, y_w) – r_\phi(x, y_l))
$$ ここで $\sigma$ はシグモイド関数: $$
\sigma(z) = \frac{1}{1 + e^{-z}}
$$ 比較データセット $D = \{(x^{(i)}, y_w^{(i)}, y_l^{(i)})\}$ に対して: $$
\mathcal{L}_{\text{RM}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma(r_\phi(x, y_w) – r_\phi(x, y_l)) \right]
$$ これは二値分類の交差エントロピー損失と同じ形式です。 報酬モデルのスコアを最大化しつつ、元のモデル(SFT後)から大きく乖離しないようにモデルを最適化します。 $$
\mathcal{J}(\theta) = \mathbb{E}_{x \sim D, y \sim \pi_\theta(\cdot|x)} \left[ r_\phi(x, y) – \beta \cdot D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}}) \right]
$$ ここで:
– $r_\phi(x, y)$: 報酬モデルのスコア
– $D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}})$: 現在のモデルと参照モデル(SFT後)のKLダイバージェンス
– $\beta$: KLペナルティの係数 KLペナルティがないと、モデルは報酬モデルを「ハック」して高いスコアを得ようとします(reward hacking)。 KLペナルティにより、元のモデルに近い範囲で最適化されます。 PPO(Proximal Policy Optimization)は、安定した方策勾配法です。 クリップ目的関数: $$
L^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]
$$ ここで:
– $r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$: 確率比
– $\hat{A}_t$: アドバンテージ推定値
– $\epsilon$: クリッピング範囲(例: 0.2) 言語モデルでは、報酬は系列全体の生成後に得られます。各トークン位置でのアドバンテージを計算するために、Generalized Advantage Estimation (GAE) を使用します。 $$
\hat{A}_t = \sum_{l=0}^{T-t} (\gamma \lambda)^l \delta_{t+l}
$$ $$
\delta_t = r_t + \gamma V(s_{t+1}) – V(s_t)
$$ $$
\max_\theta \mathbb{E}_{x \sim D} \left[ \mathbb{E}_{y \sim \pi_\theta(\cdot|x)} [r_\phi(x, y)] – \beta \cdot D_{\text{KL}}(\pi_\theta(\cdot|x) \| \pi_{\text{ref}}(\cdot|x)) \right]
$$ トークンごとのKLダイバージェンス: $$
D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta} \left[ \sum_{t=1}^{T} \log \frac{\pi_\theta(y_t | x, y_{ 実際には、KLペナルティを報酬に組み込んで: $$
r'(x, y) = r_\phi(x, y) – \beta \sum_{t=1}^{T} \log \frac{\pi_\theta(y_t | x, y_{ 1. 報酬ハッキング 報酬モデルの欠陥を悪用して、実際には低品質な応答で高いスコアを得る問題。 2. 報酬モデルのバイアス 人間の評価者のバイアスが報酬モデルに反映される。 3. 学習の不安定性 強化学習の学習は不安定になりやすい。 4. 計算コスト 複数のモデル(方策、報酬、価値、参照)を同時に管理する必要がある。 1. DPO (Direct Preference Optimization) 報酬モデルを明示的に学習せず、選好データから直接方策を最適化。 $$
\mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} – \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]
$$ 2. RLAIF (RL from AI Feedback) 人間の代わりにAIがフィードバックを提供。 3. Constitutional AI AIが自身の出力を評価・修正するフレームワーク。 本記事では、RLHF(人間のフィードバックからの強化学習)について解説しました。 RLHFは、現代のLLMを「有益で、害がなく、正直」にするための重要な技術です。今後もDPOなどの代替手法と共に発展が続くでしょう。 次のステップとして、以下の記事も参考にしてください。Stage 2: 報酬モデルの学習
目的
比較データの収集
プロンプト: "量子力学を説明してください"
応答A: "量子力学は、原子・分子スケールの
物理現象を記述する理論です..."
応答B: "量子力学は難しいです。"
人間の評価: A > B (Aの方が良い)
Bradley-Terryモデル
損失関数
報酬モデルのアーキテクチャ
import torch
import torch.nn as nn
class RewardModel(nn.Module):
"""報酬モデル"""
def __init__(self, base_model, hidden_size):
super().__init__()
self.base_model = base_model # 事前学習済みTransformer
self.reward_head = nn.Linear(hidden_size, 1)
def forward(self, input_ids, attention_mask):
# 基盤モデルで特徴抽出
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask
)
# 最後のトークンの表現を使用
last_hidden = outputs.last_hidden_state[:, -1, :]
# スカラーの報酬値を出力
reward = self.reward_head(last_hidden).squeeze(-1)
return reward
Stage 3: PPOによる方策最適化
目的
強化学習としての定式化
目的関数
KLペナルティの役割
問題例:
- 同じフレーズを繰り返す
- 異常に長い応答を生成
- 報酬モデルのバイアスを悪用
PPOアルゴリズム
言語モデルへの適用
数学的定式化のまとめ
全体の最適化問題
KLダイバージェンスの計算
報酬の修正
実装の概要
トレーニングループ
import torch
from torch.optim import Adam
def rlhf_training_step(policy_model, ref_model, reward_model,
prompts, optimizer, beta=0.1):
"""
RLHFの1ステップ
Args:
policy_model: 最適化対象のモデル
ref_model: 参照モデル(固定)
reward_model: 報酬モデル(固定)
prompts: プロンプトのバッチ
optimizer: オプティマイザ
beta: KLペナルティ係数
"""
# 応答を生成
responses, log_probs = generate_responses(policy_model, prompts)
# 報酬を計算
rewards = reward_model(prompts, responses)
# 参照モデルの対数確率
with torch.no_grad():
ref_log_probs = compute_log_probs(ref_model, prompts, responses)
# KLペナルティを計算
kl_penalty = (log_probs - ref_log_probs).sum(dim=-1)
# 修正済み報酬
modified_rewards = rewards - beta * kl_penalty
# PPOの損失を計算
loss = compute_ppo_loss(log_probs, modified_rewards)
# バックプロパゲーション
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item(), rewards.mean().item()
応答生成
def generate_responses(model, prompts, max_length=100):
"""
応答を生成し、対数確率を記録
Returns:
responses: 生成されたトークンID
log_probs: 各トークンの対数確率
"""
responses = []
all_log_probs = []
for prompt in prompts:
input_ids = prompt
log_probs = []
for _ in range(max_length):
outputs = model(input_ids)
logits = outputs.logits[:, -1, :]
probs = torch.softmax(logits, dim=-1)
# サンプリング
next_token = torch.multinomial(probs, num_samples=1)
log_prob = torch.log(probs.gather(-1, next_token))
log_probs.append(log_prob)
input_ids = torch.cat([input_ids, next_token], dim=-1)
if next_token.item() == eos_token_id:
break
responses.append(input_ids)
all_log_probs.append(torch.cat(log_probs))
return responses, all_log_probs
価値関数の学習
class ValueHead(nn.Module):
"""価値関数ヘッド"""
def __init__(self, hidden_size):
super().__init__()
self.value_head = nn.Linear(hidden_size, 1)
def forward(self, hidden_states):
return self.value_head(hidden_states).squeeze(-1)
def compute_value_loss(values, returns):
"""価値関数の損失"""
return ((values - returns) ** 2).mean()
課題と発展
課題
発展手法
まとめ