CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PyTorch の MultiheadAttention を検算してみる

今回は、言わずと知れた Transformer 1 において、処理の中心的な役割を果たしている (とされる) Multi-Head Attention を扱ってみる。 これは、Scaled Dot Product Attention という処理を改良したもの。

PyTorch には Multi-Head Attention の実装として MultiheadAttention というクラスが用意されている。 今回は、これがどういった処理をしているのかを、検算しながら確かめてみる。

pytorch.org

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.6
BuildVersion:   20G165
$ python -V          
Python 3.9.7
$ pip list | grep -i torch
torch             1.9.1

もくじ

下準備

あらかじめ PyTorch と NumPy をインストールしておく。

$ pip install torch numpy

続いて、Python のインタプリタを起動しておこう。

$ python

PyTorch 関連のモジュールをインポートしておく。

>>> import torch
>>> from torch import nn
>>> from torch.nn import functional as F

検算する

Multi-Head Attention は、Query と Key と Value (以下、Q, K, V) という 3 つのパラメータを入力として受け取る。 それぞれのパラメータは同じ次元数で、返す値は Query と同一の形状になるという特徴がある。 なお、Attention 自体の説明は以下のブログが詳しい。

deeplearning.hatenablog.com

はじめに、返り値の次元数を定義する。 この次元数は、前述のとおり Q の次元数と同じになる。

>>> embed_dim = 4

続いてヘッド数を定義する。 ヘッドというのは、入力をいくつかに分割して処理するそれぞれの Scaled Dot Product Attention のこと。 元の入力データの次元数が \displaystyle{d _ {model}} とすると、各ヘッドに入力されるデータの次元数は  \frac{d_{model}}{h} になる。 ここで  h がヘッド数を表す。 ヘッド数はハイパーパラメータで、多すぎても少なすぎても良くないらしい。 まずは単純なケースとしてヘッド数が 1 の場合を確かめよう。

>>> num_heads = 1

定義した次元数とヘッド数を使って MultiheadAttention をインスタンス化する。 今回は単純にするためバイアス項をなくし、入力データの形状としてバッチが最初にくるようにした。

>>> model = nn.MultiheadAttention(embed_dim=embed_dim,
...                               num_heads=num_heads,
...                               bias=False,
...                               batch_first=True)

つづいて、上記に入力するダミーデータを用意する。 今回は 2 x 5 x 4 という形状のデータにした。 仮に自然言語を想定するなら、最初がバッチ、2 番目が文章の系列長、最後が単語の分散表現の次元数を表すことになる。 それぞれ batch_sizeLembed_dim という変数名で用意している。

>>> batch_size = 2  # 一度に処理するデータの数
>>> L = 5  # 入力する系列データの長さ
>>> X = torch.randn(batch_size, L, embed_dim)  # ダミーの入力データ

今回は Transformer でも用いられている Self Attention を想定するので Q, K, V すべてに同じ X を入れる。

>>> Q = K = V = X

モデルとデータが用意できたので順伝搬させて返り値を得よう。 返り値としては、変換された Q と同一形状のテンソルと、モデルが何に注目したかを表す Attention Weights (Map とも) が返ってくる。

>>> attn_output, attn_weights = model(Q, K, V)

もちろん、上記は初期値の重みとダミーデータを使って得たものなので、中身自体には何の意味があるわけでもない。 しかし、どういった計算を経てこれが得られるのかを確かめる分には十分だ。

以下を見てわかるとおり、出力は入力した Query と同じ形状になっている。

>>> attn_output
tensor([[[-0.1453, -0.1466, -0.0371,  0.3497],
         [-0.1822, -0.1085, -0.0900,  0.3613],
         [-0.0751, -0.1506,  0.0799,  0.2830],
         [-0.1552, -0.1184, -0.0450,  0.3429],
         [-0.2029, -0.0971, -0.1205,  0.3772]],

        [[-0.2694, -0.3017, -0.4528,  0.5128],
         [-0.3347, -0.3734, -0.4705,  0.6975],
         [-0.2867, -0.3250, -0.4454,  0.5810],
         [-0.2401, -0.2732, -0.4402,  0.4353],
         [-0.2539, -0.2964, -0.4241,  0.5060]]], grad_fn=<TransposeBackward0>)
