10万トークンの小説を要約したい。数百万塩基のDNA配列から遺伝子の機能領域を見つけたい。リアルタイムで音声を文字起こししたい — こうした「長い系列を効率的に処理したい」という要求は、大規模言語モデルの時代にますます切実になっています。
現在のTransformerはSelf-Attentionの計算量が系列長 $n$ に対して $O(n^2)$ であるため、系列長を2倍にすると計算量は4倍に膨れ上がります。GPT-4やClaude 3のようなモデルがコンテキストウィンドウを10万トークン以上に拡張している今、この二乗コストは深刻なボトルネックです。
この問題に対して、まったく異なるアプローチで挑んだのがState Space Models(SSM)です。SSMは制御理論で何十年も使われてきた状態空間方程式を深層学習に導入し、系列長に対して $O(n)$ の計算量で長距離依存性を捉えます。特に2022年に登場したS4(Structured State Spaces for Sequence Modeling)はHiPPO行列による長距離記憶で注目を集め、2023年末のMambaは入力依存の選択機構を加えることでTransformerに匹敵する性能を達成しました。
SSMの理論を理解すると、以下のような分野に応用できます。
- ゲノム配列解析: 数百万塩基のDNA配列を線形時間でモデル化し、遺伝子の構造予測や変異解析に利用する
- 長文書処理: 書籍全体や法律文書を一度に読み込んで要約・検索する
- リアルタイム推論: RNNのように1トークンずつ逐次処理でき、音声認識やロボット制御に使える
- 時系列予測: 長期的な傾向を捉えた気象データや金融データの予測
本記事の内容
- Transformerの $O(n^2)$ 問題とSSMが解決するもの
- 制御理論に由来する状態空間方程式の基礎
- 連続→離散化(ゼロ次ホールド、バイリニア変換)の数学
- 離散SSMの畳み込み表現とFFTによる高速化
- S4の革新:HiPPO行列と対角化による効率的な長距離記憶
- Mambaの革新:Selective State Space(入力依存のフィルタリング)
- PyTorchでの離散SSMとMamba風Selective SSMの実装
- Mamba vs Transformerの比較と使い分け
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
Transformerのスケーラビリティ問題
$O(n^2)$ の壁
SSMの動機を理解するために、まずTransformerの計算量問題を整理しましょう。
Self-Attentionの計算式を振り返ります。
$$ \text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \text{softmax}\left(\frac{\bm{Q}\bm{K}^\top}{\sqrt{d_k}}\right)\bm{V} $$
$\bm{Q}, \bm{K} \in \mathbb{R}^{n \times d_k}$ のとき、$\bm{Q}\bm{K}^\top$ の結果は $n \times n$ の注意行列になります。この行列を構成するためだけで $O(n^2 d_k)$ の計算量と $O(n^2)$ のメモリが必要です。
具体的な数値で見ると、深刻さがわかります。系列長512ならば注意行列は約26万要素(1MB程度)ですが、系列長100,000にすると100億要素(約40GB)になります。GPUのメモリ容量を簡単に超えてしまいます。
これまでのアプローチとその限界
この問題に対して、いくつかの解決策が提案されてきました。
Sparse Attention(Longformer, BigBird)はアテンションパターンをスパースにして $O(n\sqrt{n})$ や $O(n)$ に落としますが、「どのトークンに注目するか」のパターンを事前に固定するため、タスクに応じた柔軟な情報選択が難しい場面があります。
Linear Attention は softmax を除去してカーネルトリックで $O(n)$ を達成しますが、softmax の鋭いピークを再現しにくく、精度がやや劣化する傾向があります。
Flash Attention はGPUメモリ階層を活用して定数倍を大幅に削減しますが、漸近的な計算量 $O(n^2)$ 自体は変わりません。
SSMはこれらとは根本的に異なるアプローチを取ります。「全トークンペアの関係を計算する」というAttentionのパラダイム自体を捨て、代わりに隠れ状態を介して情報を線形再帰的に伝播するモデルを構築します。これにより、学習時は畳み込みで $O(n \log n)$、推論時は再帰で $O(1)$(1ステップあたり)という効率性を実現しました。
では、SSMの基盤となる状態空間方程式がどのような仕組みなのかを見ていきましょう。
状態空間モデルの基礎
制御理論から深層学習へ
State Space Model(状態空間モデル)は、もともと1960年代にルドルフ・カルマンが制御理論で体系化した概念です。ロケットの姿勢制御、自動車のクルーズコントロール、電気回路のフィルタ設計 — これらすべてで状態空間表現が使われています。
直感的に言えば、状態空間モデルは「見えない内部状態を通じて、入力を出力にマッピングする」仕組みです。水槽に水を注ぐ場面を想像してみてください。蛇口からの流量(入力)が決まっても、水槽の水位(内部状態)がわからなければ、溢れるかどうか(出力)は予測できません。水位という「隠れた状態」が、入力と出力をつなぐ仲介役になっています。
深層学習のSSMもまったく同じ発想です。入力トークンの系列 $u(t)$ が与えられたとき、隠れ状態 $\bm{x}(t)$ を介して出力 $y(t)$ を計算します。
連続時間の状態空間方程式
連続時間のSSMは、次の2つの方程式で定義されます。
$$ \frac{d\bm{x}(t)}{dt} = \bm{A}\bm{x}(t) + \bm{B}u(t) \tag{1} $$
$$ y(t) = \bm{C}\bm{x}(t) + Du(t) \tag{2} $$
各変数の意味を整理します。
- $u(t) \in \mathbb{R}$: 時刻 $t$ での入力信号(スカラー)
- $\bm{x}(t) \in \mathbb{R}^N$: 時刻 $t$ での隠れ状態($N$ 次元ベクトル)
- $y(t) \in \mathbb{R}$: 時刻 $t$ での出力信号(スカラー)
- $\bm{A} \in \mathbb{R}^{N \times N}$: 状態遷移行列(系の動特性を決める)
- $\bm{B} \in \mathbb{R}^{N \times 1}$: 入力行列(入力が状態にどう影響するか)
- $\bm{C} \in \mathbb{R}^{1 \times N}$: 出力行列(状態から出力をどう読み出すか)
- $D \in \mathbb{R}$: 直接伝達項(スキップ接続に相当、多くの場合0と仮定)
式(1)は「隠れ状態の時間変化は、現在の状態と入力の線形結合で決まる」と読めます。式(2)は「出力は隠れ状態の線形射影で得られる」と読めます。
ここで重要な点は、行列 $\bm{A}$ が系の記憶特性を決定するということです。$\bm{A}$ の固有値が負の実部を持てば系は安定し、過去の入力は指数的に減衰して忘れられます。固有値が0に近い成分があれば、過去の情報を長く保持できます。この $\bm{A}$ をどう設計するかがSSMの核心であり、後述するS4のHiPPO行列が画期的だった理由です。
RNN・CNNとの関係
SSMは面白い二面性を持っています。
再帰モード(RNNに対応): 式(1)(2)を時間ステップごとに逐次計算すれば、RNNのような再帰的処理になります。1ステップあたりの計算量は $O(N)$ で、推論時に高速です。
畳み込みモード(CNNに対応): 後で示すように、SSMの出力は入力と特定のカーネルの畳み込みとして表現できます。これを利用すればFFTで $O(n \log n)$ の並列学習が可能です。
つまりSSMは、学習時はCNNのように並列化でき、推論時はRNNのように逐次処理できるという、両方の良いところを兼ね備えたモデルです。この二面性がSSMの大きな強みです。
しかし、ニューラルネットワークで使うには連続時間の方程式をそのまま扱うことはできません。離散的なトークン列を処理するために、連続時間のSSMを離散化する必要があります。
連続時間SSMの離散化
なぜ離散化が必要か
テキストや音声のようなデジタルデータは、離散的な時間ステップ $k = 0, 1, 2, \ldots$ で与えられます。微分方程式 $d\bm{x}/dt = \bm{A}\bm{x} + \bm{B}u$ は連続時間の世界の話なので、離散トークン列を処理するためには「連続→離散」の変換が必要です。
離散化のイメージとしては、滑らかな曲線を一定間隔でサンプリングして折れ線グラフに変換する操作を思い浮かべてください。サンプリング間隔 $\Delta$ が小さいほど元の曲線に忠実ですが、計算コストは増えます。この $\Delta$(ステップサイズ)がSSMの重要なハイパーパラメータになります。
ゼロ次ホールド(ZOH)離散化
最も広く使われる離散化手法がゼロ次ホールド(Zero-Order Hold, ZOH)です。これは「サンプル間で入力は一定値を保つ」と仮定する方法です。
連続時間の状態方程式 $\frac{d\bm{x}}{dt} = \bm{A}\bm{x} + \bm{B}u$ の解は、初期値 $\bm{x}(t_0)$ と入力 $u(\tau)$ を使って次のように書けます。
$$ \bm{x}(t) = e^{\bm{A}(t – t_0)}\bm{x}(t_0) + \int_{t_0}^{t} e^{\bm{A}(t – \tau)}\bm{B}u(\tau)\,d\tau \tag{3} $$
ここで $t_0 = k\Delta$, $t = (k+1)\Delta$ とし、この区間で $u(\tau) = u_k$(一定)と仮定します。まず、行列指数関数の性質 $e^{\bm{A}(t-\tau)} = e^{\bm{A}((k+1)\Delta – \tau)}$ を使って積分を計算します。
変数置換 $s = (k+1)\Delta – \tau$ とすると $d\tau = -ds$ であり、$\tau = k\Delta$ のとき $s = \Delta$、$\tau = (k+1)\Delta$ のとき $s = 0$ です。積分の向きを反転させると次を得ます。
$$ \int_{k\Delta}^{(k+1)\Delta} e^{\bm{A}((k+1)\Delta – \tau)}\bm{B}u_k\,d\tau = \left(\int_0^{\Delta} e^{\bm{A}s}\,ds\right)\bm{B}u_k \tag{4} $$
$\bm{A}$ が正則(逆行列を持つ)ならば、行列指数関数の積分公式を適用できます。
$$ \int_0^{\Delta} e^{\bm{A}s}\,ds = \bm{A}^{-1}\left(e^{\bm{A}\Delta} – \bm{I}\right) \tag{5} $$
式(3)に式(4)(5)を代入すると、離散時間の状態方程式が得られます。
$$ \bm{x}_{k+1} = \overline{\bm{A}}\,\bm{x}_k + \overline{\bm{B}}\,u_k \tag{6} $$
$$ y_k = \bm{C}\bm{x}_k + Du_k \tag{7} $$
ここで、離散化されたパラメータは次のとおりです。
$$ \overline{\bm{A}} = e^{\bm{A}\Delta}, \quad \overline{\bm{B}} = \bm{A}^{-1}(e^{\bm{A}\Delta} – \bm{I})\bm{B} = (\overline{\bm{A}} – \bm{I})\bm{A}^{-1}\bm{B} \tag{8} $$
$\bm{C}$ は離散化の影響を受けず、そのまま使用します。直接伝達項 $D$ も同様です。
バイリニア変換(Tustin法)
ZOHの代替としてよく使われるのがバイリニア変換(Tustin法)です。連続時間の伝達関数における変換 $s = \frac{2}{\Delta}\frac{z-1}{z+1}$ に対応し、周波数特性の保存に優れています。
バイリニア変換では、微分を次のように近似します。
$$ \frac{d\bm{x}}{dt} \approx \frac{\bm{x}_{k+1} – \bm{x}_k}{\Delta} $$
状態方程式に代入し、入力を2時刻の平均 $\frac{u_{k+1} + u_k}{2}$ で近似すると次を得ます。
$$ \frac{\bm{x}_{k+1} – \bm{x}_k}{\Delta} = \bm{A}\frac{\bm{x}_{k+1} + \bm{x}_k}{2} + \bm{B}\frac{u_{k+1} + u_k}{2} $$
$\bm{x}_{k+1}$ について解くために、左辺を整理して $\bm{x}_{k+1}$ と $\bm{x}_k$ を分離します。
$$ \bm{x}_{k+1} – \bm{x}_k = \frac{\Delta}{2}\bm{A}(\bm{x}_{k+1} + \bm{x}_k) + \frac{\Delta}{2}\bm{B}(u_{k+1} + u_k) $$
$\bm{x}_{k+1}$ を含む項を左辺に集めると次のようになります。
$$ \left(\bm{I} – \frac{\Delta}{2}\bm{A}\right)\bm{x}_{k+1} = \left(\bm{I} + \frac{\Delta}{2}\bm{A}\right)\bm{x}_k + \frac{\Delta}{2}\bm{B}(u_{k+1} + u_k) $$
両辺に $\left(\bm{I} – \frac{\Delta}{2}\bm{A}\right)^{-1}$ を左からかけると、離散化パラメータが得られます。
$$ \overline{\bm{A}} = \left(\bm{I} – \frac{\Delta}{2}\bm{A}\right)^{-1}\left(\bm{I} + \frac{\Delta}{2}\bm{A}\right) \tag{9} $$
$$ \overline{\bm{B}} = \left(\bm{I} – \frac{\Delta}{2}\bm{A}\right)^{-1}\Delta\bm{B} \tag{10} $$
バイリニア変換のメリットは2つあります。第1に、行列指数関数 $e^{\bm{A}\Delta}$ の計算が不要で、逆行列の計算だけで済みます。第2に、連続系が安定($\bm{A}$ の固有値が負の実部)ならば離散系も必ず安定になるという安定性の保存が保証されます。S4の原論文ではこのバイリニア変換が採用されました。
ZOHとバイリニア変換のどちらを使っても、離散SSMの形は同じです。違いは $\overline{\bm{A}}$ と $\overline{\bm{B}}$ の具体的な計算方法だけです。以降は、離散化の方法に依存しない一般的な離散SSMの性質を見ていきましょう。
離散SSMの畳み込み表現
再帰式を展開する
離散SSMの再帰式(6)(7)を、$\bm{x}_0 = \bm{0}$(ゼロ初期状態)として展開してみましょう。
$k = 0$ のとき:
$$ \bm{x}_1 = \overline{\bm{A}}\,\bm{x}_0 + \overline{\bm{B}}\,u_0 = \overline{\bm{B}}\,u_0 $$
$$ y_0 = \bm{C}\bm{x}_0 = 0 $$
$k = 1$ のとき、$\bm{x}_1$ を代入すると:
$$ \bm{x}_2 = \overline{\bm{A}}\overline{\bm{B}}\,u_0 + \overline{\bm{B}}\,u_1 $$
$$ y_1 = \bm{C}\overline{\bm{B}}\,u_0 $$
$k = 2$ のとき、$\bm{x}_2$ を代入すると:
$$ y_2 = \bm{C}\overline{\bm{A}}\overline{\bm{B}}\,u_0 + \bm{C}\overline{\bm{B}}\,u_1 $$
一般に、直接伝達項 $D = 0$ とすれば、出力 $y_k$ は次のように書けます。
$$ y_k = \sum_{j=0}^{k-1} \bm{C}\overline{\bm{A}}^{j}\overline{\bm{B}}\,u_{k-1-j} \tag{11} $$
畳み込みカーネル
式(11)をよく見ると、これは入力系列 $\{u_0, u_1, \ldots\}$ とカーネル $\overline{\bm{K}}$ の畳み込みにほかなりません。カーネルを次のように定義します。
$$ \overline{K}_j = \bm{C}\overline{\bm{A}}^{j}\overline{\bm{B}}, \quad j = 0, 1, \ldots, L-1 \tag{12} $$
ここで $L$ は系列長です。するとSSMの出力は次の畳み込みで表現できます。
$$ y = \overline{\bm{K}} * u \tag{13} $$
これは1次元畳み込み(因果的畳み込み)そのものです。カーネル $\overline{\bm{K}}$ さえ事前に計算しておけば、FFT(高速フーリエ変換)を使って $O(L \log L)$ で出力を計算できます。
$$ y = \mathcal{F}^{-1}[\mathcal{F}(\overline{\bm{K}}) \odot \mathcal{F}(u)] \tag{14} $$
ここで $\mathcal{F}$ はFFT、$\odot$ は要素ごとの積を表します。
二重性のまとめ
SSMの本質的な強みは、次の二重性にあります。
| モード | 計算量 | 適した場面 |
|---|---|---|
| 再帰モード(式(6)(7)) | 1ステップ $O(N)$、全体 $O(LN)$ | 推論(1トークンずつ生成) |
| 畳み込みモード(式(14)) | $O(L \log L)$ | 学習(系列全体を並列処理) |
Transformerは学習も推論も $O(L^2)$(KVキャッシュを使えば推論は $O(L)$ だがメモリは $O(L)$ 必要)であるのに対し、SSMは学習 $O(L \log L)$、推論時の1ステップ $O(N)$(隠れ状態のサイズのみに依存)という大きな優位性があります。
ただし、ここで一つ問題が残っています。カーネル $\overline{K}_j = \bm{C}\overline{\bm{A}}^j\overline{\bm{B}}$ の計算には $\overline{\bm{A}}$ の $j$ 乗が必要で、$j$ が大きくなると数値的に不安定になり得ます。さらに、ランダムな $\bm{A}$ では長距離依存性を捉えることができません — $\overline{\bm{A}}^j$ が指数的に減衰してしまうからです。この問題をエレガントに解決したのがS4のHiPPO行列です。
S4の革新:HiPPO行列による長距離記憶
なぜランダムな $\bm{A}$ ではうまくいかないのか
ここまでの議論では、状態遷移行列 $\bm{A}$ は任意の行列でした。しかし、ランダムに初期化した $\bm{A}$ を使ってSSMを学習させると、うまくいきません。
理由は明快です。離散系での出力は $y_k = \sum_j \bm{C}\overline{\bm{A}}^j \overline{\bm{B}}\,u_{k-1-j}$ であり、$j$ ステップ前の入力に対する応答は $\overline{\bm{A}}^j$ のノルムに比例します。ランダムな行列の固有値は一般に単位円の内側に分布するため、$\overline{\bm{A}}^j$ は $j$ が大きくなると指数的に減衰します。つまり、数百ステップ前の入力はほとんど忘れ去られてしまうのです。
これはRNN/LSTMが長距離依存性を学習しにくいのと本質的に同じ問題(勾配消失)です。SSMで長距離記憶を実現するには、$\bm{A}$ を意図的に設計する必要があります。
HiPPO:高次多項式射影演算子
S4の著者であるGu et al.は、HiPPO(High-order Polynomial Projection Operators)という数学的フレームワークを用いて $\bm{A}$ を設計しました。
HiPPOの発想は非常にエレガントです。「過去の入力信号 $u(\tau)$($\tau \leq t$)を、直交多項式の係数として圧縮して記憶する」という発想です。
たとえば、ある関数を3次多項式 $a_0 + a_1 x + a_2 x^2 + a_3 x^3$ で近似することを考えてください。4つの係数 $\{a_0, a_1, a_2, a_3\}$ さえ覚えていれば、元の関数の大まかな形を復元できます。HiPPOは、この「関数を多項式係数に圧縮する」操作を微分方程式として定式化し、その結果として $\bm{A}$ 行列が自然に導かれることを示しました。
具体的には、時刻 $t$ までの入力関数 $u(\tau)$($0 \leq \tau \leq t$)を、区間 $[0, t]$ 上の $N$ 個のルジャンドル多項式で最適に近似するとき、近似係数 $\bm{x}(t) \in \mathbb{R}^N$ の時間発展は次の微分方程式に従います。
$$ \frac{d\bm{x}(t)}{dt} = \bm{A}_{\text{HiPPO}}\bm{x}(t) + \bm{B}_{\text{HiPPO}}u(t) \tag{15} $$
HiPPO-LegS行列の定義
最も広く使われるのはHiPPO-LegS(Legendre Scaled)行列で、次のように定義されます。
$$ (\bm{A}_{\text{HiPPO}})_{nk} = \begin{cases} -(2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ -(n+1) & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \tag{16} $$
$$ (\bm{B}_{\text{HiPPO}})_n = (2n+1)^{1/2} \tag{17} $$
この行列は下三角行列であり、対角要素は $-(n+1)$ です。たとえば $N = 4$ のとき、具体的な形は次のようになります。
$$ \bm{A}_{\text{HiPPO}} = -\begin{pmatrix} 1 & 0 & 0 & 0 \\ \sqrt{3} & 2 & 0 & 0 \\ \sqrt{5} & \sqrt{15} & 3 & 0 \\ \sqrt{7} & \sqrt{21} & \sqrt{35} & 4 \end{pmatrix} $$
この行列の固有値は $\{-1, -2, -3, -4\}$(対角要素そのもの)です。全て負ですが、$-1$ のように0に近い固有値があるため、対応する成分はゆっくりと減衰し、遠い過去の情報も保持できます。これがHiPPO行列が長距離記憶を実現できる数学的理由です。
なぜHiPPOが優れているのか
HiPPO行列の本質的な優位性は、「記憶の最適な圧縮」を保証している点にあります。
通常のRNNは固定サイズの隠れベクトルに過去の情報を圧縮しますが、どのように圧縮するかは学習に任されています。一方、HiPPO-SSMでは「過去の入力関数のルジャンドル多項式近似として最適な係数を常に保持する」ことが数学的に証明されています。
$N$ 次元の隠れ状態で $N$ 個の多項式係数を保持するので、$N$ を増やせばより精密な記憶が可能です。そして重要なことに、近似の精度は時間に依存しません — 10ステップ前の入力も1000ステップ前の入力も同じ精度で記憶されます。これは「最近の情報を優先して古い情報を忘れる」RNNとは根本的に異なる特性です。
対角化による効率化(DPLR構造)
HiPPO行列は $N \times N$ の密な下三角行列です。カーネル $\overline{K}_j = \bm{C}\overline{\bm{A}}^j\overline{\bm{B}}$ を素朴に計算すると、各 $j$ について $O(N^2)$ の行列ベクトル積が必要で、カーネル全体では $O(LN^2)$ の計算量がかかります。$N$ が数百の場合、これは無視できないコストです。
S4はこの問題をDPLR(Diagonal Plus Low-Rank)構造で解決しました。HiPPO行列は次のように分解できます。
$$ \bm{A}_{\text{HiPPO}} = \bm{\Lambda} + \bm{p}\bm{q}^* \tag{18} $$
ここで $\bm{\Lambda}$ は対角行列、$\bm{p}, \bm{q} \in \mathbb{C}^N$ はベクトルです。${}^*$ は共役転置を表します。
「対角行列 + ランク1行列」の構造を持つため、Woodbury行列恒等式を適用すると、離散化で必要な逆行列を効率的に計算できます。さらに、カーネルの計算を周波数領域で行うことで、全体の計算量を $O(N + L)$ に削減できます。
具体的には、カーネルの $z$ 変換を閉じた形で計算します。カーネル生成関数を次のように定義します。
$$ \hat{K}(z) = \sum_{j=0}^{L-1} \overline{K}_j z^j = \sum_{j=0}^{L-1} \bm{C}\overline{\bm{A}}^j\overline{\bm{B}}\,z^j = \bm{C}(\bm{I} – \overline{\bm{A}}z)^{-1}\overline{\bm{B}} \tag{19} $$
$\overline{\bm{A}}$ がDPLR構造を持つとき、$(\bm{I} – \overline{\bm{A}}z)^{-1}$ はWoodbury恒等式で効率的に計算でき、$z$ を単位根 $\omega_L^j = e^{2\pi i j/L}$($j = 0, \ldots, L-1$)に代入すれば、$L$ 個のカーネル値がそれぞれ $O(N)$ で得られます。合計で $O(NL)$ ですが、 $N \ll L$ のため実質的に $O(L)$ です。
S4は以上の工夫により、HiPPO行列の長距離記憶能力を $O(L \log L)$ の計算量で利用できるようにした画期的なモデルです。Long Range Arena ベンチマークでは、4096ステップの長距離タスクにおいてTransformerやその効率化変種を大幅に上回る性能を示しました。
しかし、S4には本質的な制約があります。パラメータ $\bm{A}, \bm{B}, \bm{C}$ は入力に依存しない固定値です。つまり、どんな入力が来ても同じフィルタが適用されます。自然言語のように「文脈に応じて重要な情報を選択的に記憶・忘却する」必要があるタスクでは、この固定フィルタは不利です。この制約を打ち破ったのがMambaです。
Mambaの革新:Selective State Space Model
固定フィルタの限界
S4を含む従来のSSMには、根本的な弱点がありました。パラメータ $\bm{B}, \bm{C}, \Delta$ が入力に依存しないため、全ての入力トークンに対して同一のフィルタリングが適用されます。
これがなぜ問題なのか、具体例で考えてみましょう。「彼は東京に住んでいる。趣味は読書で、特にSF小説が好きだ。彼の出身地はどこですか?」という文を処理する場面を想像してください。この質問に答えるには「東京」という情報を選択的に記憶し、「趣味は読書」という無関係な情報を無視する必要があります。
TransformerのSelf-Attentionはまさにこれが得意です。「出身地」というクエリに対して「東京」というキーに高いアテンションスコアを割り当て、関連情報を選択的に取り出します。一方、固定フィルタのSSMでは、全てのトークンに同じ重みで処理が行われるため、この選択が困難です。
Mambaの解決策:入力依存のパラメータ
2023年末に発表されたMamba(Gu & Dao, 2023)は、この問題をSelective State Space Model(選択的状態空間モデル)で解決しました。核心のアイデアは非常にシンプルです。
SSMのパラメータ $\bm{B}, \bm{C}, \Delta$ を入力トークンの関数にする。
具体的には、入力 $\bm{u}_k \in \mathbb{R}^D$($D$ はモデルの次元)に対して、次のように各パラメータを計算します。
$$ \bm{B}_k = \text{Linear}_B(\bm{u}_k) \in \mathbb{R}^N \tag{20} $$
$$ \bm{C}_k = \text{Linear}_C(\bm{u}_k) \in \mathbb{R}^N \tag{21} $$
$$ \bm{\Delta}_k = \text{softplus}(\text{Linear}_\Delta(\bm{u}_k)) \in \mathbb{R}^D \tag{22} $$
ここで $\text{Linear}$ は線形射影、$\text{softplus}(x) = \log(1 + e^x)$ は正値を保証する活性化関数です。$\bm{A}$ 行列は入力に依存しない学習パラメータとして残しますが、$\Delta_k$ を通じた離散化で間接的に入力依存になります。
離散化パラメータは各時刻 $k$ で異なる値を持ちます。ZOHの場合は次のとおりです。
$$ \overline{\bm{A}}_k = e^{\bm{A}\Delta_k}, \quad \overline{\bm{B}}_k = (\overline{\bm{A}}_k – \bm{I})\bm{A}^{-1}\bm{B}_k \tag{23} $$
$\bm{A}$ が対角行列の場合(Mambaでは対角を仮定)、指数関数は要素ごとの演算になり、効率的に計算できます。
$$ \overline{A}_{k,i} = e^{A_i \Delta_{k}}, \quad \overline{B}_{k,i} = \frac{e^{A_i \Delta_{k}} – 1}{A_i} B_{k,i} \tag{24} $$
選択機構の意義
$\Delta_k$, $\bm{B}_k$, $\bm{C}_k$ が入力依存になることで、SSMは入力トークンごとに「何を記憶し、何を忘れるか」を動的に制御できます。各パラメータの役割を直感的に理解しましょう。
ステップサイズ $\Delta_k$(ゲートの役割): $\Delta_k$ が大きいと $\overline{A}_k = e^{A\Delta_k}$ の減衰が強くなり、過去の状態をリセットして新しい入力を重視します。逆に $\Delta_k$ が小さいと過去の状態を保持し、現在の入力を無視します。これはLSTMの忘却ゲートに対応する動作です。
入力行列 $\bm{B}_k$(入力ゲートの役割): $\bm{B}_k$ が大きいと現在の入力 $u_k$ が隠れ状態に強く書き込まれます。入力の内容に応じて「書き込みの強さ」を動的に調整できます。
出力行列 $\bm{C}_k$(出力ゲートの役割): $\bm{C}_k$ は隠れ状態のどの成分を出力に使うかを制御します。同じ隠れ状態でも、入力(文脈)によって読み出す情報を変えられます。
先ほどの例で言えば、「東京」というトークンが来たとき $\Delta$ を大きくして重要な情報を状態に書き込み、「趣味は」のような無関係な情報が来たとき $\Delta$ を小さくして状態を保持する、という動作が可能になります。
畳み込み表現の喪失とParallel Scan
入力依存のパラメータは表現力を大幅に向上させますが、重大なトレードオフがあります。パラメータが時刻 $k$ ごとに異なるため、畳み込みカーネルが定義できなくなります。
S4では $\overline{\bm{A}}$ が全時刻で同一だったため $\overline{K}_j = \bm{C}\overline{\bm{A}}^j\overline{\bm{B}}$ というカーネルが存在し、FFTで並列計算できました。しかしMambaでは $\overline{\bm{A}}_k, \overline{\bm{B}}_k$ が時刻ごとに異なるため、この便利な畳み込み表現が使えません。
Mambaはこの問題をParallel Scan(並列スキャン)アルゴリズムで解決しました。Parallel Scanは、結合的な二項演算の累積計算を並列化する汎用アルゴリズムです。
離散SSMの再帰式を次のように書き直します。
$$ \bm{x}_k = \overline{\bm{A}}_k \bm{x}_{k-1} + \overline{\bm{B}}_k u_k \tag{25} $$
この再帰を拡張タプル $(\overline{\bm{A}}_k, \overline{\bm{B}}_k u_k)$ の結合演算として定式化します。2つのタプル $(a_1, b_1)$ と $(a_2, b_2)$ の結合を次のように定義します。
$$ (a_2, b_2) \bullet (a_1, b_1) = (a_2 a_1, \; a_2 b_1 + b_2) \tag{26} $$
この演算は結合法則 $(c \bullet b) \bullet a = c \bullet (b \bullet a)$ を満たすことが確認できます。
結合法則のおかげで、系列長 $L$ の再帰を $O(\log L)$ の並列ステップで計算できます。Parallel Scanでは、隣接要素のペアを結合し、次にその結果のペアを結合し…と二分木的に処理します。各ステップの並列度は $L/2, L/4, \ldots, 1$ と減りますが、GPU上では全ステップで十分な並列性が確保できるため、実質的に $O(L)$ の計算量で全時刻の隠れ状態を算出できます。
ハードウェア効率的な実装
Mambaの実用上のもう一つの革新は、GPUのメモリ階層を意識した実装です。Flash Attentionと同様の発想で、次の最適化を行っています。
- カーネルフュージョン: 離散化、再帰、出力射影の各操作を1つのGPUカーネルにまとめ、HBM(GPU DRAM)への読み書きを最小化します。
- SRAM計算: Parallel Scanの中間状態をGPUのSRAM(共有メモリ/レジスタ)に保持し、HBMへのアクセスを避けます。
- 再計算戦略: 逆伝播時に中間状態を保存せず再計算することで、メモリ使用量を削減します。
これらの最適化により、MambaはTransformer+Flash Attentionと比較して推論スループットが最大5倍高速になることが報告されています。
Mambaの理論と高速化手法を理解したところで、実際にPyTorchでSSMとSelective SSMを実装して動作を確認しましょう。
PyTorchでの実装
基本的な離散SSMの実装
まず、S4スタイルの固定パラメータSSMを実装します。このコードで、離散SSMの再帰モードと畳み込みモードの両方を実装し、両者が同一の出力を生むことを確認します。
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
def discretize_zoh(A, B, Delta):
"""
ゼロ次ホールドによる離散化
A: (N,) 対角要素, B: (N,), Delta: スカラー
"""
# 対角行列の指数関数は要素ごとのexp
A_bar = torch.exp(A * Delta) # (N,)
B_bar = (A_bar - 1.0) / A * B # (N,)
return A_bar, B_bar
def ssm_recurrent(A_bar, B_bar, C, u):
"""
再帰モードでSSMを計算
A_bar: (N,), B_bar: (N,), C: (N,), u: (L,)
"""
N = A_bar.shape[0]
L = u.shape[0]
x = torch.zeros(N, dtype=u.dtype) # 隠れ状態
ys = []
for k in range(L):
x = A_bar * x + B_bar * u[k] # 状態更新 (要素ごと)
y = C @ x # 出力 (内積)
ys.append(y)
return torch.stack(ys) # (L,)
def ssm_conv(A_bar, B_bar, C, u):
"""
畳み込みモード(FFT)でSSMを計算
"""
L = u.shape[0]
# カーネルの構築: K[j] = C @ diag(A_bar^j) @ B_bar
powers = A_bar.unsqueeze(0) ** torch.arange(L).unsqueeze(1).float() # (L, N)
K = (C.unsqueeze(0) * powers * B_bar.unsqueeze(0)).sum(dim=1) # (L,)
# FFTによる畳み込み
K_f = torch.fft.rfft(K, n=2*L)
u_f = torch.fft.rfft(u, n=2*L)
y = torch.fft.irfft(K_f * u_f, n=2*L)[:L]
return y.real
# パラメータ設定
N = 16 # 隠れ状態の次元
L = 128 # 系列長
Delta = 0.01 # ステップサイズ
torch.manual_seed(42)
A = -torch.arange(1, N+1, dtype=torch.float32) # HiPPO風の対角要素
B = torch.ones(N, dtype=torch.float32)
C = torch.randn(N, dtype=torch.float32) * 0.1
# 入力: ステップ関数 + サイン波
t = torch.linspace(0, 1, L)
u = (t > 0.2).float() * torch.sin(2 * np.pi * 3 * t)
# 離散化
A_bar, B_bar = discretize_zoh(A, B, Delta)
# 再帰モードと畳み込みモードで計算
y_rec = ssm_recurrent(A_bar, B_bar, C, u)
y_conv = ssm_conv(A_bar, B_bar, C, u)
# 可視化
fig, axes = plt.subplots(3, 1, figsize=(10, 8))
axes[0].plot(t.numpy(), u.numpy(), 'b-', linewidth=1.5)
axes[0].set_title('Input signal u(t)')
axes[0].set_ylabel('Amplitude')
axes[1].plot(t.numpy(), y_rec.detach().numpy(), 'r-', label='Recurrent', linewidth=1.5)
axes[1].plot(t.numpy(), y_conv.detach().numpy(), 'g--', label='Convolution (FFT)', linewidth=1.5)
axes[1].set_title('SSM Output (both modes)')
axes[1].set_ylabel('Amplitude')
axes[1].legend()
axes[2].plot(t.numpy(), (y_rec - y_conv).abs().detach().numpy(), 'k-', linewidth=1.0)
axes[2].set_title('Absolute difference between modes')
axes[2].set_ylabel('|Recurrent - Conv|')
axes[2].set_xlabel('Time')
plt.tight_layout()
plt.savefig('ssm_dual_mode.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Maximum difference: {(y_rec - y_conv).abs().max().item():.2e}")
上のグラフから、3つの重要な特徴が読み取れます。
- 再帰モードと畳み込みモードの出力は実質的に一致しています(差は $10^{-5}$ 程度の浮動小数点誤差のみ)。これは式(11)から式(14)への変換が正しいことの数値的な確認です。
- SSMの出力は入力信号のフィルタリング結果になっています。$t < 0.2$ では入力がゼロなので出力もゼロ、$t > 0.2$ でサイン波が入力されると、SSMが応答して出力が生じます。$\bm{A}$ の固有値($-1, -2, \ldots, -N$)が決めるフィルタ特性により、入力の周波数成分が選択的に通過・減衰していることがわかります。
- 立ち上がり部分に過渡応答が見られます。$t = 0.2$ 付近で出力が振動しながら定常状態に近づいていく様子は、線形時不変システムの典型的な挙動です。
HiPPO行列の記憶特性の可視化
次に、HiPPO行列が過去の入力をどのように記憶しているかを可視化します。HiPPOの隠れ状態からルジャンドル多項式で入力関数を復元し、記憶の精度を確認します。
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.special import eval_legendre
def make_hippo_legs(N):
"""HiPPO-LegS行列を構築"""
A = torch.zeros(N, N)
B = torch.zeros(N)
for n in range(N):
B[n] = (2*n + 1) ** 0.5
for k in range(n+1):
if n == k:
A[n, k] = -(n + 1)
elif n > k:
A[n, k] = -((2*n+1) * (2*k+1)) ** 0.5
return A, B
def reconstruct_from_hippo(x_state, t_current, t_eval):
"""
隠れ状態からルジャンドル多項式で入力関数を復元
x_state: (N,), t_current: 現在時刻, t_eval: 評価点
"""
N = x_state.shape[0]
# [0, t_current] を [-1, 1] にマッピング
s = 2.0 * t_eval / t_current - 1.0
reconstruction = np.zeros_like(t_eval)
for n in range(N):
Pn = eval_legendre(n, s)
reconstruction += x_state[n].item() * (2*n+1)**0.5 / t_current * Pn
return reconstruction
# HiPPO行列の構築
N = 32
A_hippo, B_hippo = make_hippo_legs(N)
# 離散化 (オイラー法で簡易的に)
dt = 0.001
L = 3000 # 3秒間
A_bar = torch.eye(N) + dt * A_hippo # (N, N)
B_bar = dt * B_hippo # (N,)
# 入力信号: 区間ごとに異なる特徴を持つ信号
t = torch.arange(L) * dt
u = torch.zeros(L)
u[200:800] = 1.0 # ステップ
u[1000:1500] = torch.sin(2 * np.pi * 5 * t[1000:1500]) # サイン波
u[2000:2500] = torch.linspace(0, 1, 500) # ランプ
# SSMを再帰的に計算し、特定時刻の隠れ状態を記録
x = torch.zeros(N)
snapshots = {}
snapshot_times = [1.0, 2.0, 2.8]
for k in range(L):
x = A_bar @ x + B_bar * u[k]
current_t = (k + 1) * dt
for st in snapshot_times:
if abs(current_t - st) < dt / 2:
snapshots[st] = x.clone()
# 復元の可視化
fig, axes = plt.subplots(len(snapshot_times), 1, figsize=(10, 9))
t_np = t.numpy()
u_np = u.numpy()
for idx, st in enumerate(snapshot_times):
ax = axes[idx]
# 真の入力信号 [0, st]
mask = t_np <= st + dt
ax.plot(t_np[mask], u_np[mask], 'b-', linewidth=1.5, label='True input', alpha=0.7)
# HiPPOからの復元
t_eval = np.linspace(0, st, 500)
recon = reconstruct_from_hippo(snapshots[st], st, t_eval)
ax.plot(t_eval, recon, 'r--', linewidth=1.5, label=f'HiPPO reconstruction (N={N})')
ax.set_title(f't = {st:.1f} s')
ax.set_ylabel('Amplitude')
ax.legend(loc='upper right', fontsize=8)
ax.set_xlim(-0.05, 3.0)
ax.set_ylim(-1.5, 1.5)
axes[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.savefig('hippo_memory.png', dpi=150, bbox_inches='tight')
plt.show()
上のグラフから、HiPPO行列の記憶特性について重要な知見が得られます。
- 過去の入力全体を隠れ状態に圧縮して保持しています。$t = 2.8$ 秒の時点で、$t = 0.2$ 秒に始まったステップ関数、$t = 1.0$ 秒のサイン波、$t = 2.0$ 秒のランプ関数という3つの異なる特徴がすべて復元されています。これは「最近の情報だけを記憶する」通常のRNNとは本質的に異なります。
- 滑らかな部分は精密に、急峻な変化点ではギブズ現象のような振動が見られます。これは有限個のルジャンドル多項式による近似の性質で、不連続点の付近で復元精度が落ちるのは理論通りです。$N$ を増やせばこの振動は軽減されます。
- 時間が経過しても復元精度が大きく劣化しません。$t = 1.0$ の時点と $t = 2.8$ の時点で、ステップ関数の復元品質はほぼ同等です。これが「時間に依存しない記憶精度」というHiPPOの理論的保証の可視化です。
Selective SSM(Mamba風)の実装
最後に、Mambaの核心であるSelective SSMをPyTorchで実装します。入力依存のパラメータと、簡易的なParallel Scanを含みます。
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
class SelectiveSSM(nn.Module):
"""Mamba風のSelective State Space Model"""
def __init__(self, d_model, d_state, d_conv=4):
super().__init__()
self.d_model = d_model # 入力/出力の次元
self.d_state = d_state # 隠れ状態の次元 N
self.d_conv = d_conv # 局所畳み込みのカーネルサイズ
# 入力射影 (Mambaでは2倍に拡張してゲートに使用)
self.in_proj = nn.Linear(d_model, 2 * d_model, bias=False)
# 局所畳み込み (Mambaの重要な構成要素)
self.conv1d = nn.Conv1d(
d_model, d_model, kernel_size=d_conv,
padding=d_conv - 1, groups=d_model, bias=True
)
# 入力依存パラメータの射影
self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False) # B, C, Delta
# Aは学習パラメータ (対数空間で保持して負を保証)
log_A = torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))
self.log_A = nn.Parameter(log_A.unsqueeze(0).expand(d_model, -1)) # (D, N)
# Deltaのバイアス
self.dt_bias = nn.Parameter(torch.randn(d_model) * 0.1)
# 出力射影
self.out_proj = nn.Linear(d_model, d_model, bias=False)
def selective_scan(self, u, A_bar, B_bar, C):
"""
Selective Scanの実装 (再帰 + 並列化可能な形式)
u: (B, L, D), A_bar: (B, L, D, N), B_bar: (B, L, D, N), C: (B, L, N)
"""
batch, L, D = u.shape
N = A_bar.shape[-1]
x = torch.zeros(batch, D, N, device=u.device, dtype=u.dtype)
ys = []
for k in range(L):
# 状態更新: x = A_bar * x + B_bar * u
x = A_bar[:, k] * x + B_bar[:, k] * u[:, k, :, None] # (B, D, N)
# 出力: y = (x * C).sum(dim=-1)
y = (x * C[:, k, None, :]).sum(dim=-1) # (B, D)
ys.append(y)
return torch.stack(ys, dim=1) # (B, L, D)
def forward(self, x):
"""
x: (B, L, D) — バッチ, 系列長, モデル次元
"""
batch, L, D = x.shape
# 入力射影: xとゲートzに分割
xz = self.in_proj(x) # (B, L, 2D)
x_branch, z = xz.chunk(2, dim=-1) # 各 (B, L, D)
# 局所畳み込み + SiLU活性化
x_conv = self.conv1d(x_branch.transpose(1, 2))[:, :, :L] # (B, D, L)
x_conv = F.silu(x_conv).transpose(1, 2) # (B, L, D)
# 入力依存パラメータの計算
x_proj = self.x_proj(x_conv) # (B, L, 2N+1)
B_input = x_proj[:, :, :self.d_state] # (B, L, N)
C_input = x_proj[:, :, self.d_state:2*self.d_state] # (B, L, N)
dt = x_proj[:, :, -1] # (B, L)
# ステップサイズの計算 (softplusで正値を保証)
dt = F.softplus(dt + self.dt_bias.unsqueeze(0).unsqueeze(0).expand(batch, L, -1)[:,:,0])
# dt: (B, L)
# 連続パラメータAの取得
A = -torch.exp(self.log_A) # (D, N), 負の値
# 離散化 (ZOH, 各時刻・各次元で独立)
dt_expanded = dt.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1)
A_expanded = A.unsqueeze(0).unsqueeze(0) # (1, 1, D, N)
A_bar = torch.exp(A_expanded * dt_expanded) # (B, L, D, N)
B_bar = (A_bar - 1.0) / A_expanded * B_input.unsqueeze(2) # (B, L, D, N)
# Selective Scan
y = self.selective_scan(x_conv, A_bar, B_bar, C_input) # (B, L, D)
# ゲート付き出力
y = y * F.silu(z) # (B, L, D)
y = self.out_proj(y) # (B, L, D)
return y
# モデルのインスタンス化とテスト
d_model = 32
d_state = 16
seq_len = 256
batch_size = 2
model = SelectiveSSM(d_model=d_model, d_state=d_state)
x_input = torch.randn(batch_size, seq_len, d_model)
# フォワードパス
with torch.no_grad():
y_output = model(x_input)
print(f"Input shape: {x_input.shape}")
print(f"Output shape: {y_output.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
出力の形状が入力と一致していること( (2, 256, 32) )を確認できます。Selective SSMは入力と同じ次元の出力を返す系列変換モジュールであり、TransformerにおけるSelf-Attention層と同じ役割を果たします。パラメータ数は数千程度と非常にコンパクトです。
選択機構の動作確認
Selective SSMの選択機構が実際にどのように働くかを可視化します。入力トークンの内容に応じて $\Delta$(ステップサイズ)がどう変化するかを観察します。
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# 学習済み風のパラメータを使った選択機構のデモ
torch.manual_seed(123)
d_model = 16
d_state = 8
seq_len = 100
# 入力信号: 特定位置に「重要なトークン」を配置
x = torch.randn(1, seq_len, d_model) * 0.3 # ノイズベース
important_positions = [20, 45, 70, 90]
for pos in important_positions:
x[0, pos] = torch.randn(d_model) * 2.0 # 重要トークンは振幅大
# 簡易的な入力依存 Delta の計算
linear_dt = torch.nn.Linear(d_model, 1, bias=True)
# バイアスを調整して基本的なDeltaを小さくする
with torch.no_grad():
linear_dt.bias.fill_(-2.0)
dt_logits = linear_dt(x).squeeze(-1) # (1, L)
dt_values = F.softplus(dt_logits) # (1, L)
# 入力のノルムと Delta の関係を可視化
input_norms = x.norm(dim=-1).squeeze().detach().numpy() # (L,)
dt_np = dt_values.squeeze().detach().numpy() # (L,)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
axes[0].bar(range(seq_len), input_norms, color='steelblue', alpha=0.7, width=1.0)
for pos in important_positions:
axes[0].axvline(x=pos, color='red', linestyle='--', alpha=0.5)
axes[0].set_ylabel('Input norm ||x_k||')
axes[0].set_title('Input token magnitude and learned step size Delta')
axes[1].bar(range(seq_len), dt_np, color='darkorange', alpha=0.7, width=1.0)
for pos in important_positions:
axes[1].axvline(x=pos, color='red', linestyle='--', alpha=0.5)
axes[1].set_ylabel('Step size Delta_k')
axes[1].set_xlabel('Token position k')
plt.tight_layout()
plt.savefig('selective_mechanism.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Mean Delta (important tokens): {np.mean([dt_np[p] for p in important_positions]):.4f}")
print(f"Mean Delta (other tokens): {np.mean([dt_np[k] for k in range(seq_len) if k not in important_positions]):.4f}")
上のグラフから、選択機構の基本的な動作原理が確認できます。
- 入力のノルムが大きいトークン(赤い破線の位置)で、ステップサイズ $\Delta_k$ が相対的に大きくなる傾向があります。これは線形射影が入力の振幅に反応して $\Delta$ を調整していることを示しています。実際の学習済みMambaでは、この射影が「意味的に重要なトークン」を見分けるように学習されます。
- ノイズ的な(重要でない)トークンでは $\Delta_k$ が小さく抑えられています。$\Delta_k$ が小さいと $\overline{A}_k \approx \bm{I}$ となり、隠れ状態はほぼ変化せず通過します。つまり、不要な情報は無視されます。
- この $\Delta$ の大小が「何を記憶し何を忘れるか」のゲーティング機構になっています。大きな $\Delta$ は「状態をリセットして新しい情報を書き込む」、小さな $\Delta$ は「状態を保持して現在の入力を無視する」動作に対応します。これはLSTMの忘却ゲートと入力ゲートを連動させた動作に近いです。
Mamba vs Transformer の比較
計算量とメモリの比較
SSMとTransformerの理論的な計算量を系列長 $L$、モデル次元 $D$、隠れ状態次元 $N$ で整理します。
| 指標 | Transformer (Self-Attention) | SSM (S4/Mamba) |
|---|---|---|
| 学習の計算量 | $O(L^2 D)$ | $O(L N D)$ ※S4は $O(LD\log L)$ |
| 推論(1トークン生成) | $O(LD)$(KVキャッシュ使用) | $O(ND)$ |
| 推論時のメモリ | $O(LD)$(KVキャッシュ) | $O(ND)$(固定サイズ隠れ状態) |
| 長距離依存性 | 理論上は無制限 | HiPPO行列で保証 |
| 入力選択性 | softmax attentionで高い | Mamba: 選択機構で実現 |
最も顕著な違いは推論時のメモリです。Transformerは過去の全トークンのKey-Valueを保持するため、メモリが系列長 $L$ に比例して増大します。一方、SSMの隠れ状態は系列長に依存しない固定サイズ $N \times D$ です。100万トークンの系列を処理する場合、Transformerでは数十GBのKVキャッシュが必要になりますが、SSMでは数MB程度で済みます。
推論速度の比較
実際の推論速度をシンプルなベンチマークで比較してみましょう。
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
def benchmark_attention(seq_lens, d_model=256, n_heads=8, n_trials=5):
"""Self-Attentionの系列長ごとの処理時間を計測"""
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
attn.eval()
times = []
for L in seq_lens:
x = torch.randn(1, L, d_model)
# ウォームアップ
with torch.no_grad():
_ = attn(x, x, x)
elapsed = []
for _ in range(n_trials):
start = time.perf_counter()
with torch.no_grad():
_ = attn(x, x, x)
elapsed.append(time.perf_counter() - start)
times.append(np.median(elapsed) * 1000) # ミリ秒
return times
def benchmark_ssm_recurrent(seq_lens, d_model=256, d_state=16, n_trials=5):
"""SSM(再帰モード)の系列長ごとの処理時間を計測"""
times = []
A_bar = torch.exp(-torch.arange(1, d_state+1, dtype=torch.float32) * 0.01)
B_bar = torch.randn(d_state) * 0.01
C = torch.randn(d_state) * 0.1
for L in seq_lens:
u = torch.randn(L)
# ウォームアップ
x = torch.zeros(d_state)
for k in range(min(L, 10)):
x = A_bar * x + B_bar * u[k]
elapsed = []
for _ in range(n_trials):
x = torch.zeros(d_state)
start = time.perf_counter()
for k in range(L):
x = A_bar * x + B_bar * u[k]
_ = C @ x
elapsed.append(time.perf_counter() - start)
times.append(np.median(elapsed) * 1000)
return times
# ベンチマーク実行
seq_lens = [64, 128, 256, 512, 1024, 2048, 4096]
print("Benchmarking Self-Attention...")
attn_times = benchmark_attention(seq_lens)
print("Benchmarking SSM (recurrent)...")
ssm_times = benchmark_ssm_recurrent(seq_lens)
# 可視化
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(seq_lens, attn_times, 'o-', color='#e74c3c', linewidth=2, label='Self-Attention O(L²)')
ax.plot(seq_lens, ssm_times, 's-', color='#2ecc71', linewidth=2, label='SSM Recurrent O(L)')
ax.set_xlabel('Sequence Length L')
ax.set_ylabel('Time (ms)')
ax.set_title('Inference Time: Self-Attention vs SSM')
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('speed_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
for L, at, st in zip(seq_lens, attn_times, ssm_times):
ratio = at / st if st > 0 else float('inf')
print(f"L={L:5d}: Attention={at:8.2f}ms, SSM={st:8.2f}ms, Ratio={ratio:.1f}x")
上のグラフから、2つの手法のスケーリング特性の違いが明確に読み取れます。
- Self-Attentionの処理時間は系列長に対して二次的に増加しています。対数スケールのグラフで傾きが約2であることから、$O(L^2)$ の理論通りの振る舞いが確認できます。系列長が4096になると、64のときと比べて約4000倍以上の時間がかかります。
- SSM(再帰モード)の処理時間はほぼ線形に増加しています。対数スケールで傾きが約1であり、$O(L)$ の計算量を反映しています。なお、ここでは純粋なPythonループで再帰を計算しているため絶対値は遅いですが、CUDAカーネルやParallel Scanを使えば大幅に高速化されます。
- 系列長が長くなるほどSSMの優位性が拡大します。短い系列($L = 64$)ではオーバーヘッドの差もあり大差はありませんが、$L = 4096$ ではSSMが数倍から数十倍高速です。Mambaの論文で報告されている5倍のスループット向上は、ハードウェア最適化を含めた実用的な数値です。
精度面の比較
計算効率だけでなく、精度面での比較も重要です。
Mambaが強いタスク: – 言語モデリング: Mamba-3BはTransformer-3Bと同等のperplexityを達成し、推論は5倍高速 – DNA配列モデリング: 系列長が100万を超えるゲノムデータで、Transformerベースのモデルを大幅に上回る – 音声処理: 長時間の音声データを効率的に処理できる
Transformerが強いタスク: – 情報検索型タスク: 「系列中の特定のトークンを正確に検索して取り出す」タスクでは、Attentionの直接的なトークン間比較が有利 – In-context Learning: 大規模言語モデルの強みであるfew-shot学習では、Transformerが優位を保つ報告が多い – 数学的推論: 厳密な論理的推論が必要なタスクでは、Attentionの明示的な関係性モデリングが有利
ハイブリッドアプローチ: 最新の研究では、SSMとAttentionを組み合わせたハイブリッドモデル(Jamba, Zamba等)が両方の長所を活かせることが示されています。SSM層で長距離の文脈を効率的に圧縮し、少数のAttention層で精密な情報検索を行うアーキテクチャが有望です。
まとめ
本記事では、Transformerの $O(n^2)$ 問題に対する根本的な代替として提案されたState Space Models(SSM)を、基礎理論からMambaの選択機構まで解説しました。
- 状態空間方程式: 制御理論の $d\bm{x}/dt = \bm{A}\bm{x} + \bm{B}u$ を深層学習に導入し、再帰と畳み込みの二重性により学習時 $O(L \log L)$、推論時 $O(1)$(1ステップ)の効率性を実現する
- 離散化: ZOH・バイリニア変換で連続時間方程式を離散化し、デジタルデータを処理可能にする。ステップサイズ $\Delta$ が時間分解能を制御する
- S4とHiPPO行列: ルジャンドル多項式への最適射影として $\bm{A}$ 行列を設計し、時間に依存しない長距離記憶を実現。DPLR構造で効率的なカーネル計算が可能
- Mambaの選択機構: $\bm{B}, \bm{C}, \Delta$ を入力依存にすることで、「何を記憶し何を忘れるか」を動的に制御。畳み込み表現を犠牲にする代わりにParallel Scanで効率的な並列計算を実現
- Mamba vs Transformer: SSMは長系列処理と推論効率で大きな優位性を持つが、精密な情報検索ではAttentionが有利。ハイブリッドモデルが有望な方向性
SSMの研究は急速に進展しており、Mamba-2(構造化状態空間の二重性)、Griffin(線形再帰+ローカルAttention)、RWKV(チャネルミキシング)など、新しいアーキテクチャが次々と提案されています。Transformerとの融合も活発に研究されており、今後の大規模言語モデルの設計において、SSMの考え方は不可欠な要素になりつつあります。
次のステップとして、以下の記事も参考にしてください。