GRU(Gated Recurrent Unit)の構造と実装を解説

RNN(Recurrent Neural Network)の大きな課題であった勾配消失問題を解決するために登場したのがLSTM(Long Short-Term Memory)です。しかしLSTMは3つのゲートとセル状態を持つ複雑な構造であり、パラメータ数が多くなる傾向があります。GRU(Gated Recurrent Unit)は、Cho et al.(2014)により提案された、LSTMの設計思想を継承しつつ構造を簡略化したモデルです。

GRUはLSTMと比較してゲート数が2つ(更新ゲート・リセットゲート)に減少し、セル状態を持たないためパラメータ数が少なく、計算コストも低くなります。一方で、多くのタスクにおいてLSTMと同等の性能を示すことが知られています。本記事では、GRUの設計動機から数式の導出、BPTTによる学習、そしてPythonによるスクラッチ実装まで丁寧に解説します。

本記事の内容

  • GRUの設計動機とLSTMとの関係
  • GRUの各ゲート(更新ゲート・リセットゲート)の数式と役割
  • GRU順伝播の全数式
  • BPTTによる逆伝播と勾配の流れの分析
  • LSTMとGRUの使い分けの指針
  • Pythonによるスクラッチ実装と時系列予測タスクでの性能比較

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

GRUの設計動機

LSTMの復習

LSTMは、セル状態 $\bm{c}_t$ と隠れ状態 $\bm{h}_t$ の2つの状態を持ち、忘却ゲート $\bm{f}_t$、入力ゲート $\bm{i}_t$、出力ゲート $\bm{o}_t$ の3つのゲートで情報の流れを制御します。

$$ \begin{align} \bm{f}_t &= \sigma(\bm{W}_f \bm{x}_t + \bm{U}_f \bm{h}_{t-1} + \bm{b}_f) \\ \bm{i}_t &= \sigma(\bm{W}_i \bm{x}_t + \bm{U}_i \bm{h}_{t-1} + \bm{b}_i) \\ \bm{o}_t &= \sigma(\bm{W}_o \bm{x}_t + \bm{U}_o \bm{h}_{t-1} + \bm{b}_o) \\ \tilde{\bm{c}}_t &= \tanh(\bm{W}_c \bm{x}_t + \bm{U}_c \bm{h}_{t-1} + \bm{b}_c) \\ \bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t \\ \bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t) \end{align} $$

LSTMは強力ですが、パラメータ行列が $\bm{W}_f, \bm{W}_i, \bm{W}_o, \bm{W}_c, \bm{U}_f, \bm{U}_i, \bm{U}_o, \bm{U}_c$ の8個(+バイアス4個)と多く、学習コストが高くなります。

GRUのアイデア

GRUは以下の2つの簡略化を行います。

  1. セル状態とhidden stateの統合: LSTMでは $\bm{c}_t$ と $\bm{h}_t$ を分けていましたが、GRUでは隠れ状態 $\bm{h}_t$ のみを使います。
  2. 忘却ゲートと入力ゲートの統合: LSTMでは $\bm{f}_t$ と $\bm{i}_t$ が独立でしたが、GRUでは更新ゲート $\bm{z}_t$ 1つで「古い情報をどれだけ忘れるか」と「新しい情報をどれだけ取り入れるか」を同時に制御します。具体的には、$\bm{z}_t$ の割合で新しい候補を採用し、$1 – \bm{z}_t$ の割合で前の隠れ状態を保持します。

これにより、パラメータ数が約75%に削減されます。

GRUの構造

各ゲートの定義

GRUの順伝播は以下の4つの式で定義されます。時刻 $t$ における入力を $\bm{x}_t \in \mathbb{R}^d$、前時刻の隠れ状態を $\bm{h}_{t-1} \in \mathbb{R}^h$ とします。

更新ゲート(Update Gate) $\bm{z}_t \in \mathbb{R}^h$:

$$ \bm{z}_t = \sigma(\bm{W}_z \bm{x}_t + \bm{U}_z \bm{h}_{t-1} + \bm{b}_z) $$

ここで $\sigma$ はシグモイド関数 $\sigma(x) = 1/(1 + e^{-x})$ です。$\bm{W}_z \in \mathbb{R}^{h \times d}$、$\bm{U}_z \in \mathbb{R}^{h \times h}$、$\bm{b}_z \in \mathbb{R}^h$ は学習パラメータです。