>>> attn_output.shape
torch.Size([2, 5, 4])

Attention Weights は、先頭がバッチで、2 番目と 3 番目が系列長 L と同じ形状になる。

>>> attn_weights
tensor([[[0.2358, 0.2097, 0.2365, 0.1587, 0.1592],
         [0.1823, 0.1938, 0.1635, 0.2238, 0.2366],
         [0.2124, 0.1384, 0.3916, 0.1622, 0.0953],
         [0.1892, 0.1843, 0.2213, 0.2131, 0.1921],
         [0.1687, 0.2126, 0.1319, 0.2141, 0.2727]],

        [[0.2454, 0.1563, 0.1772, 0.2531, 0.1680],
         [0.2078, 0.2083, 0.2640, 0.1586, 0.1612],
         [0.2298, 0.2132, 0.1977, 0.1992, 0.1602],
         [0.2408, 0.1406, 0.1461, 0.2921, 0.1804],
         [0.2226, 0.2197, 0.1611, 0.2224, 0.1742]]], grad_fn=<DivBackward0>)
>>> attn_weights.shape
torch.Size([2, 5, 5])

さて、それでは上記の結果がどのように得られるのかを検算で確かめてみよう。

まず、PyTorch のドキュメントを見ると MultiheadAttention の数式は以下のように定義されている。

\displaystyle{

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O \\

\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

}

Q, K, V それぞれを、パラメータ行列 \displaystyle{W_i} と積をとって Attention に入れたものが、各ヘッドの出力になる。 各ヘッドの出力は連結した上で、さらにパラメータ行列 \displaystyle{W^ O} とかけたものが MultiheadAttention の出力だ。

ここで   \text{Attention} は Scaled Dot Product Attention なので以下になる。

\displaystyle{
\text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{d_q}) V 
}

数式がわかったので、まずはモデルが持っているパラメータを確認してみよう。 どうやら in_proj_weightsout_proj.weight というパラメータがあるらしい。 これらは、上記の式において  W_i と [tex: WO] に対応している。

>>> from pprint import pprint
>>> pprint(list(model.named_parameters()))
[('in_proj_weight',
  Parameter containing:
tensor([[-0.5443,  0.3884, -0.1312, -0.1092],
        [ 0.1386, -0.3444,  0.3273,  0.1445],
        [-0.2816,  0.0416, -0.4813,  0.1620],
        [-0.4794, -0.0049, -0.5191, -0.3294],
        [-0.3429,  0.4189, -0.0930,  0.2866],
        [ 0.5036, -0.2311,  0.2426,  0.0193],
        [ 0.5196, -0.0979, -0.4762, -0.3478],
        [-0.3660, -0.3218, -0.2310, -0.2840],
        [-0.4351,  0.1184, -0.3720, -0.2419],
        [-0.2723, -0.5269,  0.2075, -0.4505],
        [ 0.0627,  0.0975,  0.5494, -0.2860],
        [ 0.4284,  0.5447, -0.1266,  0.2931]], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[-0.0758,  0.0238, -0.4159,  0.4350],
        [ 0.1650, -0.2046,  0.4133,  0.2710],
        [ 0.4356, -0.0973, -0.1273,  0.3115],
        [ 0.3645,  0.4667,  0.4714, -0.4997]], requires_grad=True))]

それぞれの行列をモデルから取り出しておこう。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> Wi = model_weights['in_proj_weight']
>>> Wo = model_weights['out_proj.weight']

Wi は Q, K, V それぞれにかける部位が分かれているので取り出す。

>>> Wi_q, Wi_k, Wi_v = Wi.chunk(3)

取り出したら次のようにして Attention に入力する部分を計算する。

