世界一分かりやすいnp.meshgridの使い方 (メッシュグリッド)

Posted: , Category: Numpy

2次元平面に何かを図示しようとしたとき、Pythonのnumpyのmeshgrid(メッシュグリッド)が頻繁に利用されます。一方で、meshgrid関数はなんとなく雰囲気で使っていて、実際に自分の思うように格子点を作りたい時に、あれどうなってるんだっけ?と思うことも多いとは思います。

今回は、numpyのmeshgrid()関数の動作について、分かりやすく解説していきます。

numpyのmeshgrid()は、格子点を作る関数

numpyのmeshgrid()は、格子点を作る関数です。ここで意識してほしいのは、格子点を作るだけで、値を保存する変数は、numpyのmeshgrid()とは別に用意する必要性があります。

まずは、numpyでmeshgrid()でどのような変数が返ってくるのかを確認しましょう。

meshgridで格子点を作る

meshgrid()では次のような格子点を作ることができます。

ここでは簡単のため、原点が(0, 0)で、X座標は5, Y座標は3までで、間隔が1の格子点を考えています。

import numpy as np

x = np.arange(0, 6, 1)
y = np.arange(0, 4, 1)
X, Y = np.meshgrid(x, y)

meshgridを利用すると、このようなコードで格子点のリストを作ることができます。

ここで、X, Yはそれぞれ次のようになっています。

自分の経験上ですが、6×4 = 24 点の格子点を作ったので、サイズが24のリストが返ってくるかと思いきや、X, Y と変数が2つに増えてるので、よくわからん…となってしまいます。

このmeshgrid()の返り値で返ってくる、X, Y とは何者なのか、ここからみていきましょう。

meshgridの返り値

先ほど、meshgridは2つの変数が返り値として返ってくるといいましたが、この2つの返り値の中身はこのようになっています。

まず最初の返り値は、meshgridで構築した格子点のX座標の集合、2番目の返り値は、Y座標の集合となっています。それぞれの中身はこのようになっています。

ビジュアライズすると分かりやすいですが、X, Yがそれぞれ作った格子点のx, y座標の値に対応していることがわかります。今回作った格子が、6×4のサイズであるため、X, Yも6×4のサイズになっています。

格子上の値を扱うには、別の変数が別途必要

meshgridでできるのは、格子点のリストだけなので、その格子状の値を扱うには、別の変数が必要になります。この値を入れる変数も、格子の点に対応した6×4のサイズの配列が必要になるわけです。

meshgridを利用して、2次元ガウス分布を描く

今回は例題として、meshgridを利用して、多次元(2次元)ガウス分布を、三次元空間で可視化します。

コードは下記のようになっています、。

matplotlibのplot_surface関数の引数に注目すると、meshgridで作成したx座標のリストX、y座標のリストY、そして、値の入ったリストZを引数に渡しているのがわかります。

# https://scipython.com/blog/visualizing-the-bivariate-gaussian-distribution/

import numpy as np
import matplotlib.pyplot as plt
from   matplotlib import cm
from   mpl_toolkits.mplot3d import Axes3D
from   scipy.stats import multivariate_normal

N = 60
X = np.linspace(-3, 3, N)
Y = np.linspace(-3, 4, N)
X, Y = np.meshgrid(X, Y)

mu = np.array([0., 1.])
Sigma = np.array([[ 1. , -0.5], [-0.5,  1.5]])

pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

F = multivariate_normal(mu, Sigma)
Z = F.pdf(pos)

fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1, antialiased=True, cmap=cm.viridis)

plt.show()

この例だと、scipyのmultivariate_normalのpdfを呼び出すのに、X, Yの座標を別の形に変形しているので注意してください。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術