更新ゲートは、前の隠れ状態 $\bm{h}_{t-1}$ をどの程度保持するかを制御します。$\bm{z}_t$ が1に近い場合は前の状態をそのまま保持し(長期記憶の保持)、0に近い場合は新しい候補で上書きします。

リセットゲート(Reset Gate) $\bm{r}_t \in \mathbb{R}^h$:

$$ \bm{r}_t = \sigma(\bm{W}_r \bm{x}_t + \bm{U}_r \bm{h}_{t-1} + \bm{b}_r) $$

$\bm{W}_r \in \mathbb{R}^{h \times d}$、$\bm{U}_r \in \mathbb{R}^{h \times h}$、$\bm{b}_r \in \mathbb{R}^h$ は学習パラメータです。

リセットゲートは、候補隠れ状態を計算する際に、前の隠れ状態 $\bm{h}_{t-1}$ の情報をどの程度リセット(忘却)するかを制御します。$\bm{r}_t$ が0に近い場合、前の隠れ状態は無視され、候補はほぼ入力 $\bm{x}_t$ のみから計算されます。

候補隠れ状態(Candidate Hidden State) $\tilde{\bm{h}}_t \in \mathbb{R}^h$:

$$ \tilde{\bm{h}}_t = \tanh(\bm{W}_h \bm{x}_t + \bm{U}_h (\bm{r}_t \odot \bm{h}_{t-1}) + \bm{b}_h) $$

$\bm{W}_h \in \mathbb{R}^{h \times d}$、$\bm{U}_h \in \mathbb{R}^{h \times h}$、$\bm{b}_h \in \mathbb{R}^h$ は学習パラメータです。$\odot$ はアダマール積(要素ごとの積)を表します。

ここでリセットゲート $\bm{r}_t$ が作用しています。$\bm{r}_t \odot \bm{h}_{t-1}$ により、前の隠れ状態の一部(または全部)がマスクされます。

隠れ状態の更新:

$$ \bm{h}_t = \bm{z}_t \odot \bm{h}_{t-1} + (1 – \bm{z}_t) \odot \tilde{\bm{h}}_t $$

この式がGRUの核心です。更新ゲート $\bm{z}_t$ が $\bm{h}_{t-1}$(過去の情報)と $\tilde{\bm{h}}_t$(新しい候補)の線形補間を行います。

更新式の直感的理解

隠れ状態の更新式を成分ごとに書くと、各 $j = 1, \dots, h$ に対して次のようになります。

$$ h_{t,j} = z_{t,j} \cdot h_{t-1,j} + (1 – z_{t,j}) \cdot \tilde{h}_{t,j} $$

$z_{t,j} = 1$ のとき $h_{t,j} = h_{t-1,j}$ となり、$j$ 番目の次元は前の時刻からそのまま引き継がれます。逆に $z_{t,j} = 0$ のとき $h_{t,j} = \tilde{h}_{t,j}$ となり、完全に新しい値に更新されます。

これは、LSTMの忘却ゲートと入力ゲートの役割を1つのゲートに統合していることに対応します。LSTMでは $\bm{f}_t$ と $\bm{i}_t$ が独立に学習されるため、$\bm{f}_t + \bm{i}_t \neq 1$ となることがあります。GRUでは $\bm{z}_t + (1 – \bm{z}_t) = 1$ が常に成り立つため、情報の流入と流出が自動的にバランスします。

パラメータ数の比較

入力次元 $d$、隠れ次元 $h$ の場合のパラメータ数を比較します。

LSTM: ゲートが4つ(忘却、入力、出力、候補セル)あり、それぞれに $\bm{W} \in \mathbb{R}^{h \times d}$、$\bm{U} \in \mathbb{R}^{h \times h}$、$\bm{b} \in \mathbb{R}^h$ があるため、

$$ N_{\text{LSTM}} = 4 \times (hd + h^2 + h) = 4h(d + h + 1) $$

GRU: ゲートが3つ(更新、リセット、候補隠れ状態)あるため、

$$ N_{\text{GRU}} = 3 \times (hd + h^2 + h) = 3h(d + h + 1) $$

よって、GRUのパラメータ数はLSTMの $3/4 = 75\%$ です。例えば $d = 64$、$h = 128$ の場合、

$$ \begin{align} N_{\text{LSTM}} &= 4 \times 128 \times (64 + 128 + 1) = 4 \times 128 \times 193 = 98{,}816 \\ N_{\text{GRU}} &= 3 \times 128 \times (64 + 128 + 1) = 3 \times 128 \times 193 = 74{,}112 \end{align} $$

