TiRex — xLSTMベースの軽量時系列基盤モデル【NeurIPS 2025】

衛星のオンボードコンピュータには厳しい制約があります。宇宙放射線に耐えるために設計された耐放射線プロセッサは、地上のGPUと比べて数桁遅い処理能力しかありません。メモリも限られ、消費電力も厳密に管理されています。しかし、テレメトリの異常をリアルタイムで検出し、迅速に対処するためには、衛星上で直接時系列パターンを認識する能力が必要です。地上局への通信には数分〜数十分のラグがあり、緊急事態に間に合わない可能性があるからです。

Sundial(2.4B)やTime-MoE(2.4B)のような大規模モデルは高い精度を誇りますが、衛星上で動かすことは不可能です。必要なのは、小さいが賢いモデル — 35Mパラメータ程度で、より大きなモデルに匹敵する性能を持つモデルです。

TiRex(NX-AI, NeurIPS 2025)は、xLSTM(Extended LSTM)をベースにした35Mパラメータの軽量モデルで、この課題に挑みました。驚くべきことに、この35Mのモデルが数百M〜数Bのパラメータを持つTransformerベースのモデルを複数のベンチマークで凌駕します。さらに、In-context Learning(コンテキスト内の時系列パターンから即座に適応)により、ファインチューニングなしで新しいドメインに対応できます。

TiRexを理解することは、以下のような場面で直接役立ちます。

  • 衛星オンボード処理: 35Mパラメータなら耐放射線プロセッサでも推論が可能で、テレメトリ異常のリアルタイム検出が実現します
  • エッジコンピューティング: IoTセンサ、工場の制御装置など、計算資源が限られた環境での時系列分析に適用できます

本記事の内容

  • なぜ軽量モデルが重要か(計算資源制約の実情)
  • xLSTM(Extended LSTM)の基礎(sLSTM, mLSTM)
  • TransformerとxLSTMの計算量比較
  • TiRexのアーキテクチャ
  • In-context Learning能力
  • Pythonによる簡易xLSTMセルの実装と時系列予測実験

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

なぜ軽量モデルが重要か

宇宙の計算資源制約

地上のML研究者が当たり前に使うNVIDIA A100 GPU(80GB HBM2e、312 TFLOPS FP16)は、宇宙環境では使えません。宇宙放射線によるシングルイベントアップセット(SEU: ビット反転)が発生するため、耐放射線設計の特殊なプロセッサが必要です。

プロセッサ FP32性能 メモリ 消費電力 用途
NVIDIA A100 19.5 TFLOPS 80 GB 300 W 地上GPU
Xilinx Kintex Ultrascale RT 〜1 GFLOPS 数MB 10 W 耐放射線FPGA
GR740 (LEON4) 〜0.2 GFLOPS 〜4 MB 3 W 衛星用CPU

地上GPUと比べて5桁以上の性能差があります。24億パラメータのモデルは約10GBのメモリが必要で、衛星コンピュータのメモリの数千倍です。35Mパラメータなら約140MBで、将来的な宇宙用ハードウェアで動作する現実的な範囲に入ります。

パラメータ効率の重要性

モデルの「質」を測る指標として、パラメータあたりの性能が重要です。

$$ \text{パラメータ効率} = \frac{\text{性能(例: MSE の逆数)}}{\text{パラメータ数}} $$

TiRexの35Mパラメータが、Chronos-Large(710M)やTimesFM(200M)を上回るということは、パラメータ効率が10倍以上高いことを意味します。この効率の源泉がxLSTMアーキテクチャです。

計算資源制約の現実を理解したところで、TiRexの核となるxLSTMの仕組みを見ていきましょう。

xLSTM(Extended LSTM)の基礎

従来のLSTMの限界

LSTM(Long Short-Term Memory, 1997)は時系列処理の定番でしたが、Transformerの登場以降、主に2つの理由で劣勢に立たされていました。

