PPO(Proximal Policy Optimization)の理論と実装

PPO(Proximal Policy Optimization)はOpenAIが2017年に発表した方策勾配法であり、TRPOの安定性を維持しつつ実装が大幅に簡単になったアルゴリズムです。現在の深層強化学習において最もよく使われるアルゴリズムの1つであり、ロボット制御からLLMのRLHFまで幅広く応用されています。

本記事の内容

  • 方策勾配法の課題とTRPOの着想
  • PPOのClipped Surrogate Objective
  • GAE(Generalized Advantage Estimation)
  • PPOアルゴリズムの全体像
  • Pythonでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

方策勾配法の課題

更新幅の問題

方策勾配法の更新則は

$$ \bm{\theta} \leftarrow \bm{\theta} + \alpha \nabla_{\bm{\theta}} J(\bm{\theta}) $$

学習率 $\alpha$ が大きすぎると方策が急激に変化し、パフォーマンスが崩壊します。小さすぎると学習が遅くなります。

方策は確率分布であるため、パラメータ空間での小さな変化が方策空間では大きな変化をもたらすことがあります。

重要度サンプリング

古い方策 $\pi_{\bm{\theta}_{\text{old}}}$ で収集したデータを使って、新しい方策 $\pi_{\bm{\theta}}$ を評価・更新したい場合、重要度サンプリングを用います。

$$ J(\bm{\theta}) = E_{a \sim \pi_{\bm{\theta}_{\text{old}}}} \left[\frac{\pi_{\bm{\theta}}(a|s)}{\pi_{\bm{\theta}_{\text{old}}}(a|s)} A^{\pi_{\bm{\theta}_{\text{old}}}}(s, a)\right] $$

ここで 確率比(probability ratio)を

$$ r_t(\bm{\theta}) = \frac{\pi_{\bm{\theta}}(a_t|s_t)}{\pi_{\bm{\theta}_{\text{old}}}(a_t|s_t)} $$

と定義すると

$$ J(\bm{\theta}) = E_t\left[r_t(\bm{\theta}) \hat{A}_t\right] $$

この $J(\bm{\theta})$ を サロゲート目的関数(surrogate objective)と呼びます。$\hat{A}_t$ はアドバンテージの推定値です。

TRPOの着想

TRPO(Trust Region Policy Optimization, Schulman et al., 2015)は、方策の更新に信頼領域制約を課すことで安定した学習を実現します。

$$ \begin{aligned} \max_{\bm{\theta}} \quad & E_t\left[r_t(\bm{\theta}) \hat{A}_t\right] \\ \text{subject to} \quad & E_t\left[D_{\text{KL}}\left(\pi_{\bm{\theta}_{\text{old}}}(\cdot|s_t) \| \pi_{\bm{\theta}}(\cdot|s_t)\right)\right] \leq \delta \end{aligned} $$

KLダイバージェンスで方策の変化量を制限し、パフォーマンスの崩壊を防ぎます。

TRPOの問題点:

  • 制約付き最適化が必要で、2次近似(フィッシャー情報行列)と共役勾配法を使う
  • 実装が複雑
  • ネットワーク間でのパラメータ共有が難しい

PPOのClipped Surrogate Objective

PPOはTRPOの制約を目的関数のクリッピングで近似的に実現します。

Clipped Objective

$$ \boxed{L^{\text{CLIP}}(\bm{\theta}) = E_t\left[\min\left(r_t(\bm{\theta}) \hat{A}_t, \; \text{clip}(r_t(\bm{\theta}), 1-\varepsilon, 1+\varepsilon) \hat{A}_t\right)\right]} $$

ここで $\varepsilon$ はハイパーパラメータ(通常 $\varepsilon = 0.2$)であり、

$$ \text{clip}(r, 1-\varepsilon, 1+\varepsilon) = \max(1-\varepsilon, \min(r, 1+\varepsilon)) $$

は確率比 $r$ を $[1-\varepsilon, 1+\varepsilon]$ の範囲に制限します。

クリッピングの直感的理解

$\min$ の中の2つの項を分析します。

場合1: $\hat{A}_t > 0$(良い行動)の場合

行動が良いので $r_t$ を大きくしたい(その行動をもっと取るようにしたい)。しかし $r_t$ が $1+\varepsilon$ を超えると

$$ \min(r_t \hat{A}_t, (1+\varepsilon) \hat{A}_t) = (1+\varepsilon) \hat{A}_t $$

となり、勾配がゼロになります。つまり、方策の変化が大きくなりすぎることを防ぎます。

場合2: $\hat{A}_t < 0$(悪い行動)の場合

行動が悪いので $r_t$ を小さくしたい(その行動を避けるようにしたい)。しかし $r_t$ が $1-\varepsilon$ を下回ると