約25%のパラメータ削減となります。

GRUのBPTT(Backpropagation Through Time)

損失関数と勾配の概要

系列長 $T$ の入力に対する損失関数を $L = \sum_{t=1}^{T} L_t$ とします。各時刻 $t$ の損失 $L_t$ は隠れ状態 $\bm{h}_t$ から計算される出力 $\bm{y}_t$ に依存します。GRUの学習では、損失 $L$ を各パラメータ $\bm{W}_z, \bm{U}_z, \bm{b}_z, \bm{W}_r, \bm{U}_r, \bm{b}_r, \bm{W}_h, \bm{U}_h, \bm{b}_h$ で微分する必要があります。

BPTTの鍵は、隠れ状態に関する勾配 $\delta_t = \frac{\partial L}{\partial \bm{h}_t}$ を時刻 $t$ から $t-1$ へ伝播することです。

勾配の逆伝播

時刻 $t$ での勾配 $\delta_t$ が与えられたとき、$\bm{h}_t = \bm{z}_t \odot \bm{h}_{t-1} + (1 – \bm{z}_t) \odot \tilde{\bm{h}}_t$ の各変数に対する勾配を計算します。

まず、中間変数の勾配を定義します。

$$ \begin{align} \frac{\partial L}{\partial \tilde{\bm{h}}_t} &= \delta_t \odot (1 – \bm{z}_t) \\ \frac{\partial L}{\partial \bm{z}_t} &= \delta_t \odot (\bm{h}_{t-1} – \tilde{\bm{h}}_t) \end{align} $$

1行目は、更新式 $\bm{h}_t = \bm{z}_t \odot \bm{h}_{t-1} + (1 – \bm{z}_t) \odot \tilde{\bm{h}}_t$ を $\tilde{\bm{h}}_t$ で偏微分した結果です。2行目は同じ式を $\bm{z}_t$ で偏微分した結果で、$\bm{h}_{t-1}$ と $\tilde{\bm{h}}_t$ の差が現れます。

次に、シグモイド関数と tanh の微分を使って、ゲートの線形変換前の勾配を求めます。シグモイドの微分は $\sigma'(x) = \sigma(x)(1 – \sigma(x))$ であり、tanhの微分は $\tanh'(x) = 1 – \tanh^2(x)$ です。

更新ゲートの線形変換前の勾配:

$$ \frac{\partial L}{\partial \bm{a}_z} = \frac{\partial L}{\partial \bm{z}_t} \odot \bm{z}_t \odot (1 – \bm{z}_t) $$

ここで $\bm{a}_z = \bm{W}_z \bm{x}_t + \bm{U}_z \bm{h}_{t-1} + \bm{b}_z$ です。

候補隠れ状態のtanh前の勾配:

$$ \frac{\partial L}{\partial \bm{a}_h} = \frac{\partial L}{\partial \tilde{\bm{h}}_t} \odot (1 – \tilde{\bm{h}}_t^2) $$

ここで $\bm{a}_h = \bm{W}_h \bm{x}_t + \bm{U}_h (\bm{r}_t \odot \bm{h}_{t-1}) + \bm{b}_h$ です。

リセットゲートの勾配は、候補隠れ状態を経由して伝播します。

$$ \frac{\partial L}{\partial \bm{r}_t} = \left(\bm{U}_h^\top \frac{\partial L}{\partial \bm{a}_h}\right) \odot \bm{h}_{t-1} $$

リセットゲートの線形変換前の勾配:

$$ \frac{\partial L}{\partial \bm{a}_r} = \frac{\partial L}{\partial \bm{r}_t} \odot \bm{r}_t \odot (1 – \bm{r}_t) $$

隠れ状態の勾配伝播

$\bm{h}_{t-1}$ に関する勾配は、3つの経路から寄与を受けます。

  1. 更新式からの直接経路: $\bm{h}_t$ の更新式 $\bm{h}_t = \bm{z}_t \odot \bm{h}_{t-1} + \dots$ から、$\delta_t \odot \bm{z}_t$
  2. 更新ゲート経由: $\bm{z}_t$ が $\bm{h}_{t-1}$ に依存するため
  3. リセットゲート・候補隠れ状態経由: $\bm{r}_t$ と $\tilde{\bm{h}}_t$ が $\bm{h}_{t-1}$ に依存するため

