【深層学習】Graph Attention Networks(GAT)を理解する

Posted: , Category: グラフニューラルネットワーク , 深層学習

深層学習やGNN(Graph Neural Network)関連の論文を読み漁っていると、Graph Attention Network(GAT, グラフアテンションネットワーク)に関する論文を目にすることが多くあると思います。

GAT(は、GNN(Graph Neural Network)やGCN(Graph Convolutional Network)の枠組みに、Attentionの機構を取り入れることで、深層学習の予測精度を上げたモデルとなっています。

この記事では、2018年に発表された元論文に基づいて、GATについて分かりやすく解説していきます。

Graph Attention Networkの概要

GATは、一言で言えば、深層学習でグラフ構造の分類などを行うことができるGCNと呼ばれる手法に、Attentionの機構を取り入れた手法です。

Attentionの機構とは、入力されたデータに対し、どこに注目するかを動的に決定する仕組みで、このAttentionの機構をGCN(グラフ畳み込みネットワーク)に取り入れることで、分類、識別、予測精度を上げています。

このため、GCNについて理解が浅い場合、GATについていまいち理解できない可能性があります。

GCNについては、下記の記事で分かりやすく解説しているので、ぜひ参考にしてみてください。

【深層学習】GCN(グラフ畳み込みネットワーク)をわかりやすく解説する
GCN(Graph Convolution Network)は、GNN(Graph Neural Network)の1種類で、2022年現在、機械学習やAIのトップカンファレンスであるICMLやICLRで、非常に多くGC […]

GCNとGATの大きな違いは、ノードの畳み込みの際の係数(これがいわゆるAttention係数です)が大きく違います。

通常のGCNでは、あるノードの次の層の特徴量(潜在変数)を求める際には、隣接するノードの特徴量に線形の重みをかけたものの和に、活性化関数(ReLUなど)を通すことで次の層に値を伝搬させていましたが(※この辺りがよくわからない人は、先にGCNの記事をご覧ください)

その隣接ノードの特徴量の線形和を求める際に、隣接ノードをどれも対等に扱っていたのに対し、GATでは、隣接するノードを台頭ではなく、重要度(Attention係数)の概念を取り入れたものになります。

イメージとしては次のようになります。

上図では、ノード1と隣接するノード2, 3, 4のうち、エッジが太線になっているように、ノード3が重要で、次に4のノード、そして2のノードの順に重要度をつけたイメージとなります。

GATは隣接するノードにこのような概念を入れた仕組みとなっています。

GATでは、この重要度を、attention scoreで、$\alpha$としています。

GATでは、各ノードに対して、他のノードからの重要度 $\alpha_{ij}$を計算します。上図では、ノード1に対する$\alpha_{11} ~ \alpha_{14}$を示しています。

ここで、他のノードだけでなく、自分のノードの重要度$\alpha_{11}$も考えることに注意してください。

GATでは、このAttention係数(Attention coefficient)を利用して、畳み込み計算を行います。

GCNの畳み込み計算は次の方に書くことができました。

GCNにおけるMessage Passing
\begin{equation}
\begin{split}
\bm{h}_{i}^{(l+1)} = \sigma
\biggl (
\sum_{j \in \mathcal{N}(i) \cup \{ i\} }
\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}}
(\bm{W}^T \bm{h_j}^{(l)})

\biggr )
\end{split}
\end{equation}

(1)式は、$l+1$番目の層における、ノード$i$の特徴量の更新式をnode-wiseの表現したものとなっています。(この表現についても、先述のGCNの記事で解説しています)

ここで、隣接ノードの特徴量は、$\bm{W}^T \bm{h_j}^{(l)}$で、そこに係数$\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}}$がついているのが分かりますね。

この係数$\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}}$は、GCNの場合だと、ノード$i$とその隣接ノード$j$の字数の平方根をとっているだけですが、GATではこの係数(Attention係数)をもう少し、賢く求めていきましょう、というのが発想となっています。

