【初心者】覚えておきたいpytorchの基本(関数やクラス)

Posted: , Category: pytorch , 深層学習

pytorchは、ニューラルネットワークを実装するために非常に便利な関数やクラスがたくさんあります。

また、他の人が独自に作った深層学習モデルを読むのに、pytorchでよく使われるような関数やにクラスについては、ある程度理解しておく必要性があります。

今回は、初心者が必見のpytorchで理解しておきたい基本的な関数やクラスを紹介します。またpytorch 上でのデータの基本構造である、tensorを操作する関数についても便利なものをまとめていきます。

ニューラルネットワークの基本的なクラス

ご存じのとおり、ニューラルネットワークは下の画像のように多数の層によって構成されています。

pytorchでは、このニューラルネットワークの層を実装するための基本的なクラスとして、torch.nn.Module というクラスがあります。

torch.nn.Module

基本的にpytorchで新しい層を実装する際には、torch.nn.Moduleを継承したクラスを実装することになります。また、torch.nn.Moduleの中では、順伝搬(forward propagation)を記述するために、forward関数を実装することになります。

実際に学習をする際(backword時)には、torch.nn.Moduleを継承したクラスのforwardが呼ばれ、損失関数の値の計算じにその結果が利用されます。

torch.nn.Parameter

pytorchで学習するパラメータは、torch.nn.Parameterで定義します。

torch.nn.Parameterの初期値には、初期化したTensorを渡してあげることで、パラメータを準備することができます。

Tensor関連で覚えておきたい関数

pytorchでTensorを操作する関数として、代表的なものに、下記のものがあります。

  • view
  • transpose
  • reshape
  • permute
  • squeeze
  • unsqueeze

この辺りをしっかり押さえておきましょう。

view

viewは、Tensorを所定の形に変形する関数です。numpyにおけるreshapeとほぼ同じ働きをするので、使い方について苦戦することはないと思います。

transpose

transposeはTensorを転置する関数です。実際にモデルを組む際には、結構な頻度で使います。

reshape

viewと同じように、Tensorを所定の形式に変形する際に利用します。

permute

permuteは、任意のTensorの値を軸に沿って入れ替える関数です。

squeeze

squeeze関数は、引数に与えたれたTensorから次元が1の次元を削除する関数です。numpyにあるsqueeze関数と同じ役割を担います。

import torch
a = torch.tensor([[1, 2, 3]])
a.size()
# => torch.Size([1, 3])
a = a.squeeze()
# a => tensor([1, 2, 3])
a.size()
# => torch.Size([3])

unsqueeze

unsqueeze関数は、引数に与えられたTensorから次元を1つ関数です。numpyにあるexpand_dimsという関数と同じ役割をします。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術