>>> QW = torch.matmul(Q, Wi_q.T)
>>> KW = torch.matmul(K, Wi_k.T)
>>> VW = torch.matmul(V, Wi_v.T)

数式で対応するのは、以下の Attention に渡す前の部分。

\displaystyle{

 \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

}

なお、今回は Self Attention なので、次のように計算しても結果は変わらない。

>>> QW, KW, VW = torch.matmul(Q, Wi.T).chunk(3, dim=-1)

続いては Scaled Dot Product Attention の中の処理に入る。 まずは \displaystyle{
Q K^ T
} を計算する。

>>> KW_t = KW.transpose(-2, -1)
>>> QK_t = torch.bmm(QW, KW_t)

次にスケールを調整する。 数式だと \displaystyle{
\frac{Q K^ T}{d _ q}
} の部分。

>>> import math
>>> QK_t_scaled = QK_t / math.sqrt(embed_dim)

Softmax を通して足して 1 になるようにする。 \displaystyle{
 \text{softmax}(\frac{Q K^ T}{d _ q})
} の部分。 これが Attention Weights と呼ばれるもの。

>>> attn_weights_ = F.softmax(QK_t_scaled, dim=-1)

モデルから得られた値と比較すると、ちゃんと一致していることがわかる。

>>> attn_weights
tensor([[[0.2358, 0.2097, 0.2365, 0.1587, 0.1592],
         [0.1823, 0.1938, 0.1635, 0.2238, 0.2366],
         [0.2124, 0.1384, 0.3916, 0.1622, 0.0953],
         [0.1892, 0.1843, 0.2213, 0.2131, 0.1921],
         [0.1687, 0.2126, 0.1319, 0.2141, 0.2727]],

        [[0.2454, 0.1563, 0.1772, 0.2531, 0.1680],
         [0.2078, 0.2083, 0.2640, 0.1586, 0.1612],
         [0.2298, 0.2132, 0.1977, 0.1992, 0.1602],
         [0.2408, 0.1406, 0.1461, 0.2921, 0.1804],
         [0.2226, 0.2197, 0.1611, 0.2224, 0.1742]]], grad_fn=<DivBackward0>)
>>> attn_weights_
tensor([[[0.2358, 0.2097, 0.2365, 0.1587, 0.1592],
         [0.1823, 0.1938, 0.1635, 0.2238, 0.2366],
         [0.2124, 0.1384, 0.3916, 0.1622, 0.0953],
         [0.1892, 0.1843, 0.2213, 0.2131, 0.1921],
         [0.1687, 0.2126, 0.1319, 0.2141, 0.2727]],

        [[0.2454, 0.1563, 0.1772, 0.2531, 0.1680],
         [0.2078, 0.2083, 0.2640, 0.1586, 0.1612],
         [0.2298, 0.2132, 0.1977, 0.1992, 0.1602],
         [0.2408, 0.1406, 0.1461, 0.2921, 0.1804],
         [0.2226, 0.2197, 0.1611, 0.2224, 0.1742]]])

続いては Attention Weights を重みとした、V の重みつき和を得る。 これで \displaystyle{
\text{head} _ i = \text{Attention}(QW _ i^ Q, KW _ i^ K, VW _ i^ V)
} に対応するヘッドの出力が得られた。

>>> AV = torch.matmul(attn_weights_, VW)

今回、ヘッド数が 1 なので連結処理は必要ない。 あとはヘッドの出力を \displaystyle{
W^ O
} とかけるだけ。

>>> attn_output_ = torch.matmul(AV, Wo.T)

最初にモデルから得られた値と比較すると、ちゃんと一致していることがわかる。

>>> attn_output
tensor([[[-0.1453, -0.1466, -0.0371,  0.3497],
         [-0.1822, -0.1085, -0.0900,  0.3613],
         [-0.0751, -0.1506,  0.0799,  0.2830],
         [-0.1552, -0.1184, -0.0450,  0.3429],
         [-0.2029, -0.0971, -0.1205,  0.3772]],

        [[-0.2694, -0.3017, -0.4528,  0.5128],
         [-0.3347, -0.3734, -0.4705,  0.6975],
         [-0.2867, -0.3250, -0.4454,  0.5810],
         [-0.2401, -0.2732, -0.4402,  0.4353],
         [-0.2539, -0.2964, -0.4241,  0.5060]]], grad_fn=<TransposeBackward0>)