まとめると、

$$ \begin{align} \delta_{t-1} &= \delta_t \odot \bm{z}_t \quad \text{(直接経路)} \\ &\quad + \bm{U}_z^\top \frac{\partial L}{\partial \bm{a}_z} \quad \text{(更新ゲート経由)} \\ &\quad + \bm{U}_r^\top \frac{\partial L}{\partial \bm{a}_r} \quad \text{(リセットゲート経由)} \\ &\quad + \left(\bm{U}_h^\top \frac{\partial L}{\partial \bm{a}_h}\right) \odot \bm{r}_t \quad \text{(候補隠れ状態経由)} \end{align} $$

勾配の流れの分析

GRUにおける隠れ状態の勾配伝播で最も重要な項は、第1項の $\delta_t \odot \bm{z}_t$ です。この項は行列の乗算を含まず、要素ごとの積のみです。

更新ゲート $\bm{z}_t$ が1に近い場合、勾配は $\delta_{t-1} \approx \delta_t$ と、ほぼ減衰なく前の時刻に伝播します。これはLSTMにおけるセル状態の勾配が忘却ゲート $\bm{f}_t$ を通じて伝播する仕組みと類似しています。

$$ \text{LSTM}: \frac{\partial \bm{c}_t}{\partial \bm{c}_{t-1}} = \bm{f}_t, \qquad \text{GRU}: \frac{\partial \bm{h}_t}{\partial \bm{h}_{t-1}} \supset \bm{z}_t $$

この「勾配のハイウェイ」の存在により、GRUはバニラRNNで深刻だった勾配消失問題を回避できます。長い系列でも、更新ゲートが1に近い値を取れば、遠い過去の情報を隠れ状態に保持し続けられます。

パラメータの勾配

各パラメータに対する勾配は、以下のように計算されます。全時刻の寄与を合計します。

$$ \begin{align} \frac{\partial L}{\partial \bm{W}_z} &= \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_z^{(t)}} \bm{x}_t^\top \\ \frac{\partial L}{\partial \bm{U}_z} &= \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_z^{(t)}} \bm{h}_{t-1}^\top \\ \frac{\partial L}{\partial \bm{b}_z} &= \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_z^{(t)}} \end{align} $$

$\bm{W}_r, \bm{U}_r, \bm{b}_r$ と $\bm{W}_h, \bm{U}_h, \bm{b}_h$ に対する勾配も同様の構造です。

$$ \begin{align} \frac{\partial L}{\partial \bm{W}_r} &= \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_r^{(t)}} \bm{x}_t^\top, \quad \frac{\partial L}{\partial \bm{U}_r} = \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_r^{(t)}} \bm{h}_{t-1}^\top, \quad \frac{\partial L}{\partial \bm{b}_r} = \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_r^{(t)}} \\ \frac{\partial L}{\partial \bm{W}_h} &= \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_h^{(t)}} \bm{x}_t^\top, \quad \frac{\partial L}{\partial \bm{U}_h} = \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_h^{(t)}} (\bm{r}_t \odot \bm{h}_{t-1})^\top, \quad \frac{\partial L}{\partial \bm{b}_h} = \sum_{t=1}^{T} \frac{\partial L}{\partial \bm{a}_h^{(t)}} \end{align} $$

$\bm{U}_h$ の勾配で $\bm{h}_{t-1}$ ではなく $\bm{r}_t \odot \bm{h}_{t-1}$ が現れる点に注意してください。これはリセットゲートが $\bm{h}_{t-1}$ を変調した後の値が $\bm{U}_h$ に入力されるためです。

LSTMとGRUの比較と使い分け

構造の比較

LSTM GRU
ゲート数 3(忘却・入力・出力) 2(更新・リセット)
状態変数 2($\bm{c}_t$, $\bm{h}_t$) 1($\bm{h}_t$)
パラメータ数 $4h(d+h+1)$ $3h(d+h+1)$
出力ゲート あり なし
忘却と入力 独立 連動($\bm{z}_t$ と $1-\bm{z}_t$)

使い分けの指針

GRUが有利な場面: – データセットが小さくパラメータ数を抑えたい場合 – 推論速度を重視する場合(モバイル・組み込みなど) – 比較的短い系列長のタスク

LSTMが有利な場面: – 非常に長い系列(例: 数千ステップ以上)を扱う場合。出力ゲートがあるため、隠れ状態の公開をより精密に制御できます – セル状態と隠れ状態の分離が表現力向上に寄与するような複雑なタスク – 大量のデータがあり、パラメータ数の増加が過学習を招かない場合