なんとなくGATとGCNの違いについては分かりましたか?

多分、GCNについて理解していないとなかなか難しいと思うので、GATの前にGCNについて理解を深めることをおすすめします。

では、GATではattention 係数をどのように計算するのか、以降で解説していきます。

GATにおけるAttention係数

では、GATで、各ノード間の繋がりの重要度を示す、アテンション係数 $\alpha$をどのように定義するのかみていきましょう。

まず、特徴量の更新をしたい$i$番目のノードの特徴量を$\bm{h_i} \in \mathbb{R}^{F}$、ノード$i$に隣接するノード$j$の特徴量を$\bm{h_j} \in \mathbb{R}^{F}$とすると、GATではこのノード$i$とノード$j$の繋がりのAttention係数$e_{ij}$を次のように定義します。

Graph Attention NetworkにおけるAttention係数$e_{ij}$の定義
\begin{equation}
e_{ij} = \bm{a}(\bm{Wh_i}, \bm{Wh_j})
\end{equation}

この$e_{ij}$ノード$i$とノード$j$の関連度の重要性を示しています。

ここで、$\bm{W} \in \mathbb{R}^{F \times F’}$はGCNでも登場しているように、特徴量を線形変換する重みパラメータです。次元は、$F$から$F’$に変わっていることに注意してください。

また、関数$a$は、shared attention mechanism(attention メカニズム, 注意機構)とよばれており、さまざまな関数系が考えられます。

たとえば、内積のような計算だと、似ているベクトルの内積の演算をすると、値が大きくなりますね。このような関数系がAttention メカニズム $a$のイメージとなります。

ここで、$e_{ij}$は決定された後、次のsoftmax関数によって、正規化されます。正規化されたAttention係数が、上述で説明した$\alpha_{ij}$に対応します。

\begin{equation}
\alpha_{ij} = \operatorname{softmax_j}(e_{ij}) =\frac{exp(e_{ij})}{ \sum_{k \in \mathcal{N_i}} exp(e_{ik})} 
\end{equation}

ここで、$\alpha_{ij}$を求めることまでできましたが、(1)式で提示されている、Attention mechanism $a$はどのような関数系になるのでしょうか。

実際のところ、この辺りはモデルの設計者が決めるパラメータですが、GATの元論文では、このAttention mechanism $\bm{a}$に、シンプルな単層の順伝播型ニューラルネットワークを選定しています。

このため、$\bm{a}$の次元は、$\bm{a} \in \mathbb{R}^{F’ \times F’}$です。また、非線形の活性化関数として、LeakyReLUを利用しています。

GAT元論文におけるAttention係数の定義
\begin{split}
e_{ij} = \operatorname{LeakyReLU} (\bm{a}[W\bm{h_i} || W \bm{h_j}])
\end{split}

これで、Attention係数 $\alpha_{ij}$を得ることができました。

ここで、$||$という見かけない演算子がありますが、これは単純に2つのベクトルを連結(concatenationさせているだけです。

難しく考える必要性がありません。今回は、$F’$のベクトルを2つ連結しているので、$W\bm{h_i} || W \bm{h_j}$はサイズ$2F’$のベクトルとなります。

論文中では上のような絵が登場します。よく見るとすごく簡単です。

下の丸が全て同一のベクトルで、このベクトルを単層のニューラルネットに入れて、正規化されたAttention係数を得ているだけです。

Graph Attention Network の全体像

ここまで、なんとなくGATについて理解できてきたと思うので、最後にMultihead Attentionも含めたGATの全体像をまとめていきます。

そこちらがMultihead Attentionの概要図となっています。

GATの元論文では、Attention係数の安定化のために、Attentionを複数$K$回行い、その平均を利用しています。この辺りのMultihead Attentionの詳細については、元論文を参考にしてみてください。

【広告】
統計学的にあなたの悩みを解決します。
仕事やプライベートでお悩みの方は、ベテラン占い師 蓮若菜にご相談ください。

機械学習と情報技術