>>> attn_output_
tensor([[[-0.1453, -0.1466, -0.0371,  0.3497],
         [-0.1822, -0.1085, -0.0900,  0.3613],
         [-0.0751, -0.1506,  0.0799,  0.2830],
         [-0.1552, -0.1184, -0.0450,  0.3429],
         [-0.2029, -0.0971, -0.1205,  0.3772]],

        [[-0.2694, -0.3017, -0.4528,  0.5128],
         [-0.3347, -0.3734, -0.4705,  0.6975],
         [-0.2867, -0.3250, -0.4454,  0.5810],
         [-0.2401, -0.2732, -0.4402,  0.4353],
         [-0.2539, -0.2964, -0.4241,  0.5060]]])

これで、出力と Attention Weights の両方について、モデルから得られたものと計算した値が一致した。

ヘッド数を増やしてみる

さて、先ほどの検算ではヘッド数が 1 の単純なパターンを試した。 続いてはヘッド数を 2 に増やして同様のことをやってみよう。 ちゃんとできるだろうか。

まずはヘッド数を 2 に増やした上で MultiheadAttention をインスタンス化し直す。

>>> num_heads = 2
>>> model = nn.MultiheadAttention(embed_dim=embed_dim,
...                               num_heads=num_heads,
...                               bias=False,
...                               batch_first=True)

ダミーデータはそのまま流用して、また返り値を得よう。

>>> attn_output, attn_weights = model(Q, K, V)

モデルからパラメータを取り出す。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> Wi = model_weights['in_proj_weight']
>>> Wo = model_weights['out_proj.weight']

今回、ヘッド数が 1 から 2 に増えるので、ヘッドに入力されるデータの次元数は半分になる。 そこを計算しておこう。

>>> embed_dim_per_head = embed_dim // num_heads
>>> embed_dim_per_head
2

まずは \displaystyle{
W _ i
} の方のパラメータ行列を先ほどと同じように取り出そう。

>>> Wi_q, Wi_k, Wi_v = Wi.chunk(3)

上記の行列は、ヘッド毎に使う部分が分かれている。 そこで、0 番目のヘッド用と 1 番目のヘッド用に分割して取り出す。

>>> Wi0_q, Wi1_q = Wi_q.chunk(num_heads)
>>> Wi0_k, Wi1_k = Wi_k.chunk(num_heads)
>>> Wi0_v, Wi1_v = Wi_v.chunk(num_heads)

Scaled Dot Product Attention の計算はやることが多いので一旦関数にまとめてしまおう。

>>> def scaled_dot_product_self_attention(X, Wi_q, Wi_k, Wi_v):
...     QW = torch.matmul(Q, Wi_q.T)
...     KW = torch.matmul(K, Wi_k.T)
...     VW = torch.matmul(V, Wi_v.T)
...     KW_t = KW.transpose(-2, -1)
...     QK_t = torch.bmm(QW, KW_t)
...     import math
...     QK_t_scaled = QK_t / math.sqrt(embed_dim_per_head)
...     attn_weights_ = F.softmax(QK_t_scaled, dim=-1)
...     AV = torch.matmul(attn_weights_, VW)
...     return AV, attn_weights_
... 

次のようにヘッドごとの計算結果を得る。

>>> AV0, attn_weights0_ = scaled_dot_product_self_attention(Q, Wi0_q, Wi0_k, Wi0_v)
>>> AV1, attn_weights1_ = scaled_dot_product_self_attention(Q, Wi1_q, Wi1_k, Wi1_v)

ヘッドごとの計算結果を連結する。 数式では \displaystyle{
\text{Concat}(\text{head} _ 1, \dots, \text{head} _ h)
} に対応する。