並列化の困難: LSTMは逐次処理(時刻 $t$ の出力が時刻 $t-1$ の状態に依存)であり、長い系列の学習時にGPUの並列性を活かせません。

メモリ容量の限界: セル状態 $\bm{c}_t \in \mathbb{R}^d$ はベクトル($d$ 個のスカラー)であり、長い系列の情報を $d$ 個の値に圧縮する必要があります。Transformerの注意機構は系列長に応じて $O(L \times d)$ のメモリを使えるため、長期依存の保持で有利です。

xLSTM(Beck et al., 2024)は、これらの限界を2つの拡張 — sLSTMmLSTM — で克服します。

sLSTM — スカラーメモリ + 指数ゲーティング

sLSTM(scalar LSTM)は、従来のLSTMのゲート機構を指数ゲーティングに拡張します。

従来のLSTMの忘却ゲートと入力ゲートは、シグモイド関数で0〜1の範囲に制限されていました。

$$ f_t = \sigma(\bm{W}_f \bm{x}_t + \bm{U}_f \bm{h}_{t-1} + \bm{b}_f) \in (0, 1)^d $$

$$ i_t = \sigma(\bm{W}_i \bm{x}_t + \bm{U}_i \bm{h}_{t-1} + \bm{b}_i) \in (0, 1)^d $$

sLSTMでは、これを指数関数に置き換えます。

$$ f_t = \exp(\bm{W}_f \bm{x}_t + \bm{U}_f \bm{h}_{t-1} + \bm{b}_f) $$

$$ i_t = \exp(\bm{W}_i \bm{x}_t + \bm{U}_i \bm{h}_{t-1} + \bm{b}_i) $$

指数ゲーティングにより、ゲート値が1を超える値を取れるようになります。忘却ゲート $f_t > 1$ の場合、セル状態が増幅されます。これは「重要な情報をより強く保持する」効果を持ちます。

ただし、指数関数は値が発散しやすいため、正規化状態 $n_t$ を導入して安定化します。

$$ c_t = f_t \cdot c_{t-1} + i_t \cdot \tilde{c}_t $$

$$ n_t = f_t \cdot n_{t-1} + i_t $$

$$ h_t = \frac{c_t}{n_t} $$

$n_t$ はセル状態の「累積スケール」を追跡し、出力 $h_t$ は正規化された値となります。これにより、指数ゲーティングの表現力を維持しつつ数値安定性を確保します。

mLSTM — 行列メモリ + 共分散更新

mLSTM(matrix LSTM)は、セル状態をスカラーベクトルから行列に拡張します。

従来のLSTM: $\bm{c}_t \in \mathbb{R}^d$(ベクトル、$d$ 個の値)

mLSTM: $\bm{C}_t \in \mathbb{R}^{d \times d}$(行列、$d^2$ 個の値)

セル状態が行列になることで、メモリ容量が $d$ 倍に増加します。$d = 256$ の場合、ベクトルLSTMが256個の値を記憶できるのに対し、mLSTMは65,536個の値を記憶できます。

mLSTMのセル状態更新は以下のように定義されます。

$$ \bm{C}_t = f_t \cdot \bm{C}_{t-1} + i_t \cdot \bm{v}_t \bm{k}_t^\top $$

ここで $\bm{v}_t = \bm{W}_v \bm{x}_t$(値ベクトル)、$\bm{k}_t = \bm{W}_k \bm{x}_t$(キーベクトル)です。$\bm{v}_t \bm{k}_t^\top$ は外積で、ランク1の行列です。

出力は、クエリベクトル $\bm{q}_t = \bm{W}_q \bm{x}_t$ で行列メモリを検索して得られます。

$$ \bm{h}_t = \frac{\bm{C}_t \bm{q}_t}{\max(\bm{n}_t^\top \bm{q}_t, 1)} $$

この構造はTransformerのAttentionと深い関係があります。Transformerの線形Attention(softmaxを除いたAttention)は以下のように書けます。

