【GCN】pytorchでグラフ畳み込みネットワーク(GCN)を実装する

Posted: , Category: pytorch , グラフニューラルネットワーク , 深層学習

GCN(Graph Convolutional Network、グラフ畳み込みネットワーク) は、2017年に深層学習のトップカンファレスであるICLRで発表されて以来、徐々に注目を集めており、2022年現在深層学習関連のホットトピックでもあります。

今回は、このGCNについて解説しながら、pytorchで実装をしていきます。今回はライブラリとして、GNNをpytorchから簡単に利用できる、pytorch geomatricを利用します。

主にpytorchを使ってGCNを実装することを目標にしているので、GCNの詳しい構造や理論的な側面はこちらの記事で解説をしています。

とりあえず、Pytorchを利用して、GCNを動かしたい人や、そもそもGCNでどのような問題を解くことができるのかについて解説をしていきます。

今回GCNで解くタスク

こちらの記事でGNNについて解説していますが、GCNを含むGNNで解けるグラフ上のタスクは次のようになっています。

GNNで解けるグラフ上のタスク
  • リンク予測
  • グラフ分類
  • ノード分類

またこれらの分類でも、教師あり、教師なし、半教師ありなど様々な手法があります。

これらのタスクの解説はこちらの記事ではしません。詳しいことが知りたい人は、下記の記事で絵や図を用いて解説しているので、ぜひ参考にしてください。

【深層学習】GCN(グラフ畳み込みネットワーク)をわかりやすく解説する
GCN(Graph Convolution Network)は、GNN(Graph Neural Network)の1種類で、2022年現在、機械学習やAIのトップカンファレンスであるICMLやICLRで、非常に多くGC […]

今回は論文の引用・被引用の関係性をグラフ構造で表現したCoraデータセットでGCNの、ノード予測を行なっていきます。

具体的には、引用・被引用の関係性でグラフ構造となっている論文に対してGCNをしてあげることで、あるノードがどのクラス(Coraデータセットの場合、その論文がどのカテゴリに所属しているか)に属しているかを予測します。

GCNで利用するデータセットを準備する

では、まず最初にデータセットを準備していきましょう。

先ほど書きましたが、今回はCoraデータセットを利用します。Coraデータセットを準備するには、シンプルにデータの配布元からダウンロードする方法もありますが、今回はPytorch geometricでGCNを実装するので、pytorch geomatricにバンドルされているデータAPIを利用します。

pytorch geometricのインストール方法については、こちらの記事をご覧ください。

【PyG】PyTorch Geometricのインストール方法から利用方法まで解説
PyG(PyTorch Geometric)は、PyTorchでグラフ構造を取り入れた、GCNやGATなどのニューラルネットワークを簡単に実装、学習、推論などができるライブラリです。 グラフ構造を有するオープンなデータセ […]

以降は、環境にPytorch geomatricがインストールされている前提で話を進めます。

まず最初に必要なライブラリをインポートします。今回はこの後にGCNを実装する際に必要なtorchやtorchvision などのライブラリもまとめてインポートしてしまいます。

import matplotlib.pyplot as plt
import networkx as nx

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
 
import torchvision
import torchvision.transforms as transforms

from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

続いて、Coraデータセットをダウンロードし、networkXで可視化します。


from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="./Cora", name="Cora")
data = dataset[0]

data_nx = to_networkx(data)

plt.figure(figsize=(12, 12), dpi=120)
nx.draw(data_nx, node_color = data.y, node_size=5)
plt.show()

かなりごちゃごちゃしていますが、これでデータセットを準備できました。データはdataオブジェクトに格納されています。

続いて、実際にpytorchでGCNを実装し、coraデータセットのノード分類をやっていきましょう。

PytorchでGCNを実装する

早速、pytorchでGCNを実装していきましょう。今回は、2層のGCNを実装していきます。モデルの実装は次のようなコードになります。

よくあるpytorchを用いたニューラルネットの実装とほとんど変わりありませんが、GCNConvというモジュールを利用していることが特徴的です。GCNConvモジュールでは、グラフのたたみ込み演算を行なっています。

class GCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 50)
        self.relu = nn.ReLU()
        self.conv2 = GCNConv(50, dataset.num_classes)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index

        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)

        return x

gcn = GCN()
gcn.cuda()

続いて、構築できたGCNモデルの訓練を行なっていきましょう。GCNの訓練は次のようにして行います。

data = data.cuda()
lossfunc = nn.CrossEntropyLoss()
optimizer = optim.Adam(gcn.parameters())
n_epoc = 400

gcn.train()
for epoch in range(n_epoc):
    optimizer.zero_grad()
    out = gcn(data)
    loss = lossfunc(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

しばらく時間がかかりますが、これでモデルを訓練できました。あとは検証データを用いて、モデルの検証を行います。

gcn.eval()
pred = gcn(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = int(correct) / int(data.test_mask.sum())
print("Accuracy: {} %".format(str(accuracy*100)))
# => Accuracy: 77.9 %

正解率77.9 %で予測できました。なかなかの精度と言えるのではないでしょうか。これにて、GCNの実装は完了です。

参考文献

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術