SoftCLT — ソフト連続値の対照学習で時系列表現を学習する【ICLR 2024】

衛星テレメトリの異常検知では「この温度パターンは正常か? 過去のどのパターンに近いか?」という問いに答える必要があります。ここで重要なのは、時系列パターンの類似度が0か1かの二値ではないということです。軌道上の温度変動パターンは、「ほぼ同じ」「似ているが少しずれている」「形は似ているが振幅が異なる」「まったく異なる」というように、連続的な類似度のグラデーションを持ちます。

従来の対照学習(SimCLR, MoCo等)は、「正例(同じクラス or 同じインスタンスの拡張)」と「負例(異なるクラス or 異なるインスタンス)」をハードに区別し、正例を引き寄せ負例を引き離すように表現空間を学習します。しかし、このハードな二値区分は時系列の連続的な類似構造を無視しています。「少し似ている」ペアを負例として強く引き離すと、類似度の順序関係が壊れてしまいます。

SoftCLT(Lee et al., ICLR 2024 Spotlight)は、この根本的な問題を解決しました。正例/負例のハードなラベルの代わりに、連続値のソフト類似度で対照学習を行います。さらに、インスタンスレベル(系列全体の類似度)とテンポラルレベル(時刻間の類似度)の2つの粒度でソフト対照学習を同時に行う点が革新的です。

SoftCLTを理解することは、以下のような場面で直接役立ちます。

  • テレメトリ検索の精度向上: 「似ている度合い」の連続的なランキングが直接得られるため、検索結果の順位付けの質が向上します。ハード対照学習では「似ている or 似ていない」しか区別できませんが、SoftCLTは「どの程度似ているか」を学習します
  • 既存手法への即座の適用: SoftCLTはplug-and-play型の手法であり、既存の対照学習パイプライン(TS2Vec, TNC等)に対して、損失関数を差し替えるだけで性能が向上します

本記事の内容

  • 対照学習の基礎(SimCLR、InfoNCE損失の復習)
  • ハード対照学習の問題点 — 時系列の連続的類似構造との不整合
  • SoftCLTの定式化(インスタンスレベル + テンポラルレベル)
  • ソフトInfoNCE損失の導出
  • DTWベースのソフトラベル生成
  • Pythonによるソフト対照学習の実装と実験

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

画像なし
時系列基盤モデルの全体像
時系列の表現学習の位置づけを解説しています
画像なし
TRACE — 時系列データのマルチモーダル埋め込みと検索
対照学習ベースの時系列埋め込みを解説しています

対照学習の基礎

SimCLRとInfoNCE損失

対照学習(Contrastive Learning)の基本アイデアは、「似たもの同士を近づけ、異なるもの同士を引き離す」表現空間を学習することです。SimCLR(Chen et al., 2020)の枠組みでは、以下のステップで学習が進みます。

  1. 入力 $\bm{x}$ にランダムな拡張(augmentation)を2回適用し、正例ペア $(\tilde{\bm{x}}_i, \tilde{\bm{x}}_j)$ を生成
  2. エンコーダ $f$ で表現ベクトル $\bm{h}_i = f(\tilde{\bm{x}}_i)$, $\bm{h}_j = f(\tilde{\bm{x}}_j)$ を得る
  3. 射影ヘッド $g$ で $\bm{z}_i = g(\bm{h}_i)$, $\bm{z}_j = g(\bm{h}_j)$ に変換
  4. InfoNCE損失を最小化

InfoNCE損失は以下で定義されます。ミニバッチ内の $2N$ 個のサンプルに対して、

