PyroはUber AI(現在はオープンソースコミュニティ)が開発した確率的プログラミングフレームワークです。PyTorchの上に構築されており、柔軟なベイズモデリングと変分推論を実装することができます。
本記事では、Pyroの基本的な使い方と、変分推論の実装方法を解説します。
本記事の内容
- 確率的プログラミングと変分推論の基本
- Pyroの基本概念(model, guide, plate)
- ベイズ線形回帰の変分推論
- ガウス混合モデルの変分推論
変分推論の基本
ベイズ推論では、事後分布 $p(\bm{\theta} \mid \bm{y})$ を計算したいですが、多くの場合これは解析的に計算できません。
$$ p(\bm{\theta} \mid \bm{y}) = \frac{p(\bm{y} \mid \bm{\theta}) p(\bm{\theta})}{p(\bm{y})} $$
分母の周辺尤度 $p(\bm{y}) = \int p(\bm{y} \mid \bm{\theta}) p(\bm{\theta}) d\bm{\theta}$ が計算困難なためです。
変分推論(Variational Inference, VI)では、事後分布 $p(\bm{\theta} \mid \bm{y})$ を、扱いやすい分布族 $q_\phi(\bm{\theta})$ で近似します。近似の良さはKLダイバージェンスで測ります。
$$ \text{KL}(q_\phi(\bm{\theta}) \| p(\bm{\theta} \mid \bm{y})) = \int q_\phi(\bm{\theta}) \log \frac{q_\phi(\bm{\theta})}{p(\bm{\theta} \mid \bm{y})} d\bm{\theta} $$
KLダイバージェンスを直接最小化するのは難しいため、代わりにELBO(Evidence Lower Bound)を最大化します。
$$ \text{ELBO}(\phi) = E_{q_\phi}[\log p(\bm{y}, \bm{\theta})] – E_{q_\phi}[\log q_\phi(\bm{\theta})] $$
ELBOの最大化はKLダイバージェンスの最小化と等価です。
Pyroの基本概念
Pyroでは、確率モデルを model 関数、変分分布を guide 関数として定義します。
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# 確率モデルの定義
def model(data):
# 事前分布
mu = pyro.sample("mu", dist.Normal(0., 10.))
sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
# 尤度
with pyro.plate("data", len(data)):
pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
# 変分分布(guide)の定義
def guide(data):
# 変分パラメータ
mu_loc = pyro.param("mu_loc", torch.tensor(0.))
mu_scale = pyro.param("mu_scale", torch.tensor(1.),
constraint=dist.constraints.positive)
sigma_loc = pyro.param("sigma_loc", torch.tensor(0.))
sigma_scale = pyro.param("sigma_scale", torch.tensor(1.),
constraint=dist.constraints.positive)
pyro.sample("mu", dist.Normal(mu_loc, mu_scale))
pyro.sample("sigma", dist.LogNormal(sigma_loc, sigma_scale))
Pyroの重要な概念は以下のとおりです。
| 概念 | 説明 |
|---|---|
pyro.sample |
確率変数のサンプリング |
pyro.param |
最適化対象のパラメータ |
pyro.plate |
独立な繰り返し(バッチ処理) |
model |
同時分布 $p(\bm{y}, \bm{\theta})$ を定義 |
guide |
変分分布 $q_\phi(\bm{\theta})$ を定義 |
正規分布の推定
最も基本的な例として、データから正規分布のパラメータを推定してみましょう。
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
# パラメータストアをクリア
pyro.clear_param_store()
# データ生成
torch.manual_seed(42)
true_mu = 3.0
true_sigma = 1.5
data = torch.normal(true_mu, true_sigma, size=(200,))
def model(data):
mu = pyro.sample("mu", dist.Normal(0., 10.))
sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
with pyro.plate("data", len(data)):
pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
def guide(data):
mu_loc = pyro.param("mu_loc", torch.tensor(0.))
mu_scale = pyro.param("mu_scale", torch.tensor(1.),
constraint=dist.constraints.positive)
sigma_loc = pyro.param("sigma_loc", torch.tensor(0.))
sigma_scale = pyro.param("sigma_scale", torch.tensor(0.5),
constraint=dist.constraints.positive)
pyro.sample("mu", dist.Normal(mu_loc, mu_scale))
pyro.sample("sigma", dist.LogNormal(sigma_loc, sigma_scale))
# SVIの設定
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# 学習
n_steps = 2000
losses = []
for step in range(n_steps):
loss = svi.step(data)
losses.append(loss)
if (step + 1) % 500 == 0:
print(f"Step {step+1}: ELBO = {-loss:.2f}")
# 推定結果
mu_loc = pyro.param("mu_loc").item()
mu_scale = pyro.param("mu_scale").item()
sigma_loc = pyro.param("sigma_loc").item()
print(f"\n推定結果:")
print(f" mu: {mu_loc:.3f} +/- {mu_scale:.3f} (真値: {true_mu})")
print(f" sigma: {np.exp(sigma_loc):.3f} (真値: {true_sigma})")
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# ELBO の収束
axes[0].plot(losses, linewidth=0.5)
axes[0].set_title("ELBO Convergence")
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Loss (-ELBO)")
axes[0].grid(True, alpha=0.3)
# 推定された分布
x = np.linspace(-2, 8, 200)
true_pdf = (1/np.sqrt(2*np.pi*true_sigma**2)) * np.exp(-(x-true_mu)**2/(2*true_sigma**2))
est_sigma = np.exp(sigma_loc)
est_pdf = (1/np.sqrt(2*np.pi*est_sigma**2)) * np.exp(-(x-mu_loc)**2/(2*est_sigma**2))
axes[1].hist(data.numpy(), bins=30, density=True, alpha=0.5, color='steelblue', label='Data')
axes[1].plot(x, true_pdf, 'g-', linewidth=2, label=f'True (mu={true_mu}, sigma={true_sigma})')
axes[1].plot(x, est_pdf, 'r--', linewidth=2, label=f'Estimated (mu={mu_loc:.2f}, sigma={est_sigma:.2f})')
axes[1].set_title("Estimated Distribution")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
ベイズ線形回帰
より実用的な例として、ベイズ線形回帰を変分推論で実装します。
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
pyro.clear_param_store()
torch.manual_seed(42)
# データ生成
n = 100
X = torch.linspace(-3, 3, n)
w_true = 2.0
b_true = -1.0
sigma_true = 0.5
y = w_true * X + b_true + torch.normal(0, sigma_true, size=(n,))
def model(X, y=None):
# パラメータの事前分布
w = pyro.sample("w", dist.Normal(0., 5.))
b = pyro.sample("b", dist.Normal(0., 5.))
sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
mean = w * X + b
with pyro.plate("data", len(X)):
pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
def guide(X, y=None):
w_loc = pyro.param("w_loc", torch.tensor(0.))
w_scale = pyro.param("w_scale", torch.tensor(1.),
constraint=dist.constraints.positive)
b_loc = pyro.param("b_loc", torch.tensor(0.))
b_scale = pyro.param("b_scale", torch.tensor(1.),
constraint=dist.constraints.positive)
sigma_loc = pyro.param("sigma_loc", torch.tensor(0.))
sigma_scale = pyro.param("sigma_scale", torch.tensor(0.5),
constraint=dist.constraints.positive)
pyro.sample("w", dist.Normal(w_loc, w_scale))
pyro.sample("b", dist.Normal(b_loc, b_scale))
pyro.sample("sigma", dist.LogNormal(sigma_loc, sigma_scale))
# 学習
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_steps = 3000
for step in range(n_steps):
svi.step(X, y)
# 結果
w_est = pyro.param("w_loc").item()
b_est = pyro.param("b_loc").item()
print(f"w: {w_est:.3f} (真値: {w_true})")
print(f"b: {b_est:.3f} (真値: {b_true})")
# 予測の不確実性を可視化
predictive = Predictive(model, guide=guide, num_samples=200)
X_test = torch.linspace(-4, 4, 100)
preds = predictive(X_test)["obs"]
pred_mean = preds.mean(dim=0).detach().numpy()
pred_std = preds.std(dim=0).detach().numpy()
plt.figure(figsize=(10, 6))
plt.scatter(X.numpy(), y.numpy(), c='steelblue', s=15, alpha=0.5, label='Data')
plt.plot(X_test.numpy(), pred_mean, 'r-', linewidth=2, label='Posterior mean')
plt.fill_between(X_test.numpy(), pred_mean - 2*pred_std, pred_mean + 2*pred_std,
alpha=0.2, color='red', label='2-sigma')
plt.plot(X_test.numpy(), w_true * X_test.numpy() + b_true, 'g--', linewidth=2, label='True')
plt.title("Bayesian Linear Regression with Pyro")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
ベイズ線形回帰では、点推定だけでなく予測の不確実性(赤い帯)も得られます。データが少ない領域では不確実性が大きく、データが密な領域では小さくなっていることが確認できます。
AutoGuide の利用
Pyroは変分分布を自動で構築するAutoGuideを提供しています。手動でguide関数を書く必要がなくなります。
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
pyro.clear_param_store()
torch.manual_seed(42)
n = 100
X = torch.linspace(-3, 3, n)
y = 2.0 * X - 1.0 + torch.normal(0, 0.5, size=(n,))
def model(X, y=None):
w = pyro.sample("w", dist.Normal(0., 5.))
b = pyro.sample("b", dist.Normal(0., 5.))
sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
mean = w * X + b
with pyro.plate("data", len(X)):
pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
# AutoGuide: 対角正規分布で自動構築
guide = AutoDiagonalNormal(model)
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
for step in range(2000):
svi.step(X, y)
# AutoGuideから推定値を取得
for name, value in pyro.get_param_store().items():
print(f"{name}: {value.detach().numpy()}")
AutoDiagonalNormal は全ての潜在変数を独立な正規分布で近似します。他にも AutoMultivariateNormal(相関も考慮)や AutoNormal などが利用可能です。
まとめ
本記事では、Pyroフレームワークを使った変分推論の実装方法を解説しました。
- 変分推論は事後分布を扱いやすい分布族で近似し、ELBOを最大化する手法
- PyroはPyTorchベースの確率的プログラミングフレームワーク
modelで同時分布、guideで変分分布を定義し、SVIで最適化するAutoGuideを使えば変分分布を自動構築でき、実装が簡潔になる- ベイズ推論により予測の不確実性を自然に推定できる