UniTS — タスクトークン化で時系列の予測・異常検知・分類を統一する【NeurIPS 2024】

衛星の運用現場では、テレメトリデータに対して複数のタスクを同時に処理する必要があります。「今後30分の温度推移を予測せよ」「このパターンは異常か?」「この振動パターンの種類(正常、共振、故障)を分類せよ」— これらは本質的に異なるタスクであり、従来は各タスクごとに専用のモデルを訓練していました。

しかし、タスクごとに別モデルを持つアプローチには深刻な非効率性があります。3つのタスクに3つのモデルを訓練する場合、訓練コストは3倍、デプロイのメモリも3倍、各モデルのハイパーパラメータ調整も独立に行う必要があります。さらに重要な問題として、各モデルが学習する「時系列の表現」は互いに共有されません。予測タスクで学んだ長期トレンドの知識が、異常検知に活用されることはないのです。

もし1つのモデルが全タスクを処理できたら? NLPの世界では、GPT-4がテキスト生成・要約・翻訳・質問応答を1つのモデルで処理しています。その鍵は「タスクの種類をプロンプト(テキスト指示)で指定する」というインターフェースです。

UniTS(Harvard, NeurIPS 2024)は、この考え方を時系列に適用しました。タスクトークンという特殊なトークンを入力に追加するだけで、同じTransformerが予測・異常検知・分類を切り替えて実行します。

UniTSを理解することは、以下のような場面で直接役立ちます。

  • テレメトリ運用の効率化: 1つのモデルで予測+異常検知+分類を同時処理でき、デプロイ・メンテナンスコストが大幅に削減されます
  • マルチタスク学習の相乗効果: 異なるタスクが共通の時系列表現を共有することで、各タスクの性能が単独学習よりも向上します

本記事の内容

  • なぜ統一モデルが必要か
  • タスクトークン化のアイデア(NLPのプロンプト学習との関係)
  • UniTSのアーキテクチャ詳細
  • マルチタスク学習の損失関数設計
  • Pythonによる簡易タスクトークン化モデルの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

なぜ統一モデルが必要か

タスクごとのモデルの非効率性

現在の時系列モデルの典型的な運用を考えます。

タスク モデル 出力形式 訓練データ
予測 PatchTST 連続値系列 $\hat{\bm{x}} \in \mathbb{R}^H$ 過去→未来のペア
異常検知 専用AE 異常スコア $s \in \mathbb{R}$ 正常データ
分類 ResNet-1D クラス確率 $\bm{p} \in \mathbb{R}^C$ ラベル付きデータ

3つのモデルはそれぞれ独立に時系列を処理し、それぞれが独自の「時系列の理解」を持ちます。しかし、3つのタスクは共通の基盤 — 時系列のパターン理解 — の上に成り立っています。

予測タスクで学んだ「周期的パターンの周期は90分である」という知識は、異常検知(この周期からの逸脱を検出)にも、分類(90分周期 = 低軌道衛星の温度パターン)にも有用です。

NLPにおける統一モデルの成功

NLPの世界では、T5(Text-to-Text Transfer Transformer)が全タスクを「テキスト→テキスト」の形式に統一しました。

  • 翻訳: “translate English to French: The cat sat on the mat” → “Le chat était assis sur le tapis”
  • 要約: “summarize: {長い文章}” → “{短い要約}”
  • 分類: “classify sentiment: This movie is great” → “positive”

タスクの種類は入力テキスト内のプレフィックス(プロンプト)で指定されます。UniTSはこのアイデアを時系列に適用しますが、テキストの代わりにタスクトークンを使います。

NLPとの対比で統一モデルの動機が明確になったところで、UniTSのタスクトークン化の具体的な仕組みを見ていきましょう。

タスクトークン化

タスクトークンの定義

UniTSでは、各タスクに対応する学習可能なタスクトークンを定義します。

$$ \bm{t}_{\text{forecast}} \in \mathbb{R}^d, \quad \bm{t}_{\text{anomaly}} \in \mathbb{R}^d, \quad \bm{t}_{\text{classify}} \in \mathbb{R}^d $$