$$ \mathcal{L}_i = -\log \frac{\exp(\text{sim}(\bm{z}_i, \bm{z}_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}[k \neq i] \exp(\text{sim}(\bm{z}_i, \bm{z}_k) / \tau)} $$

ここで $\text{sim}(\bm{u}, \bm{v}) = \bm{u}^\top \bm{v} / (\|\bm{u}\| \|\bm{v}\|)$ はコサイン類似度、$\tau$ は温度パラメータです。

直感的には、分子は「正例ペアの類似度を上げたい」、分母は「全ペアの中で正例が際立つようにしたい」という2つの力を表現しています。温度 $\tau$ が低いほど、分布が鋭くなり(最も似たペアに集中し)、高いほど平坦になります。

時系列への適用

時系列に対照学習を適用する場合、拡張手法が画像とは異なります。TS2Vec(Yue et al., 2022)では、タイムスタンプのマスキングや部分系列のクロッピングが拡張として使われます。同一系列の2つの拡張が正例、異なる系列の拡張が負例です。

TNC(Tonekaboni et al., 2021)では、時間的に近い窓を正例、遠い窓を負例とする「時間的コントラスト」を用います。

これらの手法では、正例/負例の区分がハードです。同一系列の拡張 → 正例(ラベル1)、異なる系列 → 負例(ラベル0)。この二値ラベルが時系列においてどのような問題を引き起こすか、次のセクションで詳しく見ていきます。

ハード対照学習の問題点

「偽の負例」問題

ミニバッチ内で「異なるインスタンス」を負例として扱うと、実際には似たパターンなのに負例として引き離されてしまうケースが多発します。

例えば、衛星テレメトリで以下のような状況を考えます。

  • インスタンスA: 軌道の昼側での太陽電池出力パターン(周期的上昇)
  • インスタンスB: 別の軌道の昼側での太陽電池出力パターン(ほぼ同じ周期的上昇)
  • インスタンスC: 食(地球の影)区間での出力パターン(急激な低下)

ハード対照学習では、AとBは異なるインスタンスなので負例として扱われ、表現空間で引き離されます。しかし、AとBは物理的にほぼ同じ現象であり、本来は近い表現を持つべきです。これが「偽の負例」問題です。

類似度の順序関係の破壊

さらに深刻な問題は、類似度の順序関係が保存されないことです。理想的には、

$$ \text{sim}(A, B) > \text{sim}(A, C) $$

が表現空間でも保持されるべきです。しかし、ハード対照学習ではAに対してBもCも同じ「負例」として扱われるため、学習後の表現空間で $\|\bm{z}_A – \bm{z}_B\|$ と $\|\bm{z}_A – \bm{z}_C\|$ の大小関係が保証されません。

テンポラルレベルの問題

時系列内の時刻間にも同じ問題があります。TS2Vecのような手法では、同じ時刻の拡張が正例、異なる時刻が負例です。しかし、時刻 $t$ と時刻 $t+1$ は非常に似ているのに負例として扱われます。一方、時刻 $t$ と時刻 $t + 1000$ はまったく異なるパターンかもしれません。「1ステップ離れた負例」と「1000ステップ離れた負例」を同等に扱うのは明らかに不自然です。

これらの問題を解決するために、SoftCLTは「どの程度似ているか」を連続値で表現するソフトラベルを導入します。

SoftCLTの定式化

2つの粒度のソフト対照学習

SoftCLTの核心は、インスタンスレベルテンポラルレベルの2つの粒度でソフト対照学習を同時に行うことです。

インスタンスレベル: 系列 $\bm{x}_i$ と系列 $\bm{x}_j$ の全体的な類似度 $w_{ij}^{\text{inst}} \in [0, 1]$ を連続値で定義し、InfoNCE損失の重みとして使用します。

テンポラルレベル: 系列 $\bm{x}_i$ 内の時刻 $t$ と時刻 $s$ の類似度 $w_{ts}^{\text{temp}} \in [0, 1]$ を連続値で定義し、同様に重みとして使用します。

全体の損失は、

$$ \mathcal{L} = \lambda_{\text{inst}} \mathcal{L}_{\text{inst}} + \lambda_{\text{temp}} \mathcal{L}_{\text{temp}} $$

です。ここで $\lambda_{\text{inst}}$ と $\lambda_{\text{temp}}$ は各項の重みです。

ソフトInfoNCE損失の導出

通常のInfoNCE損失では、正例に対してクロスエントロピーを計算します。ソフト版では、ソフトラベル分布 $\bm{q}$ とモデルが出力する予測分布 $\bm{p}$ のクロスエントロピーを計算します。

まず、サンプル $i$ に対するソフトラベル分布を定義します。ミニバッチ内の $N$ 個のサンプルについて、

$$ q_{ij} = \frac{w_{ij}}{\sum_{k \neq i} w_{ik}} $$

ここで $w_{ij}$ はサンプル $i$ と $j$ の間のソフト類似度です。$q_{ij}$ は「サンプル $j$ がサンプル $i$ の正例である確率」を表す分布と解釈できます。

次に、モデルの予測分布を定義します。

$$ p_{ij} = \frac{\exp(\text{sim}(\bm{z}_i, \bm{z}_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(\bm{z}_i, \bm{z}_k) / \tau)} $$

ソフトInfoNCE損失は、この2つの分布のクロスエントロピーです。

$$ \mathcal{L}_i^{\text{soft}} = -\sum_{j \neq i} q_{ij} \log p_{ij} $$

これを全サンプルについて平均したものが損失となります。

$$ \mathcal{L}^{\text{soft}} = \frac{1}{N} \sum_{i=1}^N \mathcal{L}_i^{\text{soft}} $$

通常のInfoNCE損失との違いを確認しましょう。通常のInfoNCE損失では $q_{ij}$ はone-hot(正例のみ1、他は0)ですが、SoftCLTでは連続値です。つまり、ソフトInfoNCE損失はラベルスムージングされたクロスエントロピーと同じ構造を持ちます。

この式からわかるように、ソフトラベル $q_{ij}$ が大きいペアは「より強く引き寄せる」力が働き、小さいペアは「弱い力」で処理されます。完全な負例($w_{ij} = 0$)は引き寄せる力がゼロとなります。

インスタンスレベルのソフトラベル

インスタンスレベルでは、DTW(Dynamic Time Warping)距離をベースにソフト類似度を計算します。

DTW距離 $d_{\text{DTW}}(\bm{x}_i, \bm{x}_j)$ は、2つの時系列の「形状の類似度」を、時間軸の伸縮を許容しながら測る距離です。これを類似度に変換するために、RBF(Radial Basis Function)カーネルを適用します。

$$ w_{ij}^{\text{inst}} = \exp\left(-\frac{d_{\text{DTW}}(\bm{x}_i, \bm{x}_j)^2}{2\sigma^2}\right) $$

$\sigma$ はバンド幅パラメータで、ミニバッチ内のDTW距離の中央値に設定するのが一般的です(メディアンヒューリスティック)。

$$ \sigma = \text{median}\{d_{\text{DTW}}(\bm{x}_i, \bm{x}_j) : i \neq j\} $$

RBFカーネルの性質により、DTW距離が小さい(形状が似ている)ペアほど $w_{ij}^{\text{inst}} \approx 1$(強い正例)、距離が大きいペアほど $w_{ij}^{\text{inst}} \approx 0$(強い負例)となります。

テンポラルレベルのソフトラベル

テンポラルレベルでは、時刻間の距離に基づいてソフト類似度を定義します。同じ系列内の時刻 $t$ と $s$ に対して、

$$ w_{ts}^{\text{temp}} = \exp\left(-\frac{|t – s|^2}{2\sigma_t^2}\right) $$

$\sigma_t$ は時間方向のバンド幅パラメータです。この定義により、時間的に近い時刻ほど高い類似度(「ソフトな正例」)として扱われ、離れた時刻ほど低い類似度(「ソフトな負例」)となります。

時間距離1の時刻ペアは $w \approx 0.95$、時間距離10の時刻ペアは $w \approx 0.61$、時間距離100の時刻ペアは $w \approx 0.007$ のように、類似度が滑らかに減衰します。

インスタンスレベルとテンポラルレベルの両方でソフトラベルが定義できたところで、実際にPythonで実装し、ハード対照学習との違いを実験で確認しましょう。

Pythonによるソフト対照学習の実装

合成データの生成と実験

以下のコードでは、5種類のパターンを持つ合成時系列を生成し、ハード対照学習とソフト対照学習の表現空間を比較します。

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist

np.random.seed(42)

# --- 合成時系列データの生成 ---
def generate_timeseries(n_per_class=50, length=64):
    """5クラスの時系列データを生成(クラス間に連続的な類似度あり)"""
    t = np.linspace(0, 4 * np.pi, length)
    data, labels = [], []
    for _ in range(n_per_class):
        # クラス0: 低周波正弦波
        data.append(np.sin(t) + np.random.randn(length) * 0.1)
        labels.append(0)
        # クラス1: やや高い周波数(クラス0に似ている)
        data.append(np.sin(1.3 * t) + np.random.randn(length) * 0.1)
        labels.append(1)
        # クラス2: 高周波正弦波(クラス0,1とはやや異なる)
        data.append(np.sin(2.5 * t) + np.random.randn(length) * 0.1)
        labels.append(2)
        # クラス3: 上昇トレンド + ノイズ
        data.append(np.linspace(-1, 1, length) + np.random.randn(length) * 0.15)
        labels.append(3)
        # クラス4: パルス信号
        pulse = np.random.randn(length) * 0.1
        pulse[length//3:length//3+8] += 2.5
        data.append(pulse)
        labels.append(4)
    return np.array(data), np.array(labels)

X, y = generate_timeseries(n_per_class=40, length=64)
N = len(X)

# --- 簡易DTW距離の計算 ---
def simple_dtw(x, y):
    """簡易DTW距離(完全DTWのO(n^2)実装)"""
    n, m = len(x), len(y)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i-1] - y[j-1]) ** 2
            D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
    return np.sqrt(D[n, m])

# 計算量削減のため、ユークリッド距離で代用(大規模データの場合)
print("ペアワイズ距離を計算中...")
dist_matrix = cdist(X.reshape(N, -1), X.reshape(N, -1), metric='euclidean')

# --- ソフトラベルの計算 ---
sigma = np.median(dist_matrix[np.triu_indices(N, k=1)])
soft_labels = np.exp(-dist_matrix ** 2 / (2 * sigma ** 2))
np.fill_diagonal(soft_labels, 0)

# 行ごとに正規化して確率分布にする
soft_probs = soft_labels / soft_labels.sum(axis=1, keepdims=True)

このコードでは、5クラスの合成時系列を生成しています。クラス0(低周波正弦波)とクラス1(やや高い周波数の正弦波)は意図的に似せてあり、ハード対照学習ではこれらを「負例」として引き離してしまう問題が発生します。ユークリッド距離にRBFカーネルを適用してソフトラベルを生成しています。

ソフトラベルの分布を可視化

# --- ソフトラベルの可視化 ---
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# (a) ソフト類似度行列
im = axes[0].imshow(soft_labels[:50, :50], cmap='viridis', aspect='auto')
axes[0].set_title('Soft Similarity Matrix (first 50 samples)')
axes[0].set_xlabel('Sample index')
axes[0].set_ylabel('Sample index')
plt.colorbar(im, ax=axes[0], fraction=0.046)

# (b) ハードラベル行列(同一クラス=1, 異クラス=0)
hard_labels = (y[:50, None] == y[None, :50]).astype(float)
np.fill_diagonal(hard_labels, 0)
axes[1].imshow(hard_labels, cmap='viridis', aspect='auto')
axes[1].set_title('Hard Labels (same class = 1)')
axes[1].set_xlabel('Sample index')
axes[1].set_ylabel('Sample index')

# (c) サンプル0に対するソフトラベルのヒストグラム
sample_soft = soft_labels[0]
colors = ['#00d4ff', '#ff6b6b', '#ffd93d', '#6bff6b', '#ff6bff']
for c in range(5):
    mask = (y == c)
    axes[2].hist(sample_soft[mask], bins=20, alpha=0.6,
                 label=f'Class {c}', color=colors[c])
axes[2].set_title('Soft Similarity from Sample 0 (Class 0)')
axes[2].set_xlabel('Soft similarity')
axes[2].set_ylabel('Count')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('softclt_labels.png', dpi=150, bbox_inches='tight')
plt.show()

3つの可視化から、ソフトラベルとハードラベルの根本的な違いが明確に見えます。左のソフト類似度行列では、クラス0とクラス1のブロック間にも高い類似度(明るい色)が残っており、「似ているが同一クラスではない」関係が連続的に表現されています。中央のハードラベル行列では、同じクラスのペアのみが1で、他は完全に0です。クラス0とクラス1の間の類似性は完全に無視されています。右のヒストグラムは、サンプル0(クラス0)から見た各クラスとのソフト類似度分布です。クラス1の類似度がクラス0に次いで高く、物理的な類似性が定量化されていることが確認できます。

ハード vs ソフト対照学習の表現空間の比較

# --- 簡易的な対照学習エンコーダ ---
class SimpleEncoder:
    def __init__(self, input_dim, latent_dim=2):
        self.W = np.random.randn(input_dim, latent_dim) * 0.01
        self.b = np.zeros(latent_dim)

    def encode(self, X):
        z = X @ self.W + self.b
        z = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1e-8)
        return z

    def train_hard(self, X, labels, n_epochs=200, lr=0.1, tau=0.5):
        """ハード対照学習"""
        losses = []
        for epoch in range(n_epochs):
            z = self.encode(X)
            sim = z @ z.T / tau
            N = len(X)
            loss = 0
            grad_z = np.zeros_like(z)
            for i in range(N):
                pos_mask = (labels == labels[i])
                pos_mask[i] = False
                neg_mask = ~pos_mask
                neg_mask[i] = False
                if not np.any(pos_mask):
                    continue
                exp_sim = np.exp(sim[i] - np.max(sim[i]))
                exp_sim[i] = 0
                total = np.sum(exp_sim)
                for j in np.where(pos_mask)[0]:
                    loss -= np.log(exp_sim[j] / total + 1e-8)
                    p = exp_sim / total
                    grad_i = (p.copy())
                    grad_i[j] -= 1
                    grad_z[i] += grad_i @ z / tau / np.sum(pos_mask)
            loss /= N
            losses.append(loss)
            dW = X.T @ grad_z / N
            self.W -= lr * dW
        return losses

    def train_soft(self, X, soft_probs, n_epochs=200, lr=0.1, tau=0.5):
        """ソフト対照学習"""
        losses = []
        for epoch in range(n_epochs):
            z = self.encode(X)
            sim = z @ z.T / tau
            N = len(X)
            loss = 0
            grad_z = np.zeros_like(z)
            for i in range(N):
                exp_sim = np.exp(sim[i] - np.max(sim[i]))
                exp_sim[i] = 0
                total = np.sum(exp_sim) + 1e-8
                p = exp_sim / total
                q = soft_probs[i]
                for j in range(N):
                    if i == j or q[j] < 1e-8:
                        continue
                    loss -= q[j] * np.log(p[j] + 1e-8)
                grad_i = (p - q) @ z / tau
                grad_z[i] += grad_i
            loss /= N
            losses.append(loss)
            dW = X.T @ grad_z / N
            self.W -= lr * dW
        return losses

