pytorchは、ニューラルネットワークを実装するために非常に便利な関数やクラスがたくさんあります。
また、他の人が独自に作った深層学習モデルを読むのに、pytorchでよく使われるような関数やにクラスについては、ある程度理解しておく必要性があります。
今回は、初心者が必見のpytorchで理解しておきたい基本的な関数やクラスを紹介します。またpytorch 上でのデータの基本構造である、tensorを操作する関数についても便利なものをまとめていきます。
ニューラルネットワークの基本的なクラス
ご存じのとおり、ニューラルネットワークは下の画像のように多数の層によって構成されています。
pytorchでは、このニューラルネットワークの層を実装するための基本的なクラスとして、torch.nn.Module というクラスがあります。
torch.nn.Module
基本的にpytorchで新しい層を実装する際には、torch.nn.Moduleを継承したクラスを実装することになります。また、torch.nn.Moduleの中では、順伝搬(forward propagation)を記述するために、forward関数を実装することになります。
実際に学習をする際(backword時)には、torch.nn.Moduleを継承したクラスのforwardが呼ばれ、損失関数の値の計算じにその結果が利用されます。
torch.nn.Parameter
pytorchで学習するパラメータは、torch.nn.Parameterで定義します。
torch.nn.Parameterの初期値には、初期化したTensorを渡してあげることで、パラメータを準備することができます。
Tensor関連で覚えておきたい関数
pytorchでTensorを操作する関数として、代表的なものに、下記のものがあります。
- view
- transpose
- reshape
- permute
- squeeze
- unsqueeze
この辺りをしっかり押さえておきましょう。
view
viewは、Tensorを所定の形に変形する関数です。numpyにおけるreshapeとほぼ同じ働きをするので、使い方について苦戦することはないと思います。
transpose
transposeはTensorを転置する関数です。実際にモデルを組む際には、結構な頻度で使います。
reshape
viewと同じように、Tensorを所定の形式に変形する際に利用します。
permute
permuteは、任意のTensorの値を軸に沿って入れ替える関数です。
squeeze
squeeze関数は、引数に与えたれたTensorから次元が1の次元を削除する関数です。numpyにあるsqueeze関数と同じ役割を担います。
import torch
a = torch.tensor([[1, 2, 3]])
a.size()
# => torch.Size([1, 3])
a = a.squeeze()
# a => tensor([1, 2, 3])
a.size()
# => torch.Size([3])
unsqueeze
unsqueeze関数は、引数に与えられたTensorから次元を1つ関数です。numpyにあるexpand_dimsという関数と同じ役割をします。