torchvisionはPytorchに含まれるライブラリの一種で、主に深層学習や機器学習で用いることができる画像や動画等のデータセットを手軽に準備したり、様々な形式に変換するための関数群などが含まれたツールセットです。
今回はこの、torchvisonを用いて、MNISTのデータセットを例に、torchvisionでデータセットを準備する方法について解説します。
- torchvisionで利用できるデータセット
- torchvisionでデータセットを準備する
torchvisionで利用できるデータセット
torchvisionでは、主に画像分類やセグメンテーションなどに利用できる、データセットが多く含まれています。torchvisionでデータセットを用意する場合、datasets モジュールを介してこれらをダウンロードするここになります。
まず最初に、torchvisionのdatasetsモジュールをインポートして、どのようなデータセットがあるか確認してみましょう。
次のようなコードを実行することで、含まれているデータセットを確認できます。
from torchvision import datasets
dir(datasets)
こちらがJupyter Notebookの実行結果ですが、このように表示されます。
深層学習や画像系のデータの扱いに慣れている人にとっては、CIFAR10やCoCoDetectionなど見慣れたデータセットが多数あると思います。
torchvisionでデータを利用する方法
さて、ここからメインの内容の解説に入っていきます。MNISTのデータセットを例に、torchvisionでデータを扱う方法について学んでいきます。
今までこのようなデータセットのダウンロード等をscikit-learn等で行なってきた人にとっては勝手が違うと戸惑うかもしれませんが、慣れればとても簡単なので気合を入れていきましょう。
まず、コードの全文を記載します。
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
mnist_train = MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("./data", train=False, download=True, transform=transforms.ToTensor())
img_size = 28
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
ここまでで、torchvisionでMNISTのデータをローカルにダウンロードし、trainとtestまで振り分け、さらにMNISTの画像をテンソル形式に変換し、さらにバッチサイズを良い感じに設定することができました。