実際には、タスクごとに両者を試して性能を比較するのが最善です。理論的にどちらが優れるかは自明ではなく、経験的に同等の性能を示すことが多いとされています。

Pythonでの実装

GRUのスクラッチ実装

numpyのみを使って、GRUセルと学習ループを実装します。

import numpy as np
import matplotlib.pyplot as plt

# ----------------------------
# シグモイド関数とtanh関数
# ----------------------------
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

def tanh(x):
    return np.tanh(x)

# ----------------------------
# GRUクラス(スクラッチ実装)
# ----------------------------
class GRU:
    """GRU(Gated Recurrent Unit)のスクラッチ実装"""

    def __init__(self, input_dim, hidden_dim, output_dim, lr=0.01):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.lr = lr

        # Xavier初期化
        scale_ih = np.sqrt(2.0 / (input_dim + hidden_dim))
        scale_hh = np.sqrt(2.0 / (hidden_dim + hidden_dim))
        scale_ho = np.sqrt(2.0 / (hidden_dim + output_dim))

        # 更新ゲートのパラメータ
        self.Wz = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Uz = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bz = np.zeros((hidden_dim, 1))

        # リセットゲートのパラメータ
        self.Wr = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Ur = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.br = np.zeros((hidden_dim, 1))

        # 候補隠れ状態のパラメータ
        self.Wh = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Uh = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bh = np.zeros((hidden_dim, 1))

        # 出力層のパラメータ
        self.Wy = np.random.randn(output_dim, hidden_dim) * scale_ho
        self.by = np.zeros((output_dim, 1))

    def forward(self, xs):
        """順伝播: xs は (T, input_dim, 1) のリストまたは配列"""
        T = len(xs)
        self.xs = xs
        self.zs = {}   # 更新ゲート
        self.rs = {}   # リセットゲート
        self.h_tildes = {}  # 候補隠れ状態
        self.hs = {-1: np.zeros((self.hidden_dim, 1))}  # 隠れ状態
        self.ys = {}   # 出力

        for t in range(T):
            x_t = xs[t]
            h_prev = self.hs[t - 1]

            # 更新ゲート
            self.zs[t] = sigmoid(self.Wz @ x_t + self.Uz @ h_prev + self.bz)

            # リセットゲート
            self.rs[t] = sigmoid(self.Wr @ x_t + self.Ur @ h_prev + self.br)

            # 候補隠れ状態
            self.h_tildes[t] = tanh(
                self.Wh @ x_t + self.Uh @ (self.rs[t] * h_prev) + self.bh
            )

            # 隠れ状態の更新
            self.hs[t] = self.zs[t] * h_prev + (1 - self.zs[t]) * self.h_tildes[t]

            # 出力
            self.ys[t] = self.Wy @ self.hs[t] + self.by

        return self.ys

    def backward(self, targets):
        """逆伝播(BPTT)"""
        T = len(targets)

        # パラメータ勾配の初期化
        dWz = np.zeros_like(self.Wz)
        dUz = np.zeros_like(self.Uz)
        dbz = np.zeros_like(self.bz)
        dWr = np.zeros_like(self.Wr)
        dUr = np.zeros_like(self.Ur)
        dbr = np.zeros_like(self.br)
        dWh = np.zeros_like(self.Wh)
        dUh = np.zeros_like(self.Uh)
        dbh = np.zeros_like(self.bh)
        dWy = np.zeros_like(self.Wy)
        dby = np.zeros_like(self.by)

        dh_next = np.zeros((self.hidden_dim, 1))
        loss = 0.0

        for t in reversed(range(T)):
            # 出力層の勾配(MSE損失)
            dy = self.ys[t] - targets[t]
            loss += 0.5 * np.sum(dy ** 2)

            dWy += dy @ self.hs[t].T
            dby += dy

            # 隠れ状態に対する勾配
            dh = self.Wy.T @ dy + dh_next

            # 更新ゲートの勾配
            dz = dh * (self.hs[t - 1] - self.h_tildes[t])
            dz_raw = dz * self.zs[t] * (1 - self.zs[t])

            # 候補隠れ状態の勾配
            dh_tilde = dh * (1 - self.zs[t])
            dh_tilde_raw = dh_tilde * (1 - self.h_tildes[t] ** 2)

            # リセットゲートの勾配
            dr = (self.Uh.T @ dh_tilde_raw) * self.hs[t - 1]
            dr_raw = dr * self.rs[t] * (1 - self.rs[t])

            # パラメータ勾配の累積
            dWz += dz_raw @ self.xs[t].T
            dUz += dz_raw @ self.hs[t - 1].T
            dbz += dz_raw

            dWr += dr_raw @ self.xs[t].T
            dUr += dr_raw @ self.hs[t - 1].T
            dbr += dr_raw

            dWh += dh_tilde_raw @ self.xs[t].T
            dUh += dh_tilde_raw @ (self.rs[t] * self.hs[t - 1]).T
            dbh += dh_tilde_raw

            # 前時刻への隠れ状態勾配の伝播
            dh_next = (dh * self.zs[t]                        # 直接経路
                      + self.Uz.T @ dz_raw                    # 更新ゲート経由
                      + self.Ur.T @ dr_raw                    # リセットゲート経由
                      + (self.Uh.T @ dh_tilde_raw) * self.rs[t])  # 候補経由

        # 勾配クリッピング
        for grad in [dWz, dUz, dbz, dWr, dUr, dbr, dWh, dUh, dbh, dWy, dby]:
            np.clip(grad, -5, 5, out=grad)

        # パラメータ更新
        self.Wz -= self.lr * dWz
        self.Uz -= self.lr * dUz
        self.bz -= self.lr * dbz
        self.Wr -= self.lr * dWr
        self.Ur -= self.lr * dUr
        self.br -= self.lr * dbr
        self.Wh -= self.lr * dWh
        self.Uh -= self.lr * dUh
        self.bh -= self.lr * dbh
        self.Wy -= self.lr * dWy
        self.by -= self.lr * dby

        return loss

