混合ガウスモデル(Gaussian Mixture Model, GMM)は、教師なしの分類アルゴリズムであり、クラスタリングの手法の1つです。他のクラスタリングのアルゴリズムとしてk-meansがありますが、k-meansと比較して、データの分布の傾向(分散)を取り入れたのが、混合ガウスモデルだとよく言われます。
イメージとしてはこちらの画像を見ると混合ガウスモデルのイメージが湧くと思います。
混合ガウスモデルのアルゴリズムを用いることで、左側のようなデータの分布を与えたときに、右側のようにデータのクラスタリングと、そのクラスタを生成していると仮定されるガウス分布を得ることができます。右側のグラフは、ガウス分布の確率の等高線を描いており、これによって、そのデータが生成される確率がどの程度かを定量的に評価することができます。
この記事では、まず混合ガウスモデルがどのような機械学習モデルであるかをわかりやすく解説し、その後、有名なアヤメデータセットであるIrisデータセットを用いて、混合ガウスモデルを実際に実装し、上記のようなクラスタリングを行います。
- 混合ガウスモデル(GMM)をわかりやすく解説
- Irisデータセットを用いてGMMを実装する
混合ガウスモデルをわかりやすく
混合ガウスモデルは混合モデルの1種
混合ガウスモデルは、重み付きの異なるガウス分布の線型結合からなる、混合モデル(mixture model)として表現できます。混合ガウスモデルを用いることで、多峰性(multimodal)の分布を表現することができます。
わかりやすく図解すると次のようになります。
左側は正規分布のグラフで、1つの山があるグラフとなっています。一方右側が、薄い線で描いている3つの正規分布を足し合わせたグラフになっています。赤い実線の分布に注目すると、山が2つあるような分布になっていることがわかります。
このように混合モデルを用いることで、よく使われるようなガウス分布やベータ、ガンマ分布のような確率分布では表現できないようなデータの分布を表現できるようになります。
一般的に、上記のような$K$個のガウス分布が複数重ね合わさった確率分布を混合ガウス分布(Mixture Gaussian Distribution)といい、次のような式で定義されます。
\begin{equation} p(x | \bm{\pi}, \bm{\mu}, \bm{\Sigma}) = \sum_{k=1}^{K} \pi_k \mathcal{ N}(x | \mu_k, \Sigma_k) \end{equation}
ここで、$\pi_k $はk番目のガウス分布の混合比率を表すパラメータであり、次の式を見たす。
\begin{split} \sum _{k=1}^{K}\pi_k = 1 \\ 0< \pi_k < 1 \end{split}
また、$\bm{\pi} = [ \pi_1, \pi_2, \cdots, \pi_K ]$であり、$\mu_k, \Sigma_k$はそれぞれk番目のガウス分布の平均と分散である。
(1)の定義式から、混合ガウス分布のパラメータは、$\bm{\pi}, \bm{\mu}, \bm{\Sigma}$であることがわかります。
通常だとこのパラメータの推定には、データから対数尤度関数を計算し、それを各パラメータで偏微分という形になりますが、(1)式のような確率分布の関数に$\Sigma$の形式が入っている場合は、通常の西遊推定はできず、EMアルゴリズムという手法を使うことになります。
EMアルゴリズムの詳細については、こちらの記事をご覧ください。
今回は、scikit-learnを用いてガウス混合モデルのフィッティングを行いますが、scikit-learnでもEMアルゴリズムによる最適化を行なっています。このあたりは、少し難しい概念なのですが、混合エキスパートモデルなどの機械学習でよく使われているモデルを実装する際などに必要なので、余裕がある人は理解して実装できるようになると良いでしょう。
今回は、GMMを使って実際のデータをクラスタリングすることを主眼としているため、この詳しいパラメータ決定のアルゴリズムの詳細までは立ち入りません。
Irisデータセットを混合ガウスモデルでクラスタリング
今回は、Irisデータセットを用いて、混合ガウスモデルのクラスタリングを行なっていきましょう。実際に理論を説明しましたが、GMMの実装自体はscikit-learnを用いることで非常に簡単に実行することができます。
Irisデータセットを準備して次元削減を行う
まず、今回はクラスタリングするデータセットとしてIrisデータを用います。Irisデータセットの説明はこちらをご覧ください。
データセットを準備します。
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
Irisデータセットはデータ点数が150点で、系列が4つある以下のようなデータになります。
混合ガウスモデルを利用する際は、データの次元はいつくでも構いませんが、今回はガウス分布の等高線を可視化してわかりやすくするために、2次元に落とします。そのために主成分分析を行います。主成分分析のコードは下記のようになります。
pca = PCA(n_components=2)
data = df[["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]].values
trans_pca = pca.fit_transform(data)
fig, ax = plt.subplots()
cmap = plt.get_cmap("tab10")
color = [cmap(label) for label in df.label]
ax.set_xlabel("PCA axis 1")
ax.set_ylabel("PCA axis 2")
for i in range(3):
ax.scatter(trans_pca[df[df.label == i].index, 0], trans_pca[df[df.label == i].index, 1], color=cmap(i), s=10, marker="o", label="Label: {}".format(i))
ax.legend()
主成分分析した結果を可視化するとこのようになります。この辺りは、本記事の主題ではないのでサクッといきます。主成分分析の詳しい解説はこちらに掲載しています。
IrisデータセットでGMMでクラスタリング
ここまで準備することができたら、実際に混合ガウスモデルを用いて、データのクラスタリングを行なっていきましょう。Irisデータセットは正解ラベルがあるので、先ほど掲載したように、データがうまくこのようにクラスタリングできていれば、GMMがうまく行っていることになります。
GMMはスクラッチでも実装できますが、最適化計算のところで、EMアルゴリズムや変分推論など数学的に少し難解なアルゴリズムを必要とするため、今回はscikit-learnで既に実装されているモジュールを利用します。
GMM自体は、sckit-learnを用いることで、学習までわずか数行の次のコードで行うことができます。trans_pcaは主成分分析によって2次元に次元削減したデータセットが格納されています。
from matplotlib.colors import LogNorm
from sklearn import mixture
gmm = mixture.GaussianMixture(n_components=3, covariance_type='full')
gmm.fit(trans_pca)
labels = clf.predict(trans_pca)
ここまでで、GMMを用いて推論をすることができました。GMMではハイパーパラメータとして、分類するクラスタ数を割り当てる必要性があります。あとは結果を可視化します。
x = np.linspace(-4, 4)
y = np.linspace(-2, 2)
X, Y = np.meshgrid(x, y)
XX = np.array([X.ravel(), Y.ravel()]).T
Z = - gmm.score_samples(XX)
Z = Z.reshape(X.shape)
fig,ax=plt.subplots(dpi=150,figsize=(5,4))
ax.scatter(trans_pca[:, 0], trans_pca[:, 1], s=0.5,c=labels)
cont = ax.contourf(X, Y, Z, norm=LogNorm(vmin=1.0, vmax=100.0), levels=np.logspace(-1, 3, 20), alpha=0.2, linestyles='dashed', linewidths=0.5)
ax.scatter(trans_pca[:, 0], trans_pca[:, 1], s=1, c=labels)
ax.set_title("GMMによるクラスタリングと等高線")
結果はこのようになりました。実際に正解ラベルと比較してみると、かなり高い精度で分類できていることがわかります。