深層学習やGNN(Graph Neural Network)関連の論文を読み漁っていると、Graph Attention Network(GAT, グラフアテンションネットワーク)に関する論文を目にすることが多くあると思います。
GAT(は、GNN(Graph Neural Network)やGCN(Graph Convolutional Network)の枠組みに、Attentionの機構を取り入れることで、深層学習の予測精度を上げたモデルとなっています。
この記事では、2018年に発表された元論文に基づいて、GATについて分かりやすく解説していきます。
Graph Attention Networkの概要
GATは、一言で言えば、深層学習でグラフ構造の分類などを行うことができるGCNと呼ばれる手法に、Attentionの機構を取り入れた手法です。
Attentionの機構とは、入力されたデータに対し、どこに注目するかを動的に決定する仕組みで、このAttentionの機構をGCN(グラフ畳み込みネットワーク)に取り入れることで、分類、識別、予測精度を上げています。
このため、GCNについて理解が浅い場合、GATについていまいち理解できない可能性があります。
GCNについては、下記の記事で分かりやすく解説しているので、ぜひ参考にしてみてください。
GCNとGATの大きな違いは、ノードの畳み込みの際の係数(これがいわゆるAttention係数です)が大きく違います。
通常のGCNでは、あるノードの次の層の特徴量(潜在変数)を求める際には、隣接するノードの特徴量に線形の重みをかけたものの和に、活性化関数(ReLUなど)を通すことで次の層に値を伝搬させていましたが(※この辺りがよくわからない人は、先にGCNの記事をご覧ください)
その隣接ノードの特徴量の線形和を求める際に、隣接ノードをどれも対等に扱っていたのに対し、GATでは、隣接するノードを台頭ではなく、重要度(Attention係数)の概念を取り入れたものになります。
イメージとしては次のようになります。
上図では、ノード1と隣接するノード2, 3, 4のうち、エッジが太線になっているように、ノード3が重要で、次に4のノード、そして2のノードの順に重要度をつけたイメージとなります。
GATは隣接するノードにこのような概念を入れた仕組みとなっています。
GATでは、この重要度を、attention scoreで、$\alpha$としています。
GATでは、各ノードに対して、他のノードからの重要度 $\alpha_{ij}$を計算します。上図では、ノード1に対する$\alpha_{11} ~ \alpha_{14}$を示しています。
ここで、他のノードだけでなく、自分のノードの重要度$\alpha_{11}$も考えることに注意してください。
GATでは、このAttention係数(Attention coefficient)を利用して、畳み込み計算を行います。
GCNの畳み込み計算は次の方に書くことができました。
\begin{equation} \begin{split} \bm{h}_{i}^{(l+1)} = \sigma \biggl ( \sum_{j \in \mathcal{N}(i) \cup \{ i\} } \frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}} (\bm{W}^T \bm{h_j}^{(l)}) \biggr ) \end{split} \end{equation}
(1)式は、$l+1$番目の層における、ノード$i$の特徴量の更新式をnode-wiseの表現したものとなっています。(この表現についても、先述のGCNの記事で解説しています)
ここで、隣接ノードの特徴量は、$\bm{W}^T \bm{h_j}^{(l)}$で、そこに係数$\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}}$がついているのが分かりますね。
この係数$\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}}$は、GCNの場合だと、ノード$i$とその隣接ノード$j$の字数の平方根をとっているだけですが、GATではこの係数(Attention係数)をもう少し、賢く求めていきましょう、というのが発想となっています。
なんとなくGATとGCNの違いについては分かりましたか?
多分、GCNについて理解していないとなかなか難しいと思うので、GATの前にGCNについて理解を深めることをおすすめします。
では、GATではattention 係数をどのように計算するのか、以降で解説していきます。
GATにおけるAttention係数
では、GATで、各ノード間の繋がりの重要度を示す、アテンション係数 $\alpha$をどのように定義するのかみていきましょう。
まず、特徴量の更新をしたい$i$番目のノードの特徴量を$\bm{h_i} \in \mathbb{R}^{F}$、ノード$i$に隣接するノード$j$の特徴量を$\bm{h_j} \in \mathbb{R}^{F}$とすると、GATではこのノード$i$とノード$j$の繋がりのAttention係数$e_{ij}$を次のように定義します。
\begin{equation} e_{ij} = \bm{a}(\bm{Wh_i}, \bm{Wh_j}) \end{equation}
この$e_{ij}$ノード$i$とノード$j$の関連度の重要性を示しています。
ここで、$\bm{W} \in \mathbb{R}^{F \times F’}$はGCNでも登場しているように、特徴量を線形変換する重みパラメータです。次元は、$F$から$F’$に変わっていることに注意してください。
また、関数$a$は、shared attention mechanism(attention メカニズム, 注意機構)とよばれており、さまざまな関数系が考えられます。
たとえば、内積のような計算だと、似ているベクトルの内積の演算をすると、値が大きくなりますね。このような関数系がAttention メカニズム $a$のイメージとなります。
ここで、$e_{ij}$は決定された後、次のsoftmax関数によって、正規化されます。正規化されたAttention係数が、上述で説明した$\alpha_{ij}$に対応します。
\begin{equation} \alpha_{ij} = \operatorname{softmax_j}(e_{ij}) =\frac{exp(e_{ij})}{ \sum_{k \in \mathcal{N_i}} exp(e_{ik})} \end{equation}
ここで、$\alpha_{ij}$を求めることまでできましたが、(1)式で提示されている、Attention mechanism $a$はどのような関数系になるのでしょうか。
実際のところ、この辺りはモデルの設計者が決めるパラメータですが、GATの元論文では、このAttention mechanism $\bm{a}$に、シンプルな単層の順伝播型ニューラルネットワークを選定しています。
このため、$\bm{a}$の次元は、$\bm{a} \in \mathbb{R}^{F’ \times F’}$です。また、非線形の活性化関数として、LeakyReLUを利用しています。
\begin{split} e_{ij} = \operatorname{LeakyReLU} (\bm{a}[W\bm{h_i} || W \bm{h_j}]) \end{split}
これで、Attention係数 $\alpha_{ij}$を得ることができました。
ここで、$||$という見かけない演算子がありますが、これは単純に2つのベクトルを連結(concatenationさせているだけです。
難しく考える必要性がありません。今回は、$F’$のベクトルを2つ連結しているので、$W\bm{h_i} || W \bm{h_j}$はサイズ$2F’$のベクトルとなります。
論文中では上のような絵が登場します。よく見るとすごく簡単です。
下の丸が全て同一のベクトルで、このベクトルを単層のニューラルネットに入れて、正規化されたAttention係数を得ているだけです。
Graph Attention Network の全体像
ここまで、なんとなくGATについて理解できてきたと思うので、最後にMultihead Attentionも含めたGATの全体像をまとめていきます。
そこちらがMultihead Attentionの概要図となっています。
GATの元論文では、Attention係数の安定化のために、Attentionを複数$K$回行い、その平均を利用しています。この辺りのMultihead Attentionの詳細については、元論文を参考にしてみてください。