【pytorch】ニューラルネットでMNISTの分類しながらpytorchの使い方を学ぶ

Posted: , Category: pytorch , 分類問題 , 深層学習

pytorchはディープラーニング用のライブラリです。ディープラーニング用のライブラリとして、他に、Kerasやtensorflowなどがありますが、pytorchは近年非常に利用され始めてきています。よく深層学習の最新論文等で、pytorchの実装が落ちていることもよくあるため、pytorchの利用方法について、理解しておくと良いでしょう。

今回は、pytorchでニューラルネットを実装しMNISTの分類を行う方法について解説していきます。

MNISTの分類は、深層学習のHello world的な題材なので、簡単な例題を通してpytorchに慣れていただければと思います。

なお本記事の対象としては、深層学習やニューラルネットワークの初心者を対象としています。

ある程度、Pytorchの利用方法に比較的重点を置いているので、ニューラルネットワークも1から学びたいんだけど、という人は、当サイトの別記事を参考にしてみて下さい。

例えば下の図を見たときに、何となく理解できる、という程度の人向けの解説記事となります。

pytorchでニューラルネットのMNIST分類器を実装する

それでは、早速、pythorchを用いたMNISTの画像分類していきましょう。

今回pytorchで実装するニューラルネットワークは、50次元の中間層が一層あって、入力層、出力層を含む全ての層が全結合(Fully Conntected)でつながっているNNを実装します。

ネットワークの全体像は次のようになっています。

また、実装する手順は次のようになっています。

ニューラルネットの実装手順
  • torchvisonでMNISTのデータセットを準備
  • ニューラルネットワークのモデルを実装
  • 損失関数とオプティマイザを実装
  • ニューラルネットモデルの学習
  • テストデータで、学習済みモデルで画像分類

早速、手順に沿ってニューラルネットの実装をしていきます。

各内容と共に、コードをかいつまみながら解説していきますが、本記事の最後に、コードの全文を掲載しますので、実装だけ知りたい人は、そちらをご覧ください。

torchvision を利用してデータセットを準備

まず最初に、torchvisionを利用してMNISTのデータセットを準備してきましょう。

MNISTは様々な機械学習アルゴリズムの検証で利用されており、MINSTデータを準備する方法はたくさんありますが、今回はpythorchの関連ライブラリであるtorchvisinを用いて準備します。

まず、MNISTのデータをダウンロードするコードを実行します。

from torchvision.datasets import MNIST
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor()
])

mnist_train = MNIST("./data", train=True, download=True, transform=transform)
mnist_test = MNIST("./data", train=False, download=True, transform=transform)

ここまででデータをダウンロードできました。

MNISTクラスの第一引数には、データの保存先のパスを指定しています。

また、オプション引数のtransformに渡すために、途中で、transformというオブジェクトを作っています。これは、torchvisonに含まれる前処理用のライブラリであるtransformsに含まれるToTensor関数を代入しています。

torchvisonから入手できるMNISTのデータはPIL(Python Image Library)形式になっているため、これを深層学習の入力に用いるためにTensor形式に変換しています。

ここで、mnist_trainとmnist_testは、torchvisionのdatasetsクラスのインスタンスとなっています。

type(mnist_train) # => torchvision.datasets.mnist.MNIST

続いて、準備したDataSetからDataLoaderを作成します。

DatLoaderは、バッチサイズを指定して与えることで作成することができ、DataLoaderを用いることで簡単にミニバッチ勾配降下法を実行することができるようになります。

DataLoaderを作成するためには、次にようにします。

batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

pytorchでニューラルネットワークモデルを実装

ここまででデータを準備できたので、ここからはニューラルネットモデルを実装していきます。

再掲になりますが、今回実装するニューラルネットは次にようになります。

50次元の中間層が1層あり、入力次元が784次元から 784→ 50→ 10 と出力が変化していき、最後は10次元の1-hot-encodingの形式で、出力ラベルを出力するニューラルネットです。

ニューラルネットの実装は次のようになります。

import torch
import torch.nn as nn
import torch.nn.functional as F

img_dim = 28 ** 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Net(nn.Module):
  def __init__(self, input_dim, output_dim):
    super(Net, self).__init__()
    self.fc1 = nn.Linear(input_dim, 100)
    self.fc2 = nn.Linear(100, output_dim)

  def forward(self, x):
    x = x.view(-1, img_dim)
    x = F.relu(self.fc1(x))
    x = F.log_softmax(self.fc2(x), dim=1)
    
    return x

net = Net(img_dim, 10).to(device)

今回実装するモデルのクラスをNetクラスとして実装しました。Netクラスは、torch.nnモジュールに含まれている、nn.Moduleクラスを継承しています。