# --- 学習と比較 ---
enc_hard = SimpleEncoder(64, 2)
enc_soft = SimpleEncoder(64, 2)
enc_soft.W = enc_hard.W.copy()
enc_soft.b = enc_hard.b.copy()

print("ハード対照学習を実行中...")
losses_hard = enc_hard.train_hard(X, y, n_epochs=150, lr=0.05, tau=0.5)
print("ソフト対照学習を実行中...")
losses_soft = enc_soft.train_soft(X, soft_probs, n_epochs=150, lr=0.05, tau=0.5)

z_hard = enc_hard.encode(X)
z_soft = enc_soft.encode(X)

# --- 可視化 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
colors = ['#00d4ff', '#ff6b6b', '#ffd93d', '#6bff6b', '#ff6bff']
class_names = ['Low-freq sine', 'Mid-freq sine', 'High-freq sine', 'Trend', 'Pulse']

for c in range(5):
    mask = (y == c)
    axes[0].scatter(z_hard[mask, 0], z_hard[mask, 1],
                    c=colors[c], label=class_names[c], alpha=0.6, s=20)
axes[0].set_title('Hard Contrastive Learning')
axes[0].legend(fontsize=7)
axes[0].grid(True, alpha=0.3)

for c in range(5):
    mask = (y == c)
    axes[1].scatter(z_soft[mask, 0], z_soft[mask, 1],
                    c=colors[c], label=class_names[c], alpha=0.6, s=20)