これらはモデルのパラメータの一部として学習されます。NLPのプロンプトがテキストで「このタスクは翻訳です」と指定するのに対し、UniTSのタスクトークンは連続ベクトルで「このタスクは予測です」と指定します。

入力の構成

時系列パッチ列 $(\bm{h}_1, \bm{h}_2, \ldots, \bm{h}_N)$ にタスクトークンを先頭に追加して、Transformerに入力します。

$$ \text{入力} = [\bm{t}_{\text{task}}; \bm{h}_1, \bm{h}_2, \ldots, \bm{h}_N] $$

ここで $[\cdot; \cdot]$ はトークン列の結合です。Transformer は $N + 1$ 個のトークン(タスクトークン1個 + パッチトークン $N$ 個)を処理します。

タスクトークンはSelf-Attentionを通じて全パッチトークンと相互作用し、タスクに依存した特徴抽出が行われます。予測タスクのタスクトークンは長期トレンドに注目するAttentionパターンを誘導し、異常検知のタスクトークンは短期的な逸脱に注目するパターンを誘導します。

プロンプト学習との関係

タスクトークン化は、NLPのプロンプトチューニング(Prompt Tuning)と数学的に同等です。

プロンプトチューニングでは、事前学習済みのTransformerの入力に学習可能なソフトプロンプト $\bm{p}_1, \ldots, \bm{p}_m$ を追加します。

$$ \text{入力} = [\bm{p}_1, \ldots, \bm{p}_m; \bm{x}_1, \ldots, \bm{x}_n] $$

UniTSのタスクトークンはソフトプロンプトの特殊なケース($m = 1$)と解釈できます。1つのプロンプトベクトルだけで、Transformerの振る舞いを予測モード/異常検知モード/分類モードに切り替えます。

タスクトークン化のアイデアが明確になったところで、UniTSのアーキテクチャ全体を見ていきましょう。

UniTSのアーキテクチャ

全体構造

UniTSの全体構造は以下の5つのコンポーネントで構成されます。

1. パッチ埋め込み: PatchTSTと同様に、時系列をパッチに分割し、線形層でトークン埋め込みに変換します。

$$ \bm{h}_i = \bm{p}_i \bm{W}_{\text{patch}} + \bm{b}_{\text{patch}} + \bm{e}_i^{\text{pos}} $$

2. タスクトークンの追加: タスクに応じたタスクトークンを先頭に追加します。

3. 共有Transformer Encoder: 全タスクで共有される $L$ 層のTransformer Encoder。タスクトークンとパッチトークンが自由に相互参照します。

$$ [\bm{t}’^{(l+1)}; \bm{h}’^{(l+1)}] = \text{TransformerBlock}([\bm{t}’^{(l)}; \bm{h}^{(l)}]) $$

4. タスク固有ヘッド: Encoderの出力に対して、タスクごとに異なる出力ヘッドを適用します。

  • 予測ヘッド: パッチトークンの出力を平坦化し、線形層で予測系列を出力

$$ \hat{\bm{x}}_{\text{future}} = \text{Linear}(\text{Flatten}(\bm{h}_1^{(L)}, \ldots, \bm{h}_N^{(L)})) $$

  • 異常検知ヘッド: タスクトークンの出力から異常スコアを出力

$$ s_{\text{anomaly}} = \sigma(\text{Linear}(\bm{t}’^{(L)})) $$

  • 分類ヘッド: タスクトークンの出力からクラス確率を出力

$$ \bm{p}_{\text{class}} = \text{softmax}(\text{Linear}(\bm{t}’^{(L)})) $$

5. マルチタスク損失: 各タスクの損失を重み付き和で統合します。

タスクトークンのAttentionの挙動

タスクトークンがどのようにTransformerの挙動を変えるかを考えます。Self-Attentionにおいて、タスクトークン $\bm{t}$ はクエリとしてもキーとしても機能します。

タスクトークンがクエリの場合: タスクトークンは全パッチトークンに注目し、タスクに必要な情報を集約します。予測タスクトークンはトレンド情報を持つパッチに高いAttention重みを割り当て、異常検知タスクトークンは異常が疑われるパッチに注目します。

