以前のエントリで扱った Simple RNN の検算は、個人的になかなか良い勉強になった。
そこで、今回は Simple RNN の改良版となる GRU (Gated Recurrent Unit) と LSTM (Long Short Term Memory) についても検算してみる。
使った環境は次のとおり。
$ sw_vers ProductName: macOS ProductVersion: 11.5.2 BuildVersion: 20G95 $ python -V Python 3.9.6 $ pip list | grep torch torch 1.9.0
もくじ
下準備
下準備として、あらかじめ PyTorch をインストールしておく。
$ pip install torch
続いて Python のインタプリタを起動する。
$ python
そして、PyTorch のパッケージをインポートする。
>>> import torch >>> from torch import nn
GRU を検算する
Simple RNN は、仕組みが単純な一方で隠れ状態が入力によって無条件に更新されてしまう。 そのため、隠れ状態に昔の情報が残りにくいことから、長期的な記憶を保つことが難しいという問題があった。 GRU では、それをゲートという仕組みを導入することで改善を試みている。
まずは次のように GRU
クラスをインスタンス化する。
Simple RNN のときと同じように、モデルの初期状態の重みをそのまま使って検算する。
>>> input_dim = 3 # モデルの入力ベクトルの次元数 >>> hidden_dim = 4 # モデルの出力 (隠れ状態) ベクトルの次元数 >>> model = nn.GRU(input_size=input_dim, hidden_size=hidden_dim)
インスタンス化するときに必要な引数は RNN
クラスと変わらない。
つまり、入力と出力のサイズを渡すだけ。
インスタンス化できたら、モデルのパラメータを確認しよう。
どうやら、モデルのパラメータが持っている名前は RNN
クラスと同じようだ。
ただし、重みを保持している行列のサイズは増している。
>>> from pprint import pprint >>> pprint(list(model.named_parameters())) [('weight_ih_l0', Parameter containing: tensor([[-0.3619, -0.1291, -0.0647], [-0.4406, -0.2705, -0.3480], [ 0.0360, 0.3222, 0.2494], [-0.0738, -0.3214, 0.4445], [-0.3551, 0.3078, -0.0846], [-0.4367, 0.4282, -0.1521], [-0.4895, 0.0713, 0.0217], [-0.2439, 0.4704, -0.2078], [ 0.0460, 0.2528, 0.3555], [-0.3008, -0.0595, 0.0586], [-0.3535, 0.2088, -0.2179], [ 0.2923, 0.0291, 0.4044]], requires_grad=True)), ('weight_hh_l0', Parameter containing: tensor([[ 0.0406, 0.3097, -0.2765, -0.2359], [ 0.4449, 0.3376, 0.3715, -0.3207], [ 0.0157, 0.0347, -0.0091, -0.0438], [ 0.1630, 0.3619, 0.3797, -0.0845], [ 0.1729, -0.1405, 0.0844, -0.3560], [ 0.0711, -0.3750, -0.0721, -0.4998], [-0.4140, -0.1105, -0.1611, 0.1338], [-0.0574, -0.1216, 0.2439, -0.2021], [ 0.1568, 0.2177, 0.4511, 0.4009], [-0.4453, -0.0780, -0.1764, 0.3598], [ 0.1704, 0.3918, -0.0727, 0.2112], [ 0.3841, 0.0154, 0.2495, 0.1840]], requires_grad=True)), ('bias_ih_l0', Parameter containing: tensor([-0.3642, -0.2804, 0.3874, -0.0016, -0.0540, -0.3060, -0.0446, -0.0145, 0.1529, -0.4700, 0.3887, 0.1273], requires_grad=True)), ('bias_hh_l0', Parameter containing: tensor([-0.0260, -0.0787, -0.3992, 0.4587, 0.3522, 0.0618, 0.0865, -0.2561, 0.0439, -0.4722, 0.2414, -0.2022], requires_grad=True))]
続いて、ダミーの入力データを用意しよう。
ダミーの入力データの形状は RNN
を使った場合と変わらない。
>>> T = 5 # 入力する系列データの長さ >>> batch_size = 2 # 一度に処理するデータの数 >>> X = torch.randn(T, batch_size, input_dim) # ダミーの入力データ >>> X.shape torch.Size([5, 2, 3])
ダミーの入力データをモデルに与えて出力を得る。
>>> H, hn = model(X)
出力は入力の系列データに対応する隠れ状態と、最後の隠れ状態になっている。
この形状も RNN
と変わらない。
つまり、PyTorch において GRU
は単純に名前を変えるだけで RNN
から差し替えて使うことができる。
>>> H.shape, hn.shape (torch.Size([5, 2, 4]), torch.Size([1, 2, 4])) >>> H[-1] tensor([[ 0.5352, -0.5132, 0.2607, 0.5642], [-0.0264, -0.6124, 0.5123, -0.2023]], grad_fn=<SelectBackward>) >>> hn tensor([[[ 0.5352, -0.5132, 0.2607, 0.5642], [-0.0264, -0.6124, 0.5123, -0.2023]]], grad_fn=<StackBackward>)
それでは、ここからは実際に検算に入ろう。 PyTorch で使われている GRU の数式は以下のドキュメントで確認できる。
数式は以下のとおり。 Simple RNN のときは 1 つだった式が 4 つに増えている。 なお、最終的に求めたいのは一番下にある「入力 に対応した隠れ状態 」になる。
ここで はシグモイド関数を表す。 と と は、活性化関数の違いはあるものの、基本的にはいずれも の形になっていることがわかる。
数式が確認できたところでモデルのパラメータから重みを取り出していこう。
>>> model_weights = {name: param.data for name, param ... in model.named_parameters()} >>> >>> W_i = model_weights['weight_ih_l0'] >>> W_h = model_weights['weight_hh_l0'] >>> b_i = model_weights['bias_ih_l0'] >>> b_h = model_weights['bias_hh_l0']
上記は、部分ごとに 用と 用と 用に分かれている。 本来は一気に行列計算した上で後から取り出すわけだけど、今回は数式をなぞるために先に取り出しておこう。
>>> W_ir = W_i[:hidden_dim] >>> W_iz = W_i[hidden_dim: hidden_dim * 2] >>> W_in = W_i[hidden_dim * 2:] >>> >>> W_hr = W_h[:hidden_dim] >>> W_hz = W_h[hidden_dim: hidden_dim * 2] >>> W_hn = W_h[hidden_dim * 2:] >>> >>> b_ir = b_i[:hidden_dim] >>> b_iz = b_i[hidden_dim: hidden_dim * 2] >>> b_in = b_i[hidden_dim * 2:] >>> >>> b_hr = b_h[:hidden_dim] >>> b_hz = b_h[hidden_dim: hidden_dim * 2] >>> b_hn = b_h[hidden_dim * 2:]
あとは定義どおりに計算していく。
まずは t = 0
の状態から。
つまり、X[0]
に対応する隠れ状態を計算してみよう。
t = 0
かつ、初期の隠れ状態を渡していないので の項が存在しない。
>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[0].T).T + b_ir + b_hr) >>> z_t = torch.sigmoid(torch.matmul(W_iz, X[0].T).T + b_iz + b_hz) >>> n_t = torch.tanh(torch.matmul(W_in, X[0].T).T + b_in + r_t * b_hn) >>> h_t = (1 - z_t) * n_t
確認すると、モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。
>>> H[0] tensor([[-0.1412, -0.2934, 0.3071, -0.2858], [ 0.1785, -0.1226, 0.2666, 0.0485]], grad_fn=<SelectBackward>) >>> h_t tensor([[-0.1412, -0.2934, 0.3071, -0.2858], [ 0.1785, -0.1226, 0.2666, 0.0485]])
次は t = 1
に対する計算を取り上げつつ、それぞれの式が意味するところを考えてみる。
まず、以下の はリセットゲート (reset gate) と呼ばれている。
リセットゲートの式は、活性化関数がシグモイド関数なので、成分は 0 ~ 1
の範囲になる。
>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[1].T).T + b_ir + torch.matmul(W_hr, H[0].T).T + b_hr)
リセットゲートは、後ほど新しい隠れ状態の候補を作るときに、一つ前の隠れ状態と積を取る。 それによって、次の隠れ状態に、一つ前の隠れ状態をどれくらい反映するか制御する役目を担っている。
ゲートは成分の値が 0
のときに「閉じている」、1
のときに「開いている」と表現するらしい。
もちろん、ゲートの値は人間が明示的に与えるのではなく、学習するデータによって最適化される。
数式で対応しているのは、この部分。
以下の は、元の論文には名前付きで登場しないものの、PyTorch の中ではニューゲート (new gate) と呼ばれているようだ 1。
これは、言うなれば次の隠れ状態の候補となるもの。
式は RNN
の隠れ状態を作るときのものに近いけど、みると一つ前の隠れ状態に先ほどのリセットゲートがかけられている。
これによって、次の隠れ状態に一つ前の隠れ状態をどれくらい混ぜるか、つまり影響を与えるかを制御している。
たとえば、リセットゲートの成分がすべてゼロなら、次の隠れ状態の候補を作るときに、一つ前の隠れ状態をまったく考慮しないことになる。
>>> n_t = torch.tanh(torch.matmul(W_in, X[1].T).T + b_in + r_t * (torch.matmul(W_hn, H[0].T).T + b_hn))
数式で対応しているのは、この部分。
次に、以下の はアップデートゲート (update gate) と呼ばれている。 このゲートは、次の隠れ状態を作るときに、どれくらい一つ前の隠れ状態を引き継ぐかを制御している。
>>> z_t = torch.sigmoid(torch.matmul(W_iz, X[1].T).T + b_iz + torch.matmul(W_hz, H[0].T).T + b_hz)
数式で対応しているのは、この部分。
最後に、以下で次の隠れ状態を求めている。 式では、先ほど計算したニューゲートとアップデートゲートが登場している。 次の隠れ状態は、基本的にニューゲートと一つ前の隠れ状態が混ぜられていることがわかる。 そして、混ぜる比率をアップデートゲートが制御している。 もしアップデートゲートの成分がすべてゼロなら、一つ前の隠れ状態はまったく考慮されず、すべてニューゲートのものになる。
>>> h_t = (1 - z_t) * n_t + z_t * H[0]
数式で対応しているのは、この部分。
計算した隠れ状態を、最初に得られたものと比較してみよう。
>>> H[1] tensor([[-0.0039, -0.3424, 0.4580, -0.2490], [ 0.3869, -0.4714, 0.1700, 0.4022]], grad_fn=<SelectBackward>) >>> h_t tensor([[-0.0039, -0.3424, 0.4580, -0.2490], [ 0.3869, -0.4714, 0.1700, 0.4022]], grad_fn=<AddBackward0>)
モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。
LSTM を検算する
続いては LTSM についても同様に検算してみる。
LSTM では、Simple RNN や GRU で扱っていた隠れ状態が「長期記憶」と「短期記憶」に分かれている。 これによって、長いスパンで記憶しておく必要のある情報と、特定のタイミングでのみ必要な情報を扱いやすくしているらしい。 ちなみに LSTM は前述の GRU よりも歴史のあるアーキテクチャで、GRU は LSTM の特殊形と捉えることもできるようだ。
LSTM も、PyTorch ではクラスの名前を LSTM
に変更するだけで使うことができる。
>>> model = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim)
モデルに含まれるパラメータを確認してみよう。
パラメータの名前は同じだけど、先ほどの GRU
よりも、さらに行列のサイズが増えている。
>>> pprint(list(model.named_parameters())) [('weight_ih_l0', Parameter containing: tensor([[ 0.3498, -0.0745, 0.0339], [-0.0537, -0.4582, -0.0305], [-0.1209, -0.1292, 0.0014], [-0.4880, 0.4027, 0.2235], [-0.3940, -0.4997, -0.4360], [ 0.4677, -0.2913, 0.3161], [-0.4162, -0.4060, -0.0483], [ 0.0281, 0.0586, -0.4602], [ 0.0145, 0.3151, -0.0132], [ 0.2642, 0.0724, -0.1972], [-0.1406, 0.2249, -0.0125], [-0.1339, -0.1570, -0.4393], [-0.1411, -0.1534, 0.4226], [-0.3554, 0.0628, 0.3336], [-0.3037, -0.4630, -0.0022], [-0.4711, 0.4282, 0.4648]], requires_grad=True)), ('weight_hh_l0', Parameter containing: tensor([[ 0.1409, 0.2027, 0.4179, 0.2062], [ 0.0182, 0.1814, -0.0826, 0.0193], [-0.3766, -0.4391, 0.0336, -0.0875], [-0.3921, 0.0581, 0.3184, -0.4362], [ 0.0616, -0.0611, -0.0350, 0.2251], [-0.1458, -0.2994, -0.4362, -0.0643], [ 0.1637, -0.1193, 0.4780, -0.0938], [-0.0130, 0.1613, 0.2988, -0.2142], [-0.1978, 0.3739, -0.4704, 0.3770], [ 0.4956, -0.3259, 0.0976, 0.1588], [ 0.2641, -0.2511, -0.3984, 0.2107], [ 0.4604, 0.1646, -0.0299, 0.4243], [ 0.4658, -0.1663, -0.0066, -0.2386], [ 0.2184, 0.3376, -0.2343, 0.2853], [-0.2000, -0.4610, -0.2787, -0.2990], [ 0.3782, -0.1738, -0.1492, -0.2577]], requires_grad=True)), ('bias_ih_l0', Parameter containing: tensor([-0.1566, 0.4039, 0.2361, 0.1422, 0.1875, 0.0293, -0.2778, 0.4168, -0.4732, 0.0960, 0.1191, 0.1664, 0.1017, 0.1526, 0.4041, 0.0643], requires_grad=True)), ('bias_hh_l0', Parameter containing: tensor([-0.2511, 0.2747, -0.0801, -0.1251, 0.0565, -0.3207, 0.0877, 0.2105, -0.3742, -0.3953, -0.3199, -0.1545, -0.1276, -0.4406, -0.3679, 0.4121], requires_grad=True))]
モデルにダミーデータを与えてみよう。
このとき、LSTM
では返り値が RNN
や GRU
よりも増えている。
>>> H, (hn, cn) = model(X)
上記で、H
と hn
は RNN
や GRU
と同じ隠れ状態を表している。
ただし、LSTM においては隠れ状態が「短期記憶」に対応する。
>>> H[-1] tensor([[-0.2198, -0.1965, -0.0670, -0.5722], [-0.1991, -0.0771, -0.1617, -0.0441]], grad_fn=<SelectBackward>) >>> >>> hn tensor([[[-0.2198, -0.1965, -0.0670, -0.5722], [-0.1991, -0.0771, -0.1617, -0.0441]]], grad_fn=<StackBackward>)
返り値で増えているのは、前述した「長期記憶」になる。 詳しくは後述するけど、LSTM の「短期記憶」はこの「長期記憶」から抜き出して作る。
>>> cn tensor([[[-0.3814, -0.5555, -0.1285, -0.9944], [-0.6473, -0.2559, -0.2589, -0.1025]]], grad_fn=<StackBackward>)
使う上で理解すべき概念の説明が終わったところで、検算に移る。 PyTorch で使われている LSTM の数式は以下のドキュメントで確認できる。
数式は次のとおり。 GRU のときよりも、さらに増えている。
上記で はアダマール積を表している。
モデルからパラメータを取り出そう。 先ほどと同じように、数式をなぞるために行列から必要な箇所を取り出して名前をつけていく。
>>> model_weights = {name: param.data for name, param ... in model.named_parameters()} >>> >>> W_i = model_weights['weight_ih_l0'] >>> W_h = model_weights['weight_hh_l0'] >>> b_i = model_weights['bias_ih_l0'] >>> b_h = model_weights['bias_hh_l0'] >>> >>> W_ii = W_i[:hidden_dim] >>> W_if = W_i[hidden_dim: hidden_dim * 2] >>> W_ig = W_i[hidden_dim * 2: hidden_dim * 3] >>> W_io = W_i[hidden_dim * 3:] >>> >>> W_hi = W_h[:hidden_dim] >>> W_hf = W_h[hidden_dim: hidden_dim * 2] >>> W_hg = W_h[hidden_dim * 2: hidden_dim * 3] >>> W_ho = W_h[hidden_dim * 3:] >>> >>> b_ii = b_i[:hidden_dim] >>> b_if = b_i[hidden_dim: hidden_dim * 2] >>> b_ig = b_i[hidden_dim * 2: hidden_dim * 3] >>> b_io = b_i[hidden_dim * 3:] >>> >>> b_hi = b_h[:hidden_dim] >>> b_hf = b_h[hidden_dim: hidden_dim * 2] >>> b_hg = b_h[hidden_dim * 2: hidden_dim * 3] >>> b_ho = b_h[hidden_dim * 3:]
とりあえず、t = 0
の時点の隠れ状態 (短期記憶) を数式のとおりに計算してみよう。
t = 0
かつ、初期の隠れ状態と長期記憶を渡していないので存在しない項がある点に注意する。
>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[0].T).T + b_ii + b_hi) >>> f_t = torch.sigmoid(torch.matmul(W_if, X[0].T).T + b_if + b_hf) >>> g_t = torch.tanh(torch.matmul(W_ig, X[0].T).T + b_ig + b_hg) >>> o_t = torch.sigmoid(torch.matmul(W_io, X[0].T).T + b_io + b_ho) >>> c_t = i_t * g_t >>> h_t = o_t * torch.tanh(c_t)
計算した値と、モデルから返ってきた隠れ状態を比較してみよう。
>>> H[0] tensor([[-0.1018, -0.0494, -0.0653, 0.1273], [-0.0617, -0.1598, 0.0546, -0.3024]], grad_fn=<SelectBackward>) >>> >>> h_t tensor([[-0.1018, -0.0494, -0.0653, 0.1273], [-0.0617, -0.1598, 0.0546, -0.3024]])
ちゃんと一致している。
続いては数式の意味を確認しながら t = 1
も計算してみよう。
計算する上で、一つ前の長期記憶が必要になるので c_0
という名前で記録しておく。
>>> c_0 = c_t
まず計算するのは、入力ゲート (input gate) で、新しい入力 を、どれくらい次の長期記憶に反映するかを司っている。
>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[1].T).T + b_ii + torch.matmul(W_hi, H[0].T).T + b_hi)
次に計算しているのは忘却ゲート (forget gate) で、一つ前の長期記憶を、次にどれだけ引き継ぐかを担っている。
>>> f_t = torch.sigmoid(torch.matmul(W_if, X[1].T).T + b_if + torch.matmul(W_hf, H[0].T).T + b_hf)
以下の式は、論文では名前がついていないけど、PyTorch ではセルゲート (cell gate) と呼んでいる。 これは Simple RNN で隠れ状態を計算していた式と同じ。 入力と、一つ前の隠れ状態 (短期記憶) を混ぜている。
>>> g_t = torch.tanh(torch.matmul(W_ig, X[1].T).T + b_ig + torch.matmul(W_hg, H[0].T).T + b_hg)
以下は出力ゲート (output gate) で、長期記憶から短期記憶をどのように抜き出すかを司っている。
>>> o_t = torch.sigmoid(torch.matmul(W_io, X[1].T).T + b_io + torch.matmul(W_ho, H[0].T).T + b_ho)
以下で、一つ前の長期記憶と出力ゲートを混ぜて、次の長期記憶を作っている。 どんな風に混ぜるかは、忘却ゲートと入力ゲートの値に依存する。
>>> c_t = f_t * c_0 + i_t * g_t
そして、最後に長期記憶から出力ゲートを使って短期記憶を抜き出している。
>>> h_t = o_t * torch.tanh(c_t)
隠れ状態を比べてみると、ちゃんと値が一致していることがわかる。
>>> H[1] tensor([[-0.1224, -0.1573, 0.0294, 0.0794], [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<SelectBackward>) >>> >>> h_t tensor([[-0.1224, -0.1573, 0.0294, 0.0794], [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<MulBackward0>)
いじょう。
参考
-
役目的にはゲートではないので何だか変な気もする↩