$$ \text{Attention}_t = \frac{\sum_{s=1}^t \bm{v}_s \bm{k}_s^\top \bm{q}_t}{\sum_{s=1}^t \bm{k}_s^\top \bm{q}_t} $$

mLSTMのセル状態更新で $f_t = 1$ とすると、

$$ \bm{C}_t = \sum_{s=1}^t i_s \cdot \bm{v}_s \bm{k}_s^\top $$

これはまさに線形Attentionの「キー・バリューのメモリ蓄積」と同じ構造です。mLSTMは忘却ゲートで情報の重要度を時間的に制御しながら、Attention的な検索を行うアーキテクチャと解釈できます。

計算量の比較

xLSTMとTransformerの計算量を比較します。系列長 $L$、隠れ層次元 $d$ に対して、

アーキテクチャ 計算量(1層) メモリ 長系列での効率
Transformer(Full Attention) $O(L^2 d)$ $O(L^2 + Ld)$ 系列長に二乗で増加
sLSTM $O(Ld)$ $O(d)$ 系列長に線形
mLSTM $O(Ld^2)$ $O(d^2)$ 系列長に線形

xLSTMは系列長に対して線形の計算量であり、$L > d$ の場合(時系列では一般的)にTransformerよりも効率的です。$L = 4096, d = 256$ の場合、

  • Transformer: $O(4096^2 \times 256) \approx 4.3 \times 10^9$
  • mLSTM: $O(4096 \times 256^2) \approx 2.7 \times 10^8$

約16倍の効率差があります。

xLSTMの基礎を理解したところで、TiRexがこれを時系列基盤モデルにどう適用するか見ていきましょう。

TiRexのアーキテクチャ

全体構造

TiRexの設計は非常にシンプルです。

1. パッチ化入力: 時系列をパッチサイズ $P$ で分割し、線形層で埋め込みに変換します。

$$ \bm{h}_i^{(0)} = \bm{p}_i \bm{W}_{\text{embed}} + \bm{b}_{\text{embed}} $$

2. xLSTMブロック: 複数のmLSTMブロックを積み重ねます。各ブロックは以下の構成です。

$$ \bm{h}^{(l’)} = \bm{h}^{(l)} + \text{mLSTM}(\text{LN}(\bm{h}^{(l)})) $$

$$ \bm{h}^{(l+1)} = \bm{h}^{(l’)} + \text{FFN}(\text{LN}(\bm{h}^{(l’)})) $$

Transformerブロックと同じ残差接続 + LayerNorm構造ですが、Self-Attentionの代わりにmLSTMが使われます。

3. 予測ヘッド: 最後のパッチ位置の隠れ状態から次のパッチの値を予測します。

$$ \hat{\bm{p}}_{N+1} = \bm{h}_N^{(L)} \bm{W}_{\text{pred}} + \bm{b}_{\text{pred}} $$

35Mパラメータの内訳

コンポーネント パラメータ数
パッチ埋め込み 〜0.5M
mLSTMブロック(12層) 〜32M
FFN 〜2M
予測ヘッド 〜0.5M
合計 〜35M

mLSTMブロックがパラメータの大部分を占めます。各ブロック内のmLSTM層は $d \times d$ のキー・バリュー・クエリ射影行列と、ゲーティング用の重み行列を持ちます。

In-context Learning

TiRexの注目すべき能力はIn-context Learning(ICL)です。これは、入力系列のコンテキスト内にあるパターンから即座に適応する能力です。

Transformerベースのモデル(GPT等)がICLを示すことは知られていますが、TiRexはxLSTMベースでもICLが可能であることを実証しました。

具体的には、以下のような動作をします。

  1. 入力として長い時系列コンテキスト(例: 1024時刻分)を与える
  2. モデルはコンテキスト内のパターン(周期、トレンド、季節性等)を行列メモリ $\bm{C}_t$ に蓄積する
  3. コンテキストの末尾で、蓄積されたパターン知識に基づいて予測を行う

