GNNを利用した異常検知手法の全体像を解説

GNN(Graph Neural Network)を利用した異常検知は、近年急速に発展している研究分野です。2017年にICLRでGCNが発表されて以来、GNNを用いた多変量時系列の異常検知(Multivariate Time Series Anomaly Detection, MTSAD)の論文がICMLやAAAIなどのトップカンファレンスで次々と発表されています。

本記事では、GNNを利用した異常検知のサーベイ論文を基に、手法の全体像と主要なアプローチを整理します。

本記事の内容

  • グラフ異常検知の問題設定
  • ノード異常・エッジ異常・グラフ異常の分類
  • GNNベースの異常検知手法の概要
  • Graph AutoEncoderによる異常検知のPython実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

グラフ異常検知の問題設定

グラフ上の異常検知は、異常の種類によって3つのレベルに分類されます。

ノード異常検知(Node-level)

グラフ上の特定のノードが異常かどうかを判定。例: ソーシャルネットワークにおける不正アカウントの検出。

エッジ異常検知(Edge-level)

特定のエッジ(接続関係)が異常かどうかを判定。例: 金融取引ネットワークにおける不正取引の検出。

グラフ異常検知(Graph-level)

グラフ全体が異常かどうかを判定。例: 化学分子構造の異常検知。

GNNベースの異常検知アプローチ

GNNベースの異常検知は、モデリングのアプローチによっていくつかの方法に分類できます。

Graph AutoEncoder (GAE)

最も基本的なアプローチで、グラフの構造やノード特徴量を再構成するAutoEncoderを学習します。正常データで学習し、再構成誤差が大きいデータを異常と判定します。

エンコーダはGCNなどのGNNで構成し、ノード埋め込み $\bm{Z}$ を獲得します。

$$ \bm{Z} = \text{GCN}(\bm{X}, \bm{A}) $$

デコーダでは、ノード埋め込みから隣接行列を再構成します。

$$ \hat{\bm{A}} = \sigma(\bm{Z}\bm{Z}^T) $$

再構成誤差(損失関数):

$$ L = \|\bm{A} – \hat{\bm{A}}\|_F^2 $$

GNN + 構造的特徴量

グラフのトポロジー(次数、クラスタリング係数など)の違いを利用して異常を検知する手法。GNNで学習した表現に構造的特徴量を組み合わせます。

GNN + Attention

Attention機構を組み込んだGNN(例: GAT, Graph Attention Network)を用いて、異常に関連する重要なノードやエッジに注目する手法です。

Graph AutoEncoderのPython実装

PyTorch Geometricを用いて、Graph AutoEncoderによるノード異常検知を実装します。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_adj

torch.manual_seed(42)
np.random.seed(42)

# --- 合成グラフデータの生成 ---
def generate_graph_data(n_normal=50, n_anomaly=5):
    """コミュニティ構造を持つグラフに異常ノードを追加"""
    n_total = n_normal + n_anomaly

    # 正常ノード: 2つのコミュニティ
    edges = []
    for i in range(n_normal // 2):
        for j in range(i + 1, n_normal // 2):
            if np.random.random() < 0.3:
                edges.append([i, j])
                edges.append([j, i])

    for i in range(n_normal // 2, n_normal):
        for j in range(i + 1, n_normal):
            if np.random.random() < 0.3:
                edges.append([i, j])
                edges.append([j, i])

    # コミュニティ間のリンク
    for i in range(n_normal // 2):
        for j in range(n_normal // 2, n_normal):
            if np.random.random() < 0.05:
                edges.append([i, j])
                edges.append([j, i])

    # 異常ノード: ランダムに少数のリンク
    for i in range(n_normal, n_total):
        targets = np.random.choice(n_normal, 2, replace=False)
        for t in targets:
            edges.append([i, t])
            edges.append([t, i])

    edge_index = torch.tensor(edges, dtype=torch.long).t()

    # ノード特徴量
    x = torch.randn(n_total, 16)
    # 異常ノードは異なる分布
    x[n_normal:] = torch.randn(n_anomaly, 16) * 3 + 2

    labels = torch.zeros(n_total, dtype=torch.long)
    labels[n_normal:] = 1

    return x, edge_index, labels, n_total

x, edge_index, labels, n_total = generate_graph_data()

# --- Graph AutoEncoder ---
class GraphAutoEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, latent_channels):
        super().__init__()
        # エンコーダ
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, latent_channels)

    def encode(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        z = self.conv2(h, edge_index)
        return z

    def decode(self, z):
        # 内積デコーダ
        return torch.sigmoid(z @ z.t())

    def forward(self, x, edge_index):
        z = self.encode(x, edge_index)
        adj_hat = self.decode(z)
        return adj_hat, z

# --- 学習 ---
model = GraphAutoEncoder(in_channels=16, hidden_channels=32, latent_channels=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 正解の隣接行列
adj_true = to_dense_adj(edge_index, max_num_nodes=n_total)[0]

losses = []
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    adj_hat, z = model(x, edge_index)
    loss = F.binary_cross_entropy(adj_hat, adj_true)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# --- 異常スコアの計算 ---
model.eval()
with torch.no_grad():
    adj_hat, z = model(x, edge_index)
    # 各ノードの再構成誤差を異常スコアとする
    recon_error = ((adj_true - adj_hat) ** 2).mean(dim=1).numpy()

# --- 可視化 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 学習曲線
ax1 = axes[0]
ax1.plot(losses, 'b-')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

# 異常スコアの分布
ax2 = axes[1]
normal_scores = recon_error[labels == 0]
anomaly_scores = recon_error[labels == 1]
ax2.hist(normal_scores, bins=20, alpha=0.6, color='blue', label='Normal', density=True)
ax2.hist(anomaly_scores, bins=10, alpha=0.6, color='red', label='Anomaly', density=True)
ax2.set_xlabel('Reconstruction Error')
ax2.set_ylabel('Density')
ax2.set_title('Anomaly Score Distribution')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 潜在空間の可視化
embeddings = z.numpy()
from sklearn.manifold import TSNE
emb_2d = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
ax3 = axes[2]
colors = ['blue' if l == 0 else 'red' for l in labels]
ax3.scatter(emb_2d[:, 0], emb_2d[:, 1], c=colors, s=50, alpha=0.7)
ax3.set_xlabel('Dim 1')
ax3.set_ylabel('Dim 2')
ax3.set_title('Latent Space (Blue=Normal, Red=Anomaly)')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"正常ノードの平均異常スコア: {normal_scores.mean():.4f}")
print(f"異常ノードの平均異常スコア: {anomaly_scores.mean():.4f}")

Graph AutoEncoderは正常なグラフ構造を再構成するように学習するため、異常なノードは再構成誤差が大きくなり、異常スコアとして利用できます。

まとめ

本記事では、GNNを利用した異常検知手法の全体像を解説しました。

  • グラフ異常検知はノード・エッジ・グラフの3つのレベルに分類される
  • Graph AutoEncoderはグラフ構造の再構成誤差を異常スコアとして利用する基本的なアプローチ
  • GCNのメッセージパッシングにより、ノードの局所的な構造情報を表現に反映できる
  • Attention機構や構造的特徴量との組み合わせにより、さらに高度な異常検知が可能

次のステップとして、以下の記事も参考にしてください。