タスクトークンがキーの場合: パッチトークンがタスクトークンに注目することで、「今処理すべきタスクは何か」という文脈情報がパッチトークンの表現に注入されます。これにより、同じパッチでも予測タスクでは長期トレンド成分が強調され、異常検知タスクでは短期変動成分が強調される、というタスク依存の表現変換が実現されます。

マルチタスク損失

全体の損失関数は、各タスクの損失の重み付き和です。

$$ \mathcal{L} = \lambda_{\text{fc}} \mathcal{L}_{\text{forecast}} + \lambda_{\text{ad}} \mathcal{L}_{\text{anomaly}} + \lambda_{\text{cl}} \mathcal{L}_{\text{classify}} $$

各タスクの損失は以下で定義されます。

予測損失: MSE(Mean Squared Error)

$$ \mathcal{L}_{\text{forecast}} = \frac{1}{H} \sum_{t=1}^H (x_{L+t} – \hat{x}_{L+t})^2 $$

異常検知損失: Binary Cross-Entropy

$$ \mathcal{L}_{\text{anomaly}} = -[y \log s + (1-y) \log(1-s)] $$

分類損失: Cross-Entropy

$$ \mathcal{L}_{\text{classify}} = -\sum_{c=1}^C y_c \log p_c $$

重み $\lambda$ の設定は重要で、タスク間のスケールの違いを補正する必要があります。UniTSでは不確実性ベースの自動重み付け(Kendall et al., 2018)が使われます。各タスクの損失の不確実性 $\sigma_{\text{task}}$ を学習パラメータとして持ち、

$$ \mathcal{L} = \frac{1}{2\sigma_{\text{fc}}^2} \mathcal{L}_{\text{forecast}} + \frac{1}{2\sigma_{\text{ad}}^2} \mathcal{L}_{\text{anomaly}} + \frac{1}{2\sigma_{\text{cl}}^2} \mathcal{L}_{\text{classify}} + \log \sigma_{\text{fc}} + \log \sigma_{\text{ad}} + \log \sigma_{\text{cl}} $$

$\log \sigma$ の正則化項が無限大への発散を防ぎます。損失が大きいタスクは自動的に $\sigma$ が大きくなり、重みが下がる(学習が容易なタスクを優先する)仕組みです。

アーキテクチャを理解したところで、Pythonでタスクトークン化の仕組みを実装してみましょう。

