ガウス過程からランダムに関数をサンプリングする

Posted: , Category: 機械学習

ガウス過程(Gaussian Process)はよく、関数$\bm{f}$を出力する箱であると例えられることがあります。しかし、関数を出力する箱と言われても、ぶっちゃけ意味がわからないと思います。

ガウス過程のイメージ図としては、次のようになります。

データセット$\mathcal{D} = { (\bm{x_1}, y_1),(\bm{x_2}, y_2), \cdots, (\bm{x_N}, y_N) }$が存在するとき、ガウス過程$\mathcal{GP}$により、それに対応する回帰関数が確率的にいくつも生成されるというイメージです。

よくあるような単回帰や重回帰では、データセットからパラメータ$\bm{w}$を1つに決定するため、回帰直線も1つに定まります。一方、ガウス過程では、これらの回帰直線を1つに定めずに、確率的に考えるため、データセットに対応する関数を無限に生成することができます。

・・・・

と、ここまで説明しても頭の中にはてな?が並ぶだけだと思うので、この記事ではその理論について解説した後に、実装を通して、ガウス過程が関数を出力する箱であることのイメージを掴んでもらえたらと思います。

ガウス過程は無限次元のガウス分布

先ほど、ガウス過程は、関数$\bm{f}$を確率的に生成する箱と書きました。これは、つまるところ、無限次元のガウス分布ということになります。

厳密な定義とは違いますが、関数を言葉で表現してみると、このようになります。

一般的を関数を表現すると

ある一定の集合に属する、任意の$x$において、対応する値$y$を決める$f$を関数と呼び、次のように表現する。

\begin{equation}
y = f(x)
\end{equation}

ここで、先ほどの表現における無限次元とは、(1)の入力$x$に相当します。つまり、関数は、任意の$x$に対して値$y$を返すものですから、ガウス過程における入力は無限になるわけです。

ガウス過程から関数をランダムにサンプリングする

それでは、ガウス過程から関数をランダムにサンプリングしていきましょう。

ガウス過程は、カーネル関数によって、挙動が変わります。今回は次の4種類のカーネル関数を用いて、ガウス過程からランダムにサンプリングすることにします。

今回実装する4つのカーネル関数
  • ガウシアンカーネル
  • 線形カーネル
  • 指数カーネル
  • 周期カーネル

それぞれ、実装は次のようになります。

# 4種類のカーネルの実装
def gaussian_kernel(theta1, theta2, x1, x2):
    return theta1 * np.exp( -1 * np.abs(x2 - x1) ** 2 / theta2)

def linear_kernel(x1, x2):
    return x1 * x2

def exponential_kernel(theta, x1, x2):
    return np.exp( -1 * np.abs(x2 - x1) / theta)

def periodic_kernel(theta1, theta2, x_1, x_2):
    return np.exp(theta1 * np.cos(np.abs(x2 - x1) / theta2))

ガウス過程で関数をサンプリング

それでは、ガウス過程で関数をサンプリングしていきましょう。

まず、必要なライブラリをインポートします。

import matplotlib.pyplot as plt
import numpy as np

plt.rcParams["image.cmap"] = "Blues"

続いて、ガウスカーネルを用いて、ガウス過程から関数をサンプリングしてみます。

サンプリングはそれぞれ5回行い、かつガウスカーネルのパラメータを変更して試してみます。コードはこのようになります。

x = np.linspace(-10, 10, 200)
x1, x2 = np.meshgrid(x, x)

mean =  np.zeros(len(x))

cols = 3
fig, ax = plt.subplots(nrows=2, ncols=cols, figsize=(10, 6), dpi=120)

params = [0.1, 1, 10]
for nc in range(cols):
    theta2 = params[nc]
    for k in range(5):
        gram_matrix = gaussian_kernel(1,  theta2,  x1,  x2)
        sample = np.random.multivariate_normal(mean, gram_matrix)
        ax[0, nc].plot(x, sample, label=f'Sample {k}')
        ax[0, nc].set_title(r"$\theta$ = {}".format(theta2))
    ax[1, nc].imshow(gram_matrix)

グラフの1行目は、ガウスカーネルのパラメータを3種類変更し、ガウス過程からサンプリングした関数群で、下のグラフが対応するグラム行列の値を可視化したものとなっています。

これを見ると、$x_1, x_2$の値が近いと、グラム行列の値、つまり、特徴空間上での内積の値が大きいことがわかります。また、ガウスカーネルにおけるパラメータを大きくすると、ガウス過程でサンプリングされる曲線が滑らかになることがわかります。

指数カーネルを用いてサンプリング

続いて、指数カーネルに変更してサンプリングし、サンプリングした関数とグラム行列をみてみましょう。

# 指数カーネルでサンプリング
x = np.linspace(-10, 10, 200)
x1, x2 = np.meshgrid(x, x)

mean =  np.zeros(len(x))

cols = 3
fig, ax = plt.subplots(nrows=2, ncols=cols, figsize=(10, 6), dpi=120)

params = [0.1, 1, 10]
for nc in range(cols):
    theta = params[nc]
    for k in range(5):
        gram_matrix = exponential_kernel(theta,  x1,  x2)
        sample = np.random.multivariate_normal(mean, gram_matrix)
        ax[0, nc].plot(x, sample, label=f'Sample {k}')
        ax[0, nc].set_title(r"$\theta$ = {}".format(theta))
    ax[1, nc].imshow(gram_matrix)

先ほどより、かなり角張った(滑らかではない)グラフが得られていることがわかります。また、ガウスカーネルを利用した時と同様に、パラメータの値を大きくすると、関数が滑らかになっていることがわかります。

周期カーネルでサンプリング

続いて周期カーネル(periodic kernel)を用いてみます。

# 周期カーネルでサンプリング
x = np.linspace(-10, 10, 200)
x1, x2 = np.meshgrid(x, x)

mean =  np.zeros(len(x))

cols = 3
fig, ax = plt.subplots(nrows=2, ncols=cols, figsize=(10, 6), dpi=120)

params = [0.1, 1, 10]
for nc in range(cols):
    theta = params[nc]
    for k in range(5):
        gram_matrix = periodic_kernel(1, theta,  x1,  x2)
        sample = np.random.multivariate_normal(mean, gram_matrix)
        ax[0, nc].plot(x, sample, label=f'Sample {k}')
        ax[0, nc].set_title(r"$\theta$ = {}".format(theta))
    ax[1, nc].imshow(gram_matrix)

周期カーネルを利用すると、今までのカーネルと違った傾向が見えてきます。

グラム行列に注目すると、$x_1, x_2$の入力が近くなくても、特徴空間での内積が大きくなっているのがわかります。

線形カーネル

最後に線形カーネルでのガウス過程のサンプリングを確かめてみます。直線が引けるだけのサンプリングとなりました。

# 線形カーネルでサンプリング
x = np.linspace(-10, 10, 200)
x1, x2 = np.meshgrid(x, x)

mean =  np.zeros(len(x))

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 3), dpi=120)

for k in range(5):
    gram_matrix = linear_kernel(x1,  x2)
    sample = np.random.multivariate_normal(mean, gram_matrix)
    ax[0].plot(x, sample, label=f'Sample {k}')
ax[1].imshow(gram_matrix)

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術