MPNN(Message Passing Neural Network)は、GNN(Graph Neural Network, グラフニューラルネットワーク)を説明する汎用的なフレームワークです。
このMPNNはメッセージパッシングのフレームワークとも呼ばれており、2017年の論文 Neural Message Passing for Quantum Chemistry で登場しました。
MPNNのメッセージパッシングの考え方は、GNNの多様なアルゴリズムを汎用的に説明できることから、GNN系のpytorch geometricなどのフレームワークでも、メッセージパッシングをベースとした実装もあり、メッセージパッシングを理解することで、GNNに対してかなりすっきりと理解できるようになります。
例えば、GNNとして代表的なアルゴリズムである、GCN(Graph Convolution Network)やGAT(Graph Attention Network)も、このMPNN のメッセージパッシングの考え方で説明することができます。
今回の記事では、上記の論文で登場した、MPNNについて分かりやすく解説していきます。
MPNNのメッセージパッシング
MPNNは、グラフにおけるノード$i$へ向けた、ノード$j$からのメッセージ$m_{ij}$を次のように定義しています。
\begin{equation} m_{ij} = M(h_i, h_j, e_{ij}) \end{equation}
ここで、$h_i$はノード$i$の特徴量(または潜在変数)、$e_ij$は、ノード$i, j$間のエッジのお耳で、$M$は、非線形関数でMessage Functionと呼びます。
Message Functionはよく、MLP(多層パーセプトロン)などが利用されます。
まず、メッセージ$m_{ij}$ですが、次のようなグラフがあったときに、ノード間で渡されるメッセージになり、グラフ上にあるような向きで渡されます。
ここで、定義でも書きましたが、Message Functionは、あるノード$i$と$j$を考えたときに、ノードの$i, j$の特徴量$h_i, h_j$とエッジの重み$e_{ij}$を入力に、メッセージを返す、MLPのような非線形関数が用いられます。
MPNNでのノードの更新
MPNNでは、(1)で定義されるメッセージを用いて、次の層のノード$i$を更新していきます。
ノード$i$の更新式は次のようになります。
\begin{equation} h_i^{(l+1)} = U \biggl ( h_i^{(l)}, \sum_{j \in \mathcal{N}_i} m_{ij} \biggr ) \end{equation}
ここで、$h_i^{(l+1)}$は、$l+1$層における、ノード$i$の重みです。Uは何かしらの、MLPのような非線形関数を想定しています。
(1)と(2)で重要なのは、グラフ上のノードにおける更新は、(1)による周辺ノードからのメッセージと (2)そのメッセージを入力とした出力によって定義できるということです。
かなり汎用的な枠組みに思えますが、このMPNNの枠組みで、今流行りのGCNやGATを説明することができます。
MPNNやメッセージパッシングの仕組みでこれらのアルゴリズムや手法を解説するのは、また次回行いたいと思います。