ファインチューニングが不要であるため、新しいドメイン(例: 未知の衛星のテレメトリ)に対しても、数秒分のコンテキストを入力するだけで即座に適応します。これは衛星運用において極めて実用的です — 新しい衛星が軌道に投入された直後から、数軌道分のデータをコンテキストとして与えるだけで異常検知が可能になります。

mLSTMのメモリとICLの関係

ICLが可能な理由は、mLSTMの行列メモリの特性にあります。

セル状態 $\bm{C}_t = \sum_{s=1}^t f_s^{(t)} \cdot i_s \cdot \bm{v}_s \bm{k}_s^\top$ は、過去のキー・バリューペアの重み付き和です。ここで $f_s^{(t)} = \prod_{\tau=s+1}^t f_\tau$ は時刻 $s$ の情報が時刻 $t$ まで残る減衰率です。

コンテキスト内のパターンが繰り返し現れると、対応するキー・バリューペアが行列メモリに蓄積され、忘却ゲートの減衰にもかかわらず強いシグナルとして残ります。予測時にクエリ $\bm{q}_t$ で行列メモリを検索すると、蓄積されたパターンに対応するバリューが出力されます。

これはTransformerのAttentionが過去のキー・バリューを明示的にキャッシュする(KVキャッシュ)のと対比されます。mLSTMは行列メモリに情報を圧縮的に蓄積するため、メモリ使用量が系列長に依存しません($O(d^2)$ 固定)。

TiRexのアーキテクチャとICL能力を理解したところで、Pythonで簡易的なmLSTMセルを実装し、時系列予測の実験を行いましょう。

Pythonによる実装

簡易mLSTMセルの実装

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

class mLSTMCell:
    """簡易mLSTMセル(行列メモリ + 指数ゲーティング)"""

    def __init__(self, input_dim, hidden_dim):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        scale = 0.1

        # キー・バリュー・クエリ射影
        self.W_k = np.random.randn(input_dim, hidden_dim) * scale
        self.W_v = np.random.randn(input_dim, hidden_dim) * scale
        self.W_q = np.random.randn(input_dim, hidden_dim) * scale

        # ゲーティング
        self.W_f = np.random.randn(input_dim, 1) * scale  # 忘却ゲート(スカラー)
        self.b_f = np.ones(1) * 2.0  # 初期バイアス(忘却ゲートを高く保つ)
        self.W_i = np.random.randn(input_dim, 1) * scale  # 入力ゲート(スカラー)

    def forward_sequence(self, x_seq):
        """系列全体を処理"""
        T = len(x_seq)
        d = self.hidden_dim

        # 行列メモリの初期化
        C = np.zeros((d, d))
        n = np.zeros(d)
        outputs = []
        gate_history = {'f': [], 'i': []}

        for t in range(T):
            x = x_seq[t]

            # キー・バリュー・クエリ
            k = x @ self.W_k  # (d,)
            v = x @ self.W_v  # (d,)
            q = x @ self.W_q  # (d,)

            # 指数ゲーティング(数値安定化のためクリッピング)
            f_logit = np.clip(x @ self.W_f + self.b_f, -10, 10)
            i_logit = np.clip(x @ self.W_i, -10, 10)
            f = np.exp(f_logit).item()
            i = np.exp(i_logit).item()

            # 行列メモリ更新
            C = f * C + i * np.outer(v, k)
            n = f * n + i * k

            # 出力(正規化検索)
            denominator = np.maximum(np.abs(n @ q), 1.0)
            h = (C @ q) / denominator

            outputs.append(h)
            gate_history['f'].append(f)
            gate_history['i'].append(i)

        return np.array(outputs), gate_history