$$ \min(r_t \hat{A}_t, (1-\varepsilon) \hat{A}_t) = (1-\varepsilon) \hat{A}_t $$

となり、やはり勾配がゼロになります。

いずれの場合も、方策が「古い方策から離れすぎる」更新が抑制されます。

GAE(Generalized Advantage Estimation)

アドバンテージ関数

アドバンテージ関数 $A^{\pi}(s, a)$ は、行動 $a$ が平均的な行動よりどれだけ良いかを表します。

$$ A^{\pi}(s, a) = Q^{\pi}(s, a) – V^{\pi}(s) $$

TD残差

1ステップのTD残差を

$$ \delta_t = r_{t+1} + \gamma V(s_{t+1}) – V(s_t) $$

と定義します。$\delta_t$ はアドバンテージの不偏推定量ですが分散が大きいです。

$n$ ステップアドバンテージ

$n$ ステップの収益を使ったアドバンテージ推定は

$$ \hat{A}_t^{(n)} = \sum_{l=0}^{n-1} \gamma^l \delta_{t+l} $$

$n = 1$ では分散は小さいがバイアスが大きく、$n = T-t$ では不偏だが分散が大きいです。

GAEの定義

GAE(Schulman et al., 2016)は、異なる $n$ の推定量を指数的に重み付け平均します。

$$ \boxed{\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \delta_{t+l}} $$

ここで $\lambda \in [0, 1]$ はバイアスと分散のトレードオフを制御するパラメータです。

導出

$k$ ステップアドバンテージを

$$ \hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l} $$

とすると、GAEはこれらの指数的加重平均です。

$$ \begin{align} \hat{A}_t^{\text{GAE}} &= (1-\lambda)\left(\hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + \cdots \right) \\ &= (1-\lambda)\left(\delta_t + \lambda(\delta_t + \gamma\delta_{t+1}) + \lambda^2(\delta_t + \gamma\delta_{t+1} + \gamma^2\delta_{t+2}) + \cdots\right) \\ &= (1-\lambda)\left(\frac{\delta_t}{1-\lambda} + \frac{\gamma\lambda\delta_{t+1}}{1-\lambda} + \frac{(\gamma\lambda)^2\delta_{t+2}}{1-\lambda} + \cdots\right) \\ &= \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l} \end{align} $$

特殊なケース

  • $\lambda = 0$: $\hat{A}_t = \delta_t = r_{t+1} + \gamma V(s_{t+1}) – V(s_t)$(1ステップTD、低分散・高バイアス)
  • $\lambda = 1$: $\hat{A}_t = \sum_{l=0}^{T-t-1} \gamma^l \delta_{t+l} = G_t – V(s_t)$(モンテカルロ、不偏・高分散)

実用上は $\lambda = 0.95$ 程度がよく使われます。

PPOアルゴリズムの全体像

全損失関数

PPOでは、Actor(方策)とCritic(価値関数)を同時に学習します。全損失関数は

$$ L(\bm{\theta}) = L^{\text{CLIP}}(\bm{\theta}) – c_1 L^{\text{VF}}(\bm{\theta}) + c_2 S[\pi_{\bm{\theta}}] $$

各項の意味:

  • $L^{\text{CLIP}}$: クリップされたサロゲート目的関数(最大化)
  • $L^{\text{VF}} = (V_{\bm{\theta}}(s_t) – V_t^{\text{target}})^2$: 価値関数の損失(最小化)
  • $S[\pi_{\bm{\theta}}] = -\sum_a \pi_{\bm{\theta}}(a|s) \log \pi_{\bm{\theta}}(a|s)$: エントロピーボーナス(探索促進)
  • $c_1, c_2$: 重み係数

アルゴリズムの流れ

  1. 現在の方策 $\pi_{\bm{\theta}_{\text{old}}}$ で $T$ ステップ分のデータを収集
  2. GAEでアドバンテージ $\hat{A}_t$ を計算
  3. 収集したデータに対して $K$ エポック(通常 $K = 3 \sim 10$)のミニバッチ更新を実行
  4. $\bm{\theta}_{\text{old}} \leftarrow \bm{\theta}$ として1に戻る

同じデータを複数エポック再利用できるのは、クリッピングにより方策が大きく変わらないためです。

Pythonでの実装

CartPole環境でのPPO

import numpy as np
import matplotlib.pyplot as plt

