深層学習モデルは多くのタスクで優れた性能を発揮しますが、その予測には重大な弱点があります。通常のニューラルネットワークは、訓練データからかけ離れた入力に対しても、何の躊躇もなく自信満々の予測を返すのです。
たとえば、犬と猫を分類するモデルに「自動車」の画像を入力すると、「犬: 95%」のような高い確信度で誤った分類を行うことがあります。人間であれば「これは犬でも猫でもない、わからない」と正直に答えるところですが、通常のニューラルネットワークにはこの「わからない」を表現する仕組みがありません。
この問題は、自動運転や医療診断など、誤った判断が深刻な結果をもたらす領域では致命的です。モデルが「自分の予測にどれだけ自信があるか」を正確に伝えられることが不可欠なのです。
ベイズニューラルネットワーク(Bayesian Neural Network, BNN)は、この問題を解決するために、ネットワークの重みを確定的な値ではなく確率分布として扱います。ベイズ推論の枠組みで重みの事後分布を推定することで、予測の不確実性を自然かつ理論的に正当な方法で定量化できるのです。
ベイズニューラルネットワークを理解すると、以下のような応用が開けます。
- 不確実性推定: 予測の信頼度を定量化し、「自信のない予測」を識別できる。医療AI、自動運転に不可欠
- OOD検出: 訓練データの分布外(Out-of-Distribution)の入力を検出できる
- 能動学習: 予測の不確実性が高いサンプルを優先的にラベル付けし、効率的にモデルを改善
- 正則化効果: 重みの事前分布が暗黙的な正則化として機能し、過学習を抑制
- 安全なAI: モデルの「知らないことを知っている」能力は、安全性が重要なシステムの基盤技術
本記事の内容
- 通常のNNの問題点(過信する予測)
- BNNの直感的理解(重みに分布を持たせる)
- 数学的定式化(重みの事後分布)
- 変分推論による近似(Bayes by Backprop)
- MCドロップアウトとの関係
- PyTorchでの実装
- 認識的不確実性と偶然的不確実性の分離
- OOD検出への応用
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
- ベイズ推定とは?仕組みについてわかりやすく解説 — 事前分布・事後分布の基本
- ベイズ予測分布とは?事後分布から未知データを予測する理論と実装 — 予測分布とパラメータの周辺化
- ガウス過程回帰を初めから分かりやすく解説 — BNNの無限幅極限との関係
通常のニューラルネットワークの問題
点推定の限界
通常のニューラルネットワークの学習では、損失関数(交差エントロピー損失やMSE)を最小化して重み $\bm{w}$ の最適値 $\hat{\bm{w}}$ を求めます。これは最尤推定(MLE)に相当し、正則化を加えるとMAP推定に相当します。
$$ \hat{\bm{w}}_{\text{MLE}} = \arg\max_{\bm{w}} \log p(\mathcal{D} | \bm{w}) $$
この点推定アプローチには、2つの根本的な問題があります。
問題1: 予測の不確実性が表現できない
分類問題では、ソフトマックスの出力を「確率」として使いますが、これはモデルが正しいことを前提とした条件付き確率 $p(y | \bm{x}, \hat{\bm{w}})$ に過ぎません。重み $\hat{\bm{w}}$ が最適解からずれている可能性を考慮していないため、実際の不確実性を過小評価します。
問題2: 分布外データに対する過信
ニューラルネットワークは、訓練データの分布(in-distribution)上ではよい予測を行いますが、分布外(out-of-distribution, OOD)のデータに対しても高い確信度で予測を行ってしまいます。これは、ネットワークが「自分が何を知らないか」を知る仕組みを持っていないためです。
具体例: 過信するソフトマックス
なぜソフトマックスの出力が不確実性の良い指標にならないのかを、直感的に理解しましょう。
2つのクラスの分類問題で、決定境界の近くに入力が来たとします。決定境界の近くでは、入力の微小な変化で予測が大きく変わります。理想的には、このような入力に対して「よくわからない(50%に近い確率)」と出力すべきです。
しかし、通常のネットワークは、決定境界付近でも何らかのクラスに偏った予測を行い、しかもその予測に高い確信度を持つことがあります。特に、ReLU活性化関数を使ったネットワークは、入力空間の大部分で線形に振る舞うため、訓練データから遠い領域でもソフトマックス出力が飽和(0か1に近づく)してしまいます。
このような過信の問題を根本的に解決するのがBNNです。重みを点推定ではなく分布として扱うことで、「この重みの設定が正しいかどうか」という不確実性を予測に反映させるのです。
BNNの直感的理解
重みに分布を持たせるということ
BNNの核心的なアイデアは、実はとてもシンプルです。通常のNNでは各重みが1つの数値ですが、BNNでは各重みが確率分布です。
たとえば、あるニューロンの重みが通常のNNでは $w = 2.5$ であるのに対し、BNNでは $w \sim \mathcal{N}(2.5, 0.3^2)$(平均2.5、標準偏差0.3の正規分布)のように表現されます。この標準偏差0.3は、データからこの重みが $2.5$ であるとどの程度確信できるかを表しています。
予測のアンサンブルとしてのBNN
BNNで予測を行うとき、重みの分布からサンプルを複数回引き、それぞれのサンプルで順伝播を行い、結果を集計します。
- 重みの事後分布からサンプル $\bm{w}^{(1)}, \bm{w}^{(2)}, \dots, \bm{w}^{(T)}$ を引く
- 各サンプルで予測 $y^{(t)} = f(\bm{x}; \bm{w}^{(t)})$ を計算する
- 予測の平均 $\bar{y} = \frac{1}{T}\sum_t y^{(t)}$ を最終的な予測とする
- 予測のばらつき(分散)を不確実性の指標とする
これは「無限個のネットワークのアンサンブル」と解釈できます。各ネットワークは少しずつ異なる重みを持ち、その予測が一致すれば確信度が高く、ばらつけば確信度が低いと判断します。
この直感を数学的に定式化しましょう。
BNNの数学的定式化
ベイズの定理による重みの事後分布
データ $\mathcal{D} = \{(\bm{x}_n, y_n)\}_{n=1}^{N}$ が与えられたとき、重み $\bm{w}$ の事後分布はベイズの定理により、
$$ \begin{equation} p(\bm{w} | \mathcal{D}) = \frac{p(\mathcal{D} | \bm{w}) \, p(\bm{w})}{p(\mathcal{D})} \end{equation} $$
ここで、
- $p(\mathcal{D} | \bm{w}) = \prod_{n=1}^{N} p(y_n | \bm{x}_n, \bm{w})$: 尤度関数。モデルの出力とデータの整合性
- $p(\bm{w})$: 重みの事前分布。通常は $p(\bm{w}) = \mathcal{N}(\bm{0}, \sigma_p^2 \bm{I})$(ガウス事前分布)
- $p(\mathcal{D}) = \int p(\mathcal{D} | \bm{w}) p(\bm{w}) d\bm{w}$: 周辺尤度(エビデンス)
予測分布
新しい入力 $\bm{x}_*$ に対する予測分布は、重みの事後分布で周辺化することで得られます。
$$ \begin{equation} p(y_* | \bm{x}_*, \mathcal{D}) = \int p(y_* | \bm{x}_*, \bm{w}) \, p(\bm{w} | \mathcal{D}) \, d\bm{w} \end{equation} $$
これはベイズ予測分布の一般的な形式であり、パラメータ(重み)の不確実性を予測に正しく反映します。
計算の困難さ
ニューラルネットワークの場合、事後分布 $p(\bm{w} | \mathcal{D})$ の計算には2つの根本的な困難があります。
- 周辺尤度の計算不能: $p(\mathcal{D}) = \int p(\mathcal{D} | \bm{w}) p(\bm{w}) d\bm{w}$ は重みの次元が数万〜数億に及ぶため、数値的に計算不可能です
- 非凸な尤度面: ニューラルネットワークの損失関数は非凸であり、事後分布は多峰性(複数のモード)を持つ可能性があります
これらの困難を回避するために、変分推論(variational inference)による近似が使われます。
変分推論による近似 — Bayes by Backprop
変分推論の基本アイデア
変分推論では、真の事後分布 $p(\bm{w} | \mathcal{D})$ を、パラメータ $\bm{\phi}$ で特徴づけられる扱いやすい分布族 $q_{\bm{\phi}}(\bm{w})$ で近似します。
日常的なアナロジーで考えてみましょう。複雑な地形(真の事後分布)を写真に収めたいとき、ある角度からの2D写真(近似分布)では完全には表現できませんが、最も情報を捉える角度を選べば良い近似が得られます。変分推論の「最適な角度を選ぶ」操作が、$\bm{\phi}$ の最適化に相当します。
ELBO(エビデンス下界)
$q_{\bm{\phi}}(\bm{w})$ と $p(\bm{w} | \mathcal{D})$ の近さをKLダイバージェンスで測り、これを最小化します。
$$ \text{KL}[q_{\bm{\phi}}(\bm{w}) \| p(\bm{w} | \mathcal{D})] = \int q_{\bm{\phi}}(\bm{w}) \log \frac{q_{\bm{\phi}}(\bm{w})}{p(\bm{w} | \mathcal{D})} d\bm{w} $$
しかし、$p(\bm{w} | \mathcal{D})$ が計算できないため、KLダイバージェンスを直接最小化することはできません。代わりに、以下のELBO(Evidence Lower Bound)を最大化します。
$$ \begin{equation} \mathcal{L}(\bm{\phi}) = \mathbb{E}_{q_{\bm{\phi}}(\bm{w})}[\log p(\mathcal{D} | \bm{w})] – \text{KL}[q_{\bm{\phi}}(\bm{w}) \| p(\bm{w})] \end{equation} $$
なぜELBOの最大化がKLダイバージェンスの最小化と等価なのかを見ましょう。対数周辺尤度は次のように分解できます。
$$ \log p(\mathcal{D}) = \mathcal{L}(\bm{\phi}) + \text{KL}[q_{\bm{\phi}}(\bm{w}) \| p(\bm{w} | \mathcal{D})] $$
$\log p(\mathcal{D})$ は $\bm{\phi}$ に依存しない定数なので、$\mathcal{L}(\bm{\phi})$ を最大化することは $\text{KL}[q_{\bm{\phi}} \| p(\cdot | \mathcal{D})]$ を最小化することと等価です。
ELBOの2つの項はそれぞれ明確な役割を持ちます。
- 第1項 $\mathbb{E}_{q}[\log p(\mathcal{D} | \bm{w})]$: データ適合項。近似分布の下でデータの対数尤度がどれだけ高いかを測ります。この項を大きくすると、データをよく説明する重みに分布が集中します
- 第2項 $-\text{KL}[q \| p(\bm{w})]$: 正則化項。近似分布が事前分布から離れすぎないようにするペナルティです。この項は過学習を防ぐ正則化として機能します
Bayes by Backprop
Blundell et al. (2015) が提案したBayes by Backpropは、ELBOをSGD(確率的勾配降下法)で最適化する実用的なアルゴリズムです。
近似分布として、各重みが独立な正規分布に従うと仮定します。
$$ q_{\bm{\phi}}(\bm{w}) = \prod_i \mathcal{N}(w_i | \mu_i, \sigma_i^2) $$
パラメータ $\bm{\phi} = \{\mu_i, \sigma_i\}$ を勾配法で最適化します。$\sigma_i$ は正でなければならないため、$\sigma_i = \log(1 + \exp(\rho_i))$(ソフトプラス変換)として $\rho_i$ を最適化パラメータとします。
再パラメータ化トリック
$\bm{w} \sim q_{\bm{\phi}}(\bm{w})$ からのサンプリングは微分可能ではないため、再パラメータ化トリック(reparameterization trick)を使います。
$$ w_i = \mu_i + \sigma_i \cdot \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, 1) $$
これにより、$\bm{w}$ は $\bm{\phi}$ の決定的な関数と確率的なノイズ $\bm{\epsilon}$ に分解されます。$\bm{\epsilon}$ は $\bm{\phi}$ に依存しないため、ELBOの $\bm{\phi}$ に関する勾配を標準的なバックプロパゲーションで計算できます。
アルゴリズムの手順
Bayes by Backpropの1ステップは以下の通りです。
- ノイズをサンプリング: $\epsilon_i \sim \mathcal{N}(0, 1)$
- 重みを計算: $w_i = \mu_i + \log(1 + \exp(\rho_i)) \cdot \epsilon_i$
- ミニバッチに対するELBOを計算(モンテカルロ推定)
- $\mu_i$ と $\rho_i$ に関する勾配を計算してパラメータを更新
各ステップでサンプリングされた重みは異なるため、通常のSGDと比べて1つの入力に対する出力が確率的になります。これは訓練時のノイズとして機能し、暗黙的な正則化効果をもたらします。
変分推論による近似は理論的にエレガントですが、実装の複雑さとパラメータ数の増加(平均と分散で2倍)が課題です。より実用的な近似として、MCドロップアウトが提案されています。
MCドロップアウトとの関係
ドロップアウトの再解釈
Gal & Ghahramani (2016) は、驚くべき理論的結果を示しました。訓練時と推論時の両方でドロップアウトを適用し、推論時に複数回のフォワードパスを行って予測を集計することが、特定のベイズニューラルネットワークの変分推論近似と等価であるというのです。
通常、ドロップアウトは訓練時にのみ適用し、推論時は全ニューロンを使います。しかしMCドロップアウトでは、推論時にもドロップアウトをオンにしたまま、同じ入力に対して $T$ 回のフォワードパスを実行します。
各フォワードパスでは異なるニューロンがドロップされるため、異なる重みの組み合わせ(サブネットワーク)で予測が行われます。これは、重みの事後分布からのサンプリングを近似しています。
数学的な対応
ドロップアウト率 $p$ を持つネットワークは、以下の変分分布を暗黙的に使っていると解釈できます。
各重み行列 $\bm{W}_l$ の列ベクトル $\bm{w}_{l,j}$ に対して、
$$ q(\bm{w}_{l,j}) = (1 – p) \cdot \delta(\bm{w}_{l,j} – \bm{m}_{l,j}) + p \cdot \delta(\bm{w}_{l,j}) $$
つまり、確率 $(1-p)$ で学習された重み $\bm{m}_{l,j}$ を使い、確率 $p$ でゼロにします。これはベルヌーイ分布と点質量の混合分布であり、変分推論の近似分布として解釈できます。
MCドロップアウトの予測
$T$ 回のフォワードパスの結果を集計して予測を行います。
予測平均:
$$ \mathbb{E}[y_* | \bm{x}_*] \approx \frac{1}{T}\sum_{t=1}^{T} f(\bm{x}_*; \hat{\bm{w}}^{(t)}) $$
予測分散(回帰の場合):
$$ \text{Var}[y_* | \bm{x}_*] \approx \frac{1}{T}\sum_{t=1}^{T} f(\bm{x}_*; \hat{\bm{w}}^{(t)})^2 – \left(\frac{1}{T}\sum_{t=1}^{T} f(\bm{x}_*; \hat{\bm{w}}^{(t)})\right)^2 + \tau^{-1} $$
ここで $\tau^{-1}$ はデータノイズの分散(偶然的不確実性)の推定値です。
MCドロップアウトの利点
MCドロップアウトの最大の利点は、既存のドロップアウト付きネットワークをそのまま使えることです。追加の訓練は必要なく、推論時にドロップアウトをオンにして複数回フォワードパスを実行するだけです。
| 観点 | Bayes by Backprop | MCドロップアウト |
|---|---|---|
| 実装の手間 | 大(カスタム層が必要) | 小(既存モデルに適用可能) |
| パラメータ数 | 2倍(平均+分散) | 変化なし |
| 近似の質 | 比較的高い | やや粗い |
| 計算コスト(訓練) | 高い | 変化なし |
| 計算コスト(推論) | $T$ 回のフォワードパス | $T$ 回のフォワードパス |
実用上は、MCドロップアウトがその手軽さから広く使われています。一方、不確実性推定の精度がより重要な場面ではBayes by Backpropが選ばれます。
理論を十分に理解したところで、PyTorchでBNNを実装してみましょう。
PyTorch での実装
Bayes by Backprop のベイズ線形層
まず、Bayes by Backpropに基づくベイズ線形層(BayesianLinear)を実装します。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class BayesianLinear(nn.Module):
"""ベイズ線形層 (Bayes by Backprop)"""
def __init__(self, in_features, out_features, prior_sigma=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# 重みの変分パラメータ
self.weight_mu = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
self.weight_rho = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-5, -4))
# バイアスの変分パラメータ
self.bias_mu = nn.Parameter(
torch.Tensor(out_features).uniform_(-0.2, 0.2))
self.bias_rho = nn.Parameter(
torch.Tensor(out_features).uniform_(-5, -4))
# 事前分布のパラメータ
self.prior_sigma = prior_sigma
self.log_prior = 0
self.log_variational_posterior = 0
def forward(self, x):
# ソフトプラス変換で正の標準偏差を保証
weight_sigma = torch.log1p(torch.exp(self.weight_rho))
bias_sigma = torch.log1p(torch.exp(self.bias_rho))
# 再パラメータ化トリック
weight_epsilon = torch.randn_like(weight_sigma)
bias_epsilon = torch.randn_like(bias_sigma)
weight = self.weight_mu + weight_sigma * weight_epsilon
bias = self.bias_mu + bias_sigma * bias_epsilon
# 対数事前確率
self.log_prior = (
-0.5 * torch.sum(weight**2) / self.prior_sigma**2
- 0.5 * torch.sum(bias**2) / self.prior_sigma**2
)
# 対数変分事後確率
self.log_variational_posterior = (
-0.5 * torch.sum(((weight - self.weight_mu) / weight_sigma)**2)
- torch.sum(torch.log(weight_sigma))
- 0.5 * torch.sum(((bias - self.bias_mu) / bias_sigma)**2)
- torch.sum(torch.log(bias_sigma))
)
return F.linear(x, weight, bias)
この実装の要点は以下の通りです。
weight_muとweight_rhoが変分パラメータです。$\sigma = \log(1 + \exp(\rho))$(ソフトプラス)により正の標準偏差を保証しますforwardメソッドの中で再パラメータ化トリックを使い、$w = \mu + \sigma \cdot \epsilon$ として重みをサンプリングしますlog_priorとlog_variational_posteriorはELBO計算のために記録されます
BNNモデルの構築
class BayesianNN(nn.Module):
"""ベイズニューラルネットワーク"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.blinear1 = BayesianLinear(input_dim, hidden_dim)
self.blinear2 = BayesianLinear(hidden_dim, hidden_dim)
self.blinear3 = BayesianLinear(hidden_dim, output_dim)
def forward(self, x):
x = F.relu(self.blinear1(x))
x = F.relu(self.blinear2(x))
x = self.blinear3(x)
return x
def log_prior(self):
return (self.blinear1.log_prior
+ self.blinear2.log_prior
+ self.blinear3.log_prior)
def log_variational_posterior(self):
return (self.blinear1.log_variational_posterior
+ self.blinear2.log_variational_posterior
+ self.blinear3.log_variational_posterior)
def elbo_loss(self, x, y, n_samples, n_data):
"""ELBO損失(最小化する)"""
total_log_lik = 0
total_log_prior = 0
total_log_var_post = 0
for _ in range(n_samples):
output = self.forward(x)
log_lik = -0.5 * torch.sum((y - output)**2)
total_log_lik += log_lik
total_log_prior += self.log_prior()
total_log_var_post += self.log_variational_posterior()
# モンテカルロ推定の平均
total_log_lik /= n_samples
total_log_prior /= n_samples
total_log_var_post /= n_samples
# ミニバッチスケーリング
kl = (total_log_var_post - total_log_prior) / n_data
nll = -total_log_lik / n_data
return nll + kl
elbo_loss メソッドは、ELBOの負(最小化するため符号を反転)を計算します。KLダイバージェンスの項はデータ数で割ることで、ミニバッチ学習に対応しています。
次に、このBNNを1次元回帰問題に適用して、不確実性の推定を確認しましょう。
1次元回帰での実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class BayesianLinear(nn.Module):
def __init__(self, in_features, out_features, prior_sigma=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight_mu = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
self.weight_rho = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-5, -4))
self.bias_mu = nn.Parameter(
torch.Tensor(out_features).uniform_(-0.2, 0.2))
self.bias_rho = nn.Parameter(
torch.Tensor(out_features).uniform_(-5, -4))
self.prior_sigma = prior_sigma
self.log_prior = 0
self.log_variational_posterior = 0
def forward(self, x):
weight_sigma = torch.log1p(torch.exp(self.weight_rho))
bias_sigma = torch.log1p(torch.exp(self.bias_rho))
weight = self.weight_mu + weight_sigma * torch.randn_like(weight_sigma)
bias = self.bias_mu + bias_sigma * torch.randn_like(bias_sigma)
self.log_prior = (-0.5 * torch.sum(weight**2) / self.prior_sigma**2
- 0.5 * torch.sum(bias**2) / self.prior_sigma**2)
self.log_variational_posterior = (
-0.5 * torch.sum(((weight - self.weight_mu) / weight_sigma)**2)
- torch.sum(torch.log(weight_sigma))
- 0.5 * torch.sum(((bias - self.bias_mu) / bias_sigma)**2)
- torch.sum(torch.log(bias_sigma)))
return F.linear(x, weight, bias)
class BayesianNN(nn.Module):
def __init__(self):
super().__init__()
self.l1 = BayesianLinear(1, 50)
self.l2 = BayesianLinear(50, 50)
self.l3 = BayesianLinear(50, 1)
def forward(self, x):
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
return self.l3(x)
def log_prior(self):
return self.l1.log_prior + self.l2.log_prior + self.l3.log_prior
def log_variational_posterior(self):
return (self.l1.log_variational_posterior
+ self.l2.log_variational_posterior
+ self.l3.log_variational_posterior)
# データ生成
torch.manual_seed(42)
np.random.seed(42)
def true_function(x):
return np.sin(2 * x) + 0.3 * x
N = 40
x_train = np.sort(np.random.uniform(-2, 2, N))
y_train = true_function(x_train) + np.random.normal(0, 0.2, N)
x_tensor = torch.FloatTensor(x_train).reshape(-1, 1)
y_tensor = torch.FloatTensor(y_train).reshape(-1, 1)
# 学習
model = BayesianNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
n_epochs = 2000
losses = []
for epoch in range(n_epochs):
model.train()
optimizer.zero_grad()
# ELBO損失
n_mc = 3 # モンテカルロサンプル数
total_loss = 0
for _ in range(n_mc):
output = model(x_tensor)
nll = 0.5 * torch.sum((y_tensor - output)**2) / N
kl = (model.log_variational_posterior() - model.log_prior()) / N
total_loss += nll + 0.1 * kl # KLの重みを調整
loss = total_loss / n_mc
loss.backward()
optimizer.step()
losses.append(loss.item())
# 予測(複数回フォワードパスで不確実性を推定)
model.eval()
x_test = np.linspace(-4, 4, 200)
x_test_tensor = torch.FloatTensor(x_test).reshape(-1, 1)
T = 100 # サンプル数
predictions = np.zeros((T, len(x_test)))
with torch.no_grad():
for t in range(T):
predictions[t] = model(x_test_tensor).numpy().flatten()
pred_mean = predictions.mean(axis=0)
pred_std = predictions.std(axis=0)
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 左: BNNの予測と不確実性
ax = axes[0]
ax.plot(x_test, pred_mean, 'b-', linewidth=2, label='BNN予測平均')
ax.fill_between(x_test, pred_mean - 2*pred_std, pred_mean + 2*pred_std,
alpha=0.2, color='blue', label='95%信頼区間')
ax.plot(x_test, true_function(x_test), 'k--', linewidth=1.5,
label='真の関数')
ax.scatter(x_train, y_train, c='red', s=30, zorder=5, label='訓練データ')
# 訓練データの範囲を示す
ax.axvspan(-2, 2, alpha=0.05, color='green')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('ベイズニューラルネットワークの予測')
ax.legend(fontsize=9, loc='upper left')
ax.set_ylim(-4, 4)
# 右: 複数のサンプルによる予測
ax = axes[1]
for t in range(20):
ax.plot(x_test, predictions[t], 'b-', alpha=0.1, linewidth=0.8)
ax.plot(x_test, true_function(x_test), 'k--', linewidth=1.5,
label='真の関数')
ax.scatter(x_train, y_train, c='red', s=30, zorder=5, label='訓練データ')
ax.axvspan(-2, 2, alpha=0.05, color='green')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('重みサンプルごとの予測(20回)')
ax.legend(fontsize=10)
ax.set_ylim(-4, 4)
plt.tight_layout()
plt.savefig("bnn_regression.png", dpi=150, bbox_inches="tight")
plt.show()
BNNの予測結果から、以下の重要な特徴が読み取れます。
-
訓練データ領域($-2 \leq x \leq 2$)では不確実性が小さい: BNNはデータがある領域では予測に自信を持ち、狭い信頼区間を示しています。これはデータが重みの事後分布を強く制約しているためです
-
訓練データ外($x < -2$ や $x > 2$)では不確実性が急激に増大: データがない領域では、重みの事後分布の不確実性が予測に大きく影響し、信頼区間が急激に広がります。これがBNNの最も重要な性質であり、「知らないことを知っている」能力を表しています
-
右のグラフ(個別サンプル): 20回の異なる重みサンプルによる予測を重ねて描画すると、訓練データの範囲内では各予測がほぼ一致していますが、範囲外では大きくばらついています。このばらつきの大きさが、そのまま予測の不確実性に反映されています
通常のニューラルネットワークでは、$x = 4$ のような外挿点でも1つの確定的な予測を返すだけですが、BNNは「この領域にはデータがないため予測は不確実です」と正直に報告しているのです。
認識的不確実性と偶然的不確実性
2種類の不確実性
BNNの予測の不確実性は、2つの異なる源泉に分解できます。
認識的不確実性(epistemic uncertainty): データ不足から生じる不確実性です。データが増えれば減少します。BNNでは、重みの事後分布の広がりとして表現されます。「もっとデータがあれば解消できる不確実性」です。
偶然的不確実性(aleatoric uncertainty): データ自体に内在するノイズから生じる不確実性です。データを増やしても減りません。測定誤差や本質的なランダム性に由来します。「原理的に解消できない不確実性」です。
数学的な分離
回帰問題で、モデルの出力が平均 $\mu(\bm{x}; \bm{w})$ と分散 $\sigma^2(\bm{x}; \bm{w})$ の両方を出す場合、全不確実性は以下のように分解できます。
全分散の法則を適用すると、
$$ \underbrace{\text{Var}[y_* | \bm{x}_*]}_{\text{全不確実性}} = \underbrace{\mathbb{E}_{\bm{w}}[\sigma^2(\bm{x}_*; \bm{w})]}_{\text{偶然的不確実性}} + \underbrace{\text{Var}_{\bm{w}}[\mu(\bm{x}_*; \bm{w})]}_{\text{認識的不確実性}} $$
偶然的不確実性は各重みサンプルでの予測分散の平均であり、認識的不確実性は予測平均のばらつきです。
分離の実装と可視化
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class BayesianLinear(nn.Module):
def __init__(self, in_features, out_features, prior_sigma=1.0):
super().__init__()
self.weight_mu = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
self.weight_rho = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-5, -4))
self.bias_mu = nn.Parameter(
torch.Tensor(out_features).uniform_(-0.2, 0.2))
self.bias_rho = nn.Parameter(
torch.Tensor(out_features).uniform_(-5, -4))
self.prior_sigma = prior_sigma
def forward(self, x):
weight_sigma = torch.log1p(torch.exp(self.weight_rho))
bias_sigma = torch.log1p(torch.exp(self.bias_rho))
weight = self.weight_mu + weight_sigma * torch.randn_like(weight_sigma)
bias = self.bias_mu + bias_sigma * torch.randn_like(bias_sigma)
return F.linear(x, weight, bias)
class HeteroscedasticBNN(nn.Module):
"""偶然的不確実性も学習するBNN"""
def __init__(self):
super().__init__()
self.shared1 = BayesianLinear(1, 50)
self.shared2 = BayesianLinear(50, 50)
self.mean_head = BayesianLinear(50, 1)
self.var_head = BayesianLinear(50, 1) # 対数分散を出力
def forward(self, x):
h = F.relu(self.shared1(x))
h = F.relu(self.shared2(h))
mean = self.mean_head(h)
log_var = self.var_head(h)
return mean, log_var
# 不均一ノイズを持つデータ
torch.manual_seed(42)
np.random.seed(42)
N = 80
x_train = np.sort(np.random.uniform(-3, 3, N))
# ノイズの大きさがxに依存(不均一分散)
noise_level = 0.1 + 0.3 * np.abs(x_train)
y_train = np.sin(x_train) + noise_level * np.random.randn(N)
x_tensor = torch.FloatTensor(x_train).reshape(-1, 1)
y_tensor = torch.FloatTensor(y_train).reshape(-1, 1)
# 学習
model = HeteroscedasticBNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
for epoch in range(3000):
model.train()
optimizer.zero_grad()
mean_pred, log_var_pred = model(x_tensor)
var_pred = torch.exp(log_var_pred)
# 異分散正規分布の負の対数尤度
nll = 0.5 * torch.mean(log_var_pred + (y_tensor - mean_pred)**2 / var_pred)
nll.backward()
optimizer.step()
# 予測
model.eval()
x_test = np.linspace(-5, 5, 200)
x_test_tensor = torch.FloatTensor(x_test).reshape(-1, 1)
T = 200
means = np.zeros((T, len(x_test)))
log_vars = np.zeros((T, len(x_test)))
with torch.no_grad():
for t in range(T):
m, lv = model(x_test_tensor)
means[t] = m.numpy().flatten()
log_vars[t] = lv.numpy().flatten()
# 不確実性の分解
pred_mean = means.mean(axis=0)
aleatoric = np.exp(log_vars).mean(axis=0) # E[sigma^2(x)]
epistemic = means.var(axis=0) # Var[mu(x)]
total = aleatoric + epistemic
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 左: 予測と2種類の不確実性
ax = axes[0]
ax.plot(x_test, pred_mean, 'b-', linewidth=2, label='予測平均')
ax.fill_between(x_test,
pred_mean - 2*np.sqrt(epistemic),
pred_mean + 2*np.sqrt(epistemic),
alpha=0.4, color='orange', label='認識的不確実性')
ax.fill_between(x_test,
pred_mean - 2*np.sqrt(total),
pred_mean + 2*np.sqrt(total),
alpha=0.2, color='blue', label='全不確実性')
ax.plot(x_test, np.sin(x_test), 'k--', linewidth=1.5, label='真の関数')
ax.scatter(x_train, y_train, c='gray', s=15, alpha=0.5)
ax.axvspan(-3, 3, alpha=0.05, color='green')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('予測と不確実性の分解')
ax.legend(fontsize=9)
ax.set_ylim(-4, 4)
# 右: 各不確実性のプロファイル
ax = axes[1]
ax.plot(x_test, np.sqrt(aleatoric), 'r-', linewidth=2, label='偶然的')
ax.plot(x_test, np.sqrt(epistemic), 'orange', linewidth=2, label='認識的')
ax.plot(x_test, np.sqrt(total), 'b-', linewidth=2, label='合計')
ax.axvspan(-3, 3, alpha=0.05, color='green', label='訓練データ範囲')
ax.set_xlabel('x')
ax.set_ylabel('標準偏差')
ax.set_title('不確実性のプロファイル')
ax.legend(fontsize=10)
plt.tight_layout()
plt.savefig("bnn_uncertainty_decomposition.png", dpi=150, bbox_inches="tight")
plt.show()
不確実性の分解から、以下の重要な洞察が得られます。
-
左のグラフ: オレンジの帯(認識的不確実性)は訓練データ範囲内では狭く、範囲外で急激に広がっています。一方、青の帯(全不確実性)は訓練データ範囲内でも一定の幅を持っています。この差が偶然的不確実性であり、データのノイズに由来する還元不能な不確実性です
-
右のグラフ: 偶然的不確実性(赤)は訓練データの範囲内で $|x|$ に比例して増加する傾向があります。これは、データ生成時にノイズレベルを $0.1 + 0.3|x|$ と設定したことを正確に捕捉しています。認識的不確実性(オレンジ)はデータの外側で急増しており、モデルがデータのない領域について不確実であることを示しています
-
2つの不確実性の使い分け: 認識的不確実性が大きい入力に対しては「もっとデータを集めるべき」と判断でき、能動学習に利用できます。偶然的不確実性が大きい入力に対しては「データを増やしても予測精度は改善しない」と判断でき、測定精度の改善やモデルの改良が必要です
最後に、BNNの重要な応用例であるOOD検出を実装しましょう。
OOD検出への応用
分布外データの検出
BNNの認識的不確実性は、訓練データの分布外(Out-of-Distribution, OOD)の入力を検出するための自然な指標になります。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class BayesianLinear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight_mu = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
self.weight_rho = nn.Parameter(
torch.Tensor(out_features, in_features).uniform_(-5, -4))
self.bias_mu = nn.Parameter(
torch.Tensor(out_features).uniform_(-0.2, 0.2))
self.bias_rho = nn.Parameter(
torch.Tensor(out_features).uniform_(-5, -4))
def forward(self, x):
ws = torch.log1p(torch.exp(self.weight_rho))
bs = torch.log1p(torch.exp(self.bias_rho))
w = self.weight_mu + ws * torch.randn_like(ws)
b = self.bias_mu + bs * torch.randn_like(bs)
return F.linear(x, w, b)
class BNNClassifier(nn.Module):
def __init__(self, input_dim, hidden_dim, n_classes):
super().__init__()
self.l1 = BayesianLinear(input_dim, hidden_dim)
self.l2 = BayesianLinear(hidden_dim, hidden_dim)
self.l3 = BayesianLinear(hidden_dim, n_classes)
def forward(self, x):
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
return self.l3(x)
# 2次元の2クラス分類データを生成
np.random.seed(42)
torch.manual_seed(42)
n_per_class = 100
# クラス0: 中心 (-1, -1)
x0 = np.random.randn(n_per_class, 2) * 0.5 + np.array([-1, -1])
# クラス1: 中心 (1, 1)
x1 = np.random.randn(n_per_class, 2) * 0.5 + np.array([1, 1])
X = np.vstack([x0, x1]).astype(np.float32)
y = np.concatenate([np.zeros(n_per_class), np.ones(n_per_class)]).astype(np.int64)
X_tensor = torch.FloatTensor(X)
y_tensor = torch.LongTensor(y)
# 学習
model = BNNClassifier(2, 50, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1000):
model.train()
optimizer.zero_grad()
n_mc = 3
loss = 0
for _ in range(n_mc):
logits = model(X_tensor)
loss += F.cross_entropy(logits, y_tensor)
loss = loss / n_mc
loss.backward()
optimizer.step()
# OOD検出のための予測不確実性マップ
model.eval()
grid_range = np.linspace(-4, 4, 100)
xx, yy = np.meshgrid(grid_range, grid_range)
grid_points = np.column_stack([xx.ravel(), yy.ravel()]).astype(np.float32)
grid_tensor = torch.FloatTensor(grid_points)
T = 50
all_probs = np.zeros((T, len(grid_points), 2))
with torch.no_grad():
for t in range(T):
logits = model(grid_tensor)
probs = F.softmax(logits, dim=1).numpy()
all_probs[t] = probs
# 不確実性指標: 予測エントロピー
mean_probs = all_probs.mean(axis=0)
predictive_entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-10), axis=1)
# 認識的不確実性(相互情報量)
entropy_per_sample = -np.sum(all_probs * np.log(all_probs + 1e-10), axis=2)
expected_entropy = entropy_per_sample.mean(axis=0)
mutual_information = predictive_entropy - expected_entropy
# 可視化
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# 左: 予測クラス
ax = axes[0]
pred_class = mean_probs[:, 1].reshape(xx.shape)
c = ax.contourf(xx, yy, pred_class, levels=20, cmap='RdBu_r', alpha=0.8)
plt.colorbar(c, ax=ax, label='P(class=1)')
ax.scatter(x0[:, 0], x0[:, 1], c='blue', s=15, alpha=0.6, label='Class 0')
ax.scatter(x1[:, 0], x1[:, 1], c='red', s=15, alpha=0.6, label='Class 1')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title('予測クラス確率')
ax.legend()
# 中央: 予測エントロピー(全不確実性)
ax = axes[1]
entropy_map = predictive_entropy.reshape(xx.shape)
c = ax.contourf(xx, yy, entropy_map, levels=20, cmap='hot_r', alpha=0.8)
plt.colorbar(c, ax=ax, label='エントロピー')
ax.scatter(x0[:, 0], x0[:, 1], c='blue', s=15, alpha=0.6)
ax.scatter(x1[:, 0], x1[:, 1], c='red', s=15, alpha=0.6)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title('予測エントロピー(全不確実性)')
# 右: 認識的不確実性(相互情報量)
ax = axes[2]
mi_map = mutual_information.reshape(xx.shape)
c = ax.contourf(xx, yy, mi_map, levels=20, cmap='hot_r', alpha=0.8)
plt.colorbar(c, ax=ax, label='相互情報量')
ax.scatter(x0[:, 0], x0[:, 1], c='blue', s=15, alpha=0.6)
ax.scatter(x1[:, 0], x1[:, 1], c='red', s=15, alpha=0.6)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title('認識的不確実性(OOD検出)')
plt.tight_layout()
plt.savefig("bnn_ood_detection.png", dpi=150, bbox_inches="tight")
plt.show()
OOD検出のグラフから、以下の重要な区別が見て取れます。
-
左(予測クラス確率): 2つのクラスの間に決定境界があり、各クラスの中心付近では高い確信度の予測が得られています。しかし、この図だけでは「データから遠い領域」の問題が見えません
-
中央(予測エントロピー): 全不確実性を表す予測エントロピーは、決定境界付近とデータから遠い領域の両方で高い値を示しています。しかし、この2つの不確実性の源泉は異なります。決定境界付近のエントロピーは偶然的不確実性(どちらのクラスか本質的に曖昧)が支配的であり、遠い領域のエントロピーは認識的不確実性(データがないため判断不能)が支配的です
-
右(認識的不確実性 = 相互情報量): 認識的不確実性のみを抽出すると、訓練データから遠い領域で高い値を示し、訓練データの近くでは低い値を示しています。決定境界付近の不確実性は比較的低く抑えられています。これがOOD検出に最も適した指標であり、「データがあるが判断が難しい」領域と「データ自体がない」領域を明確に区別できます
認識的不確実性(相互情報量)を閾値として用いることで、OODデータを効果的に検出できます。この能力は、安全性が重要なAIシステムにおいて不可欠です。
まとめ
本記事では、ベイズニューラルネットワーク(BNN)の理論と実装について解説しました。
- 通常のNNの問題: 点推定された重みによる予測は不確実性を表現できず、OODデータに対して過信した予測を返す
- BNNの核心: 重みを確率分布として扱い、予測時にパラメータの不確実性を積分で織り込む。これにより「知らないことを知っている」モデルが実現される
- 変分推論(Bayes by Backprop): 真の事後分布を変分分布で近似し、ELBOを最大化する。再パラメータ化トリックにより標準的なバックプロパゲーションで学習可能
- MCドロップアウト: 推論時にドロップアウトを適用し複数回フォワードパスを実行するだけで、BNNの近似が得られる手軽な手法
- 不確実性の分解: 認識的不確実性(データ不足)と偶然的不確実性(本質的ノイズ)を分離でき、それぞれ異なるアクションに繋がる
- OOD検出: 認識的不確実性(相互情報量)がOODデータの検出に有効であり、安全なAIシステムの基盤技術となる
BNNは「深層学習にベイズの原理を持ち込む」という意味で、不確実性を伴う意思決定が必要なあらゆる分野で重要性が増しています。計算コストの高さという課題はありますが、MCドロップアウトのような実用的な近似手法や、ハードウェアの進歩によって、適用範囲は着実に広がっています。
次のステップとして、以下の記事も参考にしてください。
- ベイズ予測分布とは?事後分布から未知データを予測する理論と実装 — 予測分布の一般理論をより深く
- ガウス過程回帰を初めから分かりやすく解説 — BNNの無限幅極限としてのGP
- PyMCで学ぶ確率的プログラミング — ベイズモデリングの実践入門 — より手軽にベイズモデリングを実践
- ベイズ最適化の理論 — ガウス過程と獲得関数でブラックボックス関数を効率的に最適化する — 不確実性を活用した最適化