pytorch-geometricでは、グラフ状の様々な畳み込み操作をメッセージパッシングの枠組みで、柔軟に記述できるMessagePasssingクラスが用意されています。
MessagePasssingを利用することで、論文の検証や研究などで新しいアルゴリズムを考えたときにサクッと実装することができ、大変便利です。
今回は、pytorch-geometricに準備されているMessagePassingの解説や利用方法を分かりやすく説明していきます。
また、GCNの畳み込み処理をスクラッチで実装することを通して、利用方法を具体例を通して解説していきます。
- GNNのメッセージパッシングについて
- MessagePassingクラスの解説
- GCNをMessagePassingクラスを用いて実装する
GNNのメッセージパッシング
GNN(Graph Neural Network)の畳み込みを行う際は、近隣ノードを集約(Aggregate)と捉えるか、近隣ノードからメッセージパッシングされるとして、畳み込みを行うことができます。
ここで、$k-1$層目のノード$i$の特徴ベクトルを$\mathbf{x}_i^{(k-1)} \in \mathbb{R}^{F}$、ノード$j$から$i$のエッジがもつ特徴量を$\mathbf{e}_{j,i}$とすると、メッセージパッシングの枠組みでは、$k$層目のノード$i$の特徴ベクトルは次のように表現できます。
\begin{equation} \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) \end{equation}
ここで、$\gamma^{(k)}$と$\phi^{(k)}$はMLPなどの微分可能な関数で、$\square$は、和や平均、合計などの順列不変(Permutation invariant)な関数を想定しています。
pytorch-geometricでは、(1)式で表現できる畳み込み処理を記述でいる、MessagePassingと呼ばれる基底クラスを用意しています。
このMessagePassingクラスを利用して畳み込み処理をすることで、パラメータの学習などを完全にフレームワーク側に任せることができるので、とても便利です。
MessagePassingクラスの解説
MessagePassingクラスを継承するクラスを準備して、上記の(1)式の$\gamma^{(k)}$に対応する、update関数と、$\phi^{(k)}$に対応するmessage()関数を実装することで、よしなにやってくれます。
また、集約関数として(1)式の$\square$に相当するPermutation invariantを、add, mean, sum のいずれかで指定することで実装可能です。
とはいえ、この辺りは言葉で説明するより実装を見た方が早いと思うので、以降は、基本的なGCNを0台に解説をしていきます。
GCNをMessagePassingで実装
では、GNNでも非常に有名なGCNを実装してみましょう。ちなみにGCNをMessagePassingで実装する内容については、公式ドキュメントのチュートリアルにも掲載されています。
GCNの畳み込みのNode-Wise表現は、次のようになっています。 この辺りが自信がない人は、GCNの解説記事をご覧ください。
\begin{equation} \begin{split} \bm{h}_{i}^{(l+1)} = \sum_{j \in \mathcal{N}(i) \cup \{ i\} }\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}} (\bm{W}^T \bm{h_j}^{(l)}) + \bm{b} \end{split} \end{equation}
(2)の式を読み解くと、まず、隣接ノードの特徴量を重みベクトル$\bm{W}$で変換した後に、ノード$i$とノード$j$の度数の平方根で正規化しています。その値を全てのノードに足し合わせた後で、バイアス$\bm{b}$を加え、最後に活性化関数を通しています。
これを、MessagePassingで実装すると次のような手順となります。
- ノード$j$の特徴量を、重みベクトル$\bm{W}$で変換
- 正規化係数を計算する (平方根の部分)
- 隣接ノードに対する特徴量を集約(aggregate)する
- bias 項を追加する
実装自体は次のようになります。
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
# torch_geometric.nn.MessagePassingを継承したクラスを作成します.
# クラス名は任意に設定可能.
class GCNConv(MessagePassing):
# in_channels: ノードの入力特徴量の次元. # out_channels: ノードの出力特徴量の次元
def __init__(self, in_channels, out_channels):
# ここで集約の方法を指定.
super().__init__(aggr='add')
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.Tensor(out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()
# edge_index has shape [2, E]
# x has shape [N, in_channels]
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
# row: 2, col: Edgesの数.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# propagateをcallすると、下で定義するmessage 関数が呼ばれる.
out = self.propagate(edge_index, x=x, norm=norm)
# biasを足し算する.
out += self.bias
return out
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j