CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PyTorch の GRU / LSTM を検算してみる

以前のエントリで扱った Simple RNN の検算は、個人的になかなか良い勉強になった。

blog.amedama.jp

そこで、今回は Simple RNN の改良版となる GRU (Gated Recurrent Unit) と LSTM (Long Short Term Memory) についても検算してみる。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.5.2
BuildVersion:   20G95
$ python -V
Python 3.9.6
$ pip list | grep torch
torch                    1.9.0

もくじ

下準備

下準備として、あらかじめ PyTorch をインストールしておく。

$ pip install torch

続いて Python のインタプリタを起動する。

$ python

そして、PyTorch のパッケージをインポートする。

>>> import torch
>>> from torch import nn

GRU を検算する

Simple RNN は、仕組みが単純な一方で隠れ状態が入力によって無条件に更新されてしまう。 そのため、隠れ状態に昔の情報が残りにくいことから、長期的な記憶を保つことが難しいという問題があった。 GRU では、それをゲートという仕組みを導入することで改善を試みている。

まずは次のように GRU クラスをインスタンス化する。 Simple RNN のときと同じように、モデルの初期状態の重みをそのまま使って検算する。

>>> input_dim = 3  # モデルの入力ベクトルの次元数
>>> hidden_dim = 4  # モデルの出力 (隠れ状態) ベクトルの次元数
>>> model = nn.GRU(input_size=input_dim, hidden_size=hidden_dim)

インスタンス化するときに必要な引数は RNN クラスと変わらない。 つまり、入力と出力のサイズを渡すだけ。

インスタンス化できたら、モデルのパラメータを確認しよう。 どうやら、モデルのパラメータが持っている名前は RNN クラスと同じようだ。 ただし、重みを保持している行列のサイズは増している。

>>> from pprint import pprint
>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[-0.3619, -0.1291, -0.0647],
        [-0.4406, -0.2705, -0.3480],
        [ 0.0360,  0.3222,  0.2494],
        [-0.0738, -0.3214,  0.4445],
        [-0.3551,  0.3078, -0.0846],
        [-0.4367,  0.4282, -0.1521],
        [-0.4895,  0.0713,  0.0217],
        [-0.2439,  0.4704, -0.2078],
        [ 0.0460,  0.2528,  0.3555],
        [-0.3008, -0.0595,  0.0586],
        [-0.3535,  0.2088, -0.2179],
        [ 0.2923,  0.0291,  0.4044]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[ 0.0406,  0.3097, -0.2765, -0.2359],
        [ 0.4449,  0.3376,  0.3715, -0.3207],
        [ 0.0157,  0.0347, -0.0091, -0.0438],
        [ 0.1630,  0.3619,  0.3797, -0.0845],
        [ 0.1729, -0.1405,  0.0844, -0.3560],
        [ 0.0711, -0.3750, -0.0721, -0.4998],
        [-0.4140, -0.1105, -0.1611,  0.1338],
        [-0.0574, -0.1216,  0.2439, -0.2021],
        [ 0.1568,  0.2177,  0.4511,  0.4009],
        [-0.4453, -0.0780, -0.1764,  0.3598],
        [ 0.1704,  0.3918, -0.0727,  0.2112],
        [ 0.3841,  0.0154,  0.2495,  0.1840]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([-0.3642, -0.2804,  0.3874, -0.0016, -0.0540, -0.3060, -0.0446, -0.0145,
         0.1529, -0.4700,  0.3887,  0.1273], requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.0260, -0.0787, -0.3992,  0.4587,  0.3522,  0.0618,  0.0865, -0.2561,
         0.0439, -0.4722,  0.2414, -0.2022], requires_grad=True))]

続いて、ダミーの入力データを用意しよう。 ダミーの入力データの形状は RNN を使った場合と変わらない。

>>> T = 5  # 入力する系列データの長さ
>>> batch_size = 2  # 一度に処理するデータの数
>>> X = torch.randn(T, batch_size, input_dim)  # ダミーの入力データ
>>> X.shape
torch.Size([5, 2, 3])

ダミーの入力データをモデルに与えて出力を得る。

>>> H, hn = model(X)