axes[1].set_title('Soft Contrastive Learning (SoftCLT)')
axes[1].legend(fontsize=7)
axes[1].grid(True, alpha=0.3)

axes[2].plot(losses_hard, label='Hard CL', color='#ff6b6b')
axes[2].plot(losses_soft, label='Soft CL', color='#00d4ff')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Loss')
axes[2].set_title('Training Loss')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('softclt_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

3つのプロットから、ハード対照学習とソフト対照学習の表現空間の違いが明確に見えます。左のハード対照学習では、クラス0(低周波正弦波)とクラス1(中周波正弦波)が離れた位置に配置されており、類似した波形パターンが表現空間で引き離されてしまっています。中央のソフト対照学習では、クラス0とクラス1が近い位置に配置されつつも適度に分離されており、類似度の順序関係が保存されています。つまり、「似ているものは近く、異なるものは遠く」という連続的なグラデーションが実現されています。右の学習曲線を見ると、ソフト対照学習の方がスムーズに収束しており、学習の安定性も向上していることがわかります。

テンポラルレベルのソフトラベル可視化

# --- テンポラルレベルのソフトラベル ---
T = 64  # 系列長
sigma_t = 10.0  # 時間方向のバンド幅

# 時間距離行列
time_dist = np.abs(np.arange(T)[:, None] - np.arange(T)[None, :])
temp_soft_labels = np.exp(-time_dist ** 2 / (2 * sigma_t ** 2))
np.fill_diagonal(temp_soft_labels, 0)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# テンポラルソフトラベル行列
im = axes[0].imshow(temp_soft_labels, cmap='viridis', aspect='auto')
axes[0].set_title('Temporal Soft Labels')
axes[0].set_xlabel('Time step')
axes[0].set_ylabel('Time step')
plt.colorbar(im, ax=axes[0], fraction=0.046)

# 特定時刻からのソフトラベル
for t0 in [0, 16, 32, 48]:
    axes[1].plot(temp_soft_labels[t0], label=f't={t0}', alpha=0.8)
axes[1].set_xlabel('Time step')
axes[1].set_ylabel('Soft similarity')
axes[1].set_title(f'Temporal Soft Similarity (σ_t={sigma_t})')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('softclt_temporal.png', dpi=150, bbox_inches='tight')
plt.show()

テンポラルソフトラベルの可視化から、時間的に近い時刻ほど高い類似度が割り当てられ、離れた時刻ほど低い類似度になることが確認できます。左の行列は対角線付近が明るく(類似度が高く)、対角線から離れるにつれて暗くなる滑らかなグラデーションを示しています。右のグラフは各基準時刻からの類似度プロファイルで、ガウス関数の形状をしています。$\sigma_t = 10$ の場合、約20ステップ離れた時刻でソフト類似度がほぼ0になります。ハード対照学習では時間距離1の時刻ペアも時間距離60の時刻ペアも等しく「負例」として扱われますが、SoftCLTでは時間距離に応じた連続的な重み付けが行われます。

SoftCLTのメリットと既存手法への適用

plug-and-play性

SoftCLTの大きな利点は、既存の対照学習フレームワークに損失関数の差し替えだけで適用できることです。エンコーダのアーキテクチャ、拡張手法、学習スケジュールはそのままで、InfoNCE損失をソフトInfoNCE損失に置き換えるだけです。

論文では、TS2Vec、TNC、TS-TCC、CoSTの4つの手法にSoftCLTを適用し、全てのケースで性能が向上したことが報告されています。これは、ソフトラベルがモデルに依存しない汎用的な改善手法であることを示しています。

計算コスト

ソフトラベルの計算にはペアワイズ距離が必要で、$O(N^2)$ の計算量がかかります。しかし、これは事前計算が可能であり、学習ループ中のオーバーヘッドはInfoNCE損失の計算に比べて無視できる程度です。DTW距離の計算が律速になる場合は、ユークリッド距離やFASTDTW(近似DTW)で代用することも可能です。

温度パラメータとの関係

ソフトInfoNCE損失の温度 $\tau$ は、通常のInfoNCE損失とは異なる最適値を持ちます。ハード対照学習では $\tau$ が小さいほど「鋭い」区別が行われますが、ソフト対照学習ではソフトラベル自体がすでに類似度のグラデーションを含んでいるため、$\tau$ をやや大きめに設定しても良好な結果が得られます。論文では $\tau = 0.5$ が推奨されています。

まとめ

本記事では、SoftCLT(ICLR 2024 Spotlight)が提案するソフト対照学習について解説しました。

  • ハード対照学習の問題: 時系列の連続的な類似構造を無視し、「偽の負例」問題と類似度順序の破壊が発生する
  • ソフトInfoNCE損失: 正例/負例のハードラベルを連続値のソフトラベルで置き換え、クロスエントロピーベースの損失で学習する
  • 2つの粒度: インスタンスレベル(DTWベース)とテンポラルレベル(時間距離ベース)の両方でソフト対照学習を行う
  • plug-and-play性: 既存の対照学習フレームワークに損失関数の差し替えだけで適用可能

テレメトリ検索においては、SoftCLTにより「似ている度合い」の連続スコアリングが可能になり、検索結果のランキング品質が向上します。

次のステップとして、以下の記事も参考にしてください。

画像なし
TimeSiam — Siameseネットワークで時系列の時間的相関を学習する
Siameseネットワークによる時系列の類似性学習を解説しています
画像なし
TOTEM — VQ-VAEで時系列を離散トークンに変換する
連続的な類似度学習とは異なる離散トークン化のアプローチを解説しています