RNN(再帰型ニューラルネットワーク)は系列データを扱う強力なアーキテクチャですが、長い系列において 勾配消失問題 が深刻な課題となることを前の記事で見ました。この問題を根本的に解決するために提案されたのが LSTM(Long Short-Term Memory) です。
LSTMは1997年にHochreiterとSchmidhuberによって提案され、その後Gersらによる忘却ゲートの追加(2000年)を経て、現在広く使われる形になりました。LSTMの核心的なアイデアは、勾配が一定のまま伝播できる経路(CEC: Constant Error Carousel) をネットワーク内に明示的に設けることです。これにより、数百〜数千ステップの長期依存性を学習することが可能になります。
本記事の内容
- RNNの勾配消失問題の復習とLSTMの設計思想
- LSTMの4つの構成要素(忘却ゲート、入力ゲート、セル状態、出力ゲート)の数式
- 各ゲートの役割の直感的説明
- LSTMの順伝播の全数式
- 勾配が消失しにくい理由の数学的説明(CEC)
- Peephole接続の変種
- numpyによるスクラッチ実装とテキスト生成タスク
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
RNNの勾配消失問題の復習
問題の再確認
RNNの基礎で見たように、BPTTでの勾配計算にはヤコビ行列の積が現れます。
$$ \frac{\partial \bm{h}_\tau}{\partial \bm{h}_t} = \prod_{k=t+1}^{\tau} \frac{\partial \bm{h}_k}{\partial \bm{h}_{k-1}} = \prod_{k=t+1}^{\tau} \text{diag}(f'(\bm{a}_k)) \, \bm{W}_h $$
$\tanh$ の微分は $|f'(x)| = |1 – \tanh^2(x)| \leq 1$ であり、$\bm{W}_h$ のスペクトル半径が1未満の場合、この積は $\tau – t$ に対して指数的にゼロに近づきます。
$$ \left\| \frac{\partial \bm{h}_\tau}{\partial \bm{h}_t} \right\| \leq \| \bm{W}_h \|^{\tau – t} \cdot \prod_{k=t+1}^{\tau} \| f'(\bm{a}_k) \| \to 0 \quad (\tau – t \to \infty) $$
この結果、遠い過去の情報が勾配に反映されず、長期依存性を学習できない という問題が生じます。
問題の根本原因
勾配消失の根本原因は、隠れ状態 $\bm{h}_t$ の更新が
$$ \bm{h}_t = \tanh(\bm{W}_h \bm{h}_{t-1} + \bm{W}_x \bm{x}_t + \bm{b}_h) $$
という形であり、前の隠れ状態 $\bm{h}_{t-1}$ が非線形変換を経て完全に上書きされる ことにあります。毎ステップで $\tanh$ が適用されるため、勾配は必ず $|f’| \leq 1$ の係数で減衰します。
LSTMの設計思想
一定勾配の経路を作る
LSTMの設計思想は明快です。勾配が1に近い値で伝播できる経路 をネットワーク内に明示的に設けます。
通常のRNNでは $\partial \bm{h}_t / \partial \bm{h}_{t-1}$ が $\tanh$ の微分と $\bm{W}_h$ に依存しましたが、LSTMでは セル状態 $\bm{C}_t$ という新しい変数を導入し、その更新を
$$ \bm{C}_t = \bm{f}_t \odot \bm{C}_{t-1} + \bm{i}_t \odot \tilde{\bm{C}}_t $$
の形にします($\odot$ はアダマール積)。この式の核心は、$\bm{C}_{t-1}$ が 加法的に $\bm{C}_t$ に伝播することです。$\bm{f}_t$(忘却ゲート)が1に近ければ
$$ \frac{\partial \bm{C}_t}{\partial \bm{C}_{t-1}} = \text{diag}(\bm{f}_t) \approx \bm{I} $$
となり、勾配がほぼ減衰せずに伝播します。これが CEC(Constant Error Carousel) と呼ばれるメカニズムです。
ゲートによる情報制御
「全ての情報を保持し続ける」のではなく、何を記憶し、何を忘れ、何を出力するか を学習的に制御するのがゲート機構です。LSTMは以下の3つのゲートを持ちます。
- 忘却ゲート $\bm{f}_t$: セル状態のどの部分を忘れるか
- 入力ゲート $\bm{i}_t$: 新しい情報のどの部分をセルに書き込むか
- 出力ゲート $\bm{o}_t$: セル状態のどの部分を外部に出力するか
LSTMの全数式
忘却ゲート(Forget Gate)
$$ \bm{f}_t = \sigma(\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_f) $$
ここで $\sigma$ はシグモイド関数 $\sigma(x) = 1/(1 + e^{-x})$ であり、出力は $[0, 1]$ の範囲です。$[\bm{h}_{t-1}, \bm{x}_t]$ は前の隠れ状態と現在の入力の結合ベクトルです。
役割: $\bm{f}_t$ の各要素は、対応するセル状態の要素をどの程度「忘れるか」を制御します。$f_t^{(j)} \approx 1$ なら保持、$f_t^{(j)} \approx 0$ なら忘却です。
展開して書くと
$$ \bm{f}_t = \sigma(\bm{W}_{fh} \bm{h}_{t-1} + \bm{W}_{fx} \bm{x}_t + \bm{b}_f) $$
です。ここで $\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] = \bm{W}_{fh} \bm{h}_{t-1} + \bm{W}_{fx} \bm{x}_t$ は結合ベクトルとの積を分解したものです。
入力ゲート(Input Gate)
$$ \bm{i}_t = \sigma(\bm{W}_i [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_i) $$
役割: 新しい候補情報のどの部分をセルに書き込むかを制御します。
候補セル状態(Candidate Cell State)
$$ \tilde{\bm{C}}_t = \tanh(\bm{W}_C [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_C) $$
役割: 新たにセル状態に追加される候補値です。$\tanh$ により $[-1, 1]$ の範囲に正規化されます。
セル状態の更新
$$ \bm{C}_t = \bm{f}_t \odot \bm{C}_{t-1} + \bm{i}_t \odot \tilde{\bm{C}}_t $$
役割: これがLSTMの核心です。
- $\bm{f}_t \odot \bm{C}_{t-1}$: 前のセル状態のうち、忘却ゲートが通す部分
- $\bm{i}_t \odot \tilde{\bm{C}}_t$: 入力ゲートが通す新しい情報
2つの項の加算によりセル状態が更新されます。
出力ゲート(Output Gate)
$$ \bm{o}_t = \sigma(\bm{W}_o [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_o) $$
隠れ状態の計算
$$ \bm{h}_t = \bm{o}_t \odot \tanh(\bm{C}_t) $$
役割: セル状態を $\tanh$ で正規化し、出力ゲートで制御した結果が、外部に出力される隠れ状態 $\bm{h}_t$ です。
パラメータの一覧
LSTMのパラメータをまとめます。入力次元を $d$、隠れ次元を $n$ とすると
| パラメータ | サイズ | 用途 |
|---|---|---|
| $\bm{W}_f$ | $n \times (n + d)$ | 忘却ゲートの重み |
| $\bm{W}_i$ | $n \times (n + d)$ | 入力ゲートの重み |
| $\bm{W}_C$ | $n \times (n + d)$ | 候補セル状態の重み |
| $\bm{W}_o$ | $n \times (n + d)$ | 出力ゲートの重み |
| $\bm{b}_f, \bm{b}_i, \bm{b}_C, \bm{b}_o$ | 各 $n \times 1$ | 各ゲートのバイアス |
パラメータの総数は $4n(n + d) + 4n = 4n(n + d + 1)$ です。通常のRNN($n(n + d) + n$ パラメータ)の約4倍です。
勾配が消失しにくい理由の数学的説明
CEC(Constant Error Carousel)
LSTMの核心であるセル状態の更新式
$$ \bm{C}_t = \bm{f}_t \odot \bm{C}_{t-1} + \bm{i}_t \odot \tilde{\bm{C}}_t $$
に対して、$\bm{C}_{t-1}$ に関する偏微分を計算します。
$$ \frac{\partial \bm{C}_t}{\partial \bm{C}_{t-1}} = \text{diag}(\bm{f}_t) + \text{diag}(\bm{C}_{t-1}) \frac{\partial \bm{f}_t}{\partial \bm{C}_{t-1}} + \text{diag}(\tilde{\bm{C}}_t) \frac{\partial \bm{i}_t}{\partial \bm{C}_{t-1}} + \text{diag}(\bm{i}_t) \frac{\partial \tilde{\bm{C}}_t}{\partial \bm{C}_{t-1}} $$
ここで重要なのは、第1項 $\text{diag}(\bm{f}_t)$ です。$\bm{f}_t$, $\bm{i}_t$, $\tilde{\bm{C}}_t$ は $\bm{h}_{t-1}$ を介して $\bm{C}_{t-1}$ に依存しますが、この依存は $\bm{h}_{t-1} = \bm{o}_{t-1} \odot \tanh(\bm{C}_{t-1})$ を通じた間接的なものです。主要項は $\text{diag}(\bm{f}_t)$ であり、近似的に
$$ \frac{\partial \bm{C}_t}{\partial \bm{C}_{t-1}} \approx \text{diag}(\bm{f}_t) $$
と書けます。
$T$ ステップ遡った勾配は
$$ \frac{\partial \bm{C}_T}{\partial \bm{C}_t} = \prod_{k=t+1}^{T} \frac{\partial \bm{C}_k}{\partial \bm{C}_{k-1}} \approx \prod_{k=t+1}^{T} \text{diag}(\bm{f}_k) = \text{diag}\left( \prod_{k=t+1}^{T} \bm{f}_k \right) $$
通常のRNNでは $\|\bm{W}_h\|^{T-t}$ の形で指数的に変化しましたが、LSTMでは $\prod_{k=t+1}^{T} f_k^{(j)}$ の形になります。
通常のRNNとの比較
| 通常のRNN | LSTM | |
|---|---|---|
| 勾配の伝播 | $\prod \text{diag}(f'(\bm{a}_k)) \bm{W}_h$ | $\prod \text{diag}(\bm{f}_k) \approx \text{diag}(\bm{f}_t \odot \cdots \odot \bm{f}_T)$ |
| 減衰の要因 | $\|f’\| \leq 1$ と $\bm{W}_h$ の固有値 | $\bm{f}_k \in (0, 1)$ のみ |
| 制御可能性 | 固定($\tanh$ と $\bm{W}_h$ で決定) | 学習可能($\bm{f}_k$ が入力に応じて変化) |
通常のRNNでは勾配の減衰が $\bm{W}_h$ の固有値で一様に決まりますが、LSTMでは忘却ゲート $\bm{f}_k$ が入力に応じてダイナミックに変化します。長期記憶が必要な場合は $\bm{f}_k \approx 1$ を学習し、不要な情報は $\bm{f}_k \approx 0$ で忘却できます。
なぜ「加法的」更新が重要か
通常のRNNの更新は 乗法的 です。
$$ \bm{h}_t = f(\bm{W}_h \bm{h}_{t-1} + \cdots) $$
非線形関数 $f$ を通すため、$\bm{h}_{t-1}$ の情報が圧縮・変形されます。
LSTMのセル状態の更新は 加法的 です。
$$ \bm{C}_t = \bm{f}_t \odot \bm{C}_{t-1} + \bm{i}_t \odot \tilde{\bm{C}}_t $$
$\bm{C}_{t-1}$ は非線形関数を通らず、ゲート $\bm{f}_t$ による要素ごとのスケーリングのみです。この加法的構造が、勾配の安定した伝播を保証します。
これは残差結合(ResNet)の $\bm{x} + F(\bm{x})$ と同じ精神です。恒等写像に近い経路を確保することで、深いネットワークの勾配伝播を改善します。
各ゲートの直感的な理解
忘却ゲート $\bm{f}_t$: 何を忘れるか
たとえば自然言語処理で、文の主語が変わったとき、忘却ゲートは以前の主語の情報を忘れるように $f_t \approx 0$ を出力します。逆に、文脈が続いている間は $f_t \approx 1$ で情報を保持します。
$$ \text{“太郎は走った。”} \xrightarrow{f \approx 1} \text{“彼は速かった。”} \xrightarrow{f \approx 0} \text{“花子は歩いた。”} $$
入力ゲート $\bm{i}_t$: 何を書き込むか
新しい重要な情報が入力されたとき $i_t \approx 1$ で書き込み、ノイズ的な入力は $i_t \approx 0$ で無視します。
出力ゲート $\bm{o}_t$: 何を出力するか
セル状態には保持しているが、現在の出力には不要な情報は $o_t \approx 0$ で隠します。必要になったタイミングで $o_t \approx 1$ で出力します。
Peephole接続
標準のLSTMでは、ゲートは $\bm{h}_{t-1}$ と $\bm{x}_t$ のみを入力として受け取ります。Peephole接続(Gers & Schmidhuber, 2000)では、ゲートがセル状態 $\bm{C}$ も直接参照できるようにします。
$$ \begin{align} \bm{f}_t &= \sigma(\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] + \bm{W}_{pf} \odot \bm{C}_{t-1} + \bm{b}_f) \\ \bm{i}_t &= \sigma(\bm{W}_i [\bm{h}_{t-1}, \bm{x}_t] + \bm{W}_{pi} \odot \bm{C}_{t-1} + \bm{b}_i) \\ \bm{o}_t &= \sigma(\bm{W}_o [\bm{h}_{t-1}, \bm{x}_t] + \bm{W}_{po} \odot \bm{C}_t + \bm{b}_o) \end{align} $$
忘却ゲートと入力ゲートは $\bm{C}_{t-1}$(更新前のセル状態)を、出力ゲートは $\bm{C}_t$(更新後のセル状態)を参照します。$\bm{W}_{pf}, \bm{W}_{pi}, \bm{W}_{po}$ は対角重み(要素ごとの重み)です。
Peephole接続により、ゲートはセル状態の現在値に基づいてより精密な制御を行えますが、パラメータ数が増加します。実用上は、標準LSTMで十分な性能が得られることが多いです。
GRUとの比較
LSTM以外のゲート付きRNNとして、GRU(Gated Recurrent Unit)(Cho et al., 2014)があります。GRUはLSTMを簡略化したもので、セル状態と隠れ状態を統合し、ゲート数を2つ(リセットゲートと更新ゲート)に削減しています。
$$ \begin{align} \bm{z}_t &= \sigma(\bm{W}_z [\bm{h}_{t-1}, \bm{x}_t]) \quad \text{(更新ゲート)} \\ \bm{r}_t &= \sigma(\bm{W}_r [\bm{h}_{t-1}, \bm{x}_t]) \quad \text{(リセットゲート)} \\ \tilde{\bm{h}}_t &= \tanh(\bm{W} [\bm{r}_t \odot \bm{h}_{t-1}, \bm{x}_t]) \\ \bm{h}_t &= (1 – \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t \end{align} $$
GRUはLSTMより少ないパラメータで同等の性能を示すことが多く、計算効率が良いという利点があります。
Pythonでの実装: LSTMのスクラッチ実装
numpyによるLSTM実装
import numpy as np
import matplotlib.pyplot as plt
class LSTM:
"""LSTMのスクラッチ実装(numpy)"""
def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01):
self.hidden_size = hidden_size
self.input_size = input_size
self.output_size = output_size
self.lr = learning_rate
# 結合入力のサイズ
concat_size = hidden_size + input_size
# 重みの初期化(Xavier初期化)
scale = np.sqrt(2.0 / concat_size)
# 忘却ゲートの重み
self.Wf = np.random.randn(hidden_size, concat_size) * scale
self.bf = np.ones((hidden_size, 1)) # 忘却ゲートのバイアスは1で初期化
# 入力ゲートの重み
self.Wi = np.random.randn(hidden_size, concat_size) * scale
self.bi = np.zeros((hidden_size, 1))
# 候補セル状態の重み
self.Wc = np.random.randn(hidden_size, concat_size) * scale
self.bc = np.zeros((hidden_size, 1))
# 出力ゲートの重み
self.Wo = np.random.randn(hidden_size, concat_size) * scale
self.bo = np.zeros((hidden_size, 1))
# 出力層の重み
self.Wy = np.random.randn(output_size, hidden_size) * np.sqrt(2.0 / hidden_size)
self.by = np.zeros((output_size, 1))
def sigmoid(self, x):
"""数値安定なシグモイド"""
return np.where(x >= 0,
1 / (1 + np.exp(-x)),
np.exp(x) / (1 + np.exp(x)))
def forward(self, inputs):
"""順伝播: inputs は文字インデックスのリスト"""
T = len(inputs)
self.T = T
self.inputs = inputs
# 各時刻の中間変数を保存
self.concat_inputs = {}
self.f_gates = {}
self.i_gates = {}
self.c_candidates = {}
self.o_gates = {}
self.cell_states = {-1: np.zeros((self.hidden_size, 1))}
self.hidden_states = {-1: np.zeros((self.hidden_size, 1))}
self.outputs = {}
for t in range(T):
# 入力のone-hotエンコーディング
x_t = np.zeros((self.input_size, 1))
x_t[inputs[t]] = 1
# 結合入力 [h_{t-1}, x_t]
concat = np.vstack([self.hidden_states[t-1], x_t])
self.concat_inputs[t] = concat
# 忘却ゲート
self.f_gates[t] = self.sigmoid(self.Wf @ concat + self.bf)
# 入力ゲート
self.i_gates[t] = self.sigmoid(self.Wi @ concat + self.bi)
# 候補セル状態
self.c_candidates[t] = np.tanh(self.Wc @ concat + self.bc)
# セル状態の更新
self.cell_states[t] = (self.f_gates[t] * self.cell_states[t-1]
+ self.i_gates[t] * self.c_candidates[t])
# 出力ゲート
self.o_gates[t] = self.sigmoid(self.Wo @ concat + self.bo)
# 隠れ状態
self.hidden_states[t] = self.o_gates[t] * np.tanh(self.cell_states[t])
# 出力(ソフトマックス)
logits = self.Wy @ self.hidden_states[t] + self.by
exp_logits = np.exp(logits - np.max(logits)) # 数値安定性
self.outputs[t] = exp_logits / np.sum(exp_logits)
return self.outputs
def backward(self, targets):
"""逆伝播(BPTT)"""
T = self.T
# 勾配の初期化
dWf = np.zeros_like(self.Wf)
dWi = np.zeros_like(self.Wi)
dWc = np.zeros_like(self.Wc)
dWo = np.zeros_like(self.Wo)
dbf = np.zeros_like(self.bf)
dbi = np.zeros_like(self.bi)
dbc = np.zeros_like(self.bc)
dbo = np.zeros_like(self.bo)
dWy = np.zeros_like(self.Wy)
dby = np.zeros_like(self.by)
# 次の時刻から伝播してくる勾配
dh_next = np.zeros((self.hidden_size, 1))
dc_next = np.zeros((self.hidden_size, 1))
loss = 0
for t in reversed(range(T)):
# 出力層の勾配(交差エントロピー損失)
dy = self.outputs[t].copy()
dy[targets[t]] -= 1 # ソフトマックス+交差エントロピーの勾配
loss -= np.log(self.outputs[t][targets[t]] + 1e-8)
dWy += dy @ self.hidden_states[t].T
dby += dy
# 隠れ状態に対する勾配
dh = self.Wy.T @ dy + dh_next
# 出力ゲートの勾配
tanh_c = np.tanh(self.cell_states[t])
do = dh * tanh_c
do_raw = do * self.o_gates[t] * (1 - self.o_gates[t])
# セル状態に対する勾配
dc = (dh * self.o_gates[t] * (1 - tanh_c ** 2)) + dc_next
# 忘却ゲートの勾配
df = dc * self.cell_states[t-1]
df_raw = df * self.f_gates[t] * (1 - self.f_gates[t])
# 入力ゲートの勾配
di = dc * self.c_candidates[t]
di_raw = di * self.i_gates[t] * (1 - self.i_gates[t])
# 候補セル状態の勾配
dc_cand = dc * self.i_gates[t]
dc_cand_raw = dc_cand * (1 - self.c_candidates[t] ** 2)
# 重み・バイアスの勾配を累積
concat = self.concat_inputs[t]
dWf += df_raw @ concat.T
dWi += di_raw @ concat.T
dWc += dc_cand_raw @ concat.T
dWo += do_raw @ concat.T
dbf += df_raw
dbi += di_raw
dbc += dc_cand_raw
dbo += do_raw
# 前の時刻への勾配
dconcat = (self.Wf.T @ df_raw + self.Wi.T @ di_raw
+ self.Wc.T @ dc_cand_raw + self.Wo.T @ do_raw)
dh_next = dconcat[:self.hidden_size]
dc_next = dc * self.f_gates[t]
# 勾配クリッピング
for grad in [dWf, dWi, dWc, dWo, dbf, dbi, dbc, dbo, dWy, dby]:
np.clip(grad, -5, 5, out=grad)
# パラメータ更新
self.Wf -= self.lr * dWf
self.Wi -= self.lr * dWi
self.Wc -= self.lr * dWc
self.Wo -= self.lr * dWo
self.bf -= self.lr * dbf
self.bi -= self.lr * dbi
self.bc -= self.lr * dbc
self.bo -= self.lr * dbo
self.Wy -= self.lr * dWy
self.by -= self.lr * dby
return loss.item()
def generate(self, seed_char, length, char_to_idx, idx_to_char, temperature=1.0):
"""テキスト生成"""
h = np.zeros((self.hidden_size, 1))
c = np.zeros((self.hidden_size, 1))
idx = char_to_idx[seed_char]
result = [seed_char]
for _ in range(length):
x = np.zeros((self.input_size, 1))
x[idx] = 1
concat = np.vstack([h, x])
# 順伝播
f = self.sigmoid(self.Wf @ concat + self.bf)
i = self.sigmoid(self.Wi @ concat + self.bi)
c_cand = np.tanh(self.Wc @ concat + self.bc)
c = f * c + i * c_cand
o = self.sigmoid(self.Wo @ concat + self.bo)
h = o * np.tanh(c)
# 出力
logits = self.Wy @ h + self.by
logits = logits / temperature
exp_logits = np.exp(logits - np.max(logits))
probs = exp_logits / np.sum(exp_logits)
idx = np.random.choice(self.input_size, p=probs.flatten())
result.append(idx_to_char[idx])
return ''.join(result)
# --- テキスト生成タスク ---
# サンプルテキスト(学習データ)
text = """the quick brown fox jumps over the lazy dog
a stitch in time saves nine
all that glitters is not gold
actions speak louder than words
knowledge is power and power is knowledge
the pen is mightier than the sword
fortune favors the bold and brave
to be or not to be that is the question
early bird catches the worm every morning
practice makes perfect in every endeavor"""
# 文字のインデックス化
chars = sorted(set(text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)
print(f"語彙サイズ: {vocab_size}")
print(f"テキスト長: {len(text)}")
# 学習データの準備
data = [char_to_idx[ch] for ch in text]
# LSTMモデルの構築
hidden_size = 64
lstm = LSTM(input_size=vocab_size, hidden_size=hidden_size,
output_size=vocab_size, learning_rate=0.005)
# 学習
n_epochs = 200
seq_length = 30 # 1回の学習で使う系列長
losses = []
for epoch in range(n_epochs):
epoch_loss = 0
n_batches = 0
# テキストをseq_length単位で処理
for start in range(0, len(data) - seq_length - 1, seq_length):
inputs = data[start:start + seq_length]
targets = data[start + 1:start + seq_length + 1]
# 順伝播
lstm.forward(inputs)
# 逆伝播
loss = lstm.backward(targets)
epoch_loss += loss
n_batches += 1
avg_loss = epoch_loss / n_batches
losses.append(avg_loss)
if (epoch + 1) % 40 == 0:
print(f"\nEpoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")
generated = lstm.generate('t', 80, char_to_idx, idx_to_char, temperature=0.8)
print(f"Generated: {generated}")
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(16, 5))
# 学習曲線
axes[0].plot(losses, linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Average Loss', fontsize=12)
axes[0].set_title('LSTM Training Loss (Character-level Language Model)', fontsize=14)
axes[0].grid(True, alpha=0.3)
# ゲートの値の可視化(最後のエピソードの忘却ゲート値)
sample_input = data[:seq_length]
lstm.forward(sample_input)
f_gate_values = np.array([lstm.f_gates[t].flatten()[:8] for t in range(seq_length)])
im = axes[1].imshow(f_gate_values.T, aspect='auto', cmap='RdYlGn', vmin=0, vmax=1)
axes[1].set_xlabel('Time Step', fontsize=12)
axes[1].set_ylabel('Hidden Unit', fontsize=12)
axes[1].set_title('Forget Gate Values (First 8 Units)', fontsize=14)
plt.colorbar(im, ax=axes[1])
# 入力文字をx軸ラベルとして表示
chars_label = [idx_to_char[idx] for idx in sample_input]
axes[1].set_xticks(range(0, seq_length, 5))
axes[1].set_xticklabels([chars_label[i] for i in range(0, seq_length, 5)])
plt.tight_layout()
plt.show()
# 最終的なテキスト生成
print("\n--- 学習済みLSTMによるテキスト生成 ---")
for seed in ['t', 'a', 'k']:
generated = lstm.generate(seed, 100, char_to_idx, idx_to_char, temperature=0.7)
print(f"\nSeed='{seed}': {generated}")
勾配伝播の比較実験
RNNとLSTMの勾配伝播の違いを定量的に比較する実験です。
import numpy as np
import matplotlib.pyplot as plt
def compute_gradient_flow_rnn(hidden_size, seq_len, n_trials=50):
"""通常のRNNの勾配流を計算"""
norms = np.zeros(seq_len)
for _ in range(n_trials):
# ランダムな重み行列
Wh = np.random.randn(hidden_size, hidden_size) * 0.5 / np.sqrt(hidden_size)
# ランダムな隠れ状態
h = np.random.randn(hidden_size, 1) * 0.1
jacobian_prod = np.eye(hidden_size)
for t in range(seq_len):
# h_t = tanh(Wh @ h_{t-1} + ...)
a = Wh @ h
h = np.tanh(a)
# ヤコビ行列: diag(1 - h^2) @ Wh
J = np.diag((1 - h ** 2).flatten()) @ Wh
jacobian_prod = J @ jacobian_prod
norms[t] += np.linalg.norm(jacobian_prod)
return norms / n_trials
def compute_gradient_flow_lstm(hidden_size, seq_len, n_trials=50):
"""LSTMの勾配流を計算(セル状態経由)"""
norms = np.zeros(seq_len)
for _ in range(n_trials):
# 忘却ゲートの値(0.9前後に設定、実際の学習済みLSTMを想定)
forget_gates = 0.85 + 0.1 * np.random.rand(seq_len, hidden_size)
gradient_prod = np.ones(hidden_size)
for t in range(seq_len):
# ∂C_t/∂C_{t-1} ≈ diag(f_t)
gradient_prod *= forget_gates[t]
norms[t] += np.linalg.norm(gradient_prod)
return norms / n_trials
# 実験の実行
hidden_size = 32
seq_len = 100
rnn_norms = compute_gradient_flow_rnn(hidden_size, seq_len)
lstm_norms = compute_gradient_flow_lstm(hidden_size, seq_len)
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
# 線形スケール
axes[0].plot(range(1, seq_len+1), rnn_norms, label='Simple RNN', linewidth=2)
axes[0].plot(range(1, seq_len+1), lstm_norms, label='LSTM (Cell State)', linewidth=2)
axes[0].set_xlabel('Time Steps Back', fontsize=12)
axes[0].set_ylabel('Gradient Norm', fontsize=12)
axes[0].set_title('Gradient Flow: RNN vs LSTM (Linear Scale)', fontsize=14)
axes[0].legend(fontsize=12)
axes[0].grid(True, alpha=0.3)
# 対数スケール
axes[1].semilogy(range(1, seq_len+1), rnn_norms + 1e-20, label='Simple RNN', linewidth=2)
axes[1].semilogy(range(1, seq_len+1), lstm_norms + 1e-20, label='LSTM (Cell State)', linewidth=2)
axes[1].set_xlabel('Time Steps Back', fontsize=12)
axes[1].set_ylabel('Gradient Norm (log scale)', fontsize=12)
axes[1].set_title('Gradient Flow: RNN vs LSTM (Log Scale)', fontsize=14)
axes[1].legend(fontsize=12)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 数値結果
print(f"100ステップ後の勾配ノルム:")
print(f" RNN: {rnn_norms[-1]:.2e}")
print(f" LSTM: {lstm_norms[-1]:.2e}")
print(f" 比率: {lstm_norms[-1] / (rnn_norms[-1] + 1e-30):.2e}")
結果の考察
テキスト生成タスク: LSTMは文字レベルの言語モデルとして、短い訓練テキストからパターンを学習します。学習が進むにつれて、英語の単語構造に近い文字列が生成されるようになります。
忘却ゲートの値の可視化: 忘却ゲートの値は入力文字に応じて動的に変化します。空白文字(単語境界)の直後ではゲート値が低くなり、新しい文脈への切り替えが起きることが観察されます。
勾配伝播の比較: 通常のRNNでは勾配ノルムが数十ステップで急速にゼロに近づきますが、LSTMのセル状態経由の勾配は100ステップ経っても一定程度の大きさを保ちます。これがLSTMの長期依存性学習能力の源泉です。
忘却ゲートのバイアス初期化
実装上の重要なテクニックとして、忘却ゲートのバイアスを1(またはそれ以上)で初期化 することが知られています(Jozefowicz et al., 2015)。
$$ \bm{b}_f \leftarrow \bm{1} $$
これにより、学習初期の忘却ゲートの出力が $\sigma(1) \approx 0.73$ と大きくなり、セル状態の情報が保持されやすくなります。バイアスがゼロの場合は $\sigma(0) = 0.5$ で、情報が半分ずつ失われるため、長期依存性の学習が困難になります。
まとめ
本記事では、LSTM(Long Short-Term Memory)の構造と理論を解説しました。
- LSTMはRNNの勾配消失問題を解決するために、セル状態 $\bm{C}_t$ と3つの ゲート機構(忘却・入力・出力)を導入しました
- セル状態の更新 $\bm{C}_t = \bm{f}_t \odot \bm{C}_{t-1} + \bm{i}_t \odot \tilde{\bm{C}}_t$ は 加法的 であり、勾配が $\text{diag}(\bm{f}_t)$ の形でほぼ減衰なく伝播します(CEC: Constant Error Carousel)
- 通常のRNNでは勾配が $\|\bm{W}_h\|^{T-t}$ の形で指数的に減衰しますが、LSTMでは $\prod \bm{f}_k$ の形であり、忘却ゲートが1に近い値を学習すれば長期記憶が可能です
- Peephole接続 はゲートがセル状態を直接参照できる変種です
- 忘却ゲートのバイアスを1で初期化するのが実用上重要なテクニックです
次のステップとして、以下の記事も参考にしてください。