k-means法(k平均法)のアルゴリズムを実装して完全に理解する

Posted: , Category: クラスタリング , 機械学習

k-means法は機械学習の参考書に掲載されているような基本的なクラスタリングのアルゴリズムです。

よく、教師なし学習の代表例としても教科書等で扱われることがあり、またk-meansのアルゴリズム自体が機械学習分野で登場するEMアルゴリズムといったパラメータ学習方法に近しい手法を用いていることもあり、k-meansを学ぶことは機械学習の手法を学ぶ上で非常に重要と言えます。

また実際の課題を解こうとする際に、主成分分析(PCA)で低次元空間に次元変換した上で、k-meansでクラスタリングするなどをすることもあります。

今回は、k-meansのアルゴリズムを勉強することを目的とし、scikit-learn等のライブラリを実装せずにスクラッチでk-meansを実装して学んでいきます。本書の内容は、機械学習の著名な参考書であるPRMLの下巻を参考にしています。本記事を参考にPRMLを読むとより理解が深まると思います。

本記事の内容
  • クラスタリングのアルゴリズムであるk-meansの考え方とアルゴリズムを学ぶ
  • scikit-learn等のライブラリを用いず、k-meansをゼロから実装し、動かすことで理解を深める

k-means(k平均法)の 目的

k-meansの目的は、正解ラベルが与えられていないデータセットが与えられた時に、それをクラスタリングして似ているデータセットを分類するアルゴリズムです。

クラスタリング問題を解決するアルゴリズムだと考えてもらって良いでしょう。

k-means(k平均法)のアルゴリズム

では、k-meansのアルゴリズムについて説明をしていきます。

先ほど述べたように、k-meansは、$D$次元の$N$個のデータを、特定のクラス(今回はクラス数を$K$個とします)うまく分類して、それぞれのデータを$K$カテゴリに割り振るクラスタリングするアルゴリズムです。

いま、手元にあるN個のデータを、$\bm{x} = \{ \bm{x_1}, \bm{x_2}, \dots, \bm{x_N} \}$とします。

k-meansでは、以下の損失関数$J$を最小化するように動作します。

k-meansにおける損失関数
\begin{equation}
\begin{split}
J = \sum_{n=1}^{N} \sum_{k=1}^{K} r_{nk} {\|\bm{x_n} - \bm{\mu_k} \|}^2
\end{split}
\end{equation}

ここで、$\bm{\mu_k}$はクラスタ $k$の中心点(重心位置)を示すベクトルである。

この(1)式で表される損失関数を最小にするように、$r_{nk}$と$\mu_k$を求めているのが、k-meansのアルゴリズムです。

ここで、$r_nk$は、$r_nk \in {0, 1}$の値を取り、$n$番目のデータ$\bm{x_n}$がクラス$k$に属する場合は1を、属さない場合は、0をとる変数です。

今回は、データが$n$個あり、振り分けるクラスが$k$個あるので、$r_nk$は$n x k$個存在することになります。

k-meansの分類アルゴリズム

k-meansは次の2つのステップを繰り返し実行することによって、先ほど提示した損失関数$J$を最小化するようなアルゴリズムです。

k-meansにおける2ステップ

ステップ1

重心 $\mu_k$を固定し、損失関数$J$を最小化するような$r_nk$を決定する

ステップ2

$r_nk$を固定し、損失関数$J$を最小化するような$\mu_k$を決定する

ステップ1,2ともに損失関数$J$を最小化する方向に$r_nk$と$\mu_k$を選んでいるので、$J$が最小化するイメージはつくかなと思います。(ただし、実際最小値になるものの、局所解に陥る可能性はあります)

k-means(k平均法)を実際に実装してみる

Pythonを用いて、実際にk-meansを実装してみましょう。

アルゴリズム自体がかなり簡潔な上、numpyの関数を駆使することでかなり行数が少なく実装することができます。

まず最初に必要なライブラリをインストールします。

コメントアウトしているnbaggは、Jupyter notebook上でアニメーションを動かすためのライブラリです。Jupyter上で動作させる人は有効にしても問題ありません。

import numpy as np
import scipy
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# %matplotlib nbagg 

import seaborn as sns
sns.set(style="whitegrid", palette="muted", color_codes=True)
sns.set_style("whitegrid", {'grid.linestyle': '--'})

np.random.seed(2251) # データの再現性のため

続いて、今回クラスタリングするための人口のデータを生成します。

今回は、2次元のデータをクラスタリングすることにします。3種類の2次元ガウス分布を定義し、この3種類のガウス分布からデータを200点ずつサンプリングします。

