LSTMの理論と仕組みを徹底解説 — ゲート機構から勾配の流れまで

天気予報モデルが明日の気温を予測するとき、昨日の気温だけでなく、1週間前の寒波や1ヶ月前の季節変動パターンも考慮する必要があります。音声認識では、文の冒頭で発話された主語が、数秒後の動詞の解釈に影響を与えることもあります。このように、時間的に離れた情報を正しく結びつける能力は、系列データを扱うモデルにとって本質的に重要です。

ところが、通常のRNN(再帰型ニューラルネットワーク)は、過去の情報を伝えようとすると勾配が指数的に減衰してしまう「勾配消失問題」を抱えています。30ステップ前の情報すら満足に利用できないケースが珍しくありません。

この問題を解決するために1997年にHochreiterとSchmidhuberが提案したのが、LSTM(Long Short-Term Memory) です。LSTMは「セル状態」と「ゲート機構」という2つのアイデアにより、必要な情報を長期間保持しつつ不要な情報を忘れることができます。

LSTMを理解すると、以下のような応用が可能になります。

  • 時系列予測: 株価、電力需要、気温などの長期依存を含む予測問題
  • 自然言語処理: 文中の長距離の文法的依存関係や共参照の学習
  • 音声認識: 長い発話にわたる音韻パターンのモデリング
  • 異常検知: 正常な時系列パターンからの逸脱の検出

本記事の内容

  • RNNの勾配消失問題の復習と数学的分析
  • LSTMの全体アーキテクチャと各ゲートの役割
  • 忘却ゲート・入力ゲート・出力ゲートの数式と直感
  • セル状態の更新メカニズム
  • 勾配の流れが安定する理由の数学的証明
  • NumPyによるLSTMセルのスクラッチ実装
  • 合成正弦波データでの学習実験

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

バニラRNNの勾配消失問題

RNNの基本構造の復習

LSTMがどのような問題を解決するのかを理解するために、まずバニラRNN(通常のRNN)の構造を振り返りましょう。

バニラRNNは、時刻 $t$ の入力 $\bm{x}_t$ と前の時刻の隠れ状態 $\bm{h}_{t-1}$ から、新しい隠れ状態 $\bm{h}_t$ を計算します。

$$ \bm{h}_t = \tanh(\bm{W}_h \bm{h}_{t-1} + \bm{W}_x \bm{x}_t + \bm{b}_h) $$

ここで $\bm{W}_h \in \mathbb{R}^{d_h \times d_h}$ は隠れ状態間の重み行列、$\bm{W}_x \in \mathbb{R}^{d_h \times d_x}$ は入力から隠れ状態への重み行列、$\bm{b}_h$ はバイアスベクトルです。出力は隠れ状態から次のように計算されます。

$$ \bm{y}_t = \bm{W}_y \bm{h}_t + \bm{b}_y $$

この構造は非常にシンプルですが、長い系列を扱うときに根本的な問題が生じます。

勾配消失の数学的メカニズム

時刻 $T$ における損失 $L_T$ を、時刻 $t$($t \ll T$)のパラメータに対して逆伝播する状況を考えましょう。連鎖律を適用すると、勾配は次のようにヤコビアンの積で表されます。

$$ \frac{\partial L_T}{\partial \bm{h}_t} = \frac{\partial L_T}{\partial \bm{h}_T} \prod_{k=t+1}^{T} \frac{\partial \bm{h}_k}{\partial \bm{h}_{k-1}} $$

各時刻のヤコビアンを具体的に計算します。$\bm{h}_k = \tanh(\bm{W}_h \bm{h}_{k-1} + \bm{W}_x \bm{x}_k + \bm{b}_h)$ を $\bm{h}_{k-1}$ で微分すると、

$$ \frac{\partial \bm{h}_k}{\partial \bm{h}_{k-1}} = \text{diag}\!\left(1 – \bm{h}_k^2\right) \cdot \bm{W}_h $$

が得られます。$\text{diag}(1 – \bm{h}_k^2)$ は $\tanh$ の微分を対角行列にしたもので、各要素は $0$ から $1$ の範囲にあります。

ここで問題の本質が見えてきます。$T – t$ 個のヤコビアンの積は

$$ \prod_{k=t+1}^{T} \frac{\partial \bm{h}_k}{\partial \bm{h}_{k-1}} = \prod_{k=t+1}^{T} \text{diag}\!\left(1 – \bm{h}_k^2\right) \cdot \bm{W}_h $$

です。$\bm{W}_h$ の最大特異値を $\sigma_{\max}$ とすると、行列積のノルムは概ね $(\gamma \cdot \sigma_{\max})^{T-t}$ のオーダーで振る舞います。ここで $\gamma \leq 1$ は $\tanh$ の微分の上限です。

  • $\gamma \cdot \sigma_{\max} < 1$ のとき: 勾配は指数的に $0$ に近づく → 勾配消失
  • $\gamma \cdot \sigma_{\max} > 1$ のとき: 勾配は指数的に発散する → 勾配爆発

