LSTMは勾配消失問題を解決する画期的なアーキテクチャですが、3つのゲートと4つの重み行列を持つため、パラメータ数が多く計算コストも高いという特徴があります。2014年にCho らは、「ゲートの数を減らしてもっとシンプルにできないか」という問いに答える形でGRU(Gated Recurrent Unit)を提案しました。
GRUの基本的な発想は明快です。LSTMの忘却ゲートと入力ゲートを1つの更新ゲートに統合し、さらにセル状態と隠れ状態を統合して、全体をよりコンパクトにします。この簡略化によりパラメータ数は約25%削減されますが、多くのタスクでLSTMと同等の性能を発揮します。
GRUを理解すると、以下のような場面で適切な選択ができるようになります。
- 計算リソースが限られた環境: エッジデバイスや組込みシステムでのリアルタイム推論
- 高速なプロトタイピング: LSTMより学習が速いため、実験サイクルを短縮
- 時系列予測タスク: 気温、株価、センサデータなどの予測問題で、LSTMとの使い分け
- モデル設計の判断: 問題の性質に応じてLSTMとGRUのどちらを選ぶべきかの判断
本記事の内容
- LSTMの構造の復習と「もっとシンプルにできないか」の動機
- GRUのアーキテクチャ(リセットゲートと更新ゲート)
- LSTMとGRUの対応関係と構造的な違い
- パラメータ数の定量的な比較
- NumPyによるGRUセルのスクラッチ実装
- LSTMとGRUの性能比較実験
- 使い分けの実践的な指針
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
- LSTMの理論と仕組みを徹底解説 — ゲート機構とセル状態の理論
- RNN(再帰型ニューラルネットワーク)の基礎を解説 — RNNの構造とBPTT
- 活性化関数の全体像 — シグモイドとtanh
LSTMの復習と簡略化の動機
LSTMの構造を振り返る
GRUの設計意図を理解するために、まずLSTMの計算式を振り返りましょう。
$$ \begin{align} \bm{f}_t &= \sigma(\bm{W}_f [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_f) & \text{(忘却ゲート)} \\ \bm{i}_t &= \sigma(\bm{W}_i [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_i) & \text{(入力ゲート)} \\ \tilde{\bm{c}}_t &= \tanh(\bm{W}_c [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_c) & \text{(候補セル状態)} \\ \bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t & \text{(セル状態更新)} \\ \bm{o}_t &= \sigma(\bm{W}_o [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_o) & \text{(出力ゲート)} \\ \bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t) & \text{(隠れ状態)} \end{align} $$
LSTMは4つの重み行列($\bm{W}_f$, $\bm{W}_i$, $\bm{W}_c$, $\bm{W}_o$)と4つのバイアス、さらに2つの状態ベクトル(セル状態 $\bm{c}_t$ と隠れ状態 $\bm{h}_t$)を持ちます。
「もっとシンプルにできないか」
LSTMの構造を注意深く観察すると、いくつかの冗長性に気づきます。
観察1: 忘却ゲートと入力ゲートの相補性
セル状態の更新式 $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$ を見ると、忘却ゲート $\bm{f}_t$ と入力ゲート $\bm{i}_t$ は相補的な役割を持っています。忘却ゲートが「古い情報をどれだけ残すか」を決め、入力ゲートが「新しい情報をどれだけ入れるか」を決めます。
もし「古い情報を忘れた分だけ新しい情報を入れる」と仮定すれば、$\bm{i}_t = 1 – \bm{f}_t$ と制約できます。すると2つのゲートを1つに統合できます。
観察2: セル状態と隠れ状態の冗長性
LSTMではセル状態 $\bm{c}_t$ と隠れ状態 $\bm{h}_t$ を別々に管理していますが、出力ゲートを通じて $\bm{h}_t = \bm{o}_t \odot \tanh(\bm{c}_t)$ と変換しています。もしセル状態と隠れ状態を1つの状態に統合すれば、出力ゲートも不要になります。
観察3: 多くのタスクで全てのゲートが必要ではない
Greff et al. (2015) はLSTMの各構成要素の重要性を大規模に比較実験し、忘却ゲートと出力ゲートが最も重要だが、入力ゲートを忘却ゲートと結合しても性能がほとんど低下しないことを示しました。
これらの観察を踏まえて設計されたのがGRUです。GRUはLSTMのエッセンスを保ちながら、より少ないパラメータで同等の表現力を目指します。
GRUのアーキテクチャ
2つのゲート
GRUは、LSTMの3つのゲート(忘却・入力・出力)を2つのゲートに集約しています。
| ゲート | 記号 | 役割 | LSTMとの対応 |
|---|---|---|---|
| リセットゲート | $\bm{r}_t$ | 過去の隠れ状態のどの部分を候補状態の計算に使うか | 出力ゲートに近い |
| 更新ゲート | $\bm{z}_t$ | 古い状態と新しい候補状態の混合比を決める | 忘却ゲート + 入力ゲートの統合 |
水と絵の具を混ぜる場面を想像してください。更新ゲートは「水(古い状態)と絵の具(新しい状態)をどの比率で混ぜるか」を決めます。リセットゲートは「絵の具を作るとき、前の色をどれだけ参考にするか」を決めます。
リセットゲート
リセットゲートは、候補隠れ状態を計算する際に、前の隠れ状態のどの部分を「使う」かを制御します。
$$ \begin{equation} \bm{r}_t = \sigma(\bm{W}_r [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_r) \end{equation} $$
$\bm{r}_t$ の各要素が $0$ に近いとき、対応する前の隠れ状態の次元は「リセット」されます。つまり、候補隠れ状態を計算する際にその次元の過去の情報は無視されます。$1$ に近いとき、過去の情報はそのまま利用されます。
リセットゲートが全ての要素で $0$ に近い値を取ると、候補状態は現在の入力のみから計算されます。これにより、GRUは過去の文脈を「忘れて」現在の入力に集中することができます。たとえば、文の区切りでトピックが大きく変わるときに、前のトピックの情報をリセットするのに有効です。
候補隠れ状態
リセットゲートを使って、候補隠れ状態 $\tilde{\bm{h}}_t$ を計算します。
$$ \begin{equation} \tilde{\bm{h}}_t = \tanh(\bm{W}_h [\bm{r}_t \odot \bm{h}_{t-1}, \bm{x}_t] + \bm{b}_h) \end{equation} $$
ここで重要なのは、$\bm{r}_t \odot \bm{h}_{t-1}$ の部分です。LSTMの候補セル状態 $\tilde{\bm{c}}_t = \tanh(\bm{W}_c [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_c)$ では $\bm{h}_{t-1}$ がそのまま使われますが、GRUではリセットゲートでフィルタリングされた $\bm{r}_t \odot \bm{h}_{t-1}$ が使われます。
この違いにより、GRUは候補状態を計算する段階で既に「過去のどの情報を参考にするか」を制御できます。リセットゲートが $\bm{0}$ に近いとき、候補状態は $\tilde{\bm{h}}_t \approx \tanh(\bm{W}_h [\bm{0}, \bm{x}_t] + \bm{b}_h)$ となり、現在の入力だけで新しい状態が計算されます。
更新ゲート
更新ゲートは、LSTMの忘却ゲートと入力ゲートの役割を1つに統合したものです。
$$ \begin{equation} \bm{z}_t = \sigma(\bm{W}_z [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_z) \end{equation} $$
隠れ状態の更新
更新ゲートを使って、最終的な隠れ状態 $\bm{h}_t$ を計算します。
$$ \begin{equation} \bm{h}_t = (1 – \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t \end{equation} $$
この式はGRUの最も重要な式であり、以下のような線形補間(linear interpolation)になっています。
- $\bm{z}_t \approx \bm{0}$ のとき: $\bm{h}_t \approx \bm{h}_{t-1}$(前の状態をそのまま保持)
- $\bm{z}_t \approx \bm{1}$ のとき: $\bm{h}_t \approx \tilde{\bm{h}}_t$(候補状態で完全に更新)
- $\bm{z}_t \approx 0.5$ のとき: 前の状態と候補状態の平均
この線形補間構造が、GRUの核心的な設計です。$(1 – \bm{z}_t)$ と $\bm{z}_t$ の合計が常に $1$ になるため、情報の「総量」が保存されます。LSTMでは忘却ゲートと入力ゲートが独立しているため、$\bm{f}_t + \bm{i}_t$ が $1$ になる保証はありませんが、GRUでは構造的に保証されています。
GRUの全体像
GRUの1ステップの計算をまとめると、次の4つの式で表されます。
$$ \begin{align} \bm{r}_t &= \sigma(\bm{W}_r [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_r) & \text{(リセットゲート)} \\ \bm{z}_t &= \sigma(\bm{W}_z [\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_z) & \text{(更新ゲート)} \\ \tilde{\bm{h}}_t &= \tanh(\bm{W}_h [\bm{r}_t \odot \bm{h}_{t-1}, \bm{x}_t] + \bm{b}_h) & \text{(候補隠れ状態)} \\ \bm{h}_t &= (1 – \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t & \text{(隠れ状態の更新)} \end{align} $$
LSTMの6つの式と比較して、GRUは4つの式で表されます。セル状態 $\bm{c}_t$ と出力ゲート $\bm{o}_t$ が不要になり、構造が大幅に簡略化されていることがわかります。
ここまででGRUの数式が明らかになりました。次に、LSTMとGRUの構造を詳しく比較して、両者の対応関係を明確にしましょう。
LSTMとGRUの対応関係
構造的な対応
LSTMとGRUの各構成要素がどう対応するかを整理します。
| 項目 | LSTM | GRU | 備考 |
|---|---|---|---|
| 状態変数 | $\bm{c}_t$ (セル状態) + $\bm{h}_t$ (隠れ状態) | $\bm{h}_t$ (隠れ状態のみ) | GRUは1つの状態に統合 |
| 忘却機構 | 忘却ゲート $\bm{f}_t$ | $1 – \bm{z}_t$ | GRUでは更新ゲートの補数 |
| 入力機構 | 入力ゲート $\bm{i}_t$ | $\bm{z}_t$ | GRUでは忘却と連動 |
| 出力機構 | 出力ゲート $\bm{o}_t$ | なし | GRUでは状態がそのまま出力 |
| 過去の制御 | なし(直接参照) | リセットゲート $\bm{r}_t$ | 候補状態計算時に過去を制御 |
| 状態更新 | 加算: $\bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$ | 補間: $(1-\bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t$ | GRUは線形補間で総量保存 |
忘却と入力の統合
LSTMのセル状態更新式とGRUの隠れ状態更新式を並べて比較すると、設計思想の違いが明確になります。
LSTM: $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$
ここで $\bm{f}_t$ と $\bm{i}_t$ は独立に計算されます。極端な場合、$\bm{f}_t \approx 1$ かつ $\bm{i}_t \approx 1$ も可能で、このとき古い情報を全て保持しつつ新しい情報も全て追加するため、セル状態の大きさが際限なく成長する可能性があります。
GRU: $\bm{h}_t = (1 – \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t$
$(1 – \bm{z}_t) + \bm{z}_t = 1$ が常に成り立つため、古い情報と新しい情報の比率の合計は常に $1$ です。これは凸結合であり、隠れ状態のノルムが発散しにくいという利点があります。
リセットゲートと出力ゲートの対応
LSTMの出力ゲート $\bm{o}_t$ は「セル状態のどの部分を出力するか」を制御します。一方、GRUのリセットゲート $\bm{r}_t$ は「候補状態の計算に過去のどの部分を使うか」を制御します。
両者は異なるタイミングで働きますが、「過去の情報へのアクセスを制限する」という点では共通しています。ただし、LSTMの出力ゲートはセル状態から隠れ状態への変換時に作用するのに対し、GRUのリセットゲートは候補状態の生成時に作用するという違いがあります。
勾配の流れの比較
GRUの隠れ状態の更新式 $\bm{h}_t = (1 – \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t$ において、$\bm{h}_t$ の $\bm{h}_{t-1}$ に対する勾配の主要項は
$$ \frac{\partial \bm{h}_t}{\partial \bm{h}_{t-1}} \approx \text{diag}(1 – \bm{z}_t) $$
です($\tilde{\bm{h}}_t$ の $\bm{h}_{t-1}$ への依存からの項を無視した近似)。長期にわたる勾配は
$$ \prod_{k=t+1}^{T} (1 – \bm{z}_k) $$
のオーダーとなります。これはLSTMのセル状態を通じた勾配 $\prod \bm{f}_k$ と同じ構造です。$\bm{z}_k \approx 0$ のとき(つまり「更新しない」とき)に勾配がほぼ $1$ で伝播します。
LSTMでは忘却ゲート $\bm{f}_t \approx 1$ のとき、GRUでは更新ゲート $\bm{z}_t \approx 0$ のときに、勾配が安定的に伝播するという、ちょうど逆の対応関係になっていることに注意してください。
ここまでで、GRUとLSTMの構造的な違いと共通点が明らかになりました。次に、パラメータ数を定量的に比較して、GRUの計算効率上の利点を確認しましょう。
パラメータ数の比較
数式での比較
入力次元を $d_x$、隠れ次元を $d_h$ として、各モデルのパラメータ数を計算します。
バニラRNN:
重み行列が $\bm{W}_h \in \mathbb{R}^{d_h \times d_h}$、$\bm{W}_x \in \mathbb{R}^{d_h \times d_x}$ と結合して考えると $d_h(d_h + d_x)$、バイアスが $d_h$ で、
$$ N_{\text{RNN}} = d_h(d_h + d_x + 1) $$
LSTM:
4つの重み行列とバイアスで、
$$ N_{\text{LSTM}} = 4 \, d_h(d_h + d_x + 1) $$
GRU:
3つの重み行列(リセットゲート $\bm{W}_r$、更新ゲート $\bm{W}_z$、候補状態 $\bm{W}_h$)とバイアスで、
$$ N_{\text{GRU}} = 3 \, d_h(d_h + d_x + 1) $$
GRUのパラメータ数はLSTMの75%(= 3/4)です。
具体的な数値例
$d_x = 10$、$d_h = 64$ の場合を計算してみましょう。
$$ \begin{align} N_{\text{RNN}} &= 64 \times (64 + 10 + 1) = 4{,}800 \\ N_{\text{LSTM}} &= 4 \times 4{,}800 = 19{,}200 \\ N_{\text{GRU}} &= 3 \times 4{,}800 = 14{,}400 \end{align} $$
GRUはLSTMと比較して $19{,}200 – 14{,}400 = 4{,}800$ パラメータ少なく、これは約25%の削減に相当します。$d_h = 256$ にスケールアップすると、
$$ \begin{align} N_{\text{LSTM}} &= 4 \times 256 \times (256 + 10 + 1) = 273{,}408 \\ N_{\text{GRU}} &= 3 \times 256 \times (256 + 10 + 1) = 205{,}056 \end{align} $$
差は $68{,}352$ パラメータとなり、隠れ次元が大きくなるほどGRUの計算効率の利点が顕著になります。
計算時間の比較
パラメータ数の差は、推論速度と学習速度に直接影響します。1タイムステップあたりの行列積の回数はLSTMが4回($\bm{W}_f$, $\bm{W}_i$, $\bm{W}_c$, $\bm{W}_o$)、GRUが3回($\bm{W}_r$, $\bm{W}_z$, $\bm{W}_h$)です。実測では、GRUはLSTMと比べて20〜30%程度高速であることが多くの研究で報告されています(Chung et al., 2014; Jozefowicz et al., 2015)。
パラメータ数と計算効率の利点がわかったところで、次にGRUをPythonでスクラッチ実装してみましょう。
NumPyによるGRUセルのスクラッチ実装
実装の方針
LSTMの記事と同様に、NumPyのみでGRUセルの順伝播と逆伝播を実装します。LSTMと比較して式が少ない分、コードもシンプルになることを確認してください。
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))
def sigmoid_deriv(s):
"""シグモイドの出力 s からその微分を計算"""
return s * (1 - s)
def tanh_deriv(t):
"""tanhの出力 t からその微分を計算"""
return 1 - t ** 2
class GRUCell:
def __init__(self, input_dim, hidden_dim):
self.d_x = input_dim
self.d_h = hidden_dim
scale = 0.1
# 重み行列: [h, x] の結合入力に対する重み
self.W_z = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
self.W_r = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
self.W_h = np.random.randn(hidden_dim, hidden_dim + input_dim) * scale
# バイアス
self.b_z = np.zeros(hidden_dim)
self.b_r = np.zeros(hidden_dim)
self.b_h = np.zeros(hidden_dim)
def forward(self, x, h_prev):
"""順伝播: 1タイムステップ分の計算"""
concat = np.concatenate([h_prev, x])
# ゲートの計算
z = sigmoid(self.W_z @ concat + self.b_z) # 更新ゲート
r = sigmoid(self.W_r @ concat + self.b_r) # リセットゲート
# 候補隠れ状態
concat_r = np.concatenate([r * h_prev, x])
h_tilde = np.tanh(self.W_h @ concat_r + self.b_h)
# 隠れ状態の更新(線形補間)
h = (1 - z) * h_prev + z * h_tilde
cache = (concat, concat_r, z, r, h_tilde, h_prev, h)
return h, cache
def backward(self, dh_next, cache):
"""逆伝播: 1タイムステップ分の勾配計算"""
concat, concat_r, z, r, h_tilde, h_prev, h = cache
# 隠れ状態の更新式からの勾配
dh_tilde = dh_next * z
dz = dh_next * (h_tilde - h_prev)
dh_prev_direct = dh_next * (1 - z)
# 候補隠れ状態の逆伝播
dh_tilde_raw = dh_tilde * tanh_deriv(h_tilde)
self.dW_h = np.outer(dh_tilde_raw, concat_r)
self.db_h = dh_tilde_raw
d_concat_r = self.W_h.T @ dh_tilde_raw
dr_h = d_concat_r[:self.d_h]
dx_from_h = d_concat_r[self.d_h:]
# リセットゲートの勾配
dr = dr_h * h_prev
dh_prev_from_r = dr_h * r
# 更新ゲートの逆伝播
dz_raw = dz * sigmoid_deriv(z)
self.dW_z = np.outer(dz_raw, concat)
self.db_z = dz_raw
# リセットゲートの逆伝播
dr_raw = dr * sigmoid_deriv(r)
self.dW_r = np.outer(dr_raw, concat)
self.db_r = dr_raw
# 入力と隠れ状態への勾配を集約
d_concat_z = self.W_z.T @ dz_raw
d_concat_rg = self.W_r.T @ dr_raw
dh_prev = (dh_prev_direct + dh_prev_from_r +
d_concat_z[:self.d_h] + d_concat_rg[:self.d_h])
dx = (dx_from_h + d_concat_z[self.d_h:] + d_concat_rg[self.d_h:])
return dh_prev, dx
LSTMの LSTMCell と比較して、いくつかの顕著な違いがあります。まず、forward メソッドで扱う状態がセル状態 $\bm{c}$ なしの隠れ状態 $\bm{h}$ のみです。また、ゲートが3つから2つに減り、計算量が削減されています。さらに、リセットゲート $\bm{r}$ が候補状態の計算に組み込まれている点がLSTMとの構造的な違いです。
続いて、GRUを時系列方向に展開するネットワーククラスを実装します。
class GRUNetwork:
def __init__(self, input_dim, hidden_dim, output_dim, lr=0.001):
self.cell = GRUCell(input_dim, hidden_dim)
self.W_y = np.random.randn(output_dim, hidden_dim) * 0.1
self.b_y = np.zeros(output_dim)
self.d_h = hidden_dim
self.lr = lr
def forward_sequence(self, xs):
"""系列全体の順伝播"""
T = len(xs)
h = np.zeros(self.d_h)
hs, caches = [], []
for t in range(T):
h, cache = self.cell.forward(xs[t], h)
hs.append(h)
caches.append(cache)
ys = [self.W_y @ h + self.b_y for h in hs]
return hs, caches, ys
def train_step(self, xs, targets):
"""1回の学習ステップ"""
T = len(xs)
hs, caches, ys = self.forward_sequence(xs)
# 損失の計算(MSE)
loss = sum(np.sum((ys[t] - targets[t])**2) for t in range(T)) / T
# 勾配の初期化
dW_z = np.zeros_like(self.cell.W_z)
dW_r = np.zeros_like(self.cell.W_r)
dW_h = np.zeros_like(self.cell.W_h)
db_z = np.zeros_like(self.cell.b_z)
db_r = np.zeros_like(self.cell.b_r)
db_h = np.zeros_like(self.cell.b_h)
dW_y = np.zeros_like(self.W_y)
db_y = np.zeros_like(self.b_y)
dh_next = np.zeros(self.d_h)
for t in reversed(range(T)):
dy = 2 * (ys[t] - targets[t]) / T
dW_y += np.outer(dy, hs[t])
db_y += dy
dh = self.W_y.T @ dy + dh_next
dh_next, _ = self.cell.backward(dh, caches[t])
dW_z += self.cell.dW_z
dW_r += self.cell.dW_r
dW_h += self.cell.dW_h
db_z += self.cell.db_z
db_r += self.cell.db_r
db_h += self.cell.db_h
# 勾配クリッピング
for grad in [dW_z, dW_r, dW_h, dW_y, db_z, db_r, db_h, db_y]:
np.clip(grad, -5, 5, out=grad)
# パラメータ更新
self.cell.W_z -= self.lr * dW_z
self.cell.W_r -= self.lr * dW_r
self.cell.W_h -= self.lr * dW_h
self.cell.b_z -= self.lr * db_z
self.cell.b_r -= self.lr * db_r
self.cell.b_h -= self.lr * db_h
self.W_y -= self.lr * dW_y
self.b_y -= self.lr * db_y
return loss
LSTMの LSTMNetwork と比較すると、GRUの GRUNetwork はセル状態の管理が不要な分、コードが明らかに簡潔です。forward_sequence ではセル状態 $\bm{c}$ のリストを管理する必要がなく、train_step でもセル状態への勾配 dc_next を追跡する必要がありません。
実装が完成したので、次にLSTMとGRUの性能を同じ条件で比較する実験を行いましょう。
LSTMとGRUの性能比較実験
実験設定
LSTMとGRUの性能を公平に比較するために、同じ合成データ、同じ隠れ次元、同じ学習率で両モデルを学習させます。
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)
# データ生成
t = np.linspace(0, 20 * np.pi, 2000)
data = np.sin(t) + 0.1 * np.sin(3 * t)
# 訓練データ作成
seq_len = 20
X, Y = [], []
for i in range(len(data) - seq_len):
X.append(data[i:i+seq_len].reshape(-1, 1))
Y.append(data[i+1:i+seq_len+1].reshape(-1, 1))
X = X[:500]
Y = Y[:500]
# 両モデルの初期化(同じ隠れ次元・学習率)
hidden_dim = 32
lr = 0.005
n_epochs = 100
# LSTMとGRUを並列に学習
from time import time
# LSTM(前のセクションのクラスを使用)
lstm = LSTMNetwork(input_dim=1, hidden_dim=hidden_dim, output_dim=1, lr=lr)
gru = GRUNetwork(input_dim=1, hidden_dim=hidden_dim, output_dim=1, lr=lr)
lstm_losses, gru_losses = [], []
lstm_time, gru_time = 0, 0
for epoch in range(n_epochs):
lstm_epoch_loss, gru_epoch_loss = 0, 0
for j in range(len(X)):
t0 = time()
l_loss = lstm.train_step(X[j], Y[j])
lstm_time += time() - t0
t0 = time()
g_loss = gru.train_step(X[j], Y[j])
gru_time += time() - t0
lstm_epoch_loss += l_loss
gru_epoch_loss += g_loss
lstm_losses.append(lstm_epoch_loss / len(X))
gru_losses.append(gru_epoch_loss / len(X))
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}: LSTM Loss={lstm_losses[-1]:.6f}, "
f"GRU Loss={gru_losses[-1]:.6f}")
print(f"\n学習時間: LSTM={lstm_time:.2f}s, GRU={gru_time:.2f}s")
print(f"GRU/LSTM速度比: {gru_time/lstm_time:.2f}")
損失の比較プロット
plt.figure(figsize=(10, 5))
plt.plot(lstm_losses, label='LSTM', color='steelblue', linewidth=2)
plt.plot(gru_losses, label='GRU', color='coral', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('LSTM vs GRU: Training Loss Comparison')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
このグラフから、LSTMとGRUの損失がほぼ同じペースで減少していくことが確認できます。多くの場合、GRUはLSTMと同等の最終損失に到達します。学習の初期段階ではGRUの方がわずかに速く収束する傾向が見られることもあり、これはパラメータ数が少ない分、各パラメータに対する勾配信号が相対的に大きくなるためと考えられます。
予測精度の比較
# テストデータでの予測比較
test_start = 600
test_len = 200
test_data = data[test_start:test_start + test_len + seq_len]
# LSTM予測
lstm_preds = []
h_l = np.zeros(hidden_dim)
c_l = np.zeros(hidden_dim)
for i in range(seq_len):
x_input = np.array([test_data[i]])
h_l, c_l, _ = lstm.cell.forward(x_input, h_l, c_l)
for i in range(test_len):
y_pred = lstm.W_y @ h_l + lstm.b_y
lstm_preds.append(y_pred[0])
x_input = np.array([test_data[seq_len + i]])
h_l, c_l, _ = lstm.cell.forward(x_input, h_l, c_l)
# GRU予測
gru_preds = []
h_g = np.zeros(hidden_dim)
for i in range(seq_len):
x_input = np.array([test_data[i]])
h_g, _ = gru.cell.forward(x_input, h_g)
for i in range(test_len):
y_pred = gru.W_y @ h_g + gru.b_y
gru_preds.append(y_pred[0])
x_input = np.array([test_data[seq_len + i]])
h_g, _ = gru.cell.forward(x_input, h_g)
# 可視化
plt.figure(figsize=(12, 5))
time_axis = np.arange(test_len)
plt.plot(time_axis, test_data[seq_len:seq_len+test_len],
label='Ground Truth', color='black', linewidth=2)
plt.plot(time_axis, lstm_preds,
label='LSTM', color='steelblue', linewidth=1.5, linestyle='--')
plt.plot(time_axis, gru_preds,
label='GRU', color='coral', linewidth=1.5, linestyle='-.')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('LSTM vs GRU: Prediction Comparison')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 定量的な比較
lstm_mse = np.mean((np.array(lstm_preds) - test_data[seq_len:seq_len+test_len])**2)
gru_mse = np.mean((np.array(gru_preds) - test_data[seq_len:seq_len+test_len])**2)
print(f"LSTM Test MSE: {lstm_mse:.6f}")
print(f"GRU Test MSE: {gru_mse:.6f}")
予測結果のグラフから、LSTMとGRUの両方が正弦波のパターンを良好に追跡していることがわかります。定量的なMSEの差はわずかであり、この程度のタスクでは両者の性能に実質的な差がないことが確認できます。一方で、GRUの学習時間はLSTMより短くなっているはずです。これはパラメータ数が25%少ない分、行列積の回数が削減されているためです。
パラメータ数の実測
# パラメータ数の比較
lstm_params = (lstm.cell.W_f.size + lstm.cell.W_i.size +
lstm.cell.W_c.size + lstm.cell.W_o.size +
lstm.cell.b_f.size + lstm.cell.b_i.size +
lstm.cell.b_c.size + lstm.cell.b_o.size +
lstm.W_y.size + lstm.b_y.size)
gru_params = (gru.cell.W_z.size + gru.cell.W_r.size +
gru.cell.W_h.size +
gru.cell.b_z.size + gru.cell.b_r.size +
gru.cell.b_h.size +
gru.W_y.size + gru.b_y.size)
print(f"LSTM パラメータ数: {lstm_params:,}")
print(f"GRU パラメータ数: {gru_params:,}")
print(f"削減率: {(1 - gru_params/lstm_params)*100:.1f}%")
出力を確認すると、GRUのパラメータ数がLSTMの約75%であることが実測でも確認できます。この25%の削減が、計算速度の向上とメモリ使用量の削減につながっています。
使い分けの指針
どちらを選ぶべきか
LSTMとGRUの使い分けについて、実践的な指針をまとめます。
GRUが適している場面:
- データセットが小さい場合: パラメータ数が少ないため過学習しにくく、限られたデータで効率的に学習できます
- 計算リソースが限られている場合: エッジデバイスやリアルタイム推論など、速度が重要な場面でGRUの計算効率が活きます
- 素早いプロトタイピング: 学習が速いため、実験サイクルを短縮できます
- シンプルな時系列パターン: 複雑な長期依存関係がそれほど重要でないタスク
LSTMが適している場面:
- 複雑な長期依存性を含むタスク: セル状態による明示的な長期記憶が有効な場合。特に、忘却と入力を独立に制御する必要がある場合
- 大規模データセット: データが十分にある場合、LSTMの追加パラメータがより豊かな表現力につながる可能性があります
- 入出力の情報流が非対称な場合: 出力ゲートにより、セル状態の一部だけを出力する柔軟性が必要な場面
- 既存の成功事例が多い分野: 音声認識や機械翻訳など、LSTMで多くのベンチマークが確立されている分野
一般的なガイドライン
実際には、明確な優劣をつけるのは難しく、タスクごとに両者を試して比較するのが最善のアプローチです。Chung et al. (2014) やJozefowicz et al. (2015) の大規模な比較実験でも、一方が常に優れているという結論は出ていません。
ただし、まずGRUから試すというのは合理的な戦略です。パラメータ数が少なく学習が速いため、ベースラインを素早く構築できます。GRUで十分な性能が得られなければLSTMに切り替える、というアプローチが効率的です。
近年のトレンド
2017年以降、多くの時系列タスクでTransformerベースのモデルがRNN系のモデル(LSTMとGRU)を上回る性能を示しています。特に長い系列や並列計算が重要な場面では、Transformerが有力な選択肢です。ただし、短い系列やストリーミングデータの処理など、RNN系のアーキテクチャが依然として有効な場面もあります。
まとめ
本記事では、GRUの理論とLSTMとの対応関係を解説し、NumPyによるスクラッチ実装と比較実験を行いました。
- GRUの設計思想: LSTMの忘却ゲートと入力ゲートを更新ゲートに統合し、セル状態と隠れ状態を一体化
- 2つのゲート: リセットゲート(過去の情報のフィルタリング)と更新ゲート(古い情報と新しい情報の線形補間)
- パラメータ数: GRUはLSTMの約75%のパラメータで、約20〜30%高速な学習が可能
- 性能: 多くのタスクでLSTMと同等の性能を発揮。どちらが優れているかはタスク依存
- 使い分け: リソースが限られた環境やプロトタイピングではGRU、複雑な長期依存性にはLSTMが適するが、まずGRUから試すのが効率的
- 勾配の安定性: $(1 – \bm{z}_t) \odot \bm{h}_{t-1}$ の構造により、LSTMと同様に長期の勾配伝播が安定
LSTMとGRUは、どちらもRNNの勾配消失問題に対する有効な解決策ですが、根本的に逐次処理であるという制約は残ります。系列の全時点を並列に処理し、任意の2時点間の依存関係を直接的に捉えるTransformerアーキテクチャは、この制約を根本的に解消します。
次のステップとして、以下の記事も参考にしてください。