出力は入力の系列データに対応する隠れ状態と、最後の隠れ状態になっている。 この形状も RNN と変わらない。 つまり、PyTorch において GRU は単純に名前を変えるだけで RNN から差し替えて使うことができる。

>>> H.shape, hn.shape
(torch.Size([5, 2, 4]), torch.Size([1, 2, 4]))
>>> H[-1]
tensor([[ 0.5352, -0.5132,  0.2607,  0.5642],
        [-0.0264, -0.6124,  0.5123, -0.2023]], grad_fn=<SelectBackward>)
>>> hn
tensor([[[ 0.5352, -0.5132,  0.2607,  0.5642],
         [-0.0264, -0.6124,  0.5123, -0.2023]]], grad_fn=<StackBackward>)

それでは、ここからは実際に検算に入ろう。 PyTorch で使われている GRU の数式は以下のドキュメントで確認できる。

pytorch.org

数式は以下のとおり。 Simple RNN のときは 1 つだった式が 4 つに増えている。 なお、最終的に求めたいのは一番下にある「入力  x_t に対応した隠れ状態  h_t」になる。

\displaystyle{
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}
}

ここで  \sigma はシグモイド関数を表す。  r_t z_t n_t は、活性化関数の違いはあるものの、基本的にはいずれも  W_i x_t + b_i + W_h h_{(t-1)} + b_h の形になっていることがわかる。

数式が確認できたところでモデルのパラメータから重みを取り出していこう。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_i = model_weights['weight_ih_l0']
>>> W_h = model_weights['weight_hh_l0']
>>> b_i = model_weights['bias_ih_l0']
>>> b_h = model_weights['bias_hh_l0']

上記は、部分ごとに  r_t 用と  z_t 用と  n_t 用に分かれている。 本来は一気に行列計算した上で後から取り出すわけだけど、今回は数式をなぞるために先に取り出しておこう。

>>> W_ir = W_i[:hidden_dim]
>>> W_iz = W_i[hidden_dim: hidden_dim * 2]
>>> W_in = W_i[hidden_dim * 2:]
>>> 
>>> W_hr = W_h[:hidden_dim]
>>> W_hz = W_h[hidden_dim: hidden_dim * 2]
>>> W_hn = W_h[hidden_dim * 2:]
>>> 
>>> b_ir = b_i[:hidden_dim]
>>> b_iz = b_i[hidden_dim: hidden_dim * 2]
>>> b_in = b_i[hidden_dim * 2:]
>>> 
>>> b_hr = b_h[:hidden_dim]
>>> b_hz = b_h[hidden_dim: hidden_dim * 2]
>>> b_hn = b_h[hidden_dim * 2:]

あとは定義どおりに計算していく。

まずは t = 0 の状態から。 つまり、X[0] に対応する隠れ状態を計算してみよう。 t = 0 かつ、初期の隠れ状態を渡していないので  W_h h_{(t-1)} の項が存在しない。

>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[0].T).T + b_ir + b_hr)
>>> z_t = torch.sigmoid(torch.matmul(W_iz, X[0].T).T + b_iz + b_hz)
>>> n_t = torch.tanh(torch.matmul(W_in, X[0].T).T + b_in + r_t * b_hn)
>>> h_t = (1 - z_t) * n_t

確認すると、モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。

>>> H[0]
tensor([[-0.1412, -0.2934,  0.3071, -0.2858],
        [ 0.1785, -0.1226,  0.2666,  0.0485]], grad_fn=<SelectBackward>)
>>> h_t
tensor([[-0.1412, -0.2934,  0.3071, -0.2858],
        [ 0.1785, -0.1226,  0.2666,  0.0485]])

次は t = 1 に対する計算を取り上げつつ、それぞれの式が意味するところを考えてみる。

まず、以下の  r_t はリセットゲート (reset gate) と呼ばれている。 リセットゲートの式は、活性化関数がシグモイド関数なので、成分は 0 ~ 1 の範囲になる。

>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[1].T).T + b_ir + torch.matmul(W_hr, H[0].T).T + b_hr)

