【PyG】PyTorch Geometricのインストール方法から利用方法まで解説

Posted: , Category: pytorch , 深層学習

PyG(PyTorch Geometric)は、PyTorchでグラフ構造を取り入れた、GCNやGATなどのニューラルネットワークを簡単に実装、学習、推論などができるライブラリです。

グラフ構造を有するオープンなデータセットへのAPIも多数あり、PyTorchでGNNのアルゴリズムを試す際には、ほぼPyG一択になるかと思います。

しかしそんな状況にもかかわらず、PyGに関する日本語での情報はほとんどないため、この記事ではPyGのインストール方法から基本的な利用方法まで説明します。

pytorch geometricで扱われるグラフやノードなどのデータ構造、またこれらの操作方法から、データセットAPIを用いて、世の中のオープンなデータセットを簡単に準備する方法まで、丁寧にまとめていきたいと思います。

基本的に公式ドキュメントの内容を丁寧にまとめる内容となっています。

本記事の内容
  • PyGのインストール方法
  • PyGでグラフ構造を扱う
  • PyG組み込みのデータセットAPIを利用しグラフデータを簡単に準備する方法

PyG(PyTorch Geometric)のインストール方法

まず、PyGをインストールする方法について説明していきます。PyG公式ドキュメントを参考にPyGをインストールしましょう。

現在のところ、PyGはPython3.7以上のPythonを必要としています。

PyGはかなり依存関係が複雑なので、諸々依存関係を解消しながらインストールしてくれるAnaconda経由でインストールするのが推奨されています。

Anacondaでインストールする場合は、次のコマンドでインストールします。

conda install pyg -c pyg

pip経由でインストールする場合は、まず、PyTorchのバージョンが1.12.0以上ある必要性があります。

また、最初にPytorchとCUDAのバージョンを確認します。PytorchとCUDAのバージョンを確認するには次のコマンドを実行します。

python -c "import torch; print(torch.__version__)"
>>> 1.13.0
python -c "import torch; print(torch.version.cuda)"
>>> 11.6

PytorchとCUDAのバージョンがわかったら次のコマンドで、PyGを関連ライブラリと一緒にインストールします。

pip install --verbose --no-cache-dir torch-scatter
pip install --verbose --no-cache-dir torch-sparse
pip install --verbose --no-cache-dir torch-cluster
pip install --verbose --no-cache-dir torch-spline-conv (optional)
pip install torch-geometric

もしこの辺りでエラーが出るようでしたら、自身の環境のPyTorchとCUDAに対応したインストール方法がPyG公式ドキュメントのインストールガイドに記載されているので、そちらをご確認ください。

PyGでグラフ構造を扱う

PyGでは、torch_geometric.data.Dataクラスのインスタンスでグラフ構造を扱います。

以降では、実際にPyGでグラフ構造を作ります。この具体例を通して、PyGでグラフ構造を扱う、torch_geometric.data.Dataクラスについて理解できるようになると思います。

この辺りは、公式ドキュメントのチュートリアルでも同様の内容がありますが、その内容をさらに噛み砕いてわかりやすくまとめていきます。

具体例を通してPyGのグラフ表現に慣れる

ではまず、PyGで下記のような有向グラフ(Directed Graph) を実装し、可視化していきましょう。

今回実装する有向グラフは次のようなものを考えます。

$V$がNode(頂点)の集合であり、$E$はEdge(辺)の集合に対応しています。

$e_1$ $e_2$ $e_3$は各辺であり、中身は、$(src, dist)$となっており、src → dist 向きの有向リンクを示しています。

$f_0, f_1, f_2$ は、各Node が持っている特徴量です。今回の設定では、Nodeが何らかの特徴量を持っている設定でグラフを構築していきます。

また、エッジも特徴量を持たせることができますが、今回はエッジは特徴量を持っていない設定でやります。

グラフ構造に慣れていないと少し難しく感じてしまうかもしれませんが、下記の記事で有向グラフについても詳しく解説しているので参考にしてみて下さい。