class SimpleTiRex:
    """簡易TiRexモデル(1層mLSTM + 線形予測ヘッド)"""

    def __init__(self, patch_size=16, hidden_dim=32, lr=0.005):
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.lr = lr

        # パッチ埋め込み
        self.W_embed = np.random.randn(patch_size, hidden_dim) * 0.1
        self.b_embed = np.zeros(hidden_dim)

        # mLSTM
        self.mlstm = mLSTMCell(hidden_dim, hidden_dim)

        # 予測ヘッド
        self.W_pred = np.random.randn(hidden_dim, patch_size) * 0.1
        self.b_pred = np.zeros(patch_size)

    def patchify(self, series):
        """時系列をパッチに分割"""
        n_patches = len(series) // self.patch_size
        patches = series[:n_patches * self.patch_size].reshape(n_patches, self.patch_size)
        return patches

    def forward(self, series):
        """順伝播: パッチ化 → 埋め込み → mLSTM → 予測"""
        patches = self.patchify(series)
        # パッチ埋め込み
        h_embed = np.tanh(patches @ self.W_embed + self.b_embed)
        # mLSTM処理
        h_out, gates = self.mlstm.forward_sequence(h_embed)
        # 予測(各時刻で次パッチを予測)
        preds = h_out @ self.W_pred + self.b_pred
        return preds, patches, gates


# --- 合成テレメトリデータ ---
def generate_telemetry_with_anomaly(length=1024, period=64):
    """軌道周期 + 異常を含む合成テレメトリ"""
    t = np.arange(length, dtype=float)
    signal = np.sin(2 * np.pi * t / period) + 0.3 * np.sin(2 * np.pi * t / (period * 3))
    noise = np.random.randn(length) * 0.1
    series = signal + noise
    # 異常挿入
    series[600:620] += 3.0  # パルス異常
    series[800:850] *= 0.1  # 振幅低下異常
    return series

series = generate_telemetry_with_anomaly(length=1024, period=64)
model = SimpleTiRex(patch_size=16, hidden_dim=32)

# --- 順伝播と可視化 ---
preds, patches, gates = model.forward(series)

fig, axes = plt.subplots(3, 1, figsize=(16, 10))

# (a) 元の時系列 + 異常区間
axes[0].plot(series, color='#00d4ff', linewidth=0.8)
axes[0].axvspan(600, 620, alpha=0.3, color='red', label='Pulse anomaly')
axes[0].axvspan(800, 850, alpha=0.3, color='orange', label='Amplitude drop')
axes[0].set_title('Input Telemetry with Anomalies')
axes[0].set_ylabel('Value')
axes[0].legend(fontsize=8)
axes[0].grid(True, alpha=0.2)

# (b) ゲート値の推移
t_gates = np.arange(len(gates['f'])) * model.patch_size
axes[1].plot(t_gates, gates['f'], label='Forget gate (f)', color='#ff6b6b', alpha=0.7)
axes[1].plot(t_gates, gates['i'], label='Input gate (i)', color='#00d4ff', alpha=0.7)
axes[1].axvspan(600, 620, alpha=0.2, color='red')
axes[1].axvspan(800, 850, alpha=0.2, color='orange')
axes[1].set_title('mLSTM Gate Values (Exponential Gating)')
axes[1].set_ylabel('Gate value')
axes[1].legend(fontsize=8)
axes[1].grid(True, alpha=0.2)
axes[1].set_yscale('log')

# (c) 予測誤差
recon_error = np.mean((preds[:-1] - patches[1:]) ** 2, axis=1)
t_error = np.arange(len(recon_error)) * model.patch_size
axes[2].plot(t_error, recon_error, color='#ffd93d', linewidth=1)
axes[2].axvspan(600, 620, alpha=0.2, color='red')
axes[2].axvspan(800, 850, alpha=0.2, color='orange')
axes[2].set_title('Next-Patch Prediction Error (anomaly detection signal)')
axes[2].set_xlabel('Time step')
axes[2].set_ylabel('MSE')
axes[2].grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig('tirex_gates.png', dpi=150, bbox_inches='tight')
plt.show()

