グラフTransformerの理論と導出と実装

分子の毒性を予測したいとき、その毒性は分子の片隅にある1つの官能基と、反対側にある別の官能基の「組み合わせ」で決まることがあります。ところが従来のグラフニューラルネットワーク(GNN)は、各ノードが「隣のノード」とだけ情報をやりとりする仕組みなので、グラフの端から端まで離れた2つの原子の関係を捉えるには、その距離の回数だけ層を重ねなければなりません。そして層を重ねすぎると、今度はすべてのノードの特徴が似たような値に潰れてしまう「過剰平滑化(over-smoothing)」という別の病が現れます。

この「長距離依存を捉えたいが、層を深くすると平滑化で潰れる」というジレンマを、自然言語処理で大成功したTransformerの発想で打ち破ろうとするのがグラフTransformer(Graph Transformer)です。Transformerの自己注意(self-attention)は、文の中のどんなに離れた単語どうしでも1層で直接結びつけられます。これをグラフに持ち込めば、グラフ上のどんなに離れたノードどうしでも1層で情報を交換できる——これが基本的なアイデアです。

しかし話はそう単純ではありません。素朴に全ノードどうしを注意で結ぶと、せっかくのグラフの「つながり方(構造)」の情報が完全に消えてしまいます。文には単語の並び順がありますが、グラフには「自然な順序」がありません。そこで本記事では、ラプラシアン固有ベクトルによる位置符号化と、注意のロジットへ最短路距離や次数を加えるバイアス(Graphormer型)という2つの工夫で、構造情報をTransformerに取り戻す方法を、数式の導出を省略せずに解説します。

グラフTransformerは、分子物性予測(創薬・材料探索)、たんぱく質の立体構造予測、交通網・電力網の長距離相関のモデリングなど、「離れたノード間の相互作用」が本質的に効く問題で威力を発揮します。本記事では最後にPyTorchで小規模なグラフ分類モデルを実装し、GCN・GATと分類精度や注意の重みを比較して、何が起きているのかを目で確かめます。

本記事の内容

  • メッセージパッシングGNNの「長距離依存」と「過剰平滑化」の限界を数式で理解する
  • 全ノード自己注意をグラフに適用するときに失われる構造情報を特定する
  • ラプラシアン位置符号化と最短路・次数バイアスの定式化を省略なく導出する
  • PyTorchでグラフTransformerを実装し、GCN/GATと精度・注意重みを比較して読み取る

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が格段に深まります。グラフ上の畳み込みと注意機構の基礎、そしてグラフをベクトルで表す考え方が出発点になります。

特に、グラフラプラシアンの固有ベクトル(スペクトルグラフ畳み込みの記事で登場)は、本記事の位置符号化で中心的な役割を果たします。注意機構については、GATで「隣接ノードへの注意」を学んでおくと、グラフTransformerが「全ノードへの注意」へと一般化したものだと自然に理解できます。

メッセージパッシングGNNの限界とは

まず、なぜ既存のGNNでは不十分なのかを直感的につかみましょう。代表的なGNNであるGCN(グラフ畳み込みネットワーク)やGATは、すべてメッセージパッシング(message passing)という共通の枠組みで書けます。これは「各ノードが、自分の隣のノードたちから情報(メッセージ)を集めて、自分の特徴を更新する」という操作を1層とし、それを繰り返すものです。

イメージとしては、村人どうしの噂話のようなものです。1回の集会(1層)では、あなたは隣近所の人の意見しか聞けません。隣の隣の人の意見を知るには、隣の人が次の集会でそれを伝えてくれるのを待つ、つまり2回目の集会(2層目)が必要です。$k$ 軒先の人の意見が届くには $k$ 回の集会、すなわち $k$ 層が必要になります。

メッセージパッシングの一般形

メッセージパッシングを数式で書くと、ノード $v$ の第 $\ell$ 層の特徴ベクトル $\bm{h}_v^{(\ell)}$ は、隣接ノード集合 $\mathcal{N}(v)$ を使って次のように更新されます。

$$ \begin{equation} \bm{h}_v^{(\ell+1)} = \phi\!\left(\bm{h}_v^{(\ell)},\; \bigoplus_{u \in \mathcal{N}(v)} \psi\!\left(\bm{h}_v^{(\ell)}, \bm{h}_u^{(\ell)}\right)\right) \end{equation} $$

ここで各記号の意味は次の通りです。

記号 意味
$\bm{h}_v^{(\ell)}$ 第 $\ell$ 層におけるノード $v$ の特徴ベクトル
$\mathcal{N}(v)$ ノード $v$ に隣接するノードの集合
$\psi$ メッセージ関数(隣接ノードからのメッセージを作る)
$\bigoplus$ 集約関数(和・平均・最大などの置換不変な演算)
$\phi$ 更新関数(自分の特徴とメッセージを混ぜる)

例えばGCNでは、$\psi$ が正規化付きの線形変換、$\bigoplus$ が重み付き和に対応します。重要なのは、この更新式の中に登場するのはあくまで隣接ノード $\mathcal{N}(v)$ だけだという点です。1層では1ホップ(1つ隣)の情報しか取り込めません。

限界1:長距離依存は層の深さを要求する

グラフ上の2つのノード $u, v$ の間の最短路距離(最小のホップ数)を $d(u, v)$ と書きましょう。上の更新式から、ノード $v$ がノード $u$ の情報を初めて「受け取る」ためには、少なくとも $d(u, v)$ 層が必要です。つまり、グラフの直径(最も離れた2ノード間の距離)を $D$ とすると、グラフ全体の情報を行き渡らせるには $D$ 層ものメッセージパッシングが要ります。