リセットゲートは、後ほど新しい隠れ状態の候補を作るときに、一つ前の隠れ状態と積を取る。 それによって、次の隠れ状態に、一つ前の隠れ状態をどれくらい反映するか制御する役目を担っている。

ゲートは成分の値が 0 のときに「閉じている」、1 のときに「開いている」と表現するらしい。 もちろん、ゲートの値は人間が明示的に与えるのではなく、学習するデータによって最適化される。

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        \end{array}
}

以下の  n_t は、元の論文には名前付きで登場しないものの、PyTorch の中ではニューゲート (new gate) と呼ばれているようだ 1。 これは、言うなれば次の隠れ状態の候補となるもの。 式は RNN の隠れ状態を作るときのものに近いけど、みると一つ前の隠れ状態に先ほどのリセットゲートがかけられている。 これによって、次の隠れ状態に一つ前の隠れ状態をどれくらい混ぜるか、つまり影響を与えるかを制御している。 たとえば、リセットゲートの成分がすべてゼロなら、次の隠れ状態の候補を作るときに、一つ前の隠れ状態をまったく考慮しないことになる。

>>> n_t = torch.tanh(torch.matmul(W_in, X[1].T).T + b_in + r_t * (torch.matmul(W_hn, H[0].T).T + b_hn))

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
        \end{array}
}

次に、以下の  z_t はアップデートゲート (update gate) と呼ばれている。 このゲートは、次の隠れ状態を作るときに、どれくらい一つ前の隠れ状態を引き継ぐかを制御している。

>>> z_t = torch.sigmoid(torch.matmul(W_iz, X[1].T).T + b_iz + torch.matmul(W_hz, H[0].T).T + b_hz)

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
        \end{array}
}

最後に、以下で次の隠れ状態を求めている。 式では、先ほど計算したニューゲートとアップデートゲートが登場している。 次の隠れ状態は、基本的にニューゲートと一つ前の隠れ状態が混ぜられていることがわかる。 そして、混ぜる比率をアップデートゲートが制御している。 もしアップデートゲートの成分がすべてゼロなら、一つ前の隠れ状態はまったく考慮されず、すべてニューゲートのものになる。

>>> h_t = (1 - z_t) * n_t + z_t * H[0]

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}
}

計算した隠れ状態を、最初に得られたものと比較してみよう。

>>> H[1]
tensor([[-0.0039, -0.3424,  0.4580, -0.2490],
        [ 0.3869, -0.4714,  0.1700,  0.4022]], grad_fn=<SelectBackward>)
>>> h_t
tensor([[-0.0039, -0.3424,  0.4580, -0.2490],
        [ 0.3869, -0.4714,  0.1700,  0.4022]], grad_fn=<AddBackward0>)

モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。

LSTM を検算する

続いては LTSM についても同様に検算してみる。

LSTM では、Simple RNN や GRU で扱っていた隠れ状態が「長期記憶」と「短期記憶」に分かれている。 これによって、長いスパンで記憶しておく必要のある情報と、特定のタイミングでのみ必要な情報を扱いやすくしているらしい。 ちなみに LSTM は前述の GRU よりも歴史のあるアーキテクチャで、GRU は LSTM の特殊形と捉えることもできるようだ。

LSTM も、PyTorch ではクラスの名前を LSTM に変更するだけで使うことができる。

>>> model = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim)

モデルに含まれるパラメータを確認してみよう。 パラメータの名前は同じだけど、先ほどの GRU よりも、さらに行列のサイズが増えている。