勾配爆発はクリッピングで対処可能ですが、勾配消失は本質的な問題です。30ステップ前の情報の勾配が $10^{-10}$ のオーダーまで減衰すれば、パラメータの更新は事実上不可能になります。

勾配消失が引き起こす実際の問題

たとえば「The cat, which was sitting on the mat in the living room, was …」という文を処理するとき、最後の動詞の活用を決定するには文頭の主語「cat」の情報が必要です。しかしバニラRNNでは、主語から動詞までの多くの中間ステップを経る間に、主語の情報に関する勾配が消失してしまい、正しい文法関係を学習できません。

この根本的な限界を突破するために、勾配が長い時間にわたって安定的に伝播するアーキテクチャが必要です。それがLSTMです。

LSTMの全体アーキテクチャ

コンベアベルトとしてのセル状態

LSTMの核心的なアイデアは、隠れ状態 $\bm{h}_t$ とは別にセル状態(cell state)$\bm{c}_t$ を導入することです。

工場のコンベアベルトを想像してください。ベルトの上を製品(情報)が流れていきます。途中の作業ステーションでは、不良品を取り除いたり(忘却)、新しい部品を載せたり(入力)、完成品を出荷したり(出力)します。しかしベルト自体は一定速度でまっすぐ流れ続けます。

セル状態 $\bm{c}_t$ はこのコンベアベルトに相当します。情報はセル状態の上をほぼそのまま流れていき、ゲートという「作業ステーション」が必要な加工だけを行います。この構造により、情報が長距離にわたって保存されやすくなります。

3つのゲートの概要

LSTMは3種類のゲートでセル状態への情報の流れを制御します。

ゲート 記号 役割 日常のアナロジー
忘却ゲート $\bm{f}_t$ 過去のセル状態のどの部分を残すか 本棚の整理: 読み終わった本を処分
入力ゲート $\bm{i}_t$ 新しい情報のどの部分を書き込むか 新しい本を本棚に追加
出力ゲート $\bm{o}_t$ セル状態のどの部分を外部に出すか 今必要な本だけを取り出して読む

各ゲートは $0$ から $1$ の値を出力するシグモイド関数を使い、要素ごとの乗算(アダマール積 $\odot$)で情報の「通す量」をきめ細かく制御します。$0$ に近ければ「遮断」、$1$ に近ければ「全通し」です。

それでは、各ゲートの仕組みを一つずつ見ていきましょう。

忘却ゲート — 何を捨てるかを決める

直感的な理解

忘却ゲートは、セル状態に保存されている過去の情報のうち、どの部分を忘れるかを決定します。

たとえば、言語モデルが「彼女はフランス出身で、…(多くの文が続く)… 彼はドイツで生まれた」という文を処理しているとします。「彼は」が現れた時点で、以前の主語「彼女」に関する情報は不要になるかもしれません。忘却ゲートは、このような古い情報を適切に「忘れる」役割を担います。

数式

忘却ゲートの値 $\bm{f}_t$ は、現在の入力 $\bm{x}_t$ と前の隠れ状態 $\bm{h}_{t-1}$ から次のように計算されます。

$$ \begin{equation} \bm{f}_t = \sigma(\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_f) \end{equation} $$

ここで $\sigma$ はシグモイド関数、$[\bm{h}_{t-1}, \bm{x}_t]$ は隠れ状態と入力の結合(concatenation)、$\bm{W}_f \in \mathbb{R}^{d_h \times (d_h + d_x)}$ は忘却ゲートの重み行列、$\bm{b}_f$ はバイアスです。

シグモイド関数の出力は各要素が $[0, 1]$ の範囲にあります。$\bm{f}_t$ の $j$ 番目の要素 $f_{t,j}$ が $0$ に近ければ、セル状態の $j$ 番目の要素は忘却されます。$1$ に近ければ、そのまま保持されます。

なお、忘却ゲートの概念はLSTMの原論文(Hochreiter & Schmidhuber, 1997)には含まれておらず、2000年にGersらによって追加されました。この追加により、LSTMの性能は大幅に向上しました。忘却ゲートがないと、セル状態は情報を蓄積し続けるだけで、不要な情報を除去できないためです。

忘却ゲートが「何を捨てるか」を決めたら、次は「何を新しく記憶するか」を決める必要があります。それが入力ゲートの役割です。

入力ゲート — 何を記憶するかを決める

直感的な理解

入力ゲートは、新しく入ってきた情報のうち、どの部分をセル状態に書き込むかを制御します。

ノートにメモを取る場面を考えてみましょう。講義で話される全ての内容をノートに書き留めるわけではありません。重要なポイントだけを選んで記録します。入力ゲートは、この「重要度の判断」に相当します。