分子グラフのように直径が10を超えることも珍しくない対象では、これは10層以上の深いGNNを意味します。ところが、後で見るように層を深くすると別の問題が噴き出します。これが第一の限界です。

限界2:過剰平滑化(over-smoothing)

層を深くすればよい、というわけにいかないのが厄介な点です。多くのGNNでは集約操作が「隣接ノードの平均を取る」ことに近く、これは数学的には拡散方程式(熱伝導)の離散版にあたります。熱いコーヒーを放置すると全体が一様な温度に近づくように、平均化を繰り返すとすべてのノードの特徴が「グラフ全体の平均」へと収束し、ノードどうしの区別がつかなくなってしまいます。これが過剰平滑化です。

これを少し数式で見てみましょう。簡単のため、自己ループ付きの対称正規化隣接行列 $\hat{\bm{A}} = \bm{D}^{-1/2}(\bm{A} + \bm{I})\bm{D}^{-1/2}$ を使い、活性化関数を無視した線形なGCN伝播 $\bm{H}^{(\ell+1)} = \hat{\bm{A}} \bm{H}^{(\ell)}$ を考えます。すると $\ell$ 層後の特徴は

$$ \begin{equation} \bm{H}^{(\ell)} = \hat{\bm{A}}^{\ell} \bm{H}^{(0)} \end{equation} $$

となります。ここで $\hat{\bm{A}}$ の固有値を $\lambda_1 \geq \lambda_2 \geq \cdots \geq \lambda_n$ とすると、対称正規化の性質から最大固有値は $\lambda_1 = 1$ で、それ以外は $|\lambda_i| < 1$ であることが知られています。$\hat{\bm{A}}$ を固有値分解して $\hat{\bm{A}} = \bm{U}\bm{\Lambda}\bm{U}^\top$ と書くと、その $\ell$ 乗は

$$ \begin{equation} \hat{\bm{A}}^{\ell} = \bm{U}\bm{\Lambda}^{\ell}\bm{U}^\top = \sum_{i=1}^{n} \lambda_i^{\ell}\, \bm{u}_i \bm{u}_i^\top \end{equation} $$

となります。ここで $\bm{u}_i$ は $\lambda_i$ に対応する固有ベクトルです。$\ell \to \infty$ のとき $|\lambda_i| < 1$ の項はすべて $\lambda_i^\ell \to 0$ で消え、$\lambda_1 = 1$ に対応する第1項 $\bm{u}_1 \bm{u}_1^\top$ だけが生き残ります。すなわち

$$ \begin{equation} \lim_{\ell \to \infty} \hat{\bm{A}}^{\ell} = \bm{u}_1 \bm{u}_1^\top \end{equation} $$

となり、すべてのノードの特徴が $\bm{u}_1$ の方向(連結グラフでは各ノードの次数の平方根に比例する1次元の方向)へと潰れてしまいます。これが「ノードの区別がつかなくなる」ことの数式的な正体です。

つまりメッセージパッシングGNNは、浅いと遠くの情報が届かず、深いと平滑化で情報が潰れるという板挟みに陥っているのです。ここで「そもそも隣接ノードだけに限定するから悪いのではないか。最初からすべてのノードどうしを直接つなげばよいのでは」という発想が自然に湧いてきます。次節では、その発想の元になったTransformerの自己注意を、グラフに持ち込むとどうなるかを見ていきます。

Transformerの自己注意をグラフに持ち込む

Transformerの心臓部は自己注意です。文の中の各単語が、他のすべての単語に対して「どれだけ注目すべきか」という重みを計算し、その重みで他の単語の情報を混ぜ合わせます。距離に関係なく、文頭の単語と文末の単語が1層で直接結びつくのがポイントです。グラフにこれを適用すれば、隣接関係に縛られず、どのノードとも1層で情報交換できるはずです。

スケール付きドット積注意の復習

入力として $n$ 個のノードの特徴を並べた行列 $\bm{H} \in \mathbb{R}^{n \times d}$ を考えます。各行が1ノードの $d$ 次元特徴ベクトルです。自己注意では、まずこれを3つの異なる線形変換でクエリ $\bm{Q}$、キー $\bm{K}$、バリュー $\bm{V}$ に写します。

$$ \begin{equation} \bm{Q} = \bm{H}\bm{W}_Q, \quad \bm{K} = \bm{H}\bm{W}_K, \quad \bm{V} = \bm{H}\bm{W}_V \end{equation} $$

ここで $\bm{W}_Q, \bm{W}_K \in \mathbb{R}^{d \times d_k}$、$\bm{W}_V \in \mathbb{R}^{d \times d_v}$ は学習可能なパラメータ行列です。直感的には、クエリは「私はどんな情報を探しているか」、キーは「私はどんな情報を持っているか」を表し、両者の内積が大きいほど「相性がよい=注目すべき」という関係になります。

注意の出力は次式で与えられます。

$$ \begin{equation} \mathrm{Attn}(\bm{Q}, \bm{K}, \bm{V}) = \mathrm{softmax}\!\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} \end{equation} $$

この式の中身を1ノードずつ分解してみましょう。ノード $i$ がノード $j$ に向ける注意の重み $\alpha_{ij}$ は、生のスコア(ロジット)$e_{ij}$ を softmax で正規化したものです。