Pythonによる実装

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# --- 合成データの生成(3タスク分)---
def generate_multitask_data(n_samples=500, length=128):
    """予測・異常検知・分類の3タスク用データ"""
    t = np.linspace(0, 4 * np.pi, length)
    data = {'forecast': [], 'anomaly': [], 'classify': []}
    labels = {'anomaly': [], 'classify': []}

    for _ in range(n_samples):
        # 基本パターン生成
        pattern_type = np.random.choice([0, 1, 2])
        noise = np.random.randn(length) * 0.1

        if pattern_type == 0:  # 正弦波
            x = np.sin(t) + noise
        elif pattern_type == 1:  # 三角波
            x = 2 * np.abs(2 * (t / (4*np.pi) - np.floor(t / (4*np.pi) + 0.5))) - 1 + noise
        else:  # ステップ関数
            x = np.sign(np.sin(t)) + noise

        # 予測タスク: 前半→後半
        data['forecast'].append(x)

        # 異常検知タスク: 50%の確率で異常を挿入
        x_anom = x.copy()
        is_anomaly = np.random.rand() > 0.5
        if is_anomaly:
            anom_pos = np.random.randint(length//4, 3*length//4)
            x_anom[anom_pos:anom_pos+10] += np.random.choice([-3, 3])
        data['anomaly'].append(x_anom)
        labels['anomaly'].append(int(is_anomaly))

        # 分類タスク
        data['classify'].append(x)
        labels['classify'].append(pattern_type)

    return {k: np.array(v) for k, v in data.items()}, {k: np.array(v) for k, v in labels.items()}

data, labels = generate_multitask_data(n_samples=600, length=128)

# --- 簡易UniTSモデル ---
class SimpleUniTS:
    def __init__(self, input_dim=128, patch_size=16, d_model=32, n_classes=3, lr=0.005):
        self.patch_size = patch_size
        self.d_model = d_model
        self.n_patches = input_dim // patch_size
        self.lr = lr

        # パッチ埋め込み
        self.W_patch = np.random.randn(patch_size, d_model) * 0.1
        self.b_patch = np.zeros(d_model)

        # タスクトークン(3種類)
        self.task_tokens = {
            'forecast': np.random.randn(d_model) * 0.1,
            'anomaly': np.random.randn(d_model) * 0.1,
            'classify': np.random.randn(d_model) * 0.1,
        }

        # Self-Attention
        self.W_q = np.random.randn(d_model, d_model) * 0.05
        self.W_k = np.random.randn(d_model, d_model) * 0.05
        self.W_v = np.random.randn(d_model, d_model) * 0.05

        # タスク固有ヘッド
        self.W_forecast = np.random.randn(d_model * self.n_patches, input_dim // 2) * 0.01
        self.W_anomaly = np.random.randn(d_model, 1) * 0.1
        self.W_classify = np.random.randn(d_model, n_classes) * 0.1

    def patchify(self, x):
        """時系列をパッチに分割"""
        patches = x.reshape(-1, self.n_patches, self.patch_size)
        return patches

    def forward(self, x, task):
        """タスクに応じた処理"""
        batch_size = len(x)
        patches = self.patchify(x)
        # パッチ埋め込み
        h = np.tanh(patches @ self.W_patch + self.b_patch)  # (B, N, d)

        # タスクトークンを追加
        task_tok = self.task_tokens[task][None, None, :].repeat(batch_size, axis=0)  # (B, 1, d)
        h_with_task = np.concatenate([task_tok, h], axis=1)  # (B, N+1, d)

        # Self-Attention(バッチ処理)
        Q = h_with_task @ self.W_q
        K = h_with_task @ self.W_k
        V = h_with_task @ self.W_v
        scale = np.sqrt(self.d_model)

        attn_weights_all = []
        output = np.zeros_like(h_with_task)
        for b in range(batch_size):
            scores = Q[b] @ K[b].T / scale
            attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
            attn /= attn.sum(axis=-1, keepdims=True)
            output[b] = attn @ V[b]
            attn_weights_all.append(attn)

        h_out = output + h_with_task  # 残差接続

        # タスク固有ヘッド
        task_out = h_out[:, 0, :]  # タスクトークンの出力 (B, d)
        patch_out = h_out[:, 1:, :]  # パッチトークンの出力 (B, N, d)

        if task == 'forecast':
            flat = patch_out.reshape(batch_size, -1)
            pred = flat @ self.W_forecast
            return pred, np.array(attn_weights_all)
        elif task == 'anomaly':
            score = 1 / (1 + np.exp(-task_out @ self.W_anomaly))
            return score.squeeze(), np.array(attn_weights_all)
        else:  # classify
            logits = task_out @ self.W_classify
            exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
            probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)
            return probs, np.array(attn_weights_all)

model = SimpleUniTS(input_dim=128, patch_size=16, d_model=32, n_classes=3)

# --- 各タスクのAttentionパターンを可視化 ---
sample_x = data['forecast'][:5]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
tasks = ['forecast', 'anomaly', 'classify']
task_titles = ['Forecast', 'Anomaly Detection', 'Classification']

for idx, task in enumerate(tasks):
    _, attn = model.forward(sample_x, task)
    avg_attn = attn.mean(axis=0)  # サンプル平均

    im = axes[idx].imshow(avg_attn, cmap='viridis', aspect='auto')
    axes[idx].set_title(f'Attention Pattern: {task_titles[idx]}')
    axes[idx].set_xlabel('Key position (0=task token)')
    axes[idx].set_ylabel('Query position (0=task token)')
    # タスクトークンの位置を強調
    axes[idx].axhline(y=0.5, color='red', linewidth=1, linestyle='--', alpha=0.5)
    axes[idx].axvline(x=0.5, color='red', linewidth=1, linestyle='--', alpha=0.5)
    plt.colorbar(im, ax=axes[idx], fraction=0.046)

plt.suptitle('Task-dependent Attention Patterns (row 0 / col 0 = task token)', fontsize=12)
plt.tight_layout()
plt.savefig('units_attention.png', dpi=150, bbox_inches='tight')
plt.show()

3つのタスクのAttentionパターンから、タスクトークンがTransformerの挙動をどう制御しているかが可視化されています。赤い破線で区切られた行0/列0がタスクトークンの位置です。各タスクでAttentionの分布パターンが異なっており、タスクトークンが「どのパッチに注目すべきか」をTransformerに指示していることが確認できます。これは学習前の初期状態(ランダム重み)でのAttentionなので、パターンの差はタスクトークンの初期値の違いに起因します。学習が進むにつれて、予測タスクでは長期トレンド情報を持つパッチに、異常検知タスクでは異常候補のパッチに、分類タスクでは波形の特徴的な部分にそれぞれ集中したAttentionパターンが形成されます。

マルチタスク学習の効果を確認

# --- タスクトークンの埋め込み空間の可視化 ---
# 3つのタスクトークンがどのような方向を向いているか
task_vecs = np.array([model.task_tokens[t] for t in tasks])
# コサイン類似度行列
norms = np.linalg.norm(task_vecs, axis=1, keepdims=True)
cos_sim = (task_vecs @ task_vecs.T) / (norms @ norms.T + 1e-8)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (a) タスクトークン間のコサイン類似度
im = axes[0].imshow(cos_sim, cmap='RdYlBu_r', vmin=-1, vmax=1, aspect='auto')
axes[0].set_xticks(range(3))
axes[0].set_xticklabels(['Forecast', 'Anomaly', 'Classify'])
axes[0].set_yticks(range(3))
axes[0].set_yticklabels(['Forecast', 'Anomaly', 'Classify'])
axes[0].set_title('Cosine Similarity between Task Tokens')
for i in range(3):
    for j in range(3):
        axes[0].text(j, i, f'{cos_sim[i,j]:.2f}', ha='center', va='center',
                     fontsize=12, color='white' if abs(cos_sim[i,j]) > 0.5 else 'black')
plt.colorbar(im, ax=axes[0])

# (b) タスク別の出力分布
# 各タスクの出力をヒストグラムで表示
forecast_out, _ = model.forward(data['forecast'][:100], 'forecast')
anomaly_out, _ = model.forward(data['anomaly'][:100], 'anomaly')
classify_out, _ = model.forward(data['classify'][:100], 'classify')

axes[1].hist(forecast_out.flatten(), bins=30, alpha=0.5, label='Forecast values',
             color='#00d4ff', density=True)
axes[1].hist(anomaly_out.flatten(), bins=30, alpha=0.5, label='Anomaly scores',
             color='#ff6b6b', density=True)
axes[1].hist(classify_out.max(axis=1), bins=30, alpha=0.5, label='Max class prob',
             color='#ffd93d', density=True)
axes[1].set_xlabel('Output value')
axes[1].set_ylabel('Density')
axes[1].set_title('Output Distributions by Task')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('units_tasks.png', dpi=150, bbox_inches='tight')
plt.show()

タスクトークンの分析から2つの重要な知見が得られます。左のコサイン類似度行列は、3つのタスクトークンが互いに異なる方向を向いていることを示しています(初期状態はランダム初期化なので自然な結果ですが、学習後にはタスクの性質に応じてさらに分化します)。右の出力分布は、同じモデルが各タスクで異なる範囲の出力を生成していることを示しています。予測値は連続的に広がり、異常スコアは0〜1に集中し、分類確率はカテゴリカルな分布を持ちます。1つのTransformerがタスクトークンの切り替えだけで、これら本質的に異なる出力形式を生成できることがUniTSの核心です。

まとめ

本記事では、UniTS(NeurIPS 2024)が提案するタスクトークン化による統一時系列モデルについて解説しました。

  • タスクトークン化: 学習可能なタスクトークンを入力に追加するだけで、同一Transformerが予測・異常検知・分類を切り替えて実行する
  • プロンプト学習との関係: タスクトークンはNLPのソフトプロンプトの特殊ケースであり、最小限の追加パラメータでタスク適応が可能
  • マルチタスク学習の相乗効果: 異なるタスクが共通表現を共有することで、単独学習よりも各タスクの性能が向上する
  • 不確実性ベースの自動重み付け: タスク間の損失スケールの違いを自動的に補正し、安定したマルチタスク学習を実現

次のステップとして、以下の記事も参考にしてください。