>>> AV_concat = torch.cat([AV0, AV1], dim=-1)

連結したらあとはさっきと同じ。

>>> attn_output_ = torch.matmul(AV_concat, Wo.T)

結果を確認すると、ちゃんとモデルの出力と一致している。

>>> attn_output
tensor([[[ 0.0637, -0.1713,  0.1646,  0.1171],
         [ 0.2022, -0.2627,  0.1944,  0.0561],
         [ 0.1051, -0.1979,  0.1749,  0.0978],
         [ 0.1367, -0.2148,  0.1865,  0.0759],
         [ 0.1614, -0.2310,  0.1923,  0.0644]],

        [[-0.3877, -0.1922, -0.1522,  0.6143],
         [-0.4006, -0.1432, -0.1958,  0.5701],
         [-0.3928, -0.2885, -0.1149,  0.7077],
         [-0.3871, -0.1448, -0.1641,  0.5711],
         [-0.4203, -0.1682, -0.1782,  0.6106]]], grad_fn=<TransposeBackward0>)
>>> attn_output_
tensor([[[ 0.0637, -0.1713,  0.1646,  0.1171],
         [ 0.2022, -0.2627,  0.1944,  0.0561],
         [ 0.1051, -0.1979,  0.1749,  0.0978],
         [ 0.1367, -0.2148,  0.1865,  0.0759],
         [ 0.1614, -0.2310,  0.1923,  0.0644]],

        [[-0.3877, -0.1922, -0.1522,  0.6143],
         [-0.4006, -0.1432, -0.1958,  0.5701],
         [-0.3928, -0.2885, -0.1149,  0.7077],
         [-0.3871, -0.1448, -0.1641,  0.5711],
         [-0.4203, -0.1682, -0.1782,  0.6106]]])

ちなみに Attention Weights はどうなっているかというと、すべてのヘッドの平均を計算しているようだ。

>>> attn_weights_ = (attn_weights0_ + attn_weights1_) / num_heads

平均をとったものが、モデルの出力と一致している。

>>> attn_weights
tensor([[[0.2021, 0.1750, 0.1170, 0.2130, 0.2929],
         [0.2060, 0.1956, 0.2399, 0.1871, 0.1713],
         [0.2028, 0.1883, 0.1492, 0.2120, 0.2476],
         [0.1966, 0.2012, 0.1703, 0.2065, 0.2254],
         [0.1972, 0.2058, 0.1914, 0.2004, 0.2052]],

        [[0.1916, 0.2228, 0.1979, 0.1893, 0.1984],
         [0.1783, 0.2872, 0.1692, 0.1739, 0.1914],
         [0.1891, 0.0948, 0.2491, 0.2291, 0.2379],
         [0.1943, 0.2597, 0.1843, 0.1796, 0.1822],
         [0.2045, 0.1544, 0.2107, 0.2196, 0.2110]]], grad_fn=<DivBackward0>)
>>> attn_weights_
tensor([[[0.2021, 0.1750, 0.1170, 0.2130, 0.2929],
         [0.2060, 0.1956, 0.2399, 0.1871, 0.1713],
         [0.2028, 0.1883, 0.1492, 0.2120, 0.2476],
         [0.1966, 0.2012, 0.1703, 0.2065, 0.2254],
         [0.1972, 0.2058, 0.1914, 0.2004, 0.2052]],

        [[0.1916, 0.2228, 0.1979, 0.1893, 0.1984],
         [0.1783, 0.2872, 0.1692, 0.1739, 0.1914],
         [0.1891, 0.0948, 0.2491, 0.2291, 0.2379],
         [0.1943, 0.2597, 0.1843, 0.1796, 0.1822],
         [0.2045, 0.1544, 0.2107, 0.2196, 0.2110]]])

これで、複数のヘッドがあるパターンについてもモデルと計算した値が一致することがわかった。

まとめ

今回は Transformer の中心的な処理である Multi-Head Attention について、PyTorch の実装を例に検算してみた。