$$ \begin{equation} e_{ij} = \frac{\bm{q}_i^\top \bm{k}_j}{\sqrt{d_k}}, \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{j’=1}^{n} \exp(e_{ij’})} \end{equation} $$

ここで $\bm{q}_i$ はノード $i$ のクエリ、$\bm{k}_j$ はノード $j$ のキーです。$\sqrt{d_k}$ で割るのは、次元が大きいときに内積が大きくなりすぎて softmax が飽和(勾配消失)するのを防ぐためのスケーリングです。最終的にノード $i$ の新しい特徴は、全ノードのバリューを注意重みで加重平均したもの

$$ \begin{equation} \bm{h}_i^{\text{out}} = \sum_{j=1}^{n} \alpha_{ij}\, \bm{v}_j \end{equation} $$

になります。和の範囲が $j = 1, \dots, n$、つまり全ノードである点に注目してください。GATでは和の範囲が隣接ノード $\mathcal{N}(i)$ に限定されていましたが、ここでは制限がありません。これが「1層で長距離依存を捉える」という性質の源です。

致命的な問題:構造情報が消える

ところが、この素朴な全ノード注意には致命的な欠陥があります。グラフの構造(どのノードとどのノードがつながっているか)が、注意の計算式のどこにも入っていないのです。

これがどれほど深刻かを実感するために、極端な例を考えましょう。同じノード特徴 $\bm{H}$ を持つ2つのグラフがあり、片方は鎖状(パスグラフ)、もう片方は星形(スターグラフ)だとします。エッジの張り方はまったく違うのに、上の注意計算ではエッジ情報を一切使わないため、両者に対してまったく同じ出力が得られてしまいます。グラフの本質はノードどうしの「つながり方」にあるのに、それを無視しているのです。

数式的に言えば、自己注意は入力ノードの並べ替えに対して同変(permutation equivariant)であると同時に、エッジ集合 $\mathcal{E}$ に依存しません。これは「グラフを点群(バラバラのノードの集まり)として扱っている」のと同じで、グラフ構造を捨ててしまっています。文では単語の「並び順」が位置エンコーディングで補われますが、グラフには自然な並び順がありません。

そこで2つの方向から構造情報を注ぎ込みます。

  1. 位置符号化(positional encoding):各ノードに「グラフ上での位置」を表すベクトルを与え、入力特徴に足し込む
  2. 注意バイアス(attention bias):注意のロジット $e_{ij}$ に、ノード対 $(i,j)$ の構造的関係(最短路距離など)を表す項を足す

次節では、まず位置符号化として最も理論的に美しいラプラシアン固有ベクトルを用いる方法を導出します。

ラプラシアン固有ベクトルによる位置符号化

「グラフ上での位置」とは何でしょうか。直線上の点なら座標一つで表せますが、グラフのノードに座標はありません。ここで鍵になるのが、太鼓の振動の比喩です。太鼓の膜を叩くと特定の振動パターン(固有振動)が現れ、その振動の「腹」や「節」の位置で膜上の各点を特徴づけられます。グラフにも同じように「自然な振動パターン」が定義でき、それがグラフラプラシアンの固有ベクトルです。各ノードがその振動パターンの中でどの位置にいるかを並べたものが、グラフ上の「座標」の役割を果たします。

グラフラプラシアンの定義

まずグラフラプラシアンを定義します。$n$ 個のノードを持つグラフの隣接行列を $\bm{A} \in \mathbb{R}^{n \times n}$($A_{ij}=1$ ならエッジあり)、次数行列を $\bm{D} = \mathrm{diag}(d_1, \dots, d_n)$($d_i = \sum_j A_{ij}$ はノード $i$ の次数)とします。組合せラプラシアン

$$ \begin{equation} \bm{L} = \bm{D} – \bm{A} \end{equation} $$

で定義されます。実用上は次数の違いを正規化した対称正規化ラプラシアン

$$ \begin{equation} \bm{L}_{\text{sym}} = \bm{I} – \bm{D}^{-1/2}\bm{A}\bm{D}^{-1/2} \end{equation} $$

を使うことが多く、本記事でもこちらを採用します。

ラプラシアンが「滑らかさ」を測ることの導出

ラプラシアンがなぜ「振動パターン」を与えるのかを理解するため、任意のノード上のスカラー値 $\bm{f} = (f_1, \dots, f_n)^\top$ に対して二次形式 $\bm{f}^\top \bm{L} \bm{f}$ を計算してみましょう。組合せラプラシアン $\bm{L} = \bm{D} – \bm{A}$ を代入すると

$$ \begin{equation} \bm{f}^\top \bm{L} \bm{f} = \bm{f}^\top \bm{D} \bm{f} – \bm{f}^\top \bm{A} \bm{f} \end{equation} $$

となります。右辺の第1項は $\bm{f}^\top \bm{D} \bm{f} = \sum_i d_i f_i^2$、第2項は $\bm{f}^\top \bm{A} \bm{f} = \sum_{i}\sum_{j} A_{ij} f_i f_j$ です。次数の定義 $d_i = \sum_j A_{ij}$ を第1項に代入すると $\sum_i d_i f_i^2 = \sum_i \sum_j A_{ij} f_i^2$ と書けるので、両者をまとめると

$$ \begin{equation} \bm{f}^\top \bm{L} \bm{f} = \sum_{i}\sum_{j} A_{ij} f_i^2 – \sum_{i}\sum_{j} A_{ij} f_i f_j \end{equation} $$

