【機械学習】正規分布で可視化することでKL情報量を理解する

Posted: , Category: 数学 , 機械学習 , 確率 , 統計学

機械学習を理解し使いこなす上で、KL情報量について理解することは避けては通れないといえるほど、KL情報量は非常に重要な概念です。

そのため、さまざまな参考書等で、KL情報量について説明されています。

だいたい次のような定義とともに、KL情報量が紹介されていると思います。

\begin{equation}
\begin{split}
KL[q(x)|p(x)] &= - \int q(x) ln \frac{p(x)}{q(x)}dx  \\
&= \int q(x) ln \frac{q(x)}{p(x)}dx 
\end{split}
\end{equation}

ある程度KL情報量について詳しい人ならともかく、この式を見て、覚えたとして、KL情報量について十分理解できた気になる人などいないと思います。

一方で、機械学習ではこのKL情報量は本当に重要なので、十分に理解する必要性があります。そこで本記事では、KL情報量を理解するために、正規分布(1次元ガウス分布)を例にとり、Pythonでうまく可視化しながら、KL情報量について解説していきます。

また、同時に、なぜKL情報量が情報なのかについて機械学習の文脈から解説してきます。

本記事の内容
  • 機械学習でKL情報量が重要な理由について解説
  • 正規分布の例でKL情報量を可視化しながら理解する

機械学習でKL情報量が重要な理由

ここは本記事の内容とは直接関係ないので、次の章に飛ばしてしまっても全く問題ありません。

機械学習分野、特にベイズ理論ではよく、データが与えられた時のパラメータ$\bm{Z}$の 事後分布$p(\bm{Z} | \bm{X})$求めたいが、解析的に求めることができない状況がほとんどです。

そのような問題では、事後分布$p(\bm{Z} | \bm{X})$をMCMC等のサンプリングによって求めることもありますが、場合によっては変分推論(Variational Inference)等で求めることが多々あります。

変分推論は、現代の機械学習の手法で、非常に重要な理論の1つで、求めたい分布$p(\bm{Z} | \bm{X})$に対し、これと似ている確率分布$q(\bm{Z})$ (pではなくqになっていることに注意)を考えます。

そして、この$p(\bm{Z} | \bm{X})$と$q(\bm{Z})$のカルバックライブラー情報量を最小化することで、求めたい事後分布$p(\bm{Z} | \bm{X})$を近似する最適な$q(\bm{Z})$を考えます。

KL情報量についてより詳しく知りたい人は、こちらの記事でわかりやすく解説しているので、ぜひこちらの記事もご覧ください。

情報理論や確率変数のKLダイバージェンス(カルバック・ライブラー情報量)を解説
KLダイバージェンス(Kullback-Leibler divergence, KL情報量)は、2つの確率分布の距離を表す統計量として、統計学や機械学習分野で頻出の統計量となっています。 KLダイバージェンスを$D_{K […]

2つの正規分布間のKL情報量を可視化する

ここまで、前座として長々と、KL情報量の重要性について語ってきました。

続いて本題に入っていきましょう。ここでは、2つの正規分布(1次元ガウス分布)を考えます。

\begin{equation}
\begin{split}
p(x) &= \mathcal{N(\mu_1, \sigma_1)}  \\
&= 
\frac{1}{\sqrt{2\pi \sigma_1^2}} exp 
\biggl \{
- \frac{(x-\mu_1)^2}{2\sigma_1^2}
\biggr \}
\end{split}
\end{equation}
\begin{equation}
\begin{split}
q(x) &= \mathcal{N(\mu_2, \sigma_2)}  \\
&= \frac{1}{\sqrt{2\pi \sigma_2^2}} exp 
\biggl \{
- \frac{(x-\mu_2)^2}{2\sigma_2^2}
\biggr \}
\end{split}
\end{equation}

これらの2つの確率分布$p(x)$と$q(x)$のKL情報量を考えていきます。

(1)で示した定義より、2つの正規分布間のKL情報量は次のように表される。

\begin{equation}
\begin{split}
\operatorname{KL}[q(x)|p(x)] &=  \int q(x) ln \frac{q(x)}{p(x)}dx  \\
&=   \int  q(x)lnq(x)dx - \int q(x)lnp(x)dx\\
&=  \space \dots \\
&= ln \biggl (\frac{\sigma_2}{\sigma_1} \biggr ) +
\frac{\sigma_1^2  + (\mu1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2}

\end{split}
\end{equation}

ここで心苦しくも(4)の式変形は行数が膨大になるために、省略しました。

PRML上巻の1章に2つの正規分布のKL情報量を求める問題が掲載されているので、途中の式変形について知りたい人はそちらをご覧ください。

PythonでKL情報量を可視化

(4)式でKL情報量を表現できたので、これらを可視化して見ていきましょう。

(4)式を見ると残念ながら、変数が$\mu_1, \mu_2, \sigma_1, \sigma_2$の4つもあるので、変数を固定して1つだけ動かしていく形で可視化していきます。

$mu_1$以外全て固定して、平均$mu_1$を動かす場合

まず、$p(x)$の平均、$\mu_1$が変数で、それ以外が1であるとします。すると、(4)式は、

\begin{equation}
\begin{split}
\operatorname{KL}[q(x)|p(x)] &=  ln \biggl (\frac{\sigma_2}{\sigma_1} \biggr ) +
\frac{\sigma_1^2  + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2} \\
&= \frac{(\mu_1 -1)^2}{2}
\end{split}
\end{equation}

となります。これを描写すると次のようになります。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-3, 5,100)
k_dist = (x -1 )**2 / 2

fig, ax = plt.subplots(figsize=(4,4))
ax.plot(x, k_dist, lw=1)
ax.set_xlabel("$\mu_1$")
ax.set_ylabel("KL distance")

非常にシンプルな図で恐縮ですが、$\mu_1$を変えた時のKL情報量の値は次のようになります。$\mu_1=1$の時には、$p(x) =q(x)$となるので、KL情報量の値も0になっていることがわかります。

$sigma_1$以外全て固定して、標準偏差$sigma_1$を動かす場合

同様に、標準偏差$sigma_1$を動かしてみます。

\begin{equation}
\begin{split}
\operatorname{KL}[q(x)|p(x)] &=  ln \biggl (\frac{\sigma_2}{\sigma_1} \biggr ) +
\frac{\sigma_1^2  + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2} \\
&= ln \frac{1}{\sigma_1}  + \frac{\sigma_1^2 - 1}{2}
\end{split}
\end{equation}

これを可視化すると次のようになります。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0.01, 5,100)
k_dist = np.log(1 / x) + ( x ** 2 - 1 ) / 2

fig, ax = plt.subplots(figsize=(4,4))
ax.plot(x, k_dist, lw=1)
ax.set_xlabel("$\sigma_1$")
ax.set_ylabel("KL distance")

これも同様に、$sigma_1 = 1$の時は、$p(x) =q(x)$となり、KL情報量の値が0になっていることがわかります。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術