有向グラフ(Directed Graph)を隣接行列で表現する
有向グラフ(Directed Graph)は、計算機科学の分野では頻出のグラフのデータ構造です。 有向グラフを用いることで、実世界のさまざまなオブジェクトやオブジェクト間の関係性を表現することができます。 身近な例では、 […]

PyGでグラフを実装する

では、上記のようなグラフを実装してみましょう。グラフオブジェクトを作成する場合、torch_geometric.data.Dataクラスからインスタンスを作成します。

torch_geometric.data.Dataオブジェクトは、グラフ情報を管理するための次のような属性値と、グラフ操作をするための関数を持っています。Dataクラスの属性値としては以下のものを使います。

属性属性の説明
Data.x (Tensor)ノードの特徴量。ノードの数を$N$、特徴量の次元を$F$とすると、多変量の特徴量を保つので、$N \times F$ 
Data.edge_index (Tensor)$2 \times num\_edge $のサイズで、エッジの配列を保持。
Data.edge_attr (Tensor)オプション引数で、エッジが特徴量を保つ場合。エッジの特徴量を$H$ とすると、$ num\_edge \times H$
Data.y (Tensor)オプション引数で、ノードのラベルを保持。

では、実際にグラフを実装していきましょう。

import torch
from torch_geometric.data import Data

edge_from = [0, 0, 2]
edge_to = [2, 1, 1]
edge_index = torch.tensor([edge_from, edge_to], dtype=torch.long)
features = torch.tensor([[0, 1], [2, 3], [4, 5]])
labels = torch.tensor([0, 1, 2], dtype=torch.float)

data = Data(x=features, y=labels, edge_index=edge_index)
data # => Data(x=[3, 2], edge_index=[2, 3], y=[3])

torch_geometric.data.Dataクラスのインスタンスを実装することができました。

edge_index属性が、有効グラフのEdgeの情報を保持しており、yでNodeのラベル値、x が特徴量を保持しています。

意外と簡単に実装できましたね。

後はこのグラフ構造を可視化してみましょう。Pythonでグラフ構造の可視化をする場合は、networkXというライブラリが非常によく使われるので、今回もnetworkXを用いて可視化していきましょう。

PyGのDataクラスをnetworkXで可視化

PyGのDataクラスのオブジェクトをnetworkXで可視化するためには、PyGとnetworkX間で変換処理をする必要性があります。

この変換処理は、torch_geometric.utilsモジュールのto_networkx関数で変換することができます。

可視化のコードの全体は下記になります。

import networkx as nx
from torch_geometric.utils import to_networkx

nxg = to_networkx(data)
nx.draw(nxg, with_labels=True)

networkXで同様のグラフ構造を可視化することができました。

PyG組み込みのデータセットAPIを利用する

冒頭にも書きましたが、PyTorch Geometricには、データセット用のAPIがあり、簡単にグラフ構造のデータを準備することができます。

torchvision等を用いて画像系のデータセットを準備したことがある人なら、イメージしやすいかもしれませんが、torchvisonのdatasetsモジュールのような機能が、PyGにも搭載されているので利用してみましょう。

まずdatasetsモジュールを読み込み、どのようなデータがあるのか確認してみましょう。

from torch_geometric import datasets
dir(datasets)

こののような出力が表示されます。たくさんデータセットがありますね。今回は、定番のデータセットであるKarateClubを読み込んで、グラフ化してみましょう。

データセットを読み込んでからnetworkXで可視化するのは、上記で説明した通りです。

dataset = datasets.KarateClub()
data = dataset[0]

nxg = to_networkx(data)
nx.draw(nxg, with_labels=True)

無事可視化することができました。PyGのdatasetsモジュールの公式ドキュメントでは、他にも多くのデータセットを取り扱っているので、気になる人はぜひ閲覧してみることをお勧めします。

引用・参考文献

本記事の引用文献

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術