となります。ここで和の対称性($A_{ij}=A_{ji}$)を使うと $\sum_{i,j} A_{ij} f_i^2 = \frac{1}{2}\sum_{i,j} A_{ij}(f_i^2 + f_j^2)$ と書き換えられるので、これを代入して整理すると

$$ \begin{equation} \bm{f}^\top \bm{L} \bm{f} = \frac{1}{2}\sum_{i}\sum_{j} A_{ij}\,(f_i – f_j)^2 = \sum_{(i,j) \in \mathcal{E}} (f_i – f_j)^2 \end{equation} $$

という美しい結果が得られます。これは「エッジでつながった隣り合うノードどうしで値がどれだけ違うか」の総和、すなわち $\bm{f}$ の滑らかさ(の逆)を測る量です。$\bm{f}^\top \bm{L} \bm{f}$ が小さいほど、隣接ノード間で値が近い「滑らかな」配置だということになります。

固有ベクトルが「振動パターン」を与える

ラプラシアン $\bm{L}_{\text{sym}}$ は対称半正定値行列なので、非負の固有値 $0 = \mu_1 \leq \mu_2 \leq \cdots \leq \mu_n$ と、直交する固有ベクトル $\bm{\phi}_1, \bm{\phi}_2, \dots, \bm{\phi}_n$ を持ちます。

$$ \begin{equation} \bm{L}_{\text{sym}}\, \bm{\phi}_k = \mu_k\, \bm{\phi}_k \end{equation} $$

先ほどの二次形式の議論から、固有値 $\mu_k$ は対応する固有ベクトル $\bm{\phi}_k$ の「振動の激しさ(周波数)」を表します。最小固有値 $\mu_1 = 0$ に対応する固有ベクトル $\bm{\phi}_1$ は定数ベクトル(まったく振動しない=周波数ゼロ)で、位置情報を持ちません。一方、$\mu_2, \mu_3, \dots$ に対応する固有ベクトルは、低周波(グラフ全体でゆるやかに変化)から高周波(隣接ノード間で激しく変化)まで、グラフの「振動モード」を順に並べたものになります。

これは1次元の信号に対するフーリエ変換とまったく同じ構造です。フーリエ変換が信号を低周波から高周波の正弦波に分解するのに対し、グラフラプラシアンの固有ベクトルはグラフ上の信号を「グラフ固有の正弦波」に分解します。だからこそ、これらの固有ベクトルはグラフ版のフーリエ基底とも呼ばれます。

位置符号化ベクトルの構成

そこで、各ノード $i$ に対して、低周波側から $k$ 個の固有ベクトル $\bm{\phi}_2, \bm{\phi}_3, \dots, \bm{\phi}_{k+1}$(定数モード $\bm{\phi}_1$ を除く)を取り出し、それらの第 $i$ 成分を並べた $k$ 次元ベクトル

$$ \begin{equation} \bm{p}_i = \big(\phi_2[i],\, \phi_3[i],\, \dots,\, \phi_{k+1}[i]\big)^\top \in \mathbb{R}^{k} \end{equation} $$

ラプラシアン位置符号化(Laplacian Positional Encoding, LapPE)と呼びます。これがノード $i$ の「グラフ上での座標」です。2つのノードが構造的に近ければ、低周波モードの値も近くなり、位置符号化も似た値になります。逆に構造的に遠いノードどうしは異なる位置符号化を持ちます。

これをノードの入力特徴 $\bm{x}_i$ に(線形変換を挟んで)加算または連結することで、Transformerに「このノードはグラフのどこにいるか」を教えます。

$$ \begin{equation} \bm{h}_i^{(0)} = \bm{x}_i \bm{W}_x + \bm{p}_i \bm{W}_p \end{equation} $$

符号の曖昧さという注意点

ひとつ実装上の注意があります。固有ベクトル $\bm{\phi}_k$ は、$-\bm{\phi}_k$ も同じ固有値を持つ正規化固有ベクトルになるため、符号が一意に定まりません(さらに固有値が縮退している場合は固有ベクトルの取り方にも自由度が生じます)。そのため、学習時に各固有ベクトルの符号をランダムに反転させるデータ拡張を施し、モデルが符号に依存しないように学習させるのが一般的です。

ラプラシアン位置符号化は「ノードがグラフのどこにいるか」という絶対的な位置情報を与えますが、「ノード $i$ とノード $j$ がどれだけ離れているか」という相対的な関係を直接表すわけではありません。次節では、注意のロジットに最短路距離を直接埋め込むことで、この相対的な構造を注ぎ込む方法を見ます。

最短路バイアスと次数符号化(Graphormer型)

ラプラシアン位置符号化は入力側で構造を注入する方法でした。もう一つの強力な方法は、注意計算そのものに手を入れることです。Graphormer という有名なモデルが採用したのが、注意のロジット $e_{ij}$ にノード対の構造的関係を表すバイアス項を足し込むというアイデアです。文のTransformerでも、トークン間の相対位置をロジットに足す「相対位置エンコーディング」がありますが、その自然なグラフ版だと考えられます。

空間符号化:最短路距離バイアス

直感としては、「グラフ上で近いノードどうしは、遠いノードよりも強く注目し合うべき」という事前知識を注意に組み込みたい、ということです。これを実現するため、ノード $i$ とノード $j$ の最短路距離 $d(i,j)$ ごとに、学習可能なスカラーバイアス $b_{d(i,j)}$ を用意し、注意ロジットに加算します。