# --- CartPole環境 ---
class CartPoleEnv:
    """簡易CartPole環境"""

    def __init__(self):
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masscart + self.masspole
        self.length = 0.5
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02
        self.theta_threshold = 12 * np.pi / 180
        self.x_threshold = 2.4
        self.max_steps = 500

    def reset(self):
        self.state = np.random.uniform(-0.05, 0.05, size=4)
        self.steps = 0
        return self.state.copy()

    def step(self, action):
        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = np.cos(theta)
        sintheta = np.sin(theta)
        temp = (force + self.polemass_length * theta_dot**2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0/3.0 - self.masspole * costheta**2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        x += self.tau * x_dot
        x_dot += self.tau * xacc
        theta += self.tau * theta_dot
        theta_dot += self.tau * thetaacc

        self.state = np.array([x, x_dot, theta, theta_dot])
        self.steps += 1
        done = (abs(x) > self.x_threshold or
                abs(theta) > self.theta_threshold or
                self.steps >= self.max_steps)
        reward = 1.0 if not done or self.steps >= self.max_steps else 0.0
        return self.state.copy(), reward, done


# --- Actor-Critic ネットワーク ---
class ActorCritic:
    """Actor-Criticネットワーク(NumPyスクラッチ実装)"""

    def __init__(self, state_size, action_size, hidden_size=64, lr=3e-4):
        self.action_size = action_size
        self.lr = lr

        # 共有層
        self.W1 = np.random.randn(state_size, hidden_size) * np.sqrt(2.0 / state_size)
        self.b1 = np.zeros(hidden_size)

        # Actor(方策)ヘッド
        self.W_actor = np.random.randn(hidden_size, action_size) * 0.01
        self.b_actor = np.zeros(action_size)

        # Critic(価値関数)ヘッド
        self.W_critic = np.random.randn(hidden_size, 1) * 0.01
        self.b_critic = np.zeros(1)

    def forward(self, state):
        """順伝播: 方策と価値を出力"""
        self.h = np.maximum(0, state @ self.W1 + self.b1)  # ReLU
        logits = self.h @ self.W_actor + self.b_actor
        value = (self.h @ self.W_critic + self.b_critic).flatten()

        # ソフトマックス
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

        return probs, value

    def get_action(self, state):
        """行動をサンプリング"""
        probs, value = self.forward(state.reshape(1, -1))
        probs = probs[0]
        action = np.random.choice(self.action_size, p=probs)
        return action, probs[action], value[0]

    def evaluate(self, states, actions):
        """バッチでの評価"""
        probs, values = self.forward(states)
        action_probs = probs[np.arange(len(actions)), actions]
        # エントロピー
        entropy = -np.sum(probs * np.log(probs + 1e-8), axis=1)
        return action_probs, values, entropy

    def update(self, states, actions, old_probs, advantages, returns,
               clip_epsilon=0.2, c1=0.5, c2=0.01):
        """PPO更新"""
        # 現在の方策での評価
        curr_probs, values, entropy = self.evaluate(states, actions)

        # 確率比
        ratio = curr_probs / (old_probs + 1e-8)

        # クリップされたサロゲート目的関数
        surr1 = ratio * advantages
        surr2 = np.clip(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
        policy_loss = -np.mean(np.minimum(surr1, surr2))

        # 価値関数の損失
        value_loss = np.mean((values - returns) ** 2)

        # エントロピーボーナス
        entropy_bonus = np.mean(entropy)

        # 合計損失
        total_loss = policy_loss + c1 * value_loss - c2 * entropy_bonus

        # --- 勾配計算と更新(簡略化した勾配降下) ---
        batch_size = len(states)

        # Actor勾配の近似計算
        probs_all, _ = self.forward(states)

        # 方策勾配
        for i in range(batch_size):
            h = np.maximum(0, states[i:i+1] @ self.W1 + self.b1)  # (1, hidden)
            p = probs_all[i]

            # ∂L/∂logits のクリップを考慮した勾配
            r = ratio[i]
            adv = advantages[i]

            if (r < 1 - clip_epsilon and adv > 0) or (r > 1 + clip_epsilon and adv < 0):
                continue  # クリップされた場合はスキップ

            grad_logit = np.zeros(self.action_size)
            a = actions[i]
            for j in range(self.action_size):
                if j == a:
                    grad_logit[j] = (1 - p[j]) * adv
                else:
                    grad_logit[j] = -p[j] * adv

            self.W_actor += self.lr * h.T @ grad_logit.reshape(1, -1) / batch_size
            self.b_actor += self.lr * grad_logit / batch_size

        # Critic勾配
        value_error = values - returns
        self.W_critic -= self.lr * c1 * self.h.T @ value_error.reshape(-1, 1) / batch_size
        self.b_critic -= self.lr * c1 * np.mean(value_error)

        return total_loss


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """GAE(Generalized Advantage Estimation)の計算"""
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0

    for t in reversed(range(T)):
        if t == T - 1:
            next_value = 0
        else:
            next_value = values[t + 1]

        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        gae = delta + gamma * lam * (1 - dones[t]) * gae
        advantages[t] = gae

    returns = advantages + values
    return advantages, returns


# --- PPO学習 ---
np.random.seed(42)

env = CartPoleEnv()
agent = ActorCritic(state_size=4, action_size=2, hidden_size=64, lr=3e-4)

n_updates = 200
steps_per_update = 256
n_epochs = 4
clip_epsilon = 0.2
gamma = 0.99
lam = 0.95

rewards_history = []
episode_reward = 0
state = env.reset()

all_episode_rewards = []
current_episode_reward = 0

for update in range(n_updates):
    # データ収集
    states_batch = []
    actions_batch = []
    rewards_batch = []
    values_batch = []
    probs_batch = []
    dones_batch = []

    for step in range(steps_per_update):
        action, prob, value = agent.get_action(state)

        next_state, reward, done = env.step(action)
        current_episode_reward += reward

        states_batch.append(state)
        actions_batch.append(action)
        rewards_batch.append(reward)
        values_batch.append(value)
        probs_batch.append(prob)
        dones_batch.append(float(done))

        state = next_state
        if done:
            all_episode_rewards.append(current_episode_reward)
            current_episode_reward = 0
            state = env.reset()

    # numpy配列に変換
    states_arr = np.array(states_batch)
    actions_arr = np.array(actions_batch)
    rewards_arr = np.array(rewards_batch)
    values_arr = np.array(values_batch)
    probs_arr = np.array(probs_batch)
    dones_arr = np.array(dones_batch)

    # GAEでアドバンテージと収益を計算
    advantages, returns = compute_gae(rewards_arr, values_arr, dones_arr, gamma, lam)

    # アドバンテージの正規化
    advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-8)

    # 複数エポックの更新
    for epoch in range(n_epochs):
        agent.update(states_arr, actions_arr, probs_arr, advantages, returns,
                     clip_epsilon=clip_epsilon)

# --- 可視化 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 学習曲線
window = 20
if len(all_episode_rewards) > window:
    smooth = np.convolve(all_episode_rewards, np.ones(window)/window, mode='valid')
    axes[0].plot(all_episode_rewards, alpha=0.3, color='blue')
    axes[0].plot(np.arange(window-1, len(all_episode_rewards)), smooth,
                 'b-', linewidth=2, label=f'Moving avg (window={window})')
else:
    axes[0].plot(all_episode_rewards, 'b-', linewidth=1.5)
axes[0].set_xlabel('Episode', fontsize=11)
axes[0].set_ylabel('Total Reward', fontsize=11)
axes[0].set_title('PPO Learning Curve (CartPole)', fontsize=13)
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# 報酬分布
if len(all_episode_rewards) > 40:
    half = len(all_episode_rewards) // 2
    axes[1].hist(all_episode_rewards[:half], bins=20, alpha=0.5, color='red',
                 label='First half')
    axes[1].hist(all_episode_rewards[half:], bins=20, alpha=0.5, color='blue',
                 label='Second half')
    axes[1].legend(fontsize=10)
axes[1].set_xlabel('Total Reward', fontsize=11)
axes[1].set_ylabel('Count', fontsize=11)
axes[1].set_title('Reward Distribution', fontsize=13)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

if len(all_episode_rewards) >= 50:
    print(f"最後の50エピソードの平均報酬: {np.mean(all_episode_rewards[-50:]):.1f}")

PPOの実装上のポイント

PPOを実装する際に重要なポイントを整理します。

  1. アドバンテージの正規化: ミニバッチ内でアドバンテージを正規化(平均0、分散1)すると学習が安定する
  2. 学習率のスケジュール: 線形に減衰させることが多い
  3. クリップ範囲 $\varepsilon$: 通常0.1から0.3の範囲。$\varepsilon = 0.2$ が最もよく使われる
  4. GAEの $\lambda$: $\lambda = 0.95$ が標準的。バイアスと分散のトレードオフを制御
  5. エントロピーボーナス: 係数 $c_2$ は0.01程度。早期の方策の収束を防ぐ
  6. ミニバッチ: 収集データをシャッフルしてミニバッチに分割

まとめ

本記事では、PPO(Proximal Policy Optimization)の理論と実装について解説しました。

  • サロゲート目的関数: 重要度サンプリングにより $r_t(\bm{\theta}) \hat{A}_t$ で方策を評価
  • Clipped Objective: 確率比を $[1-\varepsilon, 1+\varepsilon]$ にクリップし、方策の急激な変化を防止
  • GAE: $\hat{A}_t = \sum_{l=0}^{T-t-1}(\gamma\lambda)^l \delta_{t+l}$ でバイアスと分散を制御
  • 複数エポック更新: 同じデータを $K$ 回再利用でき、データ効率が良い
  • TRPOの安定性を保ちつつ、実装が大幅に簡潔

次のステップとして、以下の記事も参考にしてください。