3段の可視化から、mLSTMベースの時系列処理の特性が読み取れます。上段は入力テレメトリで、時刻600-620にパルス異常、800-850に振幅低下異常が挿入されています。中段のゲート値は対数スケールで表示されており、指数ゲーティングにより値が1を大きく超える(忘却ゲートで情報を増幅)区間と、1未満の区間が交互に現れています。異常区間付近でゲート値のパターンが変化していることが確認でき、mLSTMが入力の統計的変化に反応していることを示しています。下段の予測誤差は、学習前のランダムモデルの出力ですが、異常区間で予測誤差が変化する傾向が見られ、学習後にはこれが異常検知の有効なシグナルとなります。

In-context Learningの実験

# --- In-context Learning の実験 ---
# コンテキスト長を変えて予測精度を比較

def compute_icl_performance(series, model, context_lengths):
    """コンテキスト長と予測精度の関係"""
    results = []
    patch_size = model.patch_size

    for ctx_len in context_lengths:
        # コンテキスト部分
        ctx = series[:ctx_len]
        # モデルを通す(行列メモリにパターンを蓄積)
        preds, patches, _ = model.forward(ctx)
        # 最後のパッチの予測誤差
        if len(preds) > 1:
            error = np.mean((preds[-1] - patches[-1]) ** 2)
            # ベースライン: コンテキストなし(最初のパッチの予測)
            baseline_error = np.mean((preds[0] - patches[0]) ** 2)
        else:
            error = np.nan
            baseline_error = np.nan
        results.append({
            'context_length': ctx_len,
            'prediction_error': error,
            'baseline_error': baseline_error
        })
    return results

context_lengths = [64, 128, 256, 384, 512, 640, 768, 896, 1024]
# 周期的パターン(異常なし)
clean_series = np.sin(2 * np.pi * np.arange(1024) / 64) + \
               0.3 * np.sin(2 * np.pi * np.arange(1024) / (64*3)) + \
               np.random.randn(1024) * 0.1

results = compute_icl_performance(clean_series, model, context_lengths)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (a) コンテキスト長 vs 予測誤差
ctx = [r['context_length'] for r in results]
errors = [r['prediction_error'] for r in results]
baselines = [r['baseline_error'] for r in results]

axes[0].plot(ctx, errors, 'o-', color='#00d4ff', label='Last patch error')
axes[0].plot(ctx, baselines, 's--', color='#ff6b6b', label='First patch error (no context)')
axes[0].set_xlabel('Context length (time steps)')
axes[0].set_ylabel('Prediction MSE')
axes[0].set_title('In-context Learning: More Context → Better Prediction')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# (b) 行列メモリの「容量」可視化
# 異なるコンテキスト長での行列メモリのフロベニウスノルム
memory_norms = []
for ctx_len in context_lengths:
    ctx = clean_series[:ctx_len]
    patches_ctx = model.patchify(ctx)
    h_embed = np.tanh(patches_ctx @ model.W_embed + model.b_embed)
    _, gate_hist = model.mlstm.forward_sequence(h_embed)
    # 行列メモリを再構成して最終状態のノルムを計算
    d = model.hidden_dim
    C = np.zeros((d, d))
    n = np.zeros(d)
    for t_idx in range(len(h_embed)):
        x = h_embed[t_idx]
        k = x @ model.mlstm.W_k
        v = x @ model.mlstm.W_v
        f = gate_hist['f'][t_idx]
        i = gate_hist['i'][t_idx]
        C = f * C + i * np.outer(v, k)
    memory_norms.append(np.linalg.norm(C, 'fro'))

axes[1].plot(context_lengths, memory_norms, 'o-', color='#ffd93d')
axes[1].set_xlabel('Context length (time steps)')
axes[1].set_ylabel('Matrix memory Frobenius norm')
axes[1].set_title('Matrix Memory Growth with Context')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('tirex_icl.png', dpi=150, bbox_inches='tight')
plt.show()