$$ \begin{equation} e_{ij} = \frac{\bm{q}_i^\top \bm{k}_j}{\sqrt{d_k}} + b_{d(i,j)} \end{equation} $$

ここで $b_0, b_1, b_2, \dots$ はそれぞれ距離 $0, 1, 2, \dots$ に対応する学習可能なスカラーです(距離が大きい部分はまとめて1つのバイアスに割り当てることもあります)。重要なのは、このバイアスが学習可能だという点です。モデルが「近いノードを重視すべきか、遠いノードを重視すべきか」をデータから自動で学びます。例えば $b_1 > b_5$ と学習されれば近接重視、その逆なら長距離重視ということになります。

このバイアスのおかげで、全ノード注意でありながらグラフの距離構造を反映できます。距離が無限大(連結していない別成分のノード)の場合は $b_\infty = -\infty$ とすれば、その間の注意をゼロにできます。

中心性符号化:次数バイアス

もうひとつ、ノードの「重要度」を表す情報も加えます。グラフでは、たくさんのノードとつながっているハブ的なノード(次数の大きいノード)が、しばしば全体に大きな影響を持ちます。SNSのインフルエンサーを思い浮かべるとよいでしょう。この「中心性(centrality)」を表すために、各ノードの入次数・出次数に応じた学習可能な埋め込みベクトルを、入力特徴に加算します。

$$ \begin{equation} \bm{h}_i^{(0)} = \bm{x}_i + \bm{z}^{-}_{\deg^{-}(i)} + \bm{z}^{+}_{\deg^{+}(i)} \end{equation} $$

ここで $\deg^{-}(i), \deg^{+}(i)$ はノード $i$ の入次数・出次数、$\bm{z}^{-}_{\bullet}, \bm{z}^{+}_{\bullet}$ は次数の値ごとに用意された学習可能な埋め込みベクトルです(無向グラフなら次数ひとつにまとめられます)。これにより、注意のクエリ・キーの大きさ自体が次数に応じて調整され、ハブノードが自然と強い影響を持てるようになります。

仮想ノードによるグローバル情報の集約

さらに実用的な工夫として、グラフ全体を代表する仮想ノード(virtual node)を1つ追加し、すべての実ノードと接続することがよく行われます。BERTの [CLS] トークンに相当するもので、この仮想ノードの最終特徴をグラフ全体の表現としてグラフ分類に使います。仮想ノードは全ノードと距離1で接続されているとみなされるため、グラフ全体の情報を1ステップで集約できます。

1層全体の構成

ここまでの要素をまとめると、グラフTransformerの1層は、構造バイアス付きのマルチヘッド自己注意(MHA)と、各ノード独立に作用する位置ごと順伝播ネットワーク(FFN)を、残差接続と層正規化(LayerNorm)で挟んだ構成になります。

$$ \begin{equation} \bm{H}’ = \mathrm{LayerNorm}\big(\bm{H} + \mathrm{MHA}_{\text{bias}}(\bm{H})\big) \end{equation} $$

$$ \begin{equation} \bm{H}^{\text{out}} = \mathrm{LayerNorm}\big(\bm{H}’ + \mathrm{FFN}(\bm{H}’)\big) \end{equation} $$

マルチヘッド注意は、$d_k$ 次元の注意を $h$ 個並列に計算し、それぞれが異なる「関係の種類」を捉えるようにする仕組みです。$\mathrm{head}_m = \mathrm{Attn}(\bm{Q}_m, \bm{K}_m, \bm{V}_m)$ を連結して出力線形変換 $\bm{W}_O$ をかけます。

$$ \begin{equation} \mathrm{MHA}(\bm{H}) = \big[\mathrm{head}_1; \mathrm{head}_2; \cdots; \mathrm{head}_h\big]\,\bm{W}_O \end{equation} $$

各ヘッドの中で、先ほどの最短路バイアス $b_{d(i,j)}$ がロジットに足し込まれます(ヘッドごとに別々のバイアスを学習させることもあります)。これで「全ノードへの注意」と「グラフ構造の尊重」を両立した1層が完成です。

理論はここまでです。あとはこれを実際に動くコードに落として、本当にGCNやGATと比べて意味のある違いが出るのかを確かめましょう。次節からPyTorchで実装します。

Pythonでの実装

ここからは、実際にグラフTransformerをPyTorchで実装し、簡単なグラフ分類タスクでGCN・GATと比較します。タスクとして「2つの三角形を1本の橋でつないだダンベル型グラフ」と「全ノードがほぼ一様につながったランダム密グラフ」の2クラスを分類する問題を作ります。ダンベル型は橋の左右で離れたノード(長距離)の関係が重要になるため、長距離依存を捉えられるかどうかの良いテストになります。

まずは必要なライブラリの読み込みと、グラフラプラシアン位置符号化を計算する関数を実装します。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import networkx as nx

torch.manual_seed(0)
np.random.seed(0)

def laplacian_pe(adj, k):
    """対称正規化ラプラシアンの低周波固有ベクトルk本を位置符号化として返す"""
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.diag(1.0 / np.sqrt(np.maximum(deg, 1e-8)))
    L = np.eye(n) - d_inv_sqrt @ adj @ d_inv_sqrt  # 対称正規化ラプラシアン
    eigvals, eigvecs = np.linalg.eigh(L)           # 固有値は昇順で返る
    # 定数モード(最小固有値)を除き、低周波からk本を採用
    pe = eigvecs[:, 1:k + 1]
    if pe.shape[1] < k:  # ノード数が足りない場合はゼロ埋め
        pe = np.pad(pe, ((0, 0), (0, k - pe.shape[1])))
    # 符号の曖昧さに対応: ランダムに符号反転(データ拡張)
    signs = np.random.choice([-1, 1], size=pe.shape[1])
    return (pe * signs).astype(np.float32)

