今回は、サポートベクトルマシン(Support Vector Machine, SVM)を用いて、MNISTの画像分類に取り組んでみます。
サポートベクトルマシンは理論を勉強するとなかなか難易度が高いのですが、Pythonのライブラリであるscikit-learnを利用することで、実装自体は非常に簡単に行うことができます。
今回は、手書き文字のデータセットであるMNISTを用いて、手書き文字の分類をMNISTでやっていきましょう。
- サポートベクトルマシン(SVM)でMNIST画像を分類
- SVMの分類結果を可視化、混同行列で評価
分類するデータセットMNIST
早速始める前に、今回サポートベクトルマシンで分類するデータセットを準備してきましょう。
MNISTのデータセットの解説やダウンロード方法は、こちらの記事で詳しく解説しています。

まずデータセットを準備する前に、今回の実装で使うライブラリ群をインポートします。
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
続いて、データセットを準備していきます。
データセット自体は、scikit-learnのdatasetsモジュールを利用することで、非常に簡単に入手することができます。
mnist = fetch_openml('mnist_784', version= 1, as_frame= False)
x, y = mnist["data"], mnist["target"]
scikit-learn経由で入手できるMNISTデータは、サイズ 28×28 で合計784px で枚数は合計70,000枚程あります。
これらのデータを全てサポートベクトルマシンで学習させても良いのですが、今回はアルゴリズムを検証する目的なので、計算時間を短くするために少なめのデータ数で学習を行います。
また、訓練データに用いたデータを検証では使いたくないので、訓練用のデータと検証用のデータを分割します。
分割するために自分で実装しても簡単にできますが、よくsckit-learnのmodel_selection モジュールに含まれている train_test_split関数を利用することが多く、train_test_split関数は非常に便利なので、今回はtrain_test_split関数を用いて訓練データとテストデータの分割を行います。
seed = 1
train_size = 5000
test_size = 100
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=train_size, test_size=test_size, random_state=seed, stratify=y)
train_test_split関数は、指定した枚数の画像をランダムに分割してくるので、random_state引数にseed値を与えることで、毎回同じ分類をするようにしています。
stratifyは、データの偏りがないようにするために指定しています。(例えば、MNISTの場合だと数字0ばっかりのデータだけに偏ってしまうと、正しい学習や検証ができないので)
続いてデータの正規化を行います。正規化とは、全ての値が0-1の間に収まるようにすることです。
正規化は、今回のSVMだけでなく、ニューラルネットワークやロジスティック回帰といったアルゴリズムなど、機械学習で予測問題を解く際によく利用します。
今回は、MNISTのデータは0~255の間の値になっているので、全てのピクセル値を255で割ることで、正規化することができます。
x_train = x_train /255
x_test = x_test / 255
ここまでで、MNISTのデータの準備ができました。続いて、これらのデータを用いてSVMを用いて手書き文字の分類をやっていきます。
サポートベクトルマシン(SVM)でMNISTを分類
サポートベクトルマシンのモデルを学習してみましょう。scikit-learnを用いると非常に簡単で、次のようになります。
svm = SVC()
svm.fit(x_train, y_train)
これでモデルの学習ができました。次に検証データでテストしてみます。
predicted = svm.predict(x_test)
score = svm.score(x_test, y_test)
print("score is {}".format(score))
# => score is 0.95
このコードを実行すると、スコアが0.95となり、95%の確率で予測できていることがわかります。
なかなかの結果ですね。実際に検証データとその予測結果をみていきます。
cols = 6
rows = 5
num = rows * cols
fig, ax = plt.subplots(rows, cols, figsize=(10, 10), tight_layout=True)
for i in range(num):
r = i // cols
c = i % cols
ax[r][c].imshow(x_test[i ].reshape([28, 28]), cmap="Greys", interpolation="nearest")
ax[r][c].set_title("target: {}, predict: {}".format(y_test[i], predicted[i]))
plt.show()

ある程度、うまく分類できていますね。分類問題でよく利用する、混同行列で、どのあたりがうまく分類でいている箇所、でいていない箇所を確認してみましょう。
cm = confusion_matrix(y_test, predicted)
sns.heatmap(cm, square=True, cbar=True, cmap='Blues')

混同行列をヒートマップで可視化するようになりました。
縦軸が正解ラベルで、横軸が予測ラベルです。ざっくりですが、2と8の画像を間違うことや、3の画像を1や8と誤分類することが多少あることがわかりました。