In-context Learningの実験から2つの重要な特性が確認できます。左のグラフでは、コンテキスト長が増えるにつれて最後のパッチの予測誤差が変化する様子を示しています。赤い点線(最初のパッチの誤差 = コンテキストなしのベースライン)と比較して、コンテキストが長いほど行列メモリにより多くのパターン情報が蓄積され、予測に活用できることを示しています。これがIn-context Learningの本質です — ファインチューニングなしで、入力系列のパターンから即座に適応します。右のグラフは行列メモリのフロベニウスノルムがコンテキスト長とともに成長する様子を示しており、mLSTMの行列メモリが入力パターンの情報を継続的に蓄積していることが確認できます。ただし、忘却ゲートの減衰効果により、古い情報は徐々に減衰し、メモリが無限に増加することはありません。

Transformer vs xLSTM の計算量比較

# --- 計算量の比較 ---
seq_lengths = [64, 128, 256, 512, 1024, 2048, 4096]
d = 256  # 隠れ層次元

transformer_flops = [L**2 * d for L in seq_lengths]
mlstm_flops = [L * d**2 for L in seq_lengths]
slstm_flops = [L * d for L in seq_lengths]

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.loglog(seq_lengths, transformer_flops, 'o-', color='#ff6b6b',
          label=f'Transformer O(L²d), d={d}', linewidth=2, markersize=6)
ax.loglog(seq_lengths, mlstm_flops, 's-', color='#00d4ff',
          label=f'mLSTM O(Ld²), d={d}', linewidth=2, markersize=6)
ax.loglog(seq_lengths, slstm_flops, '^-', color='#ffd93d',
          label=f'sLSTM O(Ld), d={d}', linewidth=2, markersize=6)

# 交差点
crossover = d  # L = d でTransformerとmLSTMの計算量が等しくなる
ax.axvline(x=crossover, color='gray', linestyle='--', alpha=0.5)
ax.annotate(f'L=d={d}', xy=(crossover, crossover**2 * d), fontsize=9,
            xytext=(crossover*1.5, crossover**2 * d * 2),
            arrowprops=dict(arrowstyle='->', color='gray'))

ax.set_xlabel('Sequence length L')
ax.set_ylabel('FLOPs (per layer)')
ax.set_title('Computational Cost: Transformer vs xLSTM')
ax.legend()
ax.grid(True, alpha=0.3, which='both')
plt.tight_layout()
plt.savefig('tirex_flops.png', dpi=150, bbox_inches='tight')
plt.show()

# 具体的な倍率を表示
for L in [512, 1024, 4096]:
    ratio = (L**2 * d) / (L * d**2)
    print(f"L={L}: Transformer/mLSTM = {ratio:.1f}x")

計算量の比較から、系列長に応じたアーキテクチャ選択の指針が明確になります。赤線のTransformer($O(L^2 d)$)は系列長に対して二乗で増加するのに対し、青線のmLSTM($O(Ld^2)$)は線形にしか増加しません。$L = d$(ここでは256)で両者の計算量が等しくなり、$L > d$ の領域ではmLSTMが効率的です。$L = 4096$ では、mLSTMはTransformerの16倍効率的であり、これが35Mパラメータのモデルで大規模Transformerモデルに対抗できる理由の一つです。時系列は $L \gg d$ となることが多いため、xLSTMアーキテクチャは時系列に本質的に適していると言えます。

まとめ

本記事では、TiRex(NeurIPS 2025)が提案するxLSTMベースの軽量時系列基盤モデルについて解説しました。

  • xLSTMの2つの拡張: sLSTM(指数ゲーティング)とmLSTM(行列メモリ)により、従来のLSTMの限界を克服。特にmLSTMは線形Attentionとの数学的対応を持つ
  • 35Mパラメータの高効率: 710MのChronosや200MのTimesFMを複数ベンチマークで上回るパラメータ効率を達成
  • In-context Learning: 行列メモリにコンテキスト内のパターンを蓄積し、ファインチューニングなしで新ドメインに適応
  • 線形計算量: 系列長に対して線形の計算量で、長い時系列の処理に本質的に適している
  • 衛星運用への適用可能性: 35Mパラメータは計算資源制約のある宇宙環境で動作する現実的なモデルサイズ

次のステップとして、以下の記事も参考にしてください。