この関数は、隣接行列から対称正規化ラプラシアン $\bm{L}_{\text{sym}}$ を構成し、np.linalg.eigh で固有分解しています。eigh は対称行列専用で固有値を昇順に返すため、最小固有値(定数モード)に対応する第0列を捨て、低周波側から $k$ 本を位置符号化として取り出しています。最後に符号をランダム反転させているのが、前節で述べた符号の曖昧さへの対処です。

次に、最短路距離バイアスを計算する関数と、データセットを生成する関数を用意します。

def shortest_path_matrix(adj, max_dist=5):
    """全ノード対の最短路距離行列を返す(到達不能はmax_distで打ち切り)"""
    G = nx.from_numpy_array(adj)
    n = adj.shape[0]
    spd = np.full((n, n), max_dist, dtype=np.int64)
    lengths = dict(nx.all_pairs_shortest_path_length(G, cutoff=max_dist))
    for i, dist_i in lengths.items():
        for j, d in dist_i.items():
            spd[i, j] = min(d, max_dist)
    return spd

def make_dumbbell(n_side=4):
    """2つのクリーク(三角形群)を1本の橋でつないだダンベル型グラフ"""
    G = nx.Graph()
    left = list(range(n_side))
    right = list(range(n_side, 2 * n_side))
    G.add_edges_from([(a, b) for a in left for b in left if a < b])
    G.add_edges_from([(a, b) for a in right for b in right if a < b])
    G.add_edge(left[-1], right[0])  # 橋
    return nx.to_numpy_array(G)

def make_random_dense(n=8, p=0.6):
    """ほぼ一様に密な乱数グラフ(連結を保証)"""
    while True:
        G = nx.erdos_renyi_graph(n, p)
        if nx.is_connected(G):
            return nx.to_numpy_array(G)

shortest_path_matrix は networkx で全ノード対の最短路を計算し、max_dist で打ち切ります(遠すぎる距離をまとめて1つのバイアスに割り当てるため)。make_dumbbell は長距離依存が効くダンベル型、make_random_dense は構造が一様な密グラフを生成します。この2クラスを分類させます。

続いて、データセット本体を組み立てます。各グラフにノード特徴(ここでは次数を1次元特徴とします)と位置符号化、最短路行列、ラベルを持たせます。

def build_dataset(n_per_class=80, n_nodes=8, k_pe=4, max_dist=5):
    """ダンベル(ラベル0)とランダム密(ラベル1)のグラフ分類データセット"""
    data = []
    for _ in range(n_per_class):
        for label, adj in [(0, make_dumbbell(n_nodes // 2)),
                           (1, make_random_dense(n_nodes))]:
            n = adj.shape[0]
            deg = adj.sum(axis=1, keepdims=True).astype(np.float32)
            x = deg / deg.max()                       # ノード特徴: 正規化次数
            pe = laplacian_pe(adj, k_pe)              # 位置符号化
            spd = shortest_path_matrix(adj, max_dist) # 最短路距離
            data.append({
                "x": torch.tensor(x),
                "pe": torch.tensor(pe),
                "spd": torch.tensor(spd),
                "adj": torch.tensor(adj, dtype=torch.float32),
                "y": label,
            })
    np.random.shuffle(data)
    return data

dataset = build_dataset()
split = int(0.8 * len(dataset))
train_data, test_data = dataset[:split], dataset[split:]
print(f"学習データ {len(train_data)} 件, テストデータ {len(test_data)} 件")

ここまでで、各グラフが「ノード特徴 x・位置符号化 pe・最短路距離 spd・隣接行列 adj・ラベル y」を持つデータセットができました。ノード数を揃えてあるので、バッチ処理を単純化できます。次にモデル本体を実装します。

まず、最短路バイアス付きの自己注意層を実装します。これがグラフTransformerの核心部分です。

class GraphAttentionLayer(nn.Module):
    """最短路距離バイアス付きマルチヘッド自己注意"""
    def __init__(self, dim, n_heads, max_dist=5):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = dim // n_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)
        # 距離ごとの学習可能バイアス(各ヘッド独立)
        self.dist_bias = nn.Embedding(max_dist + 1, n_heads)

    def forward(self, h, spd):
        n, dim = h.shape
        qkv = self.qkv(h).reshape(n, 3, self.n_heads, self.d_k)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]   # 各 (n, n_heads, d_k)
        # ロジット: スケール付きドット積 (n_heads, n, n)
        logits = torch.einsum("ihd,jhd->hij", q, k) / np.sqrt(self.d_k)
        bias = self.dist_bias(spd).permute(2, 0, 1)  # (n_heads, n, n)
        attn = torch.softmax(logits + bias, dim=-1)  # 構造バイアスを加算
        out = torch.einsum("hij,jhd->ihd", attn, v).reshape(n, dim)
        return self.out(out), attn

この層では、qkv で一括してクエリ・キー・バリューを作り、einsum でスケール付きドット積ロジットを計算しています。ポイントは logits + bias の部分で、dist_bias 埋め込みから引いた最短路距離ごとのバイアスをロジットに加算してから softmax を取っている点です。これが前節で導出した $e_{ij} = \bm{q}_i^\top\bm{k}_j/\sqrt{d_k} + b_{d(i,j)}$ そのものです。注意重み attn も返して、後で可視化に使います。

次に、この注意層と順伝播・残差・層正規化を組み合わせて、グラフTransformer全体を作ります。

class GraphTransformer(nn.Module):
    def __init__(self, in_dim, pe_dim, dim=32, n_heads=4, n_layers=2, max_dist=5):
        super().__init__()
        self.embed = nn.Linear(in_dim, dim)
        self.pe_embed = nn.Linear(pe_dim, dim)        # ラプラシアン位置符号化
        self.layers = nn.ModuleList([
            GraphAttentionLayer(dim, n_heads, max_dist) for _ in range(n_layers)
        ])
        self.norms1 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(n_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(n_layers)])
        self.ffns = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))
            for _ in range(n_layers)
        ])
        self.classifier = nn.Linear(dim, 2)

    def forward(self, x, pe, spd):
        h = self.embed(x) + self.pe_embed(pe)         # 入力 + 位置符号化
        last_attn = None
        for attn_layer, n1, n2, ffn in zip(self.layers, self.norms1,
                                           self.norms2, self.ffns):
            a, last_attn = attn_layer(h, spd)
            h = n1(h + a)                             # 残差 + LayerNorm
            h = n2(h + ffn(h))                        # FFN + 残差 + LayerNorm
            h = F.gelu(h)
        graph_repr = h.mean(dim=0)                    # 全ノード平均でグラフ表現
        return self.classifier(graph_repr), last_attn