nn.Moduleは、pytorchでニューラルネットを実装する際にベースとなるクラスで、pytorchでニューラルネットを実装する場合は、全てのこのクラスを継承する必要性があります。

Netクラスでは、イニシャライザ(__init__メソッド)内で、層の構造を定義し、forwardメソッドの内部で、推論(順伝搬)の処理フローを実装していきます。

今回は、FC1層(Fully Connected の1層目)の出力の後に、活性化関数としてReLU関数を通し、その後log_softmaxで正解ラベルを出力します。

モデルの中身は、print関数でも見ることができます。

print(net)
# =>
Net(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=10, bias=True)
)

損失関数とオプティマイザを実装

続いて、ニューラルネットワークの学習に必要な、損失関数とオプティマイザを実装していきます。

実装と言っても、torchにはすでに簡単に利用できる損失関数やオプティマイザのAPIがあるので、それそ利用するだけです。今回は損失関数に交差エントロピー誤差を用い、オプティマイザには、SGD(確立的勾配降下法)を用います。

実装はこのようになります。

from torch import optim

loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

交差エントロピー誤差については、こちらの記事で詳しく解説しているので、合わせてご覧ください。

【損失関数】交差エントロピー誤差の概要と関数系を理解する
交差エントロピー誤差(Cross Entropy Error, Cross Entropy Loss)は、深層学習の分類問題で非常によく利用される損失関数です。 交差エントロピーと聞くと、初めて遭遇した人にとっては、もの […]

optim.SGDオブジェクト作成の際には、モデルのパラメータと学習率をlrオプション引数で渡しています。

モデルの訓練とテスト

ここまででようやくモデルの準備ができました。いよいよMNIST画像を用いて学習し、モデルの検証をしていきましょう。

訓練データセットで10回(エポック)学習し、その都度モデルの精度がどのように変わるか確認します。

num_epoch = 10

for i in range(num_epoch):

  # モデルの訓練
  net.train()
  loss_train = 0
  for j, (x, t) in enumerate(train_loader):
    x, t = x.cuda(), t.cuda()

    # optimizerを初期化
    optimizer.zero_grad()
    
    # モデルの出力値を計算
    y = net(x)

    # 損失を計算
    loss = loss_func(y, t)
    loss_train += loss.item()

    # 損失から勾配を計算
    loss.backward()

    # optimizerを更新
    optimizer.step()

  loss_train /= j + 1
  trace_loss_train.append(loss_train)

  # モデルの検証
  net.eval()
  loss_test = 0
  cnt_correct = 0

  for j, (x, t) in enumerate(test_loader):
    x, t = x.cuda(), t.cuda()

    # モデルの出力値を計算
    y = net(x) 

    # 損失関数を計算
    loss = loss_func(y, t)
    pred = y.argmax(1)

    # 損失を保存
    loss_test += loss.item()
    cnt_correct += pred.eq(t.view_as(pred)).sum().item()

  loss_test /= j+1
  trace_loss_test.append(loss_test)

  print("Epoch: {}, TrainLoss: {}, TestLoss:{}".format(i, loss_train, loss_test))
  print("Accurary: {}".format(float(100 * cnt_correct / len(mnist_test))))

出力はこのようになりました。

Epoch: 0, TrainLoss: 2.0389727409849776, TestLoss:1.689751812815666
Accurary: 75.47
Epoch: 1, TrainLoss: 1.3513769702708467, TestLoss:1.0231679156422615
Accurary: 80.89
Epoch: 2, TrainLoss: 0.8810299016059713, TestLoss:0.7230687990784646
Accurary: 84.5
Epoch: 3, TrainLoss: 0.674292935716345, TestLoss:0.584641245380044
Accurary: 86.39
Epoch: 4, TrainLoss: 0.5691809363821719, TestLoss:0.5070051297545433
Accurary: 87.71
Epoch: 5, TrainLoss: 0.506310485144879, TestLoss:0.4576300609856844
Accurary: 88.44
Epoch: 6, TrainLoss: 0.464679871214197, TestLoss:0.4246001744642854
Accurary: 88.99
Epoch: 7, TrainLoss: 0.4352785681156402, TestLoss:0.40008886009454725
Accurary: 89.49
Epoch: 8, TrainLoss: 0.4129511552922269, TestLoss:0.38186720553785564
Accurary: 89.77
Epoch: 9, TrainLoss: 0.39628829753145256, TestLoss:0.36700065322220327
Accurary: 90.01

エポックを増やすごとに、損失関数の値を減らすことができ、さらに検証データの正解率が上昇しているのがわかります。

参考文献

本記事の参考文献

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術