数式

入力ゲートは2段階の処理を行います。

ステップ1: まず、入力ゲートの値 $\bm{i}_t$ が「どの要素を更新するか」を決めます。

$$ \begin{equation} \bm{i}_t = \sigma(\bm{W}_i [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_i) \end{equation} $$

ステップ2: 次に、候補セル状態 $\tilde{\bm{c}}_t$ が「どんな新しい情報を書き込む候補にするか」を計算します。

$$ \begin{equation} \tilde{\bm{c}}_t = \tanh(\bm{W}_c [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_c) \end{equation} $$

$\tanh$ を使う理由は2つあります。第一に、出力範囲が $[-1, 1]$ に制限されるため、セル状態に加算する値が有界になり、数値的な安定性が保たれます。第二に、正と負の両方の値を取れるため、セル状態の各次元を増加させることも減少させることもできます。

入力ゲート $\bm{i}_t$ と候補セル状態 $\tilde{\bm{c}}_t$ のアダマール積 $\bm{i}_t \odot \tilde{\bm{c}}_t$ が、実際にセル状態に書き込まれる新しい情報となります。

ここまでで「何を忘れるか」(忘却ゲート)と「何を覚えるか」(入力ゲート)が決まりました。次に、これらを組み合わせてセル状態を実際に更新する仕組みを見ていきましょう。

セル状態の更新 — 忘却と記憶の融合

更新式

忘却ゲートと入力ゲートの結果を合わせると、セル状態の更新は次のように表されます。

$$ \begin{equation} \bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t \end{equation} $$

この式は、LSTMの最も重要な式です。2つの項の意味を確認しましょう。

  • 第1項 $\bm{f}_t \odot \bm{c}_{t-1}$: 忘却ゲートでフィルタリングされた過去のセル状態。$\bm{f}_t$ の各要素が $1$ に近ければ過去の情報はそのまま保持され、$0$ に近ければ忘却されます。
  • 第2項 $\bm{i}_t \odot \tilde{\bm{c}}_t$: 入力ゲートでフィルタリングされた新しい情報。$\bm{i}_t$ の各要素が $1$ に近ければ候補情報が全て書き込まれ、$0$ に近ければ書き込まれません。

バニラRNNとの対比

バニラRNNの更新式 $\bm{h}_t = \tanh(\bm{W}_h \bm{h}_{t-1} + \bm{W}_x \bm{x}_t + \bm{b})$ と比較すると、LSTMの更新式の特徴が明確になります。

バニラRNNでは、隠れ状態が毎時刻 $\tanh$ と行列積で完全に「上書き」されます。一方LSTMでは、セル状態は加算($+$)で更新されます。この加算的な構造が、勾配の伝播において決定的な違いを生みます。この点については、後のセクションで数学的に詳しく解析します。

具体的な数値例

セル状態のメカニズムを具体的な数値で確認してみましょう。$d_h = 3$ の場合を考えます。

前のセル状態が $\bm{c}_{t-1} = [0.8, -0.3, 0.5]^\top$ だったとします。ここで、忘却ゲートが $\bm{f}_t = [0.9, 0.1, 0.7]^\top$、入力ゲートが $\bm{i}_t = [0.2, 0.8, 0.1]^\top$、候補セル状態が $\tilde{\bm{c}}_t = [0.5, 0.6, -0.4]^\top$ だったとすると、

忘却ゲートの効果を計算します。

$$ \bm{f}_t \odot \bm{c}_{t-1} = [0.9 \times 0.8, \; 0.1 \times (-0.3), \; 0.7 \times 0.5]^\top = [0.72, \; -0.03, \; 0.35]^\top $$

入力ゲートの効果を計算します。

$$ \bm{i}_t \odot \tilde{\bm{c}}_t = [0.2 \times 0.5, \; 0.8 \times 0.6, \; 0.1 \times (-0.4)]^\top = [0.10, \; 0.48, \; -0.04]^\top $$

両者を加算して新しいセル状態を得ます。

$$ \bm{c}_t = [0.72 + 0.10, \; -0.03 + 0.48, \; 0.35 + (-0.04)]^\top = [0.82, \; 0.45, \; 0.31]^\top $$

この例から、次のことが読み取れます。

  • 第1要素: 忘却ゲートが $0.9$ なので過去の情報 $0.8$ をほぼ保持し、新しい情報は少量だけ追加
  • 第2要素: 忘却ゲートが $0.1$ なので過去の情報 $-0.3$ をほぼ忘却し、新しい情報 $0.48$ で大きく書き換え
  • 第3要素: 忘却ゲートが $0.7$ で適度に保持しつつ、入力ゲートが $0.1$ で新しい情報はわずかに追加

このように、LSTMは要素ごとに「忘れる量」と「覚える量」を独立に制御でき、きめ細かい情報管理を実現しています。

セル状態が更新されたら、最後にこの情報のうち「どの部分を外部に出力するか」を決定する必要があります。

出力ゲート — 何を出力するかを決める

直感的な理解

出力ゲートは、更新されたセル状態のうちどの部分を現時刻の出力(隠れ状態)として使うかを制御します。

図書館で調べ物をしている場面を想像してください。図書館には膨大な蔵書(セル状態)がありますが、今の調べ物に必要な本だけを取り出して机に広げます。出力ゲートは、「蔵書全体の中から今必要な情報だけを取り出す」作業に相当します。

数式

出力ゲートの値 $\bm{o}_t$ と、最終的な隠れ状態 $\bm{h}_t$ は次のように計算されます。

$$ \begin{equation} \bm{o}_t = \sigma(\bm{W}_o [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_o) \end{equation} $$

$$ \begin{equation} \bm{h}_t = \bm{o}_t \odot \tanh(\bm{c}_t) \end{equation} $$

$\tanh(\bm{c}_t)$ でセル状態を $[-1, 1]$ にスケーリングしてから、出力ゲート $\bm{o}_t$ でフィルタリングしています。$\tanh$ を適用する理由は、セル状態は理論上任意の値をとりうるため、そのまま出力すると数値が大きくなりすぎる可能性があるからです。

LSTMの全体像: 計算の流れのまとめ

ここまでの各ゲートとセル状態の更新をまとめると、LSTMの1ステップの計算は以下の6つの式で表されます。

$$ \begin{align} \bm{f}_t &= \sigma(\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_f) & \text{(忘却ゲート)} \\ \bm{i}_t &= \sigma(\bm{W}_i [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_i) & \text{(入力ゲート)} \\ \tilde{\bm{c}}_t &= \tanh(\bm{W}_c [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_c) & \text{(候補セル状態)} \\ \bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t & \text{(セル状態の更新)} \\ \bm{o}_t &= \sigma(\bm{W}_o [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_o) & \text{(出力ゲート)} \\ \bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t) & \text{(隠れ状態)} \end{align} $$

入力 $\bm{x}_t \in \mathbb{R}^{d_x}$ と前の隠れ状態 $\bm{h}_{t-1} \in \mathbb{R}^{d_h}$ を受け取り、セル状態 $\bm{c}_t \in \mathbb{R}^{d_h}$ と隠れ状態 $\bm{h}_t \in \mathbb{R}^{d_h}$ を出力します。

LSTMの構造が明らかになったところで、次の重要な疑問に答えましょう。なぜこの構造で勾配消失問題が解決されるのでしょうか?

パラメータ数の分析

LSTMのパラメータ数

LSTMの学習すべきパラメータの総数を数えてみましょう。入力次元を $d_x$、隠れ状態の次元を $d_h$ とします。

LSTMには4つの重み行列があります(忘却ゲート $\bm{W}_f$、入力ゲート $\bm{W}_i$、候補セル状態 $\bm{W}_c$、出力ゲート $\bm{W}_o$)。各行列のサイズは $d_h \times (d_h + d_x)$ です。さらに、4つのバイアスベクトル(各 $d_h$ 次元)があります。

$$ \text{パラメータ数} = 4 \times \left[ d_h \times (d_h + d_x) + d_h \right] = 4 d_h (d_h + d_x + 1) $$

たとえば $d_x = 10$、$d_h = 64$ の場合、

$$ 4 \times 64 \times (64 + 10 + 1) = 4 \times 64 \times 75 = 19{,}200 $$

バニラRNNのパラメータ数 $d_h(d_h + d_x) + d_h = d_h(d_h + d_x + 1)$ と比較すると、LSTMは約4倍のパラメータを持ちます。これはゲートが3つ(忘却・入力・出力)あり、さらに候補セル状態の計算があるためです。

パラメータ数が多いということは計算コストも高くなりますが、その代わりに長期依存性を学習できるという大きなメリットが得られます。このトレードオフがLSTMとバニラRNNの本質的な違いです。

では、なぜ4倍ものパラメータを追加することで、勾配消失が解消されるのでしょうか。次のセクションで数学的に解析します。

勾配の流れが安定する理由

セル状態を通じた勾配伝播

LSTMの最大の利点は、セル状態を通じた勾配伝播が安定することです。この性質を数学的に示しましょう。

セル状態の更新式 $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$ において、$\bm{c}_t$ の $\bm{c}_{t-1}$ に対する勾配を計算します。

$$ \frac{\partial \bm{c}_t}{\partial \bm{c}_{t-1}} = \text{diag}(\bm{f}_t) + \frac{\partial (\bm{i}_t \odot \tilde{\bm{c}}_t)}{\partial \bm{c}_{t-1}} + \frac{\partial (\bm{f}_t \odot \bm{c}_{t-1})}{\partial \bm{c}_{t-1}} – \text{diag}(\bm{f}_t) $$

ここで、$\bm{f}_t$、$\bm{i}_t$、$\tilde{\bm{c}}_t$ は $\bm{h}_{t-1}$ を通じて間接的に $\bm{c}_{t-1}$ に依存しますが、その影響は二次的です。主要な勾配経路はセル状態を直接通るパスであり、この経路の勾配は単純に $\text{diag}(\bm{f}_t)$ となります。

$$ \frac{\partial \bm{c}_t}{\partial \bm{c}_{t-1}} \approx \text{diag}(\bm{f}_t) $$

長期にわたる勾配の累積

時刻 $T$ から時刻 $t$ までセル状態を通じて勾配を逆伝播すると、

$$ \frac{\partial \bm{c}_T}{\partial \bm{c}_t} \approx \prod_{k=t+1}^{T} \text{diag}(\bm{f}_k) = \text{diag}\!\left(\prod_{k=t+1}^{T} \bm{f}_k\right) $$

これが対角行列の積になっている点がバニラRNNと決定的に異なります。バニラRNNのヤコビアン積 $\prod \text{diag}(1-\bm{h}_k^2) \cdot \bm{W}_h$ は一般の行列積であり、特異値の累積的な効果で指数的に発散・減衰します。

しかしLSTMでは、セル状態の勾配は忘却ゲートの値のスカラー的な積です。$j$ 番目の要素に着目すると、

$$ \frac{\partial c_{T,j}}{\partial c_{t,j}} \approx \prod_{k=t+1}^{T} f_{k,j} $$

忘却ゲートの値 $f_{k,j} \in [0, 1]$ が $1$ に近い値を保っていれば、この積は $1$ に近い値を維持します。

なぜこれが勾配消失を防ぐのか

バニラRNNとの比較を整理しましょう。

バニラRNN: 勾配は $\prod (\gamma \cdot \sigma_{\max})^{T-t}$ のオーダーで振る舞い、$\gamma \cdot \sigma_{\max} \neq 1$ であれば指数的に消失または爆発します。$\gamma \cdot \sigma_{\max} = 1$ を正確に維持するのは実質的に不可能です。

LSTM: 忘却ゲート $\bm{f}_t$ は学習可能であり、ネットワークが「この情報は保持すべき」と判断した次元については $f_{t,j} \approx 1$ を学習します。すると $\prod f_{k,j} \approx 1$ となり、勾配がほぼ減衰なく伝播します。

重要な点は、忘却ゲートの値が学習によって適応的に決まるということです。全ての情報を無条件に保持するのではなく、ネットワークが必要と判断した情報だけの勾配を安定的に伝播させます。これは「定数的な勾配フロー」(constant error carousel)と呼ばれ、LSTMの原論文で中核概念として提示されました。

忘却ゲートバイアスの初期化

実践的には、忘却ゲートのバイアス $\bm{b}_f$ を正の値(例えば $1$ や $2$)で初期化することが推奨されます(Jozefowicz et al., 2015)。これにより、学習初期の忘却ゲートの値がシグモイド関数によって $0.5$ より大きくなり、情報の保持が促進されます。バイアスが $0$ で初期化されると、忘却ゲートは初期値として $\sigma(0) = 0.5$ を出力し、各ステップでセル状態が半減してしまいます。10ステップで $0.5^{10} \approx 0.001$ まで減衰してしまうため、長期依存性の学習が困難になります。

ここまでLSTMの理論的な仕組みを詳しく見てきました。次に、これらの数式をPythonで実装して、理論が実際に機能することを確認しましょう。

NumPyによるLSTMセルのスクラッチ実装

実装の方針

LSTMセルの順伝播と逆伝播をNumPyでスクラッチ実装します。外部の深層学習フレームワークを使わずに一から実装することで、各ゲートの計算がどのように行われるか、また勾配がどのように伝播するかを体感できます。

まず、シグモイド関数と tanh 関数のユーティリティを定義します。

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

def sigmoid_deriv(s):
    """シグモイドの出力 s からその微分を計算"""
    return s * (1 - s)

def tanh_deriv(t):
    """tanhの出力 t からその微分を計算"""
    return 1 - t ** 2

このコードでは、sigmoid で数値安定性のために np.clip を使っています。sigmoid_derivtanh_deriv は、活性化関数の出力から直接微分を計算する効率的な方法です。

次に、LSTMセルクラスを実装します。

class LSTMCell:
    def __init__(self, input_dim, hidden_dim):
        self.d_x = input_dim
        self.d_h = hidden_dim
        scale = 0.1
        # 重み行列: [h, x] の結合入力に対する重み
        self.W_f = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
        self.W_i = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
        self.W_c = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
        self.W_o = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
        # バイアス
        self.b_f = np.ones(hidden_dim)       # 忘却ゲートバイアスを1で初期化
        self.b_i = np.zeros(hidden_dim)
        self.b_c = np.zeros(hidden_dim)
        self.b_o = np.zeros(hidden_dim)

    def forward(self, x, h_prev, c_prev):
        """順伝播: 1タイムステップ分の計算"""
        # 入力と隠れ状態の結合
        concat = np.concatenate([h_prev, x])

        # 各ゲートの計算
        f = sigmoid(self.W_f @ concat + self.b_f)     # 忘却ゲート
        i = sigmoid(self.W_i @ concat + self.b_i)     # 入力ゲート
        c_tilde = np.tanh(self.W_c @ concat + self.b_c)  # 候補セル状態
        o = sigmoid(self.W_o @ concat + self.b_o)     # 出力ゲート

        # セル状態の更新
        c = f * c_prev + i * c_tilde

        # 隠れ状態の計算
        h = o * np.tanh(c)

        # 逆伝播用にキャッシュを保存
        cache = (concat, f, i, c_tilde, o, c_prev, c, h)
        return h, c, cache

    def backward(self, dh_next, dc_next, cache):
        """逆伝播: 1タイムステップ分の勾配計算"""
        concat, f, i, c_tilde, o, c_prev, c, h = cache

        # 出力ゲートからの勾配
        tanh_c = np.tanh(c)
        do = dh_next * tanh_c
        dc = dh_next * o * tanh_deriv(tanh_c) + dc_next

        # 各ゲートへの勾配
        df = dc * c_prev
        dc_prev = dc * f
        di = dc * c_tilde
        dc_tilde = dc * i

        # 活性化関数の逆伝播
        df_raw = df * sigmoid_deriv(f)
        di_raw = di * sigmoid_deriv(i)
        dc_tilde_raw = dc_tilde * tanh_deriv(c_tilde)
        do_raw = do * sigmoid_deriv(o)

        # 重みとバイアスの勾配
        self.dW_f = np.outer(df_raw, concat)
        self.dW_i = np.outer(di_raw, concat)
        self.dW_c = np.outer(dc_tilde_raw, concat)
        self.dW_o = np.outer(do_raw, concat)
        self.db_f = df_raw
        self.db_i = di_raw
        self.db_c = dc_tilde_raw
        self.db_o = do_raw

        # 入力と隠れ状態への勾配
        d_concat = (self.W_f.T @ df_raw + self.W_i.T @ di_raw +
                    self.W_c.T @ dc_tilde_raw + self.W_o.T @ do_raw)
        dh_prev = d_concat[:self.d_h]
        dx = d_concat[self.d_h:]

        return dh_prev, dc_prev, dx

このクラスでは、forward メソッドが上で導出した6つの式を忠実に実装しています。backward メソッドでは、隠れ状態への勾配 dh_next とセル状態への勾配 dc_next の両方を受け取り、各パラメータの勾配を計算します。特に dc_prev = dc * f の部分が、忘却ゲートを通じた勾配の直接的な伝播に対応しています。

ここで注目すべきは、backward メソッド内の dc_prev = dc * f という行です。これはまさに先ほど数式で示した $\partial \bm{c}_t / \partial \bm{c}_{t-1} \approx \text{diag}(\bm{f}_t)$ を実装したものであり、忘却ゲートが勾配の流れを制御している様子がコードレベルでも確認できます。

続いて、このLSTMセルを時系列に沿って展開するクラスを実装しましょう。

class LSTMNetwork:
    def __init__(self, input_dim, hidden_dim, output_dim, lr=0.001):
        self.cell = LSTMCell(input_dim, hidden_dim)
        self.W_y = np.random.randn(output_dim, hidden_dim) * 0.1
        self.b_y = np.zeros(output_dim)
        self.d_h = hidden_dim
        self.lr = lr

    def forward_sequence(self, xs):
        """系列全体の順伝播"""
        T = len(xs)
        h = np.zeros(self.d_h)
        c = np.zeros(self.d_h)
        hs, cs, caches = [], [], []

        for t in range(T):
            h, c, cache = self.cell.forward(xs[t], h, c)
            hs.append(h)
            cs.append(c)
            caches.append(cache)

        # 出力層
        ys = [self.W_y @ h + self.b_y for h in hs]
        return hs, cs, caches, ys

    def train_step(self, xs, targets):
        """1回の学習ステップ(順伝播 + 逆伝播 + パラメータ更新)"""
        T = len(xs)
        hs, cs, caches, ys = self.forward_sequence(xs)

        # 損失の計算(MSE)
        loss = sum(np.sum((ys[t] - targets[t])**2) for t in range(T)) / T

        # 逆伝播
        dW_y = np.zeros_like(self.W_y)
        db_y = np.zeros_like(self.b_y)
        dW_f = np.zeros_like(self.cell.W_f)
        dW_i = np.zeros_like(self.cell.W_i)
        dW_c = np.zeros_like(self.cell.W_c)
        dW_o = np.zeros_like(self.cell.W_o)
        db_f = np.zeros_like(self.cell.b_f)
        db_i = np.zeros_like(self.cell.b_i)
        db_c = np.zeros_like(self.cell.b_c)
        db_o = np.zeros_like(self.cell.b_o)

        dh_next = np.zeros(self.d_h)
        dc_next = np.zeros(self.d_h)

        for t in reversed(range(T)):
            # 出力層の勾配
            dy = 2 * (ys[t] - targets[t]) / T
            dW_y += np.outer(dy, hs[t])
            db_y += dy
            dh = self.W_y.T @ dy + dh_next

            # LSTMセルの逆伝播
            dh_next, dc_next, _ = self.cell.backward(dh, dc_next, caches[t])

            # 勾配の蓄積
            dW_f += self.cell.dW_f
            dW_i += self.cell.dW_i
            dW_c += self.cell.dW_c
            dW_o += self.cell.dW_o
            db_f += self.cell.db_f
            db_i += self.cell.db_i
            db_c += self.cell.db_c
            db_o += self.cell.db_o

        # 勾配クリッピング
        for grad in [dW_f, dW_i, dW_c, dW_o, dW_y, db_f, db_i, db_c, db_o, db_y]:
            np.clip(grad, -5, 5, out=grad)

        # パラメータ更新(SGD)
        self.cell.W_f -= self.lr * dW_f
        self.cell.W_i -= self.lr * dW_i
        self.cell.W_c -= self.lr * dW_c
        self.cell.W_o -= self.lr * dW_o
        self.cell.b_f -= self.lr * db_f
        self.cell.b_i -= self.lr * db_i
        self.cell.b_c -= self.lr * db_c
        self.cell.b_o -= self.lr * db_o
        self.W_y -= self.lr * dW_y
        self.b_y -= self.lr * db_y

        return loss

LSTMNetwork クラスは、LSTMCell を時系列方向に展開し、出力層を追加した完全なネットワークです。train_step メソッドでは、BPTT(Backpropagation Through Time)を実装しています。時間を逆方向にたどりながら勾配を蓄積し、最後にパラメータを更新します。勾配クリッピング(値を $[-5, 5]$ に制限)は勾配爆発を防ぐための標準的なテクニックです。

これで LSTMの実装が完成しました。次に、実際のデータを使って学習実験を行い、LSTMが長期依存性を学習できることを確認しましょう。

合成正弦波データでの学習実験

実験の目的と設定

LSTMが周期的なパターンを学習できることを確認するために、合成正弦波データを使った予測タスクを行います。入力として過去の正弦波の値を与え、次のステップの値を予測させます。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# 合成正弦波データの生成
t = np.linspace(0, 20 * np.pi, 2000)
data = np.sin(t) + 0.1 * np.sin(3 * t)  # 複合正弦波

# 訓練データの作成(系列長20の窓でスライド)
seq_len = 20
X, Y = [], []
for i in range(len(data) - seq_len):
    X.append(data[i:i+seq_len].reshape(-1, 1))
    Y.append(data[i+1:i+seq_len+1].reshape(-1, 1))
X = X[:500]  # 訓練用に500サンプル使用
Y = Y[:500]

print(f"訓練サンプル数: {len(X)}")
print(f"系列長: {seq_len}")
print(f"入力次元: {X[0][0].shape}")

複合正弦波 $\sin(t) + 0.1 \sin(3t)$ を生成し、長さ20の窓で区切って訓練データとしています。基本波と高調波が混在したデータを使うことで、LSTMが複数の周波数成分を学習できるかを検証します。

学習ループの実行

# LSTMネットワークの初期化と学習
model = LSTMNetwork(input_dim=1, hidden_dim=32, output_dim=1, lr=0.005)

losses = []
n_epochs = 100

for epoch in range(n_epochs):
    epoch_loss = 0
    for j in range(len(X)):
        loss = model.train_step(X[j], Y[j])
        epoch_loss += loss
    epoch_loss /= len(X)
    losses.append(epoch_loss)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.6f}")

損失の推移の可視化

plt.figure(figsize=(10, 4))
plt.plot(losses, color='steelblue', linewidth=1.5)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('LSTM Training Loss on Composite Sine Wave')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

このグラフから、損失が学習とともに単調に減少していく様子が確認できます。最初の数エポックで急激に損失が下がり、その後は緩やかに収束していくのが典型的なパターンです。LSTMが正弦波のパターンを徐々に学習していることがわかります。

予測結果の可視化

学習済みモデルを使って、訓練データの一部と未知データに対する予測を行います。

# テストデータでの予測
test_start = 600
test_len = 200
test_data = data[test_start:test_start + test_len + seq_len]

# 自己回帰予測
predictions = []
h = np.zeros(model.d_h)
c = np.zeros(model.d_h)

# 最初のseq_len個で状態を初期化
for i in range(seq_len):
    x_input = np.array([test_data[i]])
    h, c, _ = model.cell.forward(x_input, h, c)

# 予測
for i in range(test_len):
    y_pred = model.W_y @ h + model.b_y
    predictions.append(y_pred[0])
    x_input = np.array([test_data[seq_len + i]])
    h, c, _ = model.cell.forward(x_input, h, c)

# 可視化
plt.figure(figsize=(12, 5))
time_axis = np.arange(test_len)
plt.plot(time_axis, test_data[seq_len:seq_len+test_len],
         label='Ground Truth', color='steelblue', linewidth=2)
plt.plot(time_axis, predictions,
         label='LSTM Prediction', color='coral', linewidth=2, linestyle='--')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('LSTM Prediction vs Ground Truth (Composite Sine Wave)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

グラフから、LSTMの予測(破線)が実際の値(実線)によく一致していることが確認できます。複合正弦波の2つの周波数成分(基本波と3倍高調波)の両方を捉えており、振幅と位相が正確に再現されています。これは、LSTMが系列長20の窓を通じて、周期的なパターンの長期依存性を学習できていることを示しています。

忘却ゲートの値の可視化

LSTMが実際にどのようにゲートを制御しているかを確認するために、忘却ゲートの値を可視化します。

# 忘却ゲートの値を収集
gate_values = {'forget': [], 'input': [], 'output': []}
h = np.zeros(model.d_h)
c = np.zeros(model.d_h)

vis_data = data[:100]
for i in range(len(vis_data)):
    x_input = np.array([vis_data[i]])
    concat = np.concatenate([h, x_input])
    f = sigmoid(model.cell.W_f @ concat + model.cell.b_f)
    ig = sigmoid(model.cell.W_i @ concat + model.cell.b_i)
    o = sigmoid(model.cell.W_o @ concat + model.cell.b_o)
    gate_values['forget'].append(f.copy())
    gate_values['input'].append(ig.copy())
    gate_values['output'].append(o.copy())
    h, c, _ = model.cell.forward(x_input, h, c)

# ゲート値のヒートマップ
fig, axes = plt.subplots(3, 1, figsize=(12, 8))
gate_names = ['Forget Gate', 'Input Gate', 'Output Gate']
gate_keys = ['forget', 'input', 'output']
cmaps = ['Blues', 'Greens', 'Oranges']

for ax, name, key, cmap in zip(axes, gate_names, gate_keys, cmaps):
    gate_matrix = np.array(gate_values[key]).T  # (hidden_dim, time)
    im = ax.imshow(gate_matrix[:8], aspect='auto', cmap=cmap, vmin=0, vmax=1)
    ax.set_ylabel(name)
    ax.set_xlabel('Time Step')
    plt.colorbar(im, ax=ax)

plt.suptitle('LSTM Gate Activations Over Time', fontsize=14)
plt.tight_layout()
plt.show()

このヒートマップでは、3つのゲートの値が時間とともにどう変化するかを表示しています(隠れ次元の最初の8次元のみ表示)。忘却ゲートのヒートマップで多くの次元が明るい色($1$ に近い値)を示していれば、セル状態の情報が保持されていることを意味します。一方、入力ゲートと出力ゲートは時刻によって値が変動し、必要なときだけ情報を書き込んだり読み出したりしている様子が見て取れます。ゲートの値が正弦波の位相に合わせて周期的に変動していれば、LSTMが入力データの周期構造を学習していることの証拠です。

まとめ

本記事では、LSTMの理論的な仕組みをゲート機構の数式から丁寧に解説しました。

  • バニラRNNの限界: ヤコビアンの積が指数的に減衰・発散するため、長期依存性の学習が困難
  • LSTMの3つのゲート: 忘却ゲート(何を捨てるか)、入力ゲート(何を覚えるか)、出力ゲート(何を出力するか)が協調してセル状態を制御
  • セル状態の加算的更新: $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$ という加算構造が、勾配の安定伝播を実現
  • 勾配の安定性: セル状態を通じた勾配は $\prod f_{k,j}$ のスカラー積で表され、$f_{k,j} \approx 1$ なら勾配がほぼ減衰しない
  • 忘却ゲートバイアスの初期化: 正の値(例: 1)で初期化することで、学習初期の情報保持を促進
  • スクラッチ実装: NumPyでLSTMセルの順伝播・逆伝播を実装し、正弦波の予測に成功

LSTMは長期依存性の問題を大幅に緩和しましたが、4つの重み行列を持つためパラメータ数が多く、計算コストが高いという課題があります。この課題に対して「ゲートの数を減らせないか」という発想で提案されたのがGRU(Gated Recurrent Unit)です。GRUはLSTMの忘却ゲートと入力ゲートを統合し、よりシンプルな構造で同等の性能を実現します。

次のステップとして、以下の記事も参考にしてください。