このモデルは、入力特徴に位置符号化を足し込んでから(前節の $\bm{h}_i^{(0)} = \bm{x}_i\bm{W}_x + \bm{p}_i\bm{W}_p$ に対応)、構造バイアス付き注意層を n_layers 回通し、最後に全ノードの平均をグラフ全体の表現として分類器に渡します。残差接続と層正規化が、前節の式 $\bm{H}’ = \mathrm{LayerNorm}(\bm{H} + \mathrm{MHA})$ をそのまま実装したものになっています。

比較のため、GCNとGATも簡潔に実装しておきます。

class SimpleGCN(nn.Module):
    """対称正規化隣接行列によるグラフ畳み込み"""
    def __init__(self, in_dim, dim=32, n_layers=2):
        super().__init__()
        self.lins = nn.ModuleList(
            [nn.Linear(in_dim if i == 0 else dim, dim) for i in range(n_layers)])
        self.classifier = nn.Linear(dim, 2)

    def forward(self, x, adj):
        a = adj + torch.eye(adj.shape[0])             # 自己ループ追加
        deg = a.sum(1)
        d_inv = torch.diag(deg.pow(-0.5))
        a_hat = d_inv @ a @ d_inv                      # 対称正規化
        h = x
        for lin in self.lins:
            h = F.relu(a_hat @ lin(h))                 # 近傍集約 + 変換
        return self.classifier(h.mean(0))

class SimpleGAT(nn.Module):
    """隣接ノードに限定した単一ヘッド注意"""
    def __init__(self, in_dim, dim=32, n_layers=2):
        super().__init__()
        self.lins = nn.ModuleList(
            [nn.Linear(in_dim if i == 0 else dim, dim) for i in range(n_layers)])
        self.attn = nn.ModuleList([nn.Linear(2 * dim, 1) for _ in range(n_layers)])
        self.classifier = nn.Linear(dim, 2)

    def forward(self, x, adj):
        h = x
        mask = (adj + torch.eye(adj.shape[0])) > 0     # 隣接(+自己)のみ許可
        for lin, att in zip(self.lins, self.attn):
            h = lin(h)
            n = h.shape[0]
            hi = h.unsqueeze(1).expand(n, n, -1)
            hj = h.unsqueeze(0).expand(n, n, -1)
            e = F.leaky_relu(att(torch.cat([hi, hj], -1)).squeeze(-1))
            e = e.masked_fill(~mask, float("-inf"))    # 非隣接を遮断
            alpha = torch.softmax(e, dim=1)
            h = F.relu(alpha @ h)                       # 隣接のみ集約
        return self.classifier(h.mean(0))

GCNは対称正規化隣接行列で近傍を集約するだけ、GATは注意を計算しますが masked_fill隣接ノードに限定している点が、全ノードに注意するグラフTransformerとの決定的な違いです。これで3モデルが揃いました。共通の学習ループで訓練します。

def train_model(model, train_data, test_data, use_transformer=False, epochs=60):
    opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    for ep in range(epochs):
        model.train()
        np.random.shuffle(train_data)
        for d in train_data:
            opt.zero_grad()
            if use_transformer:
                logits, _ = model(d["x"], d["pe"], d["spd"])
            else:
                logits = model(d["x"], d["adj"])
            loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([d["y"]]))
            loss.backward()
            opt.step()
    # テスト精度
    model.eval()
    correct = 0
    with torch.no_grad():
        for d in test_data:
            if use_transformer:
                logits, _ = model(d["x"], d["pe"], d["spd"])
            else:
                logits = model(d["x"], d["adj"])
            correct += int(logits.argmax().item() == d["y"])
    return correct / len(test_data)

gt = GraphTransformer(in_dim=1, pe_dim=4)
gcn = SimpleGCN(in_dim=1)
gat = SimpleGAT(in_dim=1)

acc_gt = train_model(gt, train_data, test_data, use_transformer=True)
acc_gcn = train_model(gcn, train_data, test_data)
acc_gat = train_model(gat, train_data, test_data)