>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[ 0.3498, -0.0745,  0.0339],
        [-0.0537, -0.4582, -0.0305],
        [-0.1209, -0.1292,  0.0014],
        [-0.4880,  0.4027,  0.2235],
        [-0.3940, -0.4997, -0.4360],
        [ 0.4677, -0.2913,  0.3161],
        [-0.4162, -0.4060, -0.0483],
        [ 0.0281,  0.0586, -0.4602],
        [ 0.0145,  0.3151, -0.0132],
        [ 0.2642,  0.0724, -0.1972],
        [-0.1406,  0.2249, -0.0125],
        [-0.1339, -0.1570, -0.4393],
        [-0.1411, -0.1534,  0.4226],
        [-0.3554,  0.0628,  0.3336],
        [-0.3037, -0.4630, -0.0022],
        [-0.4711,  0.4282,  0.4648]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[ 0.1409,  0.2027,  0.4179,  0.2062],
        [ 0.0182,  0.1814, -0.0826,  0.0193],
        [-0.3766, -0.4391,  0.0336, -0.0875],
        [-0.3921,  0.0581,  0.3184, -0.4362],
        [ 0.0616, -0.0611, -0.0350,  0.2251],
        [-0.1458, -0.2994, -0.4362, -0.0643],
        [ 0.1637, -0.1193,  0.4780, -0.0938],
        [-0.0130,  0.1613,  0.2988, -0.2142],
        [-0.1978,  0.3739, -0.4704,  0.3770],
        [ 0.4956, -0.3259,  0.0976,  0.1588],
        [ 0.2641, -0.2511, -0.3984,  0.2107],
        [ 0.4604,  0.1646, -0.0299,  0.4243],
        [ 0.4658, -0.1663, -0.0066, -0.2386],
        [ 0.2184,  0.3376, -0.2343,  0.2853],
        [-0.2000, -0.4610, -0.2787, -0.2990],
        [ 0.3782, -0.1738, -0.1492, -0.2577]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([-0.1566,  0.4039,  0.2361,  0.1422,  0.1875,  0.0293, -0.2778,  0.4168,
        -0.4732,  0.0960,  0.1191,  0.1664,  0.1017,  0.1526,  0.4041,  0.0643],
       requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.2511,  0.2747, -0.0801, -0.1251,  0.0565, -0.3207,  0.0877,  0.2105,
        -0.3742, -0.3953, -0.3199, -0.1545, -0.1276, -0.4406, -0.3679,  0.4121],
       requires_grad=True))]

モデルにダミーデータを与えてみよう。 このとき、LSTM では返り値が RNNGRU よりも増えている。

>>> H, (hn, cn) = model(X)

上記で、HhnRNNGRU と同じ隠れ状態を表している。 ただし、LSTM においては隠れ状態が「短期記憶」に対応する。

>>> H[-1]
tensor([[-0.2198, -0.1965, -0.0670, -0.5722],
        [-0.1991, -0.0771, -0.1617, -0.0441]], grad_fn=<SelectBackward>)
>>> 
>>> hn
tensor([[[-0.2198, -0.1965, -0.0670, -0.5722],
         [-0.1991, -0.0771, -0.1617, -0.0441]]], grad_fn=<StackBackward>)

返り値で増えているのは、前述した「長期記憶」になる。 詳しくは後述するけど、LSTM の「短期記憶」はこの「長期記憶」から抜き出して作る。

>>> cn
tensor([[[-0.3814, -0.5555, -0.1285, -0.9944],
         [-0.6473, -0.2559, -0.2589, -0.1025]]], grad_fn=<StackBackward>)

使う上で理解すべき概念の説明が終わったところで、検算に移る。 PyTorch で使われている LSTM の数式は以下のドキュメントで確認できる。

pytorch.org

数式は次のとおり。 GRU のときよりも、さらに増えている。

\displaystyle{
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}
}

上記で  \odot はアダマール積を表している。

モデルからパラメータを取り出そう。 先ほどと同じように、数式をなぞるために行列から必要な箇所を取り出して名前をつけていく。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_i = model_weights['weight_ih_l0']
>>> W_h = model_weights['weight_hh_l0']
>>> b_i = model_weights['bias_ih_l0']
>>> b_h = model_weights['bias_hh_l0']
>>> 
>>> W_ii = W_i[:hidden_dim]
>>> W_if = W_i[hidden_dim: hidden_dim * 2]
>>> W_ig = W_i[hidden_dim * 2: hidden_dim * 3]
>>> W_io = W_i[hidden_dim * 3:]
>>> 
>>> W_hi = W_h[:hidden_dim]
>>> W_hf = W_h[hidden_dim: hidden_dim * 2]
>>> W_hg = W_h[hidden_dim * 2: hidden_dim * 3]
>>> W_ho = W_h[hidden_dim * 3:]
>>> 
>>> b_ii = b_i[:hidden_dim]
>>> b_if = b_i[hidden_dim: hidden_dim * 2]
>>> b_ig = b_i[hidden_dim * 2: hidden_dim * 3]
>>> b_io = b_i[hidden_dim * 3:]
>>> 
>>> b_hi = b_h[:hidden_dim]
>>> b_hf = b_h[hidden_dim: hidden_dim * 2]
>>> b_hg = b_h[hidden_dim * 2: hidden_dim * 3]
>>> b_ho = b_h[hidden_dim * 3:]