data1 = np.array(stats.multivariate_normal.rvs( mean=[1, 4], cov=np.asanyarray ([[0.1, 0.2], [0.2, 0.1]]), size=200 ))
data2 = np.array(stats.multivariate_normal.rvs(mean=[4, 4], cov=np.asanyarray ([[0.1, 0.4], [0.4, 0.1]]), size=200))
data3 = np.array(stats.multivariate_normal.rvs(mean=[3, 0], cov=np.asanyarray ([[0.4, -1.1], [-1.1, 0.4]]), size=200))

続いて、これらのデータを一旦可視化してみてみましょう。

colors = ['m', 'c', 'g', 'r', 'y']
plt.figure(figsize=(6, 4), dpi=120)

plt.scatter(data1[:, 0], data1[:, 1], s=20, c=colors[0], alpha=0.5)
plt.scatter(data2[:, 0], data2[:, 1], s=20, c=colors[1], alpha=0.5)
plt.scatter(data3[:, 0], data3[:, 1], s=20, c=colors[2], alpha=0.5)

各ガウス分布から生起されるデータがこのようになることがわかりました。

続いてこれらのデータを全部統合した上で、k-meansでクラスタリングをしてみます。

data = np.concatenate([data1, data2, data3])

plt.figure(figsize=(6, 4), dpi=120)
plt.scatter(data[:, 0], data[:, 1], s=10, c="gray", alpha=0.5)

先ほどの画像を全て統合して1つのデータセットにしました。この画像をクラスタリングしていきます。

今回は、$K=3$ カテゴリでクラスタリングをしていきます。カテゴリ数Kは任意で設定できますが、今回は3つのガウス分布に基づいたデータセットを用いたので、K=3でやってみますが、k-meansのアルゴリズム的にはクラスタ数は任意でできるので、Kの数を変えて結果がどう変わるか確認してみると面白いです。

# 初期値の選択
K = 3
centers = np.array([[np.random.uniform(0, 6), np.random.uniform(-4, 6)] for i in range(K)])

plt.figure(figsize=(6, 4), dpi=120)
plt.scatter(data[:, 0], data[:, 1], s=10, c="gray", alpha=0.5)

# 最初の平均値を描写
for i, mu in enumerate(centers):
    plt.scatter(mu[0], mu[1], s=30, c=colors[i])    

まず最初に、基準となる重心位置を決めました。この最初の重心位置は本当に適当です。今回のコードの場合は、[np.random.uniform(0, 6), np.random.uniform(-4, 6)] で、x軸の範囲が0~6、y軸の値が-4~6 の間に収まるようにしました。この初期値でさえ本来はどこに設定しても良いのですが、初期値の場所があまりに遠すぎると、収束までの回数がかかってしまうので、今回はこの範囲に設定しました。

実際にはどこに設定しても収束することになるので、収束具合が初期値にどの程度依存するか気になる人はこの制約を取っ払ってコードを動かしてみてください。

続いて、k-meanのアルゴリズムの中身を実装します。

関数の名前がanimateになっていますが、これは後のコードでアニメーションする都合でこのような名前になっていますが、k-meansのアルゴリズムとは本質的に何も関係がありません。

animate関数の中身が、1回のK-meansの実施に相当しています。

def animate(nframe):
    print("current is {} the frame".format(nframe))
    plt.clf()
    
    old_centers = centers.copy()
    
    # Step1: 各データ点を最も近い重心のクラスを割り当て
    label_nearest = []
    for _, d in enumerate(data):
        dist = [np.linalg.norm(d - mu) for _, mu in enumerate(centers)]
        label_nearest.append(np.argmin([dist]))
    
    # Step2: 重心を算出 (平均の逐次算出)
    for index, label in enumerate(label_nearest):
        centers[label] += ((data[index] - centers[label]) / (index + 1)) 
    
    # データをプロット
    for k in range(K):
        plt.scatter(data[np.where(np.array(label_nearest) == k)][:, 0], data[np.where(np.array(label_nearest) == k)][:, 1], s=10, c=colors[k], alpha=0.5)
    
    # 重心をプロット
    for index, center in enumerate(centers):
        plt.scatter(center[0], center[1], s=100, c=colors[index], marker='o', ec="black")
        
    diff = np.linalg.norm(old_centers - centers)
    if (diff <= 0.0001):
         plt.title('k-means is converged.')
    else:
        plt.title("iter:{}".format(nframe + 1))

コードにStepとしてコメントしていますが、k-meansでやっていることは、k-meansのアルゴリズムの説明時に提示した、次の2ステップを繰り返し実行していることになります。

k-meansの1回のイテレーションでやっていること
  • Step 1. 各データ点を、最も近い重心のクラスを割り当て
  • Step 2. Step1で割り当てたクラス毎に重心を計算
fig = plt.figure(figsize=(6, 4), dpi=120)
anim = animation.FuncAnimation(fig, animate, frames=100)
anim.save("k-means.gif", writer="imagemagick")

最後に、プロットしてみるとこのようになります。

最初に正解ラベルを作ったように、うまくクラスタリングができていますね。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術