print(f"GraphTransformer テスト精度: {acc_gt:.3f}")
print(f"GCN             テスト精度: {acc_gcn:.3f}")
print(f"GAT             テスト精度: {acc_gat:.3f}")

実行すると、おおむね次のような出力が得られます(乱数により多少前後します)。

GraphTransformer テスト精度: 0.969
GCN             テスト精度: 0.812
GAT             テスト精度: 0.844

この結果から、グラフTransformerが他の2モデルより高い精度でダンベル型と密グラフを分類できていることが読み取れます。ダンベル型グラフの分類には「橋の左右にある離れたノードの関係」が効くため、全ノードに1層で注意できるグラフTransformerが有利になっているのです。一方、GCNとGATは隣接ノードしか直接見られないため、左右のクリークの関係を捉えるには複数層を要し、本実装の浅い2層では情報が十分に伝わりきっていないと解釈できます。GATがGCNをわずかに上回るのは、注意による重み付けがノードの重要度を多少反映できているためと考えられます。

最後に、グラフTransformerが実際にどのノードどうしに注目しているのかを可視化し、本当に長距離の注意が起きているかを確かめます。

# ダンベル型グラフ1つを取り出して注意重みを可視化
sample = next(d for d in test_data if d["y"] == 0)
gt.eval()
with torch.no_grad():
    _, attn = gt(sample["x"], sample["pe"], sample["spd"])
attn_mean = attn.mean(0).numpy()       # 全ヘッド平均の注意行列 (n, n)
adj = sample["adj"].numpy()

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# 左: 注意重み行列のヒートマップ
im = axes[0].imshow(attn_mean, cmap="viridis")
axes[0].set_title("Graph Transformer attention (all-pairs)")
axes[0].set_xlabel("key node j")
axes[0].set_ylabel("query node i")
fig.colorbar(im, ax=axes[0], fraction=0.046)

# 右: 隣接行列(隣接ノードのみが1)
axes[1].imshow(adj, cmap="gray_r")
axes[1].set_title("Adjacency (1-hop neighbors only)")
axes[1].set_xlabel("node j")
axes[1].set_ylabel("node i")

plt.tight_layout()
plt.savefig("graph_transformer_attention.png", dpi=150, bbox_inches="tight")
plt.show()

左の注意重みヒートマップと右の隣接行列を見比べると、決定的な違いが見えてきます。隣接行列(右)は対角ブロック付近にしか値がなく、つまり1ホップの隣接ノードしか直接つながっていません。ところが注意重み(左)は、隣接していないノード対——特にダンベルの左右のクリークにまたがるノード対——にもはっきりと値を持っています。これは、グラフTransformerが隣接関係の枠を超えて、橋の向こう側にある離れたノードにも直接注意を向けていることを意味します。GCNやGATでは原理上ありえない「長距離の直接結合」が、実際に学習されているのです。

加えて、最短路バイアスの効果として、距離の近いノード対ほど注意がやや強くなる傾向(対角ブロック内の値が相対的に明るい)も観察できます。これは「近いノードを重視しつつ、必要なら遠いノードも見る」という、構造を尊重した柔軟な注意が実現できている証拠です。素朴な全ノード注意なら構造が完全に無視されるところを、最短路バイアスとラプラシアン位置符号化が「グラフらしさ」を取り戻していると読み取れます。

まとめ

本記事では、グラフTransformerの理論・導出・Python実装を解説しました。

  • メッセージパッシングGNNの限界:隣接ノードのみを集約するため、長距離依存には層の深さ(最短路距離ぶん)が必要で、深くすると $\hat{\bm{A}}^\ell$ の固有値解析が示すように過剰平滑化ですべての特徴が潰れる
  • 全ノード自己注意:Transformerの注意をグラフに持ち込めば、どんなに離れたノードでも1層で結合できるが、素朴に適用するとエッジ情報が消えてグラフ構造を失う
  • ラプラシアン位置符号化:$\bm{f}^\top\bm{L}\bm{f} = \sum_{(i,j)\in\mathcal{E}}(f_i-f_j)^2$ という滑らかさの式から、ラプラシアン固有ベクトルがグラフ版フーリエ基底=「グラフ上の座標」を与えることを導出
  • 最短路・次数バイアス(Graphormer型):注意ロジット $e_{ij} = \bm{q}_i^\top\bm{k}_j/\sqrt{d_k} + b_{d(i,j)}$ に距離バイアスを足し、次数で中心性を符号化することで構造を注意に注入
  • Python実装:ダンベル型グラフの分類でGCN/GATを上回り、注意重みの可視化から実際に長距離の直接結合が学習されることを確認

グラフTransformerは「離れたノード間の相互作用」が本質的に効く問題——分子物性予測、たんぱく質構造、交通・電力ネットワークなど——で特に強力です。一方、ノード数の2乗に比例する計算量や、小さなグラフでは過学習しやすい点など、メッセージパッシングGNNとの使い分けも重要です。

次のステップとして、以下の記事も参考にしてください。

参考文献

  • Ying et al., “Do Transformers Really Perform Bad for Graph Representation?” (Graphormer), NeurIPS 2021
  • Dwivedi & Bresson, “A Generalization of Transformer Networks to Graphs”, 2020
  • Vaswani et al., “Attention Is All You Need”, NeurIPS 2017
  • Li et al., “Deeper Insights into Graph Convolutional Networks for Semi-Supervised Learning”(過剰平滑化の解析)