機械学習の分野を勉強していると、一度は目にすることがあるディリクレ分布 (Dirichlet distribution)
カテゴリカル分布や多項分布のパラメータを表現する共役事前分布になる性質もあり、ベイズ推定の分野でも度々登場するディリクレ分布ですが、一度も扱ったことがない人には、どのような分布なのかイメージがつかない人も多いでしょう。一般的にベータ分布を多次元に拡張したもの、と言われることもありますが、このような説明を一度きいても理解できる人はまずいないでしょう。
しかし、ディリクレ分布自体は、ディリクレ過程混合モデルやLDAといった近年よく用いられているアルゴリズムの前提になっており、非常に重要な確率分布と言えるので、一度どのような分布になっているのか理解できると良いでしょう。
今回はディリクレ分布についてPythonを使って図示しながら、わかりやすく解説します。また統計量として期待値や分散などを解説していきます。
- ディリクレ分布 (Dirichlet distribution)の定義
- ディリクレ分布 を可視化する
ディリクレ分布の定義式
まず最初に、ディリクレ分布の定義を示します。
\begin{equation} \begin{split} Dir(\bm{\pi} | \bm{\alpha}) = C(\bm{\alpha}) \prod_{k=1}^{K} \pi_k^{\alpha_k-1} \end{split} \end{equation}
ここで、ディリクレ分布の確率変数$\bm{\pi}$はK次元のベクトルで、その各要素$\pi_k$は、$\pi_k \in (0, 1)$かつ、$\sum_{k=1}^{K} \pi_k = 1$を満たします。
ディリクレ分布自体のパラメータは、$\bm{\alpha}$で、この$\bm{\alpha}$もK次元のベクトルであり、各要素は$\bm{\alpha} = (\alpha_1, \alpha_2, … , \alpha_K)^T $となっています。
また、$\bm{\alpha}$のk番目の要素は$\alpha_k > 0$の範囲を取ります。
(1)式に登場する$C$は、ベータ分布やガンマ分布に登場したような正規化項で、この項によって(1)式の確率密度関数の総和が1になることが保証されます。
ディリクレ分布を扱う上で、定数部分になりあまり意識することはありませんが、Cは下記のような関数形になっています。
\begin{equation} \begin{split} C = \frac{\Gamma(\alpha_0)}{\Gamma(\alpha_1)\Gamma(\alpha_2) \dots \Gamma(\alpha_K)} \end{split} \end{equation}
ここで、$\Gamma(x)$は、ベータ分布やガンマ分布にも登場するガンマ関数です。
ガンマ関数については、こちらの記事で詳しく取り扱っているので、こちらをご覧ください。
(2)式に登場する$\ahpla_0$は下記の省略になっています。
\begin{equation} \begin{split} \alpha_0= \sum_{k=1}^{K} \alpha_k \end{split} \end{equation}
(2)式と(3)式を全て(1)式に組み込むと、ディリクレ分布は下記のような定義式になります。
\begin{equation} \begin{split} Dir(\bm{\pi} | \bm{\alpha}) = \frac{\Gamma(\sum_{k=1}^{K}\alpha_k)}{\Gamma(\alpha_1) \dots \Gamma(\alpha_K)} \prod_{k=1}^{K} \pi_k^{\alpha_k-1} \end{split} \end{equation}
ディリクレ分布のイメージ
ここまで天下り的に、ディリクレ分布の数式について書き下してきましたが、サイコロをイメージするとディリクレ分布の理解に役立ちます。
今、出目の出る確率が一律ではない、イカサマサイコロを考えます。出目の数は$K$個とすると、k番目の出目の出る確率は、$\pi_k$で表現できます。
ディリクレ分布は、この$\pi_k$を決めるための、分布ということになります。
ディリクレ分布とベータ分布の関係性
ディリクレ分布は、ベータ分布の関係を一般化したものです。
この関係は、コインの表と裏が出る確率分布を示す、ベルヌーイ分布を一般的に$K$次元に拡張したカテゴリカル分布との関係性に似ています。
実際、(4)式において、$K=2$にすることで、ベータ分布の確率密度関数に一致することがわかります。
ディリクレ分布を可視化してみる
ディリクレ分布を可視化してみましょう。可視化のコードはこちらのサイトを参考にさせていただきました。サイコロの例のように6次元で可視化してみたいところですが、6次元データをプロットするのは難しいので3次元でプロットします。
まずライブラリ等の準備を行います。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
AREA = 0.5 * 1 * 0.75**0.5
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=4)
pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]
tri_area = lambda xy, pair: 0.5 * np.linalg.norm(np.cross(*(pair - xy)))
続いて、ディリクレ分布を描写する関数などを記述します。
def xy2bc(xy, tol=1.e-4):
coords = np.array([tri_area(xy, p) for p in pairs]) / AREA
return np.clip(coords, tol, 1.0 - tol)
class Dirichlet(object):
def __init__(self, alpha):
from math import gamma
from operator import mul
self._alpha = np.array(alpha)
self._coef = gamma(np.sum(self._alpha)) / \
np.multiply.reduce([gamma(a) for a in self._alpha])
def pdf(self, x):
from operator import mul
return self._coef * np.multiply.reduce([xx ** (aa - 1)
for (xx, aa)in zip(x, self._alpha)])
def draw_pdf_contours(dist, nlevels=200, subdiv=8, **kwargs):
import math
refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=subdiv)
pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]
plt.tricontourf(trimesh, pvals, nlevels, cmap='jet', **kwargs)
plt.axis('equal')
plt.xlim(0, 1)
plt.ylim(0, 0.75**0.5)
plt.axis('off')
コードは参考サイトからの引用になっています。
ディリクレ分布を描写する
ライブラリとコードのインポートが終わったら、ディリクレ分布を描写してみます。
今回確認するのは、$\alpha$を変えたときに、$\pi_1, \pi_2, \pi_3$の確率分布がどのように変わるかです。
$\pi_1, \pi_2, \pi_3$と三変数ありますが、実際には制約条件として、$\pi_1 + \pi_2 + \pi_3 = 1$を満たすので、2次元で表現することができます。
$\alpha_1 = \alpha_2 = \alpha_3 = 1$ を与えてみるとこのようになります。
draw_pdf_contours(Dirichlet([1, 1, 1]))
$\pi_1, \pi_2, \pi_3$の値が一律になることがわかります。
draw_pdf_contours(Dirichlet([3, 3, 3]))
中心付近($\pi_1, \pi_2, \pi_3$の確率が等しくなっている近傍)の確率が高くなることがわかります。
ディリクレ分布の期待値と分散
ディリクレ分布の期待値と分散はこのようになっています。
ディリクレ分布の期待値
\begin{equation} \begin{split} \mathop{E}[\pi_k] = \frac{\alpha_k}{\alpha_0} \end{split} \end{equation}
ディリクレ分布の分散
\begin{equation} \begin{split} Var [\pi_k] = \frac{\alpha_k(\alpha_0 - \alpha_k)}{\alpha_0^2(1+\alpha_0)} \end{split} \end{equation}