LSTMのスクラッチ実装(比較用)

class LSTM:
    """LSTM(Long Short-Term Memory)のスクラッチ実装(比較用)"""

    def __init__(self, input_dim, hidden_dim, output_dim, lr=0.01):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.lr = lr

        scale_ih = np.sqrt(2.0 / (input_dim + hidden_dim))
        scale_hh = np.sqrt(2.0 / (hidden_dim + hidden_dim))
        scale_ho = np.sqrt(2.0 / (hidden_dim + output_dim))

        # 忘却ゲート
        self.Wf = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Uf = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bf = np.ones((hidden_dim, 1))  # 忘却ゲートのバイアスは1で初期化

        # 入力ゲート
        self.Wi = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Ui = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bi = np.zeros((hidden_dim, 1))

        # 出力ゲート
        self.Wo = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Uo = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bo = np.zeros((hidden_dim, 1))

        # 候補セル
        self.Wc = np.random.randn(hidden_dim, input_dim) * scale_ih
        self.Uc = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.bc = np.zeros((hidden_dim, 1))

        # 出力層
        self.Wy = np.random.randn(output_dim, hidden_dim) * scale_ho
        self.by = np.zeros((output_dim, 1))

    def forward(self, xs):
        T = len(xs)
        self.xs = xs
        self.fs, self.ins, self.os, self.c_tildes = {}, {}, {}, {}
        self.cs = {-1: np.zeros((self.hidden_dim, 1))}
        self.hs = {-1: np.zeros((self.hidden_dim, 1))}
        self.ys = {}

        for t in range(T):
            x_t = xs[t]
            h_prev = self.hs[t - 1]
            c_prev = self.cs[t - 1]

            self.fs[t] = sigmoid(self.Wf @ x_t + self.Uf @ h_prev + self.bf)
            self.ins[t] = sigmoid(self.Wi @ x_t + self.Ui @ h_prev + self.bi)
            self.os[t] = sigmoid(self.Wo @ x_t + self.Uo @ h_prev + self.bo)
            self.c_tildes[t] = tanh(self.Wc @ x_t + self.Uc @ h_prev + self.bc)

            self.cs[t] = self.fs[t] * c_prev + self.ins[t] * self.c_tildes[t]
            self.hs[t] = self.os[t] * tanh(self.cs[t])
            self.ys[t] = self.Wy @ self.hs[t] + self.by

        return self.ys

    def backward(self, targets):
        T = len(targets)

        dWf = np.zeros_like(self.Wf); dUf = np.zeros_like(self.Uf); dbf = np.zeros_like(self.bf)
        dWi = np.zeros_like(self.Wi); dUi = np.zeros_like(self.Ui); dbi = np.zeros_like(self.bi)
        dWo = np.zeros_like(self.Wo); dUo = np.zeros_like(self.Uo); dbo = np.zeros_like(self.bo)
        dWc = np.zeros_like(self.Wc); dUc = np.zeros_like(self.Uc); dbc = np.zeros_like(self.bc)
        dWy = np.zeros_like(self.Wy); dby = np.zeros_like(self.by)

        dh_next = np.zeros((self.hidden_dim, 1))
        dc_next = np.zeros((self.hidden_dim, 1))
        loss = 0.0

        for t in reversed(range(T)):
            dy = self.ys[t] - targets[t]
            loss += 0.5 * np.sum(dy ** 2)

            dWy += dy @ self.hs[t].T
            dby += dy

            dh = self.Wy.T @ dy + dh_next

            do = dh * tanh(self.cs[t])
            do_raw = do * self.os[t] * (1 - self.os[t])

            dc = dh * self.os[t] * (1 - tanh(self.cs[t]) ** 2) + dc_next

            df = dc * self.cs[t - 1]
            df_raw = df * self.fs[t] * (1 - self.fs[t])

            di = dc * self.c_tildes[t]
            di_raw = di * self.ins[t] * (1 - self.ins[t])

            dc_tilde = dc * self.ins[t]
            dc_tilde_raw = dc_tilde * (1 - self.c_tildes[t] ** 2)

            dWf += df_raw @ self.xs[t].T; dUf += df_raw @ self.hs[t-1].T; dbf += df_raw
            dWi += di_raw @ self.xs[t].T; dUi += di_raw @ self.hs[t-1].T; dbi += di_raw
            dWo += do_raw @ self.xs[t].T; dUo += do_raw @ self.hs[t-1].T; dbo += do_raw
            dWc += dc_tilde_raw @ self.xs[t].T; dUc += dc_tilde_raw @ self.hs[t-1].T; dbc += dc_tilde_raw

            dc_next = dc * self.fs[t]
            dh_next = (self.Uf.T @ df_raw + self.Ui.T @ di_raw
                      + self.Uo.T @ do_raw + self.Uc.T @ dc_tilde_raw)

        for grad in [dWf, dUf, dbf, dWi, dUi, dbi, dWo, dUo, dbo,
                     dWc, dUc, dbc, dWy, dby]:
            np.clip(grad, -5, 5, out=grad)

        self.Wf -= self.lr * dWf; self.Uf -= self.lr * dUf; self.bf -= self.lr * dbf
        self.Wi -= self.lr * dWi; self.Ui -= self.lr * dUi; self.bi -= self.lr * dbi
        self.Wo -= self.lr * dWo; self.Uo -= self.lr * dUo; self.bo -= self.lr * dbo
        self.Wc -= self.lr * dWc; self.Uc -= self.lr * dUc; self.bc -= self.lr * dbc
        self.Wy -= self.lr * dWy; self.by -= self.lr * dby

        return loss

