グラフ構造の表現学習(Representation Learning)を解説

表現学習(Representation Learning)は、データの本質的な特徴を低次元のベクトル表現(分散表現)として自動的に獲得する手法です。特にグラフ構造データにおいては、ノードやエッジの構造的な情報を密なベクトルに変換することで、ノード分類やリンク予測などの下流タスクに利用できます。

グラフ畳み込みネットワーク(GCN)はこの表現学習の代表的手法で、隣接ノードの特徴量を集約して各ノードの表現を更新するメッセージパッシングの仕組みに基づいています。

本記事の内容

  • 表現学習の基本概念
  • GCNのメッセージパッシング
  • 数学的定式化
  • PyTorch Geometric(PyG)での実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

表現学習とは

表現学習の目的は、入力データを有用な特徴ベクトル(分散表現、embedding)に変換することです。

例えば色はRGB値 $(r, g, b)$ のような3次元ベクトルで表現されます。同様に、グラフのノードもベクトルとして表現できれば、機械学習のアルゴリズムで扱いやすくなります。

重要なのは、このベクトル表現が手動で設計されるのではなく、データから自動的に学習される点です。

グラフにおける表現学習の課題

画像やテキストと異なり、グラフデータには以下の特殊性があります。

  • 不規則な構造: ノードの隣接数は可変
  • 順序がない: ノードに自然な順序がない(順列不変性)
  • 構造情報: 隣接関係が重要な情報を持つ

これらの特性に対応するため、GCN(Graph Convolutional Network)はメッセージパッシングの仕組みを用います。

GCNの数学的定式化

GCNの1層のメッセージパッシングは以下のように定式化されます。

$$ \bm{h}_i^{(l+1)} = \sigma\left( \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{c_{ij}} \bm{h}_j^{(l)} \bm{W}^{(l)} \right) $$

ここで、

  • $\bm{h}_i^{(l)}$: $l$ 層目のノード $i$ の特徴ベクトル
  • $\mathcal{N}(i)$: ノード $i$ の隣接ノード集合
  • $c_{ij} = \sqrt{|\mathcal{N}(i)||\mathcal{N}(j)|}$: 正規化定数
  • $\bm{W}^{(l)}$: $l$ 層目の学習可能な重み行列
  • $\sigma$: 活性化関数(ReLUなど)

行列形式では、

$$ \bm{H}^{(l+1)} = \sigma(\tilde{\bm{D}}^{-1/2}\tilde{\bm{A}}\tilde{\bm{D}}^{-1/2}\bm{H}^{(l)}\bm{W}^{(l)}) $$

  • $\tilde{\bm{A}} = \bm{A} + \bm{I}$: 自己ループを加えた隣接行列
  • $\tilde{\bm{D}}$: $\tilde{\bm{A}}$ の次数行列($\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$)

メッセージパッシングの直感

メッセージパッシングを直感的に理解すると、各ノードが「隣接ノードの情報を集約して自身の表現を更新する」というプロセスを繰り返すことになります。

1層のGCNでは1ホップ先の隣接ノードの情報が集約されます。2層にすると2ホップ先まで、$L$ 層にすると $L$ ホップ先の情報まで取り込めます。

PyTorch Geometricでの実装

Zachary Karate Clubのノード分類をGCNで実装します。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx

# --- データの準備 ---
dataset = KarateClub()
data = dataset[0]

print(f"ノード数: {data.num_nodes}")
print(f"エッジ数: {data.num_edges}")
print(f"クラス数: {len(data.y.unique())}")

# --- GCNモデルの定義 ---
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # 第1層: メッセージパッシング + ReLU
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)

        # 第2層: メッセージパッシング + ReLU
        h = self.conv2(h, edge_index)
        h = F.relu(h)

        # 分類層
        out = self.classifier(h)
        return out, h

# --- 学習 ---
torch.manual_seed(42)
model = GCN(in_channels=data.num_node_features,
            hidden_channels=16,
            out_channels=len(data.y.unique()))

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 訓練マスク(一部のノードのみラベル付き)
train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
train_mask[data.train_mask] = True

losses = []
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# --- 評価 ---
model.eval()
with torch.no_grad():
    out, embeddings = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    acc = (pred == data.y).float().mean()
    print(f"\n全ノードの分類精度: {acc:.4f}")

# --- 可視化 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 学習曲線
ax1 = axes[0]
ax1.plot(losses, 'b-')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

# ノード埋め込みの可視化
ax2 = axes[1]
emb = embeddings.detach().numpy()
# t-SNEで2次元に射影
from sklearn.manifold import TSNE
emb_2d = TSNE(n_components=2, random_state=42).fit_transform(emb)
scatter = ax2.scatter(emb_2d[:, 0], emb_2d[:, 1],
                      c=data.y.numpy(), cmap='tab10', s=60, alpha=0.8)
ax2.set_xlabel('Dim 1')
ax2.set_ylabel('Dim 2')
ax2.set_title('Node Embeddings (t-SNE)')
plt.colorbar(scatter, ax=ax2, label='Class')
ax2.grid(True, alpha=0.3)

# 分類結果のグラフ可視化
ax3 = axes[2]
G = to_networkx(data, to_undirected=True)
pos = nx.spring_layout(G, seed=42)
node_colors = pred.numpy()
nx.draw(G, pos, node_color=node_colors, cmap=plt.cm.Set3,
        with_labels=True, node_size=300, font_size=8,
        edge_color='gray', alpha=0.8, ax=ax3)
ax3.set_title('Classification Result')

plt.tight_layout()
plt.show()

GCNによる表現学習の結果、同じコミュニティに属するノードは潜在空間上で近くに配置され、ノード分類タスクで高い精度が得られることが確認できます。

まとめ

本記事では、グラフ構造の表現学習について解説しました。

  • 表現学習はデータの本質的な特徴を低次元のベクトル表現として自動的に学習する手法
  • GCNはメッセージパッシングにより隣接ノードの情報を集約してノード表現を更新する
  • 行列形式では $\bm{H}^{(l+1)} = \sigma(\tilde{\bm{D}}^{-1/2}\tilde{\bm{A}}\tilde{\bm{D}}^{-1/2}\bm{H}^{(l)}\bm{W}^{(l)})$ で表される
  • 学習されたノード埋め込みは、ノード分類やリンク予測などの下流タスクに利用できる

次のステップとして、以下の記事も参考にしてください。