とりあえず、t = 0 の時点の隠れ状態 (短期記憶) を数式のとおりに計算してみよう。 t = 0 かつ、初期の隠れ状態と長期記憶を渡していないので存在しない項がある点に注意する。

>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[0].T).T + b_ii + b_hi)
>>> f_t = torch.sigmoid(torch.matmul(W_if, X[0].T).T + b_if + b_hf)
>>> g_t = torch.tanh(torch.matmul(W_ig, X[0].T).T + b_ig + b_hg)
>>> o_t = torch.sigmoid(torch.matmul(W_io, X[0].T).T + b_io + b_ho)
>>> c_t = i_t * g_t
>>> h_t = o_t * torch.tanh(c_t)

計算した値と、モデルから返ってきた隠れ状態を比較してみよう。

>>> H[0]
tensor([[-0.1018, -0.0494, -0.0653,  0.1273],
        [-0.0617, -0.1598,  0.0546, -0.3024]], grad_fn=<SelectBackward>)
>>> 
>>> h_t
tensor([[-0.1018, -0.0494, -0.0653,  0.1273],
        [-0.0617, -0.1598,  0.0546, -0.3024]])

ちゃんと一致している。

続いては数式の意味を確認しながら t = 1 も計算してみよう。 計算する上で、一つ前の長期記憶が必要になるので c_0 という名前で記録しておく。

>>> c_0 = c_t

まず計算するのは、入力ゲート (input gate) で、新しい入力  x_t を、どれくらい次の長期記憶に反映するかを司っている。

>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[1].T).T + b_ii + torch.matmul(W_hi, H[0].T).T + b_hi)

次に計算しているのは忘却ゲート (forget gate) で、一つ前の長期記憶を、次にどれだけ引き継ぐかを担っている。

>>> f_t = torch.sigmoid(torch.matmul(W_if, X[1].T).T + b_if + torch.matmul(W_hf, H[0].T).T + b_hf)

以下の式は、論文では名前がついていないけど、PyTorch ではセルゲート (cell gate) と呼んでいる。 これは Simple RNN で隠れ状態を計算していた式と同じ。 入力と、一つ前の隠れ状態 (短期記憶) を混ぜている。

>>> g_t = torch.tanh(torch.matmul(W_ig, X[1].T).T + b_ig + torch.matmul(W_hg, H[0].T).T + b_hg)

以下は出力ゲート (output gate) で、長期記憶から短期記憶をどのように抜き出すかを司っている。

>>> o_t = torch.sigmoid(torch.matmul(W_io, X[1].T).T + b_io + torch.matmul(W_ho, H[0].T).T + b_ho)

以下で、一つ前の長期記憶と出力ゲートを混ぜて、次の長期記憶を作っている。 どんな風に混ぜるかは、忘却ゲートと入力ゲートの値に依存する。

>>> c_t = f_t * c_0 + i_t * g_t

そして、最後に長期記憶から出力ゲートを使って短期記憶を抜き出している。

>>> h_t = o_t * torch.tanh(c_t)

隠れ状態を比べてみると、ちゃんと値が一致していることがわかる。

>>> H[1]
tensor([[-0.1224, -0.1573,  0.0294,  0.0794],
        [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<SelectBackward>)
>>> 
>>> h_t
tensor([[-0.1224, -0.1573,  0.0294,  0.0794],
        [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<MulBackward0>)

いじょう。

参考

arxiv.org

(PDF) Long Short-term Memory

youtu.be

youtu.be


  1. 役目的にはゲートではないので何だか変な気もする