時系列予測タスクでの性能比較

正弦波の予測タスクを使って、GRUとLSTMの学習曲線とパラメータ数を比較します。

# ----------------------------
# データ生成: 正弦波の予測
# ----------------------------
np.random.seed(42)

# sin波にノイズを加えたデータを生成
T_total = 500
t_axis = np.linspace(0, 10 * np.pi, T_total)
data = np.sin(t_axis) + 0.1 * np.random.randn(T_total)

# 入力系列と教師系列の作成
seq_len = 20  # 過去20ステップから次の1ステップを予測
X_sequences = []
Y_sequences = []
for i in range(T_total - seq_len):
    X_sequences.append(data[i:i + seq_len])
    Y_sequences.append(data[i + seq_len])

X_sequences = np.array(X_sequences)
Y_sequences = np.array(Y_sequences)

# 訓練・テスト分割
n_train = 350
X_train = X_sequences[:n_train]
Y_train = Y_sequences[:n_train]
X_test = X_sequences[n_train:]
Y_test = Y_sequences[n_train:]

# ----------------------------
# 学習ループ
# ----------------------------
input_dim = 1
hidden_dim = 16
output_dim = 1
n_epochs = 50

gru_model = GRU(input_dim, hidden_dim, output_dim, lr=0.001)
lstm_model = LSTM(input_dim, hidden_dim, output_dim, lr=0.001)

gru_losses = []
lstm_losses = []

