GCN(Graph Convolution Network)は、GNN(Graph Neural Network)の1種類で、2022年現在、機械学習やAIのトップカンファレンスであるICMLやICLRで、非常に多くGCNやGNN関連の論文が登場するなど、現在非常に盛んに研究が盛んに行われています。
GCN自体の論文は、2017年にICLRで、 SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS の論文として発表されました。
その後、GCNの派生系のアーキテクチャや手法は数多く登場していますが、一般的にGCNと呼ぶ際には、この論文の手法を指している場合が多く、非常に重要な論文となっています。なんと、2022年現在でも、14000件近い引用がされているほど、深層学習分野で非常に重要な論文となっていると思います。
今回はこの、論文の内容を解説する形で、できるだけ分かりやすくGCNの全体像やざっくりとした数式、GCNで何ができるのか?について、解説していきます。
- GCNの全体像やアーキテクチャを解説
- GCNでどのようなタスクを解くことができるか
- ICLRの論文に即してGCNの数式を解説
GCNの全体像やアーキテクチャを解説
まず最初に、GCNではどのようなアルゴリズムなのか全体像を理解していきましょう。GCNのよくある紹介記事や論文では、数式について詳細に立ち入るものもありますが、どのような操作をしているか、漫然と読んでいるだけでは中々理解が難しかったりします。
今回はまずGCNの全体像から最初に解説していきます。
まず、GCNの全体像の流れはこのようになっています。
画像の深層学習で有名なCNNは、深層学習モデルに画像の画素ベクトルを入力しますが、GCNで深層学習モデルに入力するのは、図のようなグラフです。
上図で右に向かって矢印が引かれていますが、この矢印によって畳み込み処理が行われています。
グラフの畳み込み処理の前後では、グラフの形状は変わりませんが、グラフのノードの有する特徴量が変わります。最初に入力したグラフの各ノードの特徴量$x_1, x_2, x_3, \dots, x_6$が畳み込みによって、$h_1^{(1)}, h_2^{(1)}, h_3^{(1)}, \dots, h_6^{(1)}$に変わっていますね。
ちなみに、GNNではこの畳み込みよって更新されたノードの更新量を潜在変数と呼ぶことがあります。
GCNを理解するポイントとして、グラフの畳み込みによって、グラフの構造自体は変わらず、グラフのノードの特徴量が変わっていくのがポイントです。
では、畳み込みによって、どのようにグラフ上の特徴量を更新していくかというと、あるノードの特徴量$x_i$は、そのノードが隣接するノードの特徴量$x_j ~~ (j \in \mathcal{N}(i))$を用いて更新していきます。
この更新用の関数を上図では、$g$としましたが、これは一般的にニューラルネットワークの関数で、MLP(多層パーセプトロン)や非線形活性化関数であるReLUやtanhなどに相当します。
この関数gのパラメータを一般的に$\bm{\theta}$や$\bm{W}$と表現すると、GCNの学習フェーズでは、この$\bm{\theta}$や$\bm{W}$を学習することになります。
ここで、$j \in \mathcal{N}(i)$という記号を使いましたが、これは$i$番目のノードに隣接している、ノードのインデックスの集合を表現しています。
グラフの畳み込みのイメージ
CNNとの比較でグラフの畳み込みを図解するとこのようになります。
左側は皆さんよくご存知の画像(CNN)の畳み込みの模式図で、右側は今回提示するグラフの畳み込み(GCN)の模式図です。
CNNは特定の画素値の値は、近隣のピクセルの値の線型結合でできるのに対し、グラフも同様に、リンクを有する(隣接する)ノードの値を足し合わせて、特定のノードの値が構成されていきます。
GCNの内容自体は本当にこれだけです。
この内容を頭に叩き込んだ上で、次の内容に入ると良いかと思います。
- GCNの入力、中間層、出力層でグラフの形状は変わらず、ノード上の特徴量が変わる
- GCNの学習では、畳み込み関数のパラメータ$\bm{\theta}$や$\bm{W}$が学習される
GCNで解きたい問題設定
続いて、GCNで解きたい問題設定について説明します。
これは本論文の内容とは逸脱しますが、非常に重要で頭に入れておいた方がいいので、軽く触れておきます。こちらの記事でも解説していますが、GCNを含む、GNNはグラフ上の次のようなタスクを解くために利用されます。
- ノード分類
- グラフ分類
- リンク予測
論文等でもよく登場するのがノード分類やグラフ分類です。
この3つのタスクについて、簡単に解説していきます。
ノード分類
ノード分類は、グラフ上のノードがどのカテゴリに属するかを予測するタスクです。具体的には、$i$番目のノードのカテゴリを予測したい場合、GCNの出力層における$i$番目の特徴量(潜在変数)から、そのノードのカテゴリを分類します。
学習時には、各層のパラメータや、最後の潜在変数からカテゴリを分類する$f$を学習します。
グラフ分類
グラフ分類はノード分類と似ていますが、GCNの最終層で出力された全ノードの潜在変数から、そのグラフが属するカテゴリを分類します。
リンク予測
リンク予測は、上のグラフとノード分類と少し違うような感じもしますが、ほとんど一緒です。
GCNの最後の出力をもとに、あるノードとノード間のエッジが存在するかどうかを予測するタスクです。
差し当たりは、この辺りのタスクがあることを覚えておければ大丈夫でしょう。
元論文では、GCNについて紹介した後、先ほど一番最初に紹介したタスクであるノード分類を実際にデータセットで実施し、精度を検証しています。
この論文中ではノード分類の中でも、部分的なノードだけ正解ラベルを与えた、半教師ありのノード分類をしています。
このタスクは、ある1つのグラフにおいて、その部分グラフのラベルがわかっているとき、残りのラベルが分かっていないノードのクラスを予測する、分類問題です。
イメージとしてはこのようになっています。グラフのノードのうち、色がついているのが正解ラベルが既知のラベルで、白いノードが正解ラベルを与えず、GCNで予測するノードです。
また、ノードの脇に書いてあるカラーバーが、各Node(頂点)が持っている特徴量を模式的に表現したものです。
先ほども書きましたが、グラフの構造も既知であり、かついくつかのNodeに関しては、正解のラベルがわかっています。
このような状況で、クラスがわからないNodeのクラスを当てるのが、今回の問題設定です。
- 全体のグラフの構造(Graph Topology)は既知
- 各ノードは、ラベルと特徴量(ベクトル)を持っている
- ラベルが分かっているノードと、分かっていないノードがある
- 各ノードの特徴量とグラフの構造から、ラベルが分かっていないノードのラベルを予測
ここまで、GCNでどのようなタスクが解けるのか、そして論文中で検証していた半教師あるのノード分類についても紹介しました。
と言っても、GCNやGNNを用いて解くことができるグラフ上の問題というのは他にも多数あると思います。
冒頭でも述べましたが、GNNやGCNの研究は現在ものすごいで行われているので、ぜひ最新の研究をぜひ調べてみてください。
GCNのアーキテクチャ
ここまでで、GCNの全体像と、どのようなタスクを解くことができるのかについて書いてきました。
続いては、少しの数式を交えながら、GCNの構造やアーキテクチャに入っていきます。
まず、これがGCNのネットワークの全体像となっています。
Input(入力)にグラフ構造を入力し、出力にも、入力時と同じ形式のグラフを出力しています。ここで意識すべき点として、先ほども記載しましたが、入力と出力が同じ形状のグラフになっているということです。
また、上手では中間層が2層のGCNですが、実際は層は何層でも実装することができます。
ここで、今回このGCNの入力とするグラフを$\mathcal{G}=(\mathcal{V}, \mathcal{E})$とします。ここで、GCNの入力は特徴行列$\bm{X} \in \mathbb{R}^{N \times D}$と隣接行列$\bm{A} \in \mathbb{R}^{N \times N}$です。ここで、$D$は各ノードの特徴ベクトルの大きさで、$N$はグラフ上のノードの数を示しています。
特徴行列は、N個のノード上の特徴量$\bm{X}$をまとめた行列で、隣接行列は、どのノード間が隣り合っているかを表現した行列となっています。
GCNやGNNの世界では、特徴行列や隣接行列は頻出ですので、どのような形式なのかイメージできるようにしておきましょう。
また、GCNの出力を$\bm{Z} \in \mathbb{R}^{N \times F}$と表現します。ここで、$F$は、出力層の各ノードの特徴ベクトルの次元を示します。
ここで、入力の特徴の次元は、$D$で、出力層の特徴ベクトルの次元は$F$となっていますが、一般的に$D \neq F$とすることができます。
GCNの数式
続いて、GCNの数式に入っていきます。
ここからは、少ししんどくなっていきます。またグラフラプラシアン などの用語が出てきますが、もし一度で理解できなくても、何度も見返すうちになんとなく理解できるようになっていくので、安心してください。
では、各畳み込み層におけるl番目の層の出力を$\bm{H}^{(l)}$とすると、非線形関数$f$を利用して、次のようにかけます。
\begin{equation} \begin{split} \bm{H}^{(l+1)} = f (\bm{H}^{(l)}, \bm{A} ) \end{split} \end{equation}
ここで、$\bm{H}$は、ニューラルネットワークの各層で、$\bm{H}^{(0)} = \bm{X}$で、$\bm{H}^{(L)} = \bm{Z}$とします。
ここで、$L$は全部の層の数で、$\bm{Z}$は、出力層の特徴ベクトルです。また、$\bm{A}$は隣接行列です。
(1)式で表現できるGNNを、論文の著者はGCN(Graph Convolutional Network)と定義しています。
そして、非線形関数$f$をどのように選定するかで、GCNの中でも手法として分岐してくるとしています。
ここでは、論文内で提示されている、最もシンプルかつ基本的なGCNについて取り上げます。
(1)で表現される非線形関数$f$を次のようにおいた場合を考えます。
\begin{equation} \begin{split} f(\bm{H}^{(l)}, \bm{A} ) = \sigma (\bm{A} \bm{H}^{(l)} \bm{W}^{(l)}) \end{split} \end{equation}
ここで、$\sigma$はReLUなどの非線形な活性関数であり、$\bm{H}^{(l)}$は$l$層目の出力であり、$\bm{W}^{(l)}$は$l$層目の重みです。
非常にシンプルな関数ですが、この(2)式では2つの問題があると論文で述べています。
1つ目は、これは隣接行列の行列の形を見るとわかりますが、畳み込みの演算の際に自身のノードの値が含まれません。そのため、隣接行列$\bm{A}$に単位行列$I$を足し合わせた、$\tilde{\bm{A}} = \bm{A} + I$を利用することで、この問題を解決します。
通常の隣接行列$\bm{A}$ではなく、この$\tilde{\bm{A}} $を利用することで、畳み込みの計算時に、自分のノードの特徴量を次の層に伝えることができます。
また2つ目の問題として、(2)では層の数だけ、$\bm{AW}$の計算が実行されていますが、これだと各ノードでスケールが異なってしまうことになります。
このため、字数行列$D$で対角化をすることでこれを解決します。今回は問題1の議論から、$\tilde{A} $を対角化して、$\tilde{D}^{- \frac{1}{2}} \tilde{A} \tilde{D}^{- \frac{1}{2}}$で正規化します。
この辺りの正規化の話は、グラフラプラシアンを理解していないと難しいかもしれません。。。
この辺りを突きつけていくと、元論文で提示されている、GCNにおけるpropagationの式が導出されます。
(この辺りは自分もよく理解していないの省略させてください、すみません。以下の結論だけ覚えておくと良いと思います。)
GCNにおける伝搬(propagration)(または、更新(update))は、下記のようになる。
\begin{equation} \begin{split} H^{(l+1)} = \sigma \biggl ( \tilde{D}^{- \frac{1}{2}} \tilde{A} \tilde{D}^{- \frac{1}{2}} H^{(l)}W^{(l)} \biggr ) \end{split} \end{equation}
また、$\bm{H}^{(l+1)}$の各要素$\bm{h}_{i}^{(l)}$については、このようにかける。
\begin{equation} \begin{split} \bm{h}_{i}^{(l+1)} = \sigma \biggl ( \sum_{j \in \mathcal{N}(i) \cup \{ i\} }\frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}} (\bm{W}^T \bm{h_j}^{(l)}) \biggr ) \end{split} \end{equation}
(4)の式は、(3)における$i$番目のノードの値に注目して取り出しただけです。
このような表現は、node-wiseと呼ばれています。GNN周りの新しい最新論文などを読んでいると、たまにnode-wiseの形式で記載している論文等もあるので、(4)のように書けるんだなーということくらいは、覚えておくと良いでしょう。
以上でGCNとその論文の解説を終わりにします。
少し省略してしまったところがあったのですが(グラフラプラシアンのところなど)、おおよそGCNの全体像について理解できたのではないでしょうか。
もう少しGCNについて詳しく知りたい人や、実装をもとに理解したい人は、こちらの記事でGCNで実際にグラフ上のノード分類を解説しているので、実装や手を動かして確認したい人はぜひご覧ください。