for epoch in range(n_epochs):
    gru_epoch_loss = 0.0
    lstm_epoch_loss = 0.0

    # ミニバッチなしで各系列ごとに学習
    for i in range(n_train):
        # 入力を (T, 1, 1) の形式に変換
        xs = [data_pt.reshape(1, 1) for data_pt in X_train[i]]
        target = [Y_train[i].reshape(1, 1)]

        # GRU
        gru_model.forward(xs)
        # 最終時刻の出力のみで損失を計算
        gru_targets = {t: np.zeros((1, 1)) for t in range(seq_len)}
        gru_targets[seq_len - 1] = target[0]
        gru_loss = gru_model.backward(gru_targets)
        gru_epoch_loss += float(np.sum((gru_model.ys[seq_len - 1] - target[0]) ** 2))

        # LSTM
        lstm_model.forward(xs)
        lstm_targets = {t: np.zeros((1, 1)) for t in range(seq_len)}
        lstm_targets[seq_len - 1] = target[0]
        lstm_loss = lstm_model.backward(lstm_targets)
        lstm_epoch_loss += float(np.sum((lstm_model.ys[seq_len - 1] - target[0]) ** 2))

    gru_losses.append(gru_epoch_loss / n_train)
    lstm_losses.append(lstm_epoch_loss / n_train)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}  "
              f"GRU Loss: {gru_losses[-1]:.6f}  "
              f"LSTM Loss: {lstm_losses[-1]:.6f}")

# ----------------------------
# パラメータ数の確認
# ----------------------------
n_gru_params = (3 * (hidden_dim * input_dim + hidden_dim * hidden_dim + hidden_dim)
                + output_dim * hidden_dim + output_dim)
n_lstm_params = (4 * (hidden_dim * input_dim + hidden_dim * hidden_dim + hidden_dim)
                 + output_dim * hidden_dim + output_dim)
print(f"\nGRU  パラメータ数: {n_gru_params}")
print(f"LSTM パラメータ数: {n_lstm_params}")
print(f"比率 (GRU/LSTM): {n_gru_params / n_lstm_params:.2%}")

# ----------------------------
# 学習曲線の可視化
# ----------------------------
plt.figure(figsize=(10, 5))
plt.plot(range(1, n_epochs + 1), gru_losses, label="GRU", linewidth=2)
plt.plot(range(1, n_epochs + 1), lstm_losses, label="LSTM", linewidth=2, linestyle="--")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Training Loss: GRU vs LSTM")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ----------------------------
# テストデータでの予測結果の可視化
# ----------------------------
gru_preds = []
lstm_preds = []

for i in range(len(X_test)):
    xs = [data_pt.reshape(1, 1) for data_pt in X_test[i]]

    gru_model.forward(xs)
    gru_preds.append(float(gru_model.ys[seq_len - 1]))

    lstm_model.forward(xs)
    lstm_preds.append(float(lstm_model.ys[seq_len - 1]))

plt.figure(figsize=(12, 5))
plt.plot(Y_test, label="Ground Truth", linewidth=2, color="black")
plt.plot(gru_preds, label="GRU Prediction", linewidth=1.5, alpha=0.8)
plt.plot(lstm_preds, label="LSTM Prediction", linewidth=1.5, alpha=0.8, linestyle="--")
plt.xlabel("Test Sample Index")
plt.ylabel("Value")
plt.title("Time Series Prediction: GRU vs LSTM")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

上記のコードを実行すると、GRUとLSTMの学習曲線が比較できます。GRUはパラメータ数が約75%であるにもかかわらず、この程度のタスクではLSTMと同等の学習速度と精度を示すことが確認できます。

まとめ

本記事では、GRU(Gated Recurrent Unit)について解説しました。

  • GRUはLSTMを簡略化したモデルで、更新ゲートとリセットゲートの2つのゲートで動作します
  • 更新ゲート $\bm{z}_t$ は、忘却と入力の2つの役割を $\bm{z}_t$ と $1 – \bm{z}_t$ の線形補間で統一的に制御します
  • セル状態を廃止し、隠れ状態のみで情報を伝搬することで、パラメータ数をLSTMの75%に削減しています
  • BPTTにおいて、$\delta_{t-1} \approx \bm{z}_t \odot \delta_t$ という直接経路により、勾配消失を回避します
  • 多くのタスクでLSTMと同等の性能を発揮し、計算効率の面で有利です

次のステップとして、以下の記事も参考にしてください。