CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: Luigi の DateIntervalParameter について

バッチ処理に特化した Python のデータパイプライン構築用のフレームワークに Luigi がある。 今回は、特定の時系列的な範囲を Task が受け取るのに使える DateIntervalParameter というパラメータを紹介する。 これは、たとえば一週間とか一ヶ月あるいは特定の日付から日付といった範囲で、何らかの集計をする処理を書くときに便利に使える。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.6
BuildVersion:   20G165
$ python -V        
Python 3.9.7
$ pip list | grep -i luigi
luigi                    3.0.3

もくじ

下準備

あらかじめ、Luigi をインストールしておく。

$ pip install luigi

DateIntervalParameter について

早速だけど以下にサンプルコードを示す。

以下では ExampleTask というタスクを定義している。 このタスクは dt_interval という名前で DateIntervalParameter 型のパラメータを受け取る。 タスクが実行されると DateIntervalParameter#dates() メソッドを呼んで、範囲に含まれる日付を標準出力に書き出す。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import luigi


class ExampleTask(luigi.Task):
    # 期間を指定できるパラメータ
    dt_interval = luigi.DateIntervalParameter()

    def run(self):
        for date in self.dt_interval.dates():
            # 各日付に適用する擬似的な処理
            print(f'Processing: {date}')

    def complete(self):
        # 動作確認用に output() メソッドを定義しないで常にタスクが実行されるようにする
        return False

上記を実行してみよう。 まずは "2021-W01" という文字列を渡してみる。 これは 2021 年の第 01 週を表している。

$ python -m luigi \
    --local-scheduler \
    --module example \
    ExampleTask \
    --dt-interval "2021-W01"

... (snip) ...

Processing: 2021-01-04
Processing: 2021-01-05
Processing: 2021-01-06
Processing: 2021-01-07
Processing: 2021-01-08
Processing: 2021-01-09
Processing: 2021-01-10

... (snip)

上記を見ると 2021-01-04 から 2021-01-10 が、指定した 2021-W01 に含まれる日付ということがわかる。

同様に "2021-01" という文字列を渡してみよう。 これは 2021 年の 1 月を表している。

$ python -m luigi \
    --local-scheduler \
    --module example \
    ExampleTask \
    --dt-interval "2021-01"

... (snip) ...

Processing: 2021-01-01
Processing: 2021-01-02
Processing: 2021-01-03
... (snip) ...
Processing: 2021-01-29
Processing: 2021-01-30
Processing: 2021-01-31

... (snip)

数が多いので省略しているけど、2021-01-01 から 2021-01-31 が範囲に含まれることがわかる。

また、ISO 8601 形式の日付をハイフン (-) でつなぐと、任意の日付の範囲が指定できる。 たとえば 2021-09-01 から 2021-09-07 を指定してみよう。

$ python -m luigi \
    --local-scheduler \
    --module example \
    ExampleTask \
    --dt-interval "2021-09-01-2021-09-07"

... (snip) ...

Processing: 2021-09-01
Processing: 2021-09-02
Processing: 2021-09-03
Processing: 2021-09-04
Processing: 2021-09-05
Processing: 2021-09-06

... (snip)

末尾の日付は含まれずに、2021-09-01 から 2021-09-06 が範囲に含まれていることがわかる。

このように DateIntervalParameter を使うと、特定の日付の範囲を受け取る処理が書きやすい。 典型的には、開始日と終了日を個別に取っていたような処理を置き換えることができる。

動作原理

使い方はわかったので、この DateIntervalParameter というパラメータが、どのように実現されているのか見ていこう。

以下のサンプルコードでは、受け取った dt_interval の型を表示している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import luigi


class ExampleTask(luigi.Task):
    dt_interval = luigi.DateIntervalParameter()

    def run(self):
        # dt_interval の型を表示する
        print(f'*** Type of dt_interval: {type(self.dt_interval)} ***')

    def complete(self):
        return False

上記に "2021-W01" という文字列を渡すと luigi.date_interval.Week という型になっていることが確認できる。

$ python -m luigi \
    --local-scheduler \
    --module example \
    ExampleTask \
    --dt-interval "2021-W01"

... (snip) ...

*** Type of dt_interval: <class 'luigi.date_interval.Week'> ***

... (snip)

同様に、"2021-01" を渡したときでは luigi.date_interval.Month になる。

$ python -m luigi \
    --local-scheduler \
    --module example \
    ExampleTask \
    --dt-interval "2021-01"

... (snip) ...

*** Type of dt_interval: <class 'luigi.date_interval.Month'> ***

... (snip)

以下、実行については省略しつつ、以下のように対応している。

  • <年-月-日>
    • luigi.date_interval.Date
  • <年-W週>
    • luigi.date_interval.Week
  • <年-月>
    • luigi.date_interval.Month
  • <年-月-日>-<年-月-日>
    • luigi.date_interval.Custom

ここからは Python のインタプリタを使って確認していこう。

$ python

luigi.date_interval をインポートする。

>>> from luigi import date_interval

たとえば luigi.date_interval.Week をインスタンス化してみよう。 このクラスには年と週数を渡す必要がある。

>>> from pprint import pprint
>>> week_interval = date_interval.Week(2021, 1)

ちなみに、実行時と同じように文字列を使ってインスタンス化するときは parse() メソッドを使えば良い。

>>> date_interval.Week.parse('2021-W01')
2021-W01

このオブジェクトには、先ほどのサンプルコードでも登場したように dates() というメソッドがある。 このメソッドは、指定された期間に含まれる日付を返す。

>>> pprint(week_interval.dates())
[datetime.date(2021, 1, 4),
 datetime.date(2021, 1, 5),
 datetime.date(2021, 1, 6),
 datetime.date(2021, 1, 7),
 datetime.date(2021, 1, 8),
 datetime.date(2021, 1, 9),
 datetime.date(2021, 1, 10)]

また、next() メソッドを使うと次の期間が、prev() メソッドを使うと前の期間が得られる。

>>> week_interval.next()
2021-W02
>>> week_interval.prev()
2020-W53

これらのメソッドは、Week 以外にも MonthCustom などでそれぞれ実装されている。

ちなみに DateIntervalParameter 自体は、受け取った文字列をそれぞれのクラスの parse() に順番に渡す実装になっている。

luigi.readthedocs.io

いじょう。

Python: Luigi の RangeDaily 系の使い方と注意点について

Python の Luigi はバッチ処理に特化したデータパイプライン構築用のフレームワーク。 バッチ処理に特化しているとあって、定期的に実行する系のユーティリティも色々と用意されている。 今回は、その中でも特定の期間に実行すべきバッチ処理をまとめて扱うことのできる、RangeDaily を代表としたクラス群について書いてみる。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.6
BuildVersion:   20G165
$ python -V
Python 3.9.7
$ pip list | grep -i luigi    
luigi                    3.0.3

もくじ

下準備

まずは Luigi をインストールしておく。

$ pip install luigi

RangeDaily について

RangeDaily は、日付を受け取って何らかの処理をする Task をまとめて扱うことのできるラッパータスク。 これを使うと、一日ずつ手動でタスクを実行する代わりに、開始期間 (と終了期間) を指定して一気に処理が実行できる。 とはいえ、文章で説明してもあんまりイメージできないと思うのでサンプルコードを使って説明していく。

以下のサンプルコードでは MyDailyTask というタスクを定義している。 このタスクは date というパラメータ名で日付を受け取る。 実行すると、/tmp 以下に受け取った日付を名前に含んだファイルを作成する。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import luigi


class MyDailyTask(luigi.Task):
    """日付を受け取って、名前に日付を含むファイルを生成するタスク"""
    date = luigi.DateParameter()

    def run(self):
        with self.output().open(mode='w') as fp:
            print('Hello, World!', file=fp)

    def output(self):
        path = '/tmp/luigi-{date:%Y-%m-%d}'.format(date=self.date)
        return luigi.LocalTarget(path)

上記を example.py という名前で保存しよう。

通常なら、上記のタスクを実行するには次のようにする。 Luigi のスケジューラに、タスクを表したクラスと、実行したい日付を --date パラメータで指定する。

$ python -m luigi \
    --local-scheduler \
    --module example \
    MyDailyTask \
    --date 2021-09-01

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 1 tasks of which:
* 1 ran successfully:
    - 1 MyDailyTask(date=2021-09-01)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

実行すると、次のようにファイルが作られる。

$ ls /tmp/luigi-2021-09-01  
/tmp/luigi-2021-09-01
$ cat /tmp/luigi-2021-09-01 
Hello, World!

上記は一日ずつ処理する場合の例になる。

続いては、今回の主題である RangeDaily を使ってみよう。 RangeDaily を使うと、この日付以降をまとめて実行する、といったことができる。 以下は --start オプションを使って 2021-09-01 以降を一気に実行する場合の例。 実行するタスクとしては RangeDaily を指定して、--of パラメータで自分で定義したタスクを指定する。

$ python -m luigi \
    --local-scheduler \
    RangeDaily \
    --module example \
    --of MyDailyTask \
    --start 2021-09-01

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 19 tasks of which:
* 19 ran successfully:
    - 18 MyDailyTask(date=2021-09-02...2021-09-19)
    - 1 RangeDaily(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

上記の実行結果を見ると、指定した日付以降を一括で実行できていることがわかる。 なお、これを実行しているシステムの日付は 2021-09-19 である。

$ date "+%Y-%m-%d"
2021-09-19

ディレクトリを確認すると 2021-09-01 から 2021-09-19 の範囲でファイルができている。 なお、2021-09-01 については最初に単発で実行したもの。

$ ls /tmp/luigi-* | head -n 1
/tmp/luigi-2021-09-01
$ ls /tmp/luigi-* | tail -n 1
/tmp/luigi-2021-09-19

試しにもう一度同じコマンドを実行してみよう。 すると、いずれのタスクも実行されないことがわかる。 これは、既にタスクが成果物としているファイルが存在するため。

$ python -m luigi \
    --local-scheduler \
    RangeDaily \
    --module example \
    --of MyDailyTask \
    --start 2021-09-01

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 1 tasks of which:
* 1 complete ones were encountered:
    - 1 RangeDaily(...)

Did not run any tasks
This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

この特性はバッチ処理を扱う上でなかなか便利だったりする。 というのも、何らかの理由で成果物のファイルが日付として歯抜けになっている場合、自動で存在しない日付を検出して実行してくれる。 また、日付の最後も特に指定しない限り今日になるので、バッチ処理として仕込むコマンドに RangeDaily を指定することもできる。

注意点について

続いては RangeDaily の注意点について。 端的に言うと、現時刻から遡れる日数とタスク数に上限が設定されている。

たとえば開始の日付として年初を指定して実行してみよう。

$ python -m luigi \
    --local-scheduler \
    RangeDaily \
    --module example \
    --of MyDailyTask \
    --start 2021-01-01

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 51 tasks of which:
* 51 ran successfully:
    - 50 MyDailyTask(date=2021-06-12...2021-07-31)
    - 1 RangeDaily(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

上記の実行結果 (Luigi Execution Summary) を見て違和感を覚えたかもしれない。 というのも、作成されたファイルの日付が 2021-06-12 から始まっているため。

$ ls /tmp/luigi-* | sort | head -n 5
/tmp/luigi-2021-06-12
/tmp/luigi-2021-06-13
/tmp/luigi-2021-06-14
/tmp/luigi-2021-06-15
/tmp/luigi-2021-06-16

また、末尾についても 2021-07-31 で終わっている。

$ ls /tmp/luigi-* | sort | tail -n 25 | head -n 10
/tmp/luigi-2021-07-26
/tmp/luigi-2021-07-27
/tmp/luigi-2021-07-28
/tmp/luigi-2021-07-29
/tmp/luigi-2021-07-30
/tmp/luigi-2021-07-31
/tmp/luigi-2021-09-01
/tmp/luigi-2021-09-02
/tmp/luigi-2021-09-03
/tmp/luigi-2021-09-04

本来なら 2021-01-01 から歯抜けなくファイルができてほしかった。 どうして、こんなことが起こるのか。

この問題はドキュメントを見るとわかる。

luigi.readthedocs.io

遡れる日数の上限 (--days-back)

まず、RangeDailyBase という、RangeDaily の基底クラスには days_back というパラメータがある。 このパラメータは、現在の日時から遡る日付の上限を定めていて、デフォルトで 100 に設定されている。

確認すると、今日から 100 日前は 2021-06-11 だった。 つまり、その日付を含む過去は遡って実行されないようになっている。

$ brew install coreutils
$ gdate --iso-8601 --date '100 days ago'
2021-06-11

試しに --days-back パラメータに 365 を指定して、上限を 1 年まで広げてみよう。

$ python -m luigi \
    --local-scheduler \
    RangeDaily \
    --module example \
    --of MyDailyTask \
    --start 2021-01-01 \
    --days-back 365

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 51 tasks of which:
* 51 ran successfully:
    - 50 MyDailyTask(date=2021-01-01...2021-02-19)
    - 1 RangeDaily(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

すると、上記を見てわかるとおり年初から実行されるようになった。

実行できるタスク数の上限 (--task-limit)

同様に、一度に実行できるタスクの数にも上限がある。 こちらは --task-limit というパラメータで指定できる。 デフォルトでは 50 に設定されている。

試しに、タスク数の上限も 365 に引きあげて実行してみよう。

$ python -m luigi \
    --local-scheduler \
    RangeDaily \
    --module example \
    --of MyDailyTask \
    --start 2021-01-01 \
    --days-back 365 \
    --task-limit 365

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 144 tasks of which:
* 144 ran successfully:
    - 143 MyDailyTask(date=2021-02-20,2021-02-21,2021-02-22,...)
    - 1 RangeDaily(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

上記を見てわかるとおり、これまでの上限だった 50 を越えてタスクが実行されている。

ファイルの数を検算しても、年初から今日までの日数と一致した。

$ ls /tmp/luigi-* | wc -l
     262
$ gdate --iso-8601 --date '2021-01-01 261 days'
2021-09-19

確認がおわったら、一旦作られたファイルをすべて削除しておこう。

$ rm /tmp/luigi-*

一ヶ月毎の処理 (RangeMonthly)

ちなみに、タイトルに RangeDaily と書いたように、日次のバッチ以外を扱うためのクラスも用意されている。

以下のサンプルコードでは特定の月に対して実行することを想定した MyMonthlyTask というタスクを定義している。 パラメータの名前は date だけど、型が luigi.DateParameter ではなく luigi.MonthParameter になっている。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import luigi


class MyMonthlyTask(luigi.Task):
    """月を受け取って、名前に月を含むファイルを生成するタスク"""
    date = luigi.MonthParameter()

    def run(self):
        with self.output().open(mode='w') as fp:
            print('Hello, World!', file=fp)

    def output(self):
        path = '/tmp/luigi-{month:%Y-%m}'.format(month=self.date)
        return luigi.LocalTarget(path)

先ほどと同じように example.py という名前で保存しておこう。

上記を RangeMonthly 経由で実行する。

$ python -m luigi \
    --local-scheduler \
    RangeMonthly \
    --module example \
    --of MyMonthlyTask \
    --start 2021-01

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 9 tasks of which:
* 9 ran successfully:
    - 8 MyMonthlyTask(date=2021-01...2021-08)
    - 1 RangeMonthly(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

上記を見ると、今月を含まない形でタスクが実行されていることがわかる。 これは典型的には日付が揃わないうちに集計処理を実行することは少ないことが関係しているんだろう。

先ほどと同じように、確認できたらファイルを一旦きれいにしておく。

$ rm /tmp/luigi-*

一時間毎の処理 (RangeHourly)

同じように一時間毎の処理も扱うことができる。 以下では一時間毎に実行することを期待した処理を MyHourlyTask という名前で定義している。 パラメータの名前は date だけど、クラスは luigi.DateHourParameter を用いる。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import luigi


class MyHourlyTask(luigi.Task):
    """時間 (Hour) を受け取って、名前に時間を含むファイルを生成するタスク"""
    date = luigi.DateHourParameter()

    def run(self):
        with self.output().open(mode='w') as fp:
            print('Hello, World!', file=fp)

    def output(self):
        path = '/tmp/luigi-{month:%Y-%m-%dT%H}'.format(month=self.date)
        return luigi.LocalTarget(path)

一時間ごとの処理になるとファイルが増えるので、今回は明示的に終了の時刻も指定しておこう。 以下では --start2021-09-15T00 を、--stop2021-09-16T12 を指定している。 T 以降が時間を表している。

$ python -m luigi \
    --local-scheduler \
    RangeHourly \
    --module example \
    --of MyHourlyTask \
    --start 2021-09-15T00 \
    --stop 2021-09-16T12

... (snip) ...

===== Luigi Execution Summary =====

Scheduled 37 tasks of which:
* 37 ran successfully:
    - 36 MyHourlyTask(date=2021-09-15T00...2021-09-16T11)
    - 1 RangeHourly(...)

This progress looks :) because there were no failed tasks or missing dependencies

===== Luigi Execution Summary =====

結果を見てわかるとおり、--start の時刻を含んで --stop を含まない形でタスクが実行されていることがわかる。

$ ls /tmp/luigi-*  
/tmp/luigi-2021-09-15T00    /tmp/luigi-2021-09-15T12    /tmp/luigi-2021-09-16T00
/tmp/luigi-2021-09-15T01    /tmp/luigi-2021-09-15T13    /tmp/luigi-2021-09-16T01
/tmp/luigi-2021-09-15T02    /tmp/luigi-2021-09-15T14    /tmp/luigi-2021-09-16T02
/tmp/luigi-2021-09-15T03    /tmp/luigi-2021-09-15T15    /tmp/luigi-2021-09-16T03
/tmp/luigi-2021-09-15T04    /tmp/luigi-2021-09-15T16    /tmp/luigi-2021-09-16T04
/tmp/luigi-2021-09-15T05    /tmp/luigi-2021-09-15T17    /tmp/luigi-2021-09-16T05
/tmp/luigi-2021-09-15T06    /tmp/luigi-2021-09-15T18    /tmp/luigi-2021-09-16T06
/tmp/luigi-2021-09-15T07    /tmp/luigi-2021-09-15T19    /tmp/luigi-2021-09-16T07
/tmp/luigi-2021-09-15T08    /tmp/luigi-2021-09-15T20    /tmp/luigi-2021-09-16T08
/tmp/luigi-2021-09-15T09    /tmp/luigi-2021-09-15T21    /tmp/luigi-2021-09-16T09
/tmp/luigi-2021-09-15T10    /tmp/luigi-2021-09-15T22    /tmp/luigi-2021-09-16T10
/tmp/luigi-2021-09-15T11    /tmp/luigi-2021-09-15T23    /tmp/luigi-2021-09-16T11

まとめ

今回は Luigi で特定の期間毎に実行するバッチ処理を、まとめて扱う RangeDaily 系の使い方と注意点について紹介した。 注意点としては、遡る日付や実行するタスクの数に上限がある点や、期間によって開始・終了を含む・含まないが微妙に異なる点が挙げられる。

Python: PyTorch の GRU / LSTM を検算してみる

以前のエントリで扱った Simple RNN の検算は、個人的になかなか良い勉強になった。

blog.amedama.jp

そこで、今回は Simple RNN の改良版となる GRU (Gated Recurrent Unit) と LSTM (Long Short Term Memory) についても検算してみる。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.5.2
BuildVersion:   20G95
$ python -V
Python 3.9.6
$ pip list | grep torch
torch                    1.9.0

もくじ

下準備

下準備として、あらかじめ PyTorch をインストールしておく。

$ pip install torch

続いて Python のインタプリタを起動する。

$ python

そして、PyTorch のパッケージをインポートする。

>>> import torch
>>> from torch import nn

GRU を検算する

Simple RNN は、仕組みが単純な一方で隠れ状態が入力によって無条件に更新されてしまう。 そのため、隠れ状態に昔の情報が残りにくいことから、長期的な記憶を保つことが難しいという問題があった。 GRU では、それをゲートという仕組みを導入することで改善を試みている。

まずは次のように GRU クラスをインスタンス化する。 Simple RNN のときと同じように、モデルの初期状態の重みをそのまま使って検算する。

>>> input_dim = 3  # モデルの入力ベクトルの次元数
>>> hidden_dim = 4  # モデルの出力 (隠れ状態) ベクトルの次元数
>>> model = nn.GRU(input_size=input_dim, hidden_size=hidden_dim)

インスタンス化するときに必要な引数は RNN クラスと変わらない。 つまり、入力と出力のサイズを渡すだけ。

インスタンス化できたら、モデルのパラメータを確認しよう。 どうやら、モデルのパラメータが持っている名前は RNN クラスと同じようだ。 ただし、重みを保持している行列のサイズは増している。

>>> from pprint import pprint
>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[-0.3619, -0.1291, -0.0647],
        [-0.4406, -0.2705, -0.3480],
        [ 0.0360,  0.3222,  0.2494],
        [-0.0738, -0.3214,  0.4445],
        [-0.3551,  0.3078, -0.0846],
        [-0.4367,  0.4282, -0.1521],
        [-0.4895,  0.0713,  0.0217],
        [-0.2439,  0.4704, -0.2078],
        [ 0.0460,  0.2528,  0.3555],
        [-0.3008, -0.0595,  0.0586],
        [-0.3535,  0.2088, -0.2179],
        [ 0.2923,  0.0291,  0.4044]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[ 0.0406,  0.3097, -0.2765, -0.2359],
        [ 0.4449,  0.3376,  0.3715, -0.3207],
        [ 0.0157,  0.0347, -0.0091, -0.0438],
        [ 0.1630,  0.3619,  0.3797, -0.0845],
        [ 0.1729, -0.1405,  0.0844, -0.3560],
        [ 0.0711, -0.3750, -0.0721, -0.4998],
        [-0.4140, -0.1105, -0.1611,  0.1338],
        [-0.0574, -0.1216,  0.2439, -0.2021],
        [ 0.1568,  0.2177,  0.4511,  0.4009],
        [-0.4453, -0.0780, -0.1764,  0.3598],
        [ 0.1704,  0.3918, -0.0727,  0.2112],
        [ 0.3841,  0.0154,  0.2495,  0.1840]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([-0.3642, -0.2804,  0.3874, -0.0016, -0.0540, -0.3060, -0.0446, -0.0145,
         0.1529, -0.4700,  0.3887,  0.1273], requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.0260, -0.0787, -0.3992,  0.4587,  0.3522,  0.0618,  0.0865, -0.2561,
         0.0439, -0.4722,  0.2414, -0.2022], requires_grad=True))]

続いて、ダミーの入力データを用意しよう。 ダミーの入力データの形状は RNN を使った場合と変わらない。

>>> T = 5  # 入力する系列データの長さ
>>> batch_size = 2  # 一度に処理するデータの数
>>> X = torch.randn(T, batch_size, input_dim)  # ダミーの入力データ
>>> X.shape
torch.Size([5, 2, 3])

ダミーの入力データをモデルに与えて出力を得る。

>>> H, hn = model(X)

出力は入力の系列データに対応する隠れ状態と、最後の隠れ状態になっている。 この形状も RNN と変わらない。 つまり、PyTorch において GRU は単純に名前を変えるだけで RNN から差し替えて使うことができる。

>>> H.shape, hn.shape
(torch.Size([5, 2, 4]), torch.Size([1, 2, 4]))
>>> H[-1]
tensor([[ 0.5352, -0.5132,  0.2607,  0.5642],
        [-0.0264, -0.6124,  0.5123, -0.2023]], grad_fn=<SelectBackward>)
>>> hn
tensor([[[ 0.5352, -0.5132,  0.2607,  0.5642],
         [-0.0264, -0.6124,  0.5123, -0.2023]]], grad_fn=<StackBackward>)

それでは、ここからは実際に検算に入ろう。 PyTorch で使われている GRU の数式は以下のドキュメントで確認できる。

pytorch.org

数式は以下のとおり。 Simple RNN のときは 1 つだった式が 4 つに増えている。 なお、最終的に求めたいのは一番下にある「入力  x_t に対応した隠れ状態  h_t」になる。

\displaystyle{
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}
}

ここで  \sigma はシグモイド関数を表す。  r_t z_t n_t は、活性化関数の違いはあるものの、基本的にはいずれも  W_i x_t + b_i + W_h h_{(t-1)} + b_h の形になっていることがわかる。

数式が確認できたところでモデルのパラメータから重みを取り出していこう。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_i = model_weights['weight_ih_l0']
>>> W_h = model_weights['weight_hh_l0']
>>> b_i = model_weights['bias_ih_l0']
>>> b_h = model_weights['bias_hh_l0']

上記は、部分ごとに  r_t 用と  z_t 用と  n_t 用に分かれている。 本来は一気に行列計算した上で後から取り出すわけだけど、今回は数式をなぞるために先に取り出しておこう。

>>> W_ir = W_i[:hidden_dim]
>>> W_iz = W_i[hidden_dim: hidden_dim * 2]
>>> W_in = W_i[hidden_dim * 2:]
>>> 
>>> W_hr = W_h[:hidden_dim]
>>> W_hz = W_h[hidden_dim: hidden_dim * 2]
>>> W_hn = W_h[hidden_dim * 2:]
>>> 
>>> b_ir = b_i[:hidden_dim]
>>> b_iz = b_i[hidden_dim: hidden_dim * 2]
>>> b_in = b_i[hidden_dim * 2:]
>>> 
>>> b_hr = b_h[:hidden_dim]
>>> b_hz = b_h[hidden_dim: hidden_dim * 2]
>>> b_hn = b_h[hidden_dim * 2:]

あとは定義どおりに計算していく。

まずは t = 0 の状態から。 つまり、X[0] に対応する隠れ状態を計算してみよう。 t = 0 かつ、初期の隠れ状態を渡していないので  W_h h_{(t-1)} の項が存在しない。

>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[0].T).T + b_ir + b_hr)
>>> z_t = torch.sigmoid(torch.matmul(W_iz, X[0].T).T + b_iz + b_hz)
>>> n_t = torch.tanh(torch.matmul(W_in, X[0].T).T + b_in + r_t * b_hn)
>>> h_t = (1 - z_t) * n_t

確認すると、モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。

>>> H[0]
tensor([[-0.1412, -0.2934,  0.3071, -0.2858],
        [ 0.1785, -0.1226,  0.2666,  0.0485]], grad_fn=<SelectBackward>)
>>> h_t
tensor([[-0.1412, -0.2934,  0.3071, -0.2858],
        [ 0.1785, -0.1226,  0.2666,  0.0485]])

次は t = 1 に対する計算を取り上げつつ、それぞれの式が意味するところを考えてみる。

まず、以下の  r_t はリセットゲート (reset gate) と呼ばれている。 リセットゲートの式は、活性化関数がシグモイド関数なので、成分は 0 ~ 1 の範囲になる。

>>> r_t = torch.sigmoid(torch.matmul(W_ir, X[1].T).T + b_ir + torch.matmul(W_hr, H[0].T).T + b_hr)

リセットゲートは、後ほど新しい隠れ状態の候補を作るときに、一つ前の隠れ状態と積を取る。 それによって、次の隠れ状態に、一つ前の隠れ状態をどれくらい反映するか制御する役目を担っている。

ゲートは成分の値が 0 のときに「閉じている」、1 のときに「開いている」と表現するらしい。 もちろん、ゲートの値は人間が明示的に与えるのではなく、学習するデータによって最適化される。

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        \end{array}
}

以下の  n_t は、元の論文には名前付きで登場しないものの、PyTorch の中ではニューゲート (new gate) と呼ばれているようだ 1。 これは、言うなれば次の隠れ状態の候補となるもの。 式は RNN の隠れ状態を作るときのものに近いけど、みると一つ前の隠れ状態に先ほどのリセットゲートがかけられている。 これによって、次の隠れ状態に一つ前の隠れ状態をどれくらい混ぜるか、つまり影響を与えるかを制御している。 たとえば、リセットゲートの成分がすべてゼロなら、次の隠れ状態の候補を作るときに、一つ前の隠れ状態をまったく考慮しないことになる。

>>> n_t = torch.tanh(torch.matmul(W_in, X[1].T).T + b_in + r_t * (torch.matmul(W_hn, H[0].T).T + b_hn))

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
        \end{array}
}

次に、以下の  z_t はアップデートゲート (update gate) と呼ばれている。 このゲートは、次の隠れ状態を作るときに、どれくらい一つ前の隠れ状態を引き継ぐかを制御している。

>>> z_t = torch.sigmoid(torch.matmul(W_iz, X[1].T).T + b_iz + torch.matmul(W_hz, H[0].T).T + b_hz)

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
        \end{array}
}

最後に、以下で次の隠れ状態を求めている。 式では、先ほど計算したニューゲートとアップデートゲートが登場している。 次の隠れ状態は、基本的にニューゲートと一つ前の隠れ状態が混ぜられていることがわかる。 そして、混ぜる比率をアップデートゲートが制御している。 もしアップデートゲートの成分がすべてゼロなら、一つ前の隠れ状態はまったく考慮されず、すべてニューゲートのものになる。

>>> h_t = (1 - z_t) * n_t + z_t * H[0]

数式で対応しているのは、この部分。

\displaystyle{
        \begin{array}{ll}
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}
}

計算した隠れ状態を、最初に得られたものと比較してみよう。

>>> H[1]
tensor([[-0.0039, -0.3424,  0.4580, -0.2490],
        [ 0.3869, -0.4714,  0.1700,  0.4022]], grad_fn=<SelectBackward>)
>>> h_t
tensor([[-0.0039, -0.3424,  0.4580, -0.2490],
        [ 0.3869, -0.4714,  0.1700,  0.4022]], grad_fn=<AddBackward0>)

モデルから返ってきた隠れ状態と、検算した値が一致していることがわかる。

LSTM を検算する

続いては LTSM についても同様に検算してみる。

LSTM では、Simple RNN や GRU で扱っていた隠れ状態が「長期記憶」と「短期記憶」に分かれている。 これによって、長いスパンで記憶しておく必要のある情報と、特定のタイミングでのみ必要な情報を扱いやすくしているらしい。 ちなみに LSTM は前述の GRU よりも歴史のあるアーキテクチャで、GRU は LSTM の特殊形と捉えることもできるようだ。

LSTM も、PyTorch ではクラスの名前を LSTM に変更するだけで使うことができる。

>>> model = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim)

モデルに含まれるパラメータを確認してみよう。 パラメータの名前は同じだけど、先ほどの GRU よりも、さらに行列のサイズが増えている。

>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[ 0.3498, -0.0745,  0.0339],
        [-0.0537, -0.4582, -0.0305],
        [-0.1209, -0.1292,  0.0014],
        [-0.4880,  0.4027,  0.2235],
        [-0.3940, -0.4997, -0.4360],
        [ 0.4677, -0.2913,  0.3161],
        [-0.4162, -0.4060, -0.0483],
        [ 0.0281,  0.0586, -0.4602],
        [ 0.0145,  0.3151, -0.0132],
        [ 0.2642,  0.0724, -0.1972],
        [-0.1406,  0.2249, -0.0125],
        [-0.1339, -0.1570, -0.4393],
        [-0.1411, -0.1534,  0.4226],
        [-0.3554,  0.0628,  0.3336],
        [-0.3037, -0.4630, -0.0022],
        [-0.4711,  0.4282,  0.4648]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[ 0.1409,  0.2027,  0.4179,  0.2062],
        [ 0.0182,  0.1814, -0.0826,  0.0193],
        [-0.3766, -0.4391,  0.0336, -0.0875],
        [-0.3921,  0.0581,  0.3184, -0.4362],
        [ 0.0616, -0.0611, -0.0350,  0.2251],
        [-0.1458, -0.2994, -0.4362, -0.0643],
        [ 0.1637, -0.1193,  0.4780, -0.0938],
        [-0.0130,  0.1613,  0.2988, -0.2142],
        [-0.1978,  0.3739, -0.4704,  0.3770],
        [ 0.4956, -0.3259,  0.0976,  0.1588],
        [ 0.2641, -0.2511, -0.3984,  0.2107],
        [ 0.4604,  0.1646, -0.0299,  0.4243],
        [ 0.4658, -0.1663, -0.0066, -0.2386],
        [ 0.2184,  0.3376, -0.2343,  0.2853],
        [-0.2000, -0.4610, -0.2787, -0.2990],
        [ 0.3782, -0.1738, -0.1492, -0.2577]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([-0.1566,  0.4039,  0.2361,  0.1422,  0.1875,  0.0293, -0.2778,  0.4168,
        -0.4732,  0.0960,  0.1191,  0.1664,  0.1017,  0.1526,  0.4041,  0.0643],
       requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.2511,  0.2747, -0.0801, -0.1251,  0.0565, -0.3207,  0.0877,  0.2105,
        -0.3742, -0.3953, -0.3199, -0.1545, -0.1276, -0.4406, -0.3679,  0.4121],
       requires_grad=True))]

モデルにダミーデータを与えてみよう。 このとき、LSTM では返り値が RNNGRU よりも増えている。

>>> H, (hn, cn) = model(X)

上記で、HhnRNNGRU と同じ隠れ状態を表している。 ただし、LSTM においては隠れ状態が「短期記憶」に対応する。

>>> H[-1]
tensor([[-0.2198, -0.1965, -0.0670, -0.5722],
        [-0.1991, -0.0771, -0.1617, -0.0441]], grad_fn=<SelectBackward>)
>>> 
>>> hn
tensor([[[-0.2198, -0.1965, -0.0670, -0.5722],
         [-0.1991, -0.0771, -0.1617, -0.0441]]], grad_fn=<StackBackward>)

返り値で増えているのは、前述した「長期記憶」になる。 詳しくは後述するけど、LSTM の「短期記憶」はこの「長期記憶」から抜き出して作る。

>>> cn
tensor([[[-0.3814, -0.5555, -0.1285, -0.9944],
         [-0.6473, -0.2559, -0.2589, -0.1025]]], grad_fn=<StackBackward>)

使う上で理解すべき概念の説明が終わったところで、検算に移る。 PyTorch で使われている LSTM の数式は以下のドキュメントで確認できる。

pytorch.org

数式は次のとおり。 GRU のときよりも、さらに増えている。

\displaystyle{
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}
}

上記で  \odot はアダマール積を表している。

モデルからパラメータを取り出そう。 先ほどと同じように、数式をなぞるために行列から必要な箇所を取り出して名前をつけていく。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_i = model_weights['weight_ih_l0']
>>> W_h = model_weights['weight_hh_l0']
>>> b_i = model_weights['bias_ih_l0']
>>> b_h = model_weights['bias_hh_l0']
>>> 
>>> W_ii = W_i[:hidden_dim]
>>> W_if = W_i[hidden_dim: hidden_dim * 2]
>>> W_ig = W_i[hidden_dim * 2: hidden_dim * 3]
>>> W_io = W_i[hidden_dim * 3:]
>>> 
>>> W_hi = W_h[:hidden_dim]
>>> W_hf = W_h[hidden_dim: hidden_dim * 2]
>>> W_hg = W_h[hidden_dim * 2: hidden_dim * 3]
>>> W_ho = W_h[hidden_dim * 3:]
>>> 
>>> b_ii = b_i[:hidden_dim]
>>> b_if = b_i[hidden_dim: hidden_dim * 2]
>>> b_ig = b_i[hidden_dim * 2: hidden_dim * 3]
>>> b_io = b_i[hidden_dim * 3:]
>>> 
>>> b_hi = b_h[:hidden_dim]
>>> b_hf = b_h[hidden_dim: hidden_dim * 2]
>>> b_hg = b_h[hidden_dim * 2: hidden_dim * 3]
>>> b_ho = b_h[hidden_dim * 3:]

とりあえず、t = 0 の時点の隠れ状態 (短期記憶) を数式のとおりに計算してみよう。 t = 0 かつ、初期の隠れ状態と長期記憶を渡していないので存在しない項がある点に注意する。

>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[0].T).T + b_ii + b_hi)
>>> f_t = torch.sigmoid(torch.matmul(W_if, X[0].T).T + b_if + b_hf)
>>> g_t = torch.tanh(torch.matmul(W_ig, X[0].T).T + b_ig + b_hg)
>>> o_t = torch.sigmoid(torch.matmul(W_io, X[0].T).T + b_io + b_ho)
>>> c_t = i_t * g_t
>>> h_t = o_t * torch.tanh(c_t)

計算した値と、モデルから返ってきた隠れ状態を比較してみよう。

>>> H[0]
tensor([[-0.1018, -0.0494, -0.0653,  0.1273],
        [-0.0617, -0.1598,  0.0546, -0.3024]], grad_fn=<SelectBackward>)
>>> 
>>> h_t
tensor([[-0.1018, -0.0494, -0.0653,  0.1273],
        [-0.0617, -0.1598,  0.0546, -0.3024]])

ちゃんと一致している。

続いては数式の意味を確認しながら t = 1 も計算してみよう。 計算する上で、一つ前の長期記憶が必要になるので c_0 という名前で記録しておく。

>>> c_0 = c_t

まず計算するのは、入力ゲート (input gate) で、新しい入力  x_t を、どれくらい次の長期記憶に反映するかを司っている。

>>> i_t = torch.sigmoid(torch.matmul(W_ii, X[1].T).T + b_ii + torch.matmul(W_hi, H[0].T).T + b_hi)

次に計算しているのは忘却ゲート (forget gate) で、一つ前の長期記憶を、次にどれだけ引き継ぐかを担っている。

>>> f_t = torch.sigmoid(torch.matmul(W_if, X[1].T).T + b_if + torch.matmul(W_hf, H[0].T).T + b_hf)

以下の式は、論文では名前がついていないけど、PyTorch ではセルゲート (cell gate) と呼んでいる。 これは Simple RNN で隠れ状態を計算していた式と同じ。 入力と、一つ前の隠れ状態 (短期記憶) を混ぜている。

>>> g_t = torch.tanh(torch.matmul(W_ig, X[1].T).T + b_ig + torch.matmul(W_hg, H[0].T).T + b_hg)

以下は出力ゲート (output gate) で、長期記憶から短期記憶をどのように抜き出すかを司っている。

>>> o_t = torch.sigmoid(torch.matmul(W_io, X[1].T).T + b_io + torch.matmul(W_ho, H[0].T).T + b_ho)

以下で、一つ前の長期記憶と出力ゲートを混ぜて、次の長期記憶を作っている。 どんな風に混ぜるかは、忘却ゲートと入力ゲートの値に依存する。

>>> c_t = f_t * c_0 + i_t * g_t

そして、最後に長期記憶から出力ゲートを使って短期記憶を抜き出している。

>>> h_t = o_t * torch.tanh(c_t)

隠れ状態を比べてみると、ちゃんと値が一致していることがわかる。

>>> H[1]
tensor([[-0.1224, -0.1573,  0.0294,  0.0794],
        [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<SelectBackward>)
>>> 
>>> h_t
tensor([[-0.1224, -0.1573,  0.0294,  0.0794],
        [-0.1954, -0.1273, -0.0442, -0.3626]], grad_fn=<MulBackward0>)

いじょう。

参考

arxiv.org

(PDF) Long Short-term Memory

youtu.be

youtu.be


  1. 役目的にはゲートではないので何だか変な気もする

Python: PyTorch の RNN を検算してみる

今回は、PyTorch の RNN (Recurrent Neural Network) が内部的にどんな処理をしているのか確認してみる。 なお、ここでいう RNN は、再起的な構造をもったニューラルネットワークの総称ではなく、いわゆる古典的な Simple RNN を指している。

これを書いている人は、ニューラルネットワークが何もわからないので、再帰的な構造があったりすると尚更わからなくなる。 そこで、中身について知っておきたいと考えたのがモチベーションになっている。

使った環境は次のとおり。

$ sw_vers 
ProductName:    macOS
ProductVersion: 11.5.2
BuildVersion:   20G95
$ python -V         
Python 3.9.6
$ pip list | grep torch       
torch                    1.9.0

もくじ

下準備

まずは PyTorch をインストールしておく。

$ pip install torch

そして、Python のインタプリタを起動する。

$ python

起動できたら PyTorch のモジュールをインポートしておく。

>>> import torch
>>> from torch import nn

モデルを用意する

PyTorch には nn モジュール以下に RNN というクラスがある。 このクラスが、ミニバッチに対応した Simple RNN を実装している。 このクラスは、最低限 input_sizehidden_size という引数を指定すればインスタンス化できる。

>>> input_dim = 3  # モデルの入力ベクトルの次元数
>>> hidden_dim = 4  # モデルの出力ベクトルの次元数
>>> model = nn.RNN(input_size=input_dim, hidden_size=hidden_dim)

input_size はモデルに入力するデータの次元数で、hidden_size はモデルが出力するデータの次元数になる。 Simple RNN が出力するデータには隠れ状態 (Hidden State) ベクトルという名前がついていて、それが引数の名前に反映されている。

インスタンス化できたら、モデルに含まれるパラメータを確認してみよう。 これは何も学習していない状態の初期値だけど、ダミーのデータを使って検算する分にはそれで問題ない。

>>> from pprint import pprint
>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[ 0.4349,  0.2858, -0.3802],
        [ 0.3035,  0.4744, -0.4774],
        [ 0.4553,  0.1563, -0.0048],
        [-0.4107, -0.4734,  0.3651]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[-0.4045,  0.4994, -0.3950,  0.3627],
        [-0.4304,  0.2032,  0.2878,  0.0923],
        [ 0.0641, -0.0405, -0.2965, -0.3422],
        [ 0.3323, -0.2716, -0.1380,  0.2079]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([-0.2928,  0.2330,  0.1649, -0.2679], requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.0034, -0.0927,  0.0520, -0.0646], requires_grad=True))]

モデルには 4 つの名前つきパラメータが確認できる。 これらのパラメータが何を意味しているかは、以下のドキュメントをみるとわかる。

pytorch.org

上記には、RNN の具体的な計算式が記載されている。

\displaystyle{
h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
}

ここで、 x_t は入力となる系列データにおいて t 番目 (時点) の要素を表していて、 h_tt 番目の隠れ状態ベクトルになる。  h_{(t-1)}t - 1 番目の隠れ状態ベクトルなので、1 つ前の状態の出力を入力として使っていることがわかる。

それ以外は、先ほどのパラメータと次のように対応している。

  •  W_{ih}

    • weight_ih_l0
  •  W_{hh}

    • weight_hh_l0
  •  b_{ih}

    • bias_ih_l0
  •  b_{hh}

    • bias_hh_l0

ダミーデータを用意する

式がわかったところで、検算するための出力を適当に用意したダミーデータを使って得よう。 次のようにランダムな入力データを用意する。

>>> T = 5  # 入力する系列データの長さ
>>> batch_size = 2  # 一度に処理するデータの数
>>> X = torch.randn(T, batch_size, input_dim)  # ダミーの入力データ
>>> X.shape
torch.Size([5, 2, 3])

上記のダミーデータをモデルに入力として与える。 すると、タプルで 2 つの返り値が得られる。

>>> H, hn = model(X)

このうち、タプルの最初の要素は各時点 (0 ~ T) での隠れ状態ベクトルが入っている。 つまり、X[0] に対応した隠れ状態ベクトルが H[0] で、X[1] に対応した隠れ状態ベクトルが H[1] で...ということ。

>>> H.shape
torch.Size([5, 2, 4])
>>> H
tensor([[[-0.0096,  0.3380,  0.4147, -0.5187],
         [-0.5797,  0.0438, -0.3449, -0.0454]],

        [[-0.3769,  0.5505,  0.1542, -0.6927],
         [ 0.1021,  0.4838,  0.0174, -0.5226]],

        [[ 0.5723,  0.8306,  0.5878, -0.9012],
         [-0.5423, -0.3730,  0.1816,  0.0130]],

        [[-0.2641,  0.0466,  0.7226, -0.6048],
         [-0.6680, -0.4764,  0.2837,  0.2118]],

        [[-0.8623, -0.3724, -0.4284,  0.2948],
         [-0.2464,  0.4500, -0.4194, -0.1977]]], grad_fn=<StackBackward>)

そして、タプルで 2 番目に返ってきた値は最後の時点 (T) での隠れ状態ベクトルになる。 ようするに、上記の最後尾と同じもの。

>>> hn
tensor([[[-0.8623, -0.3724, -0.4284,  0.2948],
         [-0.2464,  0.4500, -0.4194, -0.1977]]], grad_fn=<StackBackward>)
>>> H[-1]
tensor([[-0.8623, -0.3724, -0.4284,  0.2948],
        [-0.2464,  0.4500, -0.4194, -0.1977]], grad_fn=<SelectBackward>)

検算する

次に、実際の検算に入る。 まずは、次のようにして各パラメータの Tensor オブジェクトを得る。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_ih = model_weights['weight_ih_l0']
>>> W_hh = model_weights['weight_hh_l0']
>>> b_ih = model_weights['bias_ih_l0']
>>> b_hh = model_weights['bias_hh_l0']

まずは系列データの一番最初の t = 0X[0] に対応する隠れ状態ベクトルから求める。 ターゲットはこれ。

>>> H[0]
tensor([[-0.0096,  0.3380,  0.4147, -0.5187],
        [-0.5797,  0.0438, -0.3449, -0.0454]], grad_fn=<SelectBackward>)

やることは単純で、先ほどの式を PyTorch で表現すれば良い。 なお、t = 0 の時点では  h_{(t-1)} がないので、その項は消える。

>>> torch.tanh(torch.matmul(W_ih, X[0].T).T + b_ih + b_hh)
tensor([[-0.0096,  0.3380,  0.4147, -0.5187],
        [-0.5797,  0.0438, -0.3449, -0.0454]])

Tensor の値が一致していることがわかる。

次は t = 1X[1] に対応する隠れ状態ベクトルを求める。 ターゲットは以下。

>>> H[1]
tensor([[-0.3769,  0.5505,  0.1542, -0.6927],
        [ 0.1021,  0.4838,  0.0174, -0.5226]], grad_fn=<SelectBackward>)

t = 1 では  h_{(t-1)} h_0 になる。 とはいえ項が増えるだけで、やることは先ほどと変わらない。

>>> torch.tanh(torch.matmul(W_ih, X[1].T).T + b_ih + torch.matmul(W_hh, H[0].T).T + b_hh)
tensor([[-0.3769,  0.5505,  0.1542, -0.6927],
        [ 0.1021,  0.4838,  0.0174, -0.5226]], grad_fn=<TanhBackward>)

こちらも値が一致している。

あとは添字が増えるだけなので省略する。

初期 (t = 0) の隠れ状態ベクトルを渡す場合

先ほどの例では、初期 (t = 0) のときに  h_{(t-1)} に相当する隠れ状態ベクトルが存在しなかった。 これは自分で用意して渡すこともできるので、その場合の挙動も確認しておこう。

次のようにしてランダムな値で初期の隠れ状態ベクトルを h0 として用意する。 なお、先頭の次元は Simple RNN を重ねる段数を表している。 というのも、(総称としての) RNN は縦に積み重ねることで性能向上が望める場合があるらしい 1。 PyTorch のRNN も、インスタンス化するときに num_layers という引数で重ねる数が指定できる。 なお、デフォルト値は 1 になっている。

>>> rnn_layers = 1  # Simple RNN を重ねる数 (num_layers の値)
>>> h0 = torch.randn(rnn_layers, batch_size, hidden_dim)
>>> h0.shape
torch.Size([1, 2, 4])

初期の隠れ状態ベクトルをモデルに渡すには、次のように 2 番目の引数として渡せば良い。

>>> H, hn = model(X, h0)

先ほどと同じように検算してみよう。

>>> H[0]
tensor([[-0.1925,  0.6594, -0.2041, -0.4893],
        [ 0.5740,  0.8465, -0.5979, -0.8112]], grad_fn=<SelectBackward>)

といっても、最初の  h_{(t-1)} として h0 を使うだけ。

>>> torch.tanh(torch.matmul(W_ih, X[0].T).T + b_ih + torch.matmul(W_hh, h0[0].T).T + b_hh)
tensor([[-0.1925,  0.6594, -0.2041, -0.4893],
        [ 0.5740,  0.8465, -0.5979, -0.8112]])

残りは変わらない。

>>> H[1]
tensor([[ 0.0929,  0.5283,  0.2951, -0.7210],
        [-0.1403,  0.0510,  0.3764, -0.4921]], grad_fn=<SelectBackward>)
>>> torch.tanh(torch.matmul(W_ih, X[1].T).T + b_ih + torch.matmul(W_hh, H[0].T).T + b_hh)
tensor([[ 0.0929,  0.5283,  0.2951, -0.7210],
        [-0.1403,  0.0510,  0.3764, -0.4921]], grad_fn=<TanhBackward>)

Simple RNN を重ねた場合

先ほど述べたとおり RNN は層を重ねることで性能向上が望める場合がある。 その場合についても確認しておく。

まずは RNN を 2 層重ねたモデルを用意する。

>>> rnn_layers = 2
>>> model = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, num_layers=rnn_layers)

次のようにモデルのパラメータが増えている。 具体的には名前の末尾が l0 になったものと l1 になったものがある。 これはつまりl0 の上に l1 が重なっていることを示す。

>>> pprint(list(model.named_parameters()))
[('weight_ih_l0',
  Parameter containing:
tensor([[-0.3591,  0.0948, -0.0500],
        [ 0.1963, -0.1717, -0.3551],
        [ 0.0313,  0.0495, -0.0878],
        [ 0.3109,  0.3728,  0.2577]], requires_grad=True)),
 ('weight_hh_l0',
  Parameter containing:
tensor([[-0.3050, -0.0269,  0.1772,  0.0081],
        [-0.0770,  0.3563, -0.1209,  0.0126],
        [-0.3534,  0.0264,  0.2649,  0.2235],
        [ 0.3338, -0.0708,  0.4314, -0.0149]], requires_grad=True)),
 ('bias_ih_l0',
  Parameter containing:
tensor([ 0.3767,  0.3653, -0.1024,  0.3425], requires_grad=True)),
 ('bias_hh_l0',
  Parameter containing:
tensor([-0.1083, -0.1802, -0.2972,  0.1099], requires_grad=True)),
 ('weight_ih_l1',
  Parameter containing:
tensor([[ 0.2279, -0.4886,  0.4573,  0.2441],
        [-0.0949, -0.2300,  0.1320, -0.2643],
        [ 0.0720,  0.4727,  0.2005, -0.0784],
        [-0.0784,  0.3208,  0.4977, -0.0190]], requires_grad=True)),
 ('weight_hh_l1',
  Parameter containing:
tensor([[-0.0565,  0.1433,  0.0810,  0.1619],
        [ 0.2734,  0.3270, -0.2813,  0.1076],
        [ 0.2989,  0.0412, -0.1173,  0.1614],
        [-0.0805, -0.1851, -0.1254,  0.0713]], requires_grad=True)),
 ('bias_ih_l1',
  Parameter containing:
tensor([-0.3898, -0.1349, -0.2269, -0.1637], requires_grad=True)),
 ('bias_hh_l1',
  Parameter containing:
tensor([ 0.4969,  0.3327,  0.4548, -0.3809], requires_grad=True))]

それぞれのパラメータの重みを取得しておく。

>>> model_weights = {name: param.data for name, param
...                  in model.named_parameters()}
>>> 
>>> W_ih_l0 = model_weights['weight_ih_l0']
>>> W_hh_l0 = model_weights['weight_hh_l0']
>>> b_ih_l0 = model_weights['bias_ih_l0']
>>> b_hh_l0 = model_weights['bias_hh_l0']
>>> 
>>> W_ih_l1 = model_weights['weight_ih_l1']
>>> W_hh_l1 = model_weights['weight_hh_l1']
>>> b_ih_l1 = model_weights['bias_ih_l1']
>>> b_hh_l1 = model_weights['bias_hh_l1']

入力のダミーデータはそのままに、モデルからあらためて隠れ状態ベクトルを取得する。

>>> H, hn = model(X)

初期状態 (t = 0) をターゲットにする。

>>> H[0]
tensor([[-0.1115, -0.0662,  0.2981, -0.5452],
        [ 0.0896,  0.0750,  0.2003, -0.6533]], grad_fn=<SelectBackward>)

まずは、これまでの要領で隠れ状態ベクトルを得る。 ただし、これはあくまで 1 層目の出力にすぎない。 使っているパラメータの名前も末尾が _l0 になっている。

>>> h0_l0 = torch.tanh(torch.matmul(W_ih_l0, X[0].T).T + b_ih_l0 + b_hh_l0)
>>> h0_l0
tensor([[ 0.0724,  0.3836, -0.3525,  0.4635],
        [ 0.6664,  0.0096, -0.3751,  0.0292]])

続いて、1 層目の出力を 2 層目に入力して計算する。

>>> torch.tanh(torch.matmul(W_ih_l1, h0_l0.T).T + b_ih_l1 + b_hh_l1)
tensor([[-0.1115, -0.0662,  0.2981, -0.5452],
        [ 0.0896,  0.0750,  0.2003, -0.6533]])

これで値が一致した。

そんなかんじで。

参考書籍


  1. 詳しくは「ゼロから作るDeep Learning ❷ ―自然言語処理編」を参照のこと

Python: Google Colaboratory で Cloud TPU を TensorFlow から試してみる

Google Colaboratory では、ランタイムのタイプを変更することで Cloud TPU (Tensor Processing Unit) を利用できる。 Cloud TPU は、Google が開発しているハードウェアアクセラレータの一種。 利用することで、行列計算のパフォーマンス向上が期待できる。 ただ、Cloud TPU は CPU や GPU に比べると扱う上でのクセがそれなりにつよい。 今回は、そんな Cloud TPU を使ってみることにする。

使った環境は次のとおり。

# pip list | grep "^tensorflow "
tensorflow                         2.1.0

もくじ

下準備

あらかじめ、メニューの「ランタイム」から「ランタイムのタイプを変更」を選択して、ハードウェアアクセラレータに TPU を指定しておく。

そして、TensorFlow をインポートしておく。

>>> import tensorflow as tf

TPU に接続する

まず、TPU を利用するには、最初に TPU クラスタへ接続する必要がある。 というのも、TPU のデバイスは実行中のホストで動作しているわけではない。 専用のホストに搭載されていて、それを gRPC 経由で制御するらしい。

はじめに、Google Colaboratory の環境であれば tf.distribute.cluster_resolver.TPUClusterResolver() を引数なしで実行する。 これで、利用可能な TPU クラスタの情報が得られる。

>>> tpu = tf.distribute.cluster_resolver.TPUClusterResolver()

あとは、tf.config.experimental_connect_to_cluster() を使って TPU クラスタに接続する。

>>> tf.config.experimental_connect_to_cluster(tpu)

接続できたら、TPU を初期化する。

>>> tf.tpu.experimental.initialize_tpu_system(tpu)

これで TPU を利用する準備ができた。 tf.config.list_logical_devices() を使って、タイプが TPU のデバイスを調べると、認識しているデバイスの一覧が確認できる。 下記を見て分かるとおり、複数のデバイスが確認できる。

>>> devices = tf.config.list_logical_devices('TPU')
>>> devices
[LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')]

現在利用できる TPU v2 / v3 には、最小構成単位である TPU ボード 1 枚につき 4 つの TPU チップが載っている。 そして、それぞれのチップには 2 つの TPU コアがあるため、合計で 8 つのコアがある。 上記は、各コアがデバイスとして TensorFlow から認識されていることを示している。

これは、GPU であれば筐体に複数枚のグラフィックカードを差していたり、あるいは複数のコアが載った GPU チップを利用している状態に近い。 つまり、TPU のパフォーマンスを最大限に活用しようとすると、必然的に複数のデバイスを使った分散学習をすることになる。

ちなみに、Google Colaboratory で利用できるのは単一の TPU ボードだけっぽい。 Google Cloud 経由で利用する場合には、それ以外に TPU Pod や TPU スライスといった、複数の TPU ボードから成るシステムも利用できる。 その場合も、おそらく見え方としては上記のデバイスが増えるだけなんだろう。

単一のデバイスで演算する

さて、デバイスを認識できるようになったので、早速その中の一つを使って行列演算を試してみよう。

まずは、適当に (2, 3) な形状の行列と (3, 2) な形状の行列を作る。

>>> tf.random.set_seed(42)
>>> x = tf.random.normal(shape=(2, 3))
>>> y = tf.random.normal(shape=(3, 2))

TensorFlow では、tf.device() 関数にデバイスの情報を渡してコンテキストマネージャとして使うと、そのデバイス上で演算を実行できる。 試しに先頭の TPU デバイスを使って行列の積を求めてみよう。

>>> with tf.device(devices[0]):
...      z = tf.matmul(x, y)

次のように、ちゃんと計算できた。

>>> z
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0.5277252 , 4.685486  ],
       [0.8692589 , 0.21500015]], dtype=float32)>

複数のデバイスで演算する

さて、単一のデバイスで計算できることは分かったので、続いては複数のデバイスで分散処理してみよう。

その前に、一旦 TPU の状態を初期化しておく。 TPU で何か新しい処理を始めるときは、初期化しておかないと上手く動作しないことがある。

>>> tf.tpu.experimental.initialize_tpu_system(tpu)

TensorFlow で複数のデバイスを使った分散処理をするときは、tf.distribute.Strategy というオブジェクト (以下、ストラテジオブジェクト) を使うことになる。 このオブジェクトには具体的な実装がいくつかあって、何を使うかによってどのように分散処理を進めるかが決まる。 ただし、TPU を使うときは tf.distribute.TPUStrategy を使うことに決まっているので選択の余地はない。

>>> strategy = tf.distribute.TPUStrategy(tpu)

試しに、先ほどと同じように行列の積を分散処理でやらせてみよう。 そのためには、まず行列の積を計算するためのヘルパー関数を次のように定義しておく。 生の tf.matmul() をそのまま使えないの?と思うけど、どうやら今のところ使えなさそう。

>>> @tf.function
... def matmul_fn(x, y):
...   """行列の積を計算する関数"""
...   z = tf.matmul(x, y)
...   return z
... 

あとは、上記の関数を先ほどのストラテジオブジェクトの run() メソッド経由で呼び出すだけ。

>>> zs = strategy.run(matmul_fn, args=(x, y))

結果を確認してみよう。 PerReplica というオブジェクトで、コアと同じ数の計算結果が得られていることがわかる。 それぞれのコアで同じ計算がされたようだ。

>>> zs
PerReplica:{
  0: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  1: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  2: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  3: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  4: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  5: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  6: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>,
  7: <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.4504242 , -0.07991219],
       [-0.5104828 ,  0.57960224]], dtype=float32)>
}

さて、上記はすべての処理に同じデータを与えているので、結果もすべて同じになっている。 なるほどって感じだけど、これでは複数のデバイスを使っている意味がない。 そこで、続いてはデバイス毎に与えるデータを変えてみよう。

まずは、以下のようにして整数を順番に返す Dataset オブジェクトを作る。

>>> range_dataset = tf.data.Dataset.range(16)
>>> list(range_dataset.as_numpy_iterator())
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

続いて、上記を分散処理に使うデバイスの数に合わせてミニバッチへ分割する。 以下は Strategy#num_replicas_in_sync に 1 をかけているので、各デバイスに 1 つずつサンプルを与える場合の設定。

>>> batch_size = 1 * strategy.num_replicas_in_sync
>>> batch_dataset = range_dataset.batch(batch_size)
>>> list(batch_dataset.as_numpy_iterator())
[array([0, 1, 2, 3, 4, 5, 6, 7]), array([ 8,  9, 10, 11, 12, 13, 14, 15])]

そして、上記の Dataset オブジェクトを Strategy#experimental_distribute_dataset() メソッドに渡す。 すると、DistributedDataset というオブジェクトが得られる。

>>> dist_dataset = strategy.experimental_distribute_dataset(batch_dataset)
>>> dist_dataset
<tensorflow.python.distribute.input_lib.DistributedDataset at 0x7f546167d110>

この DistributedDataset オブジェクトからは、先ほど分散処理の結果として返ってきた PerReplica というオブジェクトが得られる。

>>> ite = iter(dist_dataset)
>>> x = next(ite)
>>> x
PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>,
  2: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>,
  3: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([3])>,
  4: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
  5: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([5])>,
  6: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>,
  7: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([7])>
}

上記の PerReplica オブジェクトを使うと、それぞれのデバイスに対して異なる入力データを与えることができる。 以下の、引数を 2 倍する関数で試してみよう。

>>> @tf.function
... def double_fn(x):
...     """引数を 2 倍する関数"""
...     return x * 2
... 

DistributedDataset オブジェクトから得られる PerReplica オブジェクトを、ストラテジオブジェクト経由で上記の関数に渡す。 すると、以下のように返り値として各要素が 2 倍になった PerReplica オブジェクトが得られることがわかる。

>>> for x in dist_dataset:
...     result = strategy.run(double_fn, args=(x, ))
...     print(result)
PerReplica:{
  0: tf.Tensor([0], shape=(1,), dtype=int64),
  1: tf.Tensor([2], shape=(1,), dtype=int64),
  2: tf.Tensor([4], shape=(1,), dtype=int64),
  3: tf.Tensor([6], shape=(1,), dtype=int64),
  4: tf.Tensor([8], shape=(1,), dtype=int64),
  5: tf.Tensor([10], shape=(1,), dtype=int64),
  6: tf.Tensor([12], shape=(1,), dtype=int64),
  7: tf.Tensor([14], shape=(1,), dtype=int64)
}
PerReplica:{
  0: tf.Tensor([16], shape=(1,), dtype=int64),
  1: tf.Tensor([18], shape=(1,), dtype=int64),
  2: tf.Tensor([20], shape=(1,), dtype=int64),
  3: tf.Tensor([22], shape=(1,), dtype=int64),
  4: tf.Tensor([24], shape=(1,), dtype=int64),
  5: tf.Tensor([26], shape=(1,), dtype=int64),
  6: tf.Tensor([28], shape=(1,), dtype=int64),
  7: tf.Tensor([30], shape=(1,), dtype=int64)
}

上記から、複数のデバイスで、異なる入力データを使った分散処理をできることがわかった。

単一のデバイスで勾配降下法を試す

次は、また単一のデバイスに戻って、自動微分を使った勾配降下法が機能することを確認してみよう。 要するに、ニューラルネットワークが最適化できる本質的な部分の動作を見ておく。

以下のサンプルコードでは、最小化したい関数 objective() を定義している。 そして、それに適当な初期値を与えて、SGD をオプティマイザに最小化している。 実際に損失と勾配を計算してパラメータを更新しているのは training_step() という関数。

# -*- coding: utf-8 -*-

from __future__ import annotations

from pprint import pprint

import tensorflow as tf


def objective(params: tf.Variable) -> tf.Tensor:
    """最小化したい関数"""
    # x_0^2 + x_1^2
    loss = params[0] ** 2 + params[1] ** 2
    return loss


def main():
    # TPU クラスタに接続する
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    devices = tf.config.list_logical_devices('TPU')
    print('TPU devices:', end='')
    pprint(devices)

    # 使用するオプティマイザ
    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-1)

    @tf.function
    def training_step(params: tf.Variable) -> None:
        """勾配降下法を使った最適化の 1 ステップ"""
        with tf.GradientTape() as t:
            # 損失を計算する
            loss = objective(params)
        # 勾配を計算する
        grads = t.gradient(loss, params)
        # パラメータを更新する
        optimizer.apply_gradients([(grads, params)])
        # tf.print(params)  # 少なくとも今の TPU では利用できない...

    # 初期値を用意する
    tensor = tf.constant([1., 4.], dtype=tf.float32)

    # 先頭の TPU デバイスで計算する
    with tf.device(devices[0]):
        # TPU デバイス上に Variable を用意する
        params = tf.Variable(tensor, trainable=True)
        # 最適化のループ
        for _ in range(20):  # 回数は適当
            training_step(params)

    # 結果を出力する
    print(f'{objective(params)} @ {params.numpy()}')


if __name__ == '__main__':
    main()

上記の実行結果は次のとおり。

(snip) ...
TPU devices:[LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')]
0.0022596875205636024 @ [0.01152921 0.04611686]

最適化によって、最終的な objective(params) の結果が小さくなっていることが確認できる。

複数のデバイスで CNN を tf.keras で学習する

次は、これまでのサンプルよりも少し実用性が高めのコードを試す。 具体的には CNN のモデルを tf.keras を使って組んで、CIFAR-10 のデータを学習させてみる。

ポイントとしては、モデルやメトリックなどをストラテジオブジェクトのスコープで組み立てるところ。 こうすると、たとえば Variable オブジェクトは内部的にデバイス間で値が同期できる MirroredVariable になったりするらしい。

CIFAR-10 のデータはメモリに収まるサイズなので、オンメモリのデータから Dataset オブジェクトを生成している。 これが、もしメモリに収まりきらないときは TFRecord フォーマットで GCS に保存する必要がある。

TPU を使う際には、CPU や GPU の環境で動作したコードを転用するのがベストプラクティスらしい。 以下のサンプルコードでは、それがやりやすいように環境毎のストラテジオブジェクトを取得できる detect_strategy() という関数を定義した。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf


def detect_strategy():
    """利用できるハードウェアアクセラレータ毎に適した tf.distribute.Strategy を返す関数"""
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        devices = tf.config.list_logical_devices('TPU')
        if len(devices) > 0:
            # TPU が利用できる
            return tf.distribute.TPUStrategy(tpu)
    except ValueError:
        pass

    devices = tf.config.list_logical_devices('GPU')
    if len(devices) > 0:
        # GPU が利用できる
        return tf.distribute.MirroredStrategy()

    # Default
    return tf.distribute.get_strategy()


def normalize(element):
    """画像データを浮動小数点型にキャストして正規化する処理"""
    image = element['image']
    normalized_image = tf.cast(image, tf.float32) / 255.
    label = element['label']
    return normalized_image, label


def datafeed_pipeline(x, y, batch_size):
    """オンメモリのテンソルからデータを読み出す Dataset パイプライン"""
    mappings = {
        'image': x,
        'label': y,
    }
    ds = tf.data.Dataset.from_tensor_slices(mappings)
    ds = ds.map(normalize)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.cache()
    return ds


def main():
    # データセットをオンメモリに読み込む
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar10.load_data()

    # データセットの仕様
    image_shape = train_x.shape[1:]
    num_classes = 10

    # 乱数シードを設定しておく
    tf.random.set_seed(42)

    # 環境に応じたストラテジオブジェクトを取得する
    strategy = detect_strategy()

    # データ供給のパイプラインを Dataset API で構築する
    device_batch_size = 512  # デバイス単位で見たバッチサイズ
    global_batch_size = strategy.num_replicas_in_sync * device_batch_size
    ds_train = datafeed_pipeline(train_x, train_y, global_batch_size)
    ds_test = datafeed_pipeline(test_x, test_y, global_batch_size)

    with strategy.scope():
        # ストラテジオブジェクトのスコープでモデルを組み立てる
        # これによって内部で使われる Variable の型などが変わる
        model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=image_shape),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(num_classes, activation='softmax')
        ])
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer='adam',
            metrics=['sparse_categorical_accuracy'],
        )

    # モデルの概要
    print(model.summary())

    # モデルを学習させる
    fit_callbacs = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         patience=5,
                                         mode='min'),
    ]
    model.fit(ds_train,
              epochs=100,
              validation_data=ds_test,
              callbacks=fit_callbacs,
              )

    # テストデータを評価する
    scr, sca = model.evaluate(ds_test)
    print(f'Loss: {scr}, Accuracy: {sca}')


if __name__ == '__main__':
    main()

上記を実行してみよう。 今回は、精度とかは横に置いておく。

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_3 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 4, 4, 128)         73856     
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 64)                131136    
_________________________________________________________________
dense_3 (Dense)              (None, 10)                650       
=================================================================
Total params: 225,034
Trainable params: 225,034
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/100
13/13 [==============================] - 10s 431ms/step - loss: 2.1851 - sparse_categorical_accuracy: 0.1976 - val_loss: 1.9839 - val_sparse_categorical_accuracy: 0.2979
Epoch 2/100
13/13 [==============================] - 1s 88ms/step - loss: 1.9094 - sparse_categorical_accuracy: 0.3103 - val_loss: 1.8427 - val_sparse_categorical_accuracy: 0.3334
Epoch 3/100
13/13 [==============================] - 1s 85ms/step - loss: 1.7798 - sparse_categorical_accuracy: 0.3575 - val_loss: 1.7185 - val_sparse_categorical_accuracy: 0.3849

...(snip)...

Epoch 71/100
13/13 [==============================] - 1s 87ms/step - loss: 0.8315 - sparse_categorical_accuracy: 0.7118 - val_loss: 0.9328 - val_sparse_categorical_accuracy: 0.6744
Epoch 72/100
13/13 [==============================] - 1s 89ms/step - loss: 0.8223 - sparse_categorical_accuracy: 0.7148 - val_loss: 0.9292 - val_sparse_categorical_accuracy: 0.6737
Epoch 73/100
13/13 [==============================] - 1s 88ms/step - loss: 0.8054 - sparse_categorical_accuracy: 0.7224 - val_loss: 0.9340 - val_sparse_categorical_accuracy: 0.6756
3/3 [==============================] - 1s 16ms/step - loss: 0.9340 - sparse_categorical_accuracy: 0.6756
Loss: 0.9339648485183716, Accuracy: 0.675599992275238

ちゃんと動いているようだ。 カスタムトレーニングループを使うときは、また気にするところがあるみたいだけど、今回は取り扱わない。

参考

cloud.google.com

cloud.google.com

cloud.google.com

www.tensorflow.org

www.tensorflow.org

www.tensorflow.org

www.tensorflow.org

Python: Session State API で Streamlit をステートフルにする

これまで Streamlit で書いた Web アプリケーションは、基本的にステートレスだった。 つまり、何らかのイベントが生じてアプリケーションのコードが再評価されると、ウィジェットを除くほとんどすべてのオブジェクトの状態はリセットされていた。 アプリケーションをステートフルにする非公式なスニペットは一部で知られていたが、数行で使い始められるような気軽さはなかった。

そうした中、先日リリースされた Streamlit のバージョン 0.85 には、Session State API という機能が追加された。 この API は、読んで字のごとく Streamlit の Web アプリケーションに限定的ながらステートを持たせることができる機能となっている。

docs.streamlit.io

今回は、追加された Session State API を触ってみることにした。

使った環境は次のとおり。

$ sw_vers                 
ProductName:    macOS
ProductVersion: 11.5.1
BuildVersion:   20G80
$ python -V                   
Python 3.9.6
$ pip list | grep -i streamlit 
streamlit                0.85.0

もくじ

下準備

まずは肝心の Streamlit と、それ以外に可視化で使うデータセットを読み込むために Seaborn をインストールしておく。

$ pip install streamlit seaborn

ボタンを押すとカウンタが増減するサンプルコード

早速だけど、以下にカウンタの値をボタンに連動して増減させるサンプルコードを示す。 Session State API では、session_state という名前の辞書ライクなオブジェクトを扱う。 このオブジェクトに格納したオブジェクト (以下、便宜的にセッション変数と呼ぶ) は、アプリケーションが再評価されても消えずに引き継がれる。 セッション変数の値はウィジェットに追加されたコールバック関数の機能を介して更新する。 以下では st.button()on_change オプションにセッション変数の値を増減させるコールバック関数を登録している。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # セッション変数が存在しないときは初期化する
    # ここでは 'counter' というセッション変数を作っている
    if 'counter' not in st.session_state:
        st.session_state['counter'] = 0

    # セッション変数の状態を表示する
    msg = f"Counter value: {st.session_state['counter']}"
    st.write(msg)

    # ボタンが押されたときに発火するコールバック
    def plus_one_clicks():
        # ボタンが押されたらセッション変数の値を増やす
        st.session_state['counter'] += 1
    # ボタンを作成するときにコールバックを登録しておく
    st.button(label='+1',
              on_click=plus_one_clicks)

    # ボタンが押されたらセッション変数の値を減らすバージョン
    def minus_one_clicks():
        st.session_state['counter'] -= 1
    st.button(label='-1',
              on_click=minus_one_clicks)

    # セッション変数の値をリセットするボタン
    def reset_clicks():
        st.session_state['counter'] = 0
    st.button(label='Reset',
              on_click=reset_clicks)


if __name__ == '__main__':
    main()

上記を保存したら Streamlit 経由で実行しよう。

$ streamlit run example.py

デフォルトでは自動で Web ブラウザが開くはず。 開かない場合には以下でアクセスする。

$ open http://localhost:8501

すると、次のような WebUI が表示される。 ボタンを押すと、それに連動してカウンタの値が増えたり減ったりする。

f:id:momijiame:20210728222510p:plain

これまで、ボタンをクリックするとイベントが生じてアプリケーションが再評価され、オブジェクトは一通りリセットされていた。 しかし、Session State API を使うことで、それが回避できている。

Session State API を使う上での注意点は次のようなものがありそう。

  • (当たり前だけど) 存在しない変数 (辞書のキー) を使おうとすると例外になる
  • ブラウザをリロードすると変数はリセットされる
  • ページを複数のタブで開いたとしても変数は共有されない

データフレームのページネーションを実現するサンプルコード

続いては、もうちょっと実用的な例としてページネーションを実現してみる。 以下のサンプルコードでは、タイタニックデータセットを読み込んで、それを 10 件ずつ表示するものになっている。 表示している場所をセッション変数で管理することでページネーションが実現できる。

# -*- coding: utf-8 -*-

import math

import seaborn as sns
import streamlit as st


@st.cache
def load_dataset():
    """Titanic データセットを読み込む関数"""
    return sns.load_dataset('titanic')


def main():
    # データセットを読み込んで必要なページ数を計算する
    df = load_dataset()
    rows_per_page = 10
    total_pages = math.ceil(len(df) / rows_per_page)

    if 'page' not in st.session_state:
        st.session_state['page'] = 1

    left_col, center_col, right_col = st.beta_columns(3)

    # ページ数の増減ボタン
    with left_col:
        def minus_one_page():
            st.session_state['page'] -= 1
        if st.session_state['page'] > 1:
            st.button(label='<< Prev',
                      on_click=minus_one_page)

    with right_col:
        def plus_one_page():
            st.session_state['page'] += 1
        if st.session_state['page'] < total_pages:
            st.button(label='Next >>',
                      on_click=plus_one_page)

    # 現在のページ番号
    with center_col:
        st.write(f"Page: {st.session_state['page']} / {total_pages}")

    # ページ番号に応じた範囲のデータフレームを表示する
    start_iloc = (st.session_state['page'] - 1) * rows_per_page
    end_iloc = start_iloc + rows_per_page + 1
    st.write(df.iloc[start_iloc:end_iloc])


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ streamlit run example.py

すると、ページ単位でデータフレームの内容が確認できる画面が表示される。

f:id:momijiame:20210728224651p:plain

いじょう。

Python: TFRecord フォーマットについて

TFRecord フォーマットは、TensorFlow がサポートしているデータセットの表現形式の一つ。 このフォーマットは、一言で表すと TensorFlow で扱うデータを Protocol Buffers でシリアライズしたものになっている。 特に、Dataset API との親和性に優れていたり、Cloud TPU を扱う上で実用上はほぼ必須といった特徴がある。 今回は、そんな TFRecord の扱い方について見ていくことにする。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.5
BuildVersion:   20G71
$ python -V
Python 3.9.6
$ pip list | grep -i tensorflow
tensorflow               2.5.0
tensorflow-datasets      4.3.0
tensorflow-estimator     2.5.0
tensorflow-metadata      1.1.0

もくじ

下準備

あらかじめ TensorFlow をインストールしておく。

$ pip install tensorflow tensorflow_datasets

そして、Python のインタプリタを起動する。

$ python

tensorflow パッケージを tf という名前でインポートしておく。

>>> import tensorflow as tf

概要

TFRecord フォーマットを TensorFlow の Python API から扱おうとすると、いくつかのオブジェクト (クラス) が登場する。 ただ、意外とその数が多いので、理解する上でとっつきにくさを生んでいる感じがある。 そこで、まずは一通りトップダウンで説明することにする。

それぞれの関係は、あるオブジェクトが別のオブジェクトを内包するようになっている。 階層構造で表すと、以下のような感じ。 階層構造で上にあるオブジェクトが、下にあるオブジェクトを内包する。

  • tf.Example
    • tf.train.Features
      • tf.train.Feature
        • tf.train.BytesList
        • tf.train.FloatList
        • tf.train.Int64List

tf.Example

tf.Example は、データセットに含まれる特定のサンプル (データポイント) に対応したオブジェクトになっている。 たとえば、教師あり学習のデータセットなら、あるサンプルの説明変数と目的変数のペアがこれに当たるイメージ。 ただ、サンプルに対応しているオブジェクトというだけなので、別に必要なら何を入れても構わない。 たとえば、画像データなら付随するメタデータとして横幅 (Width) と縦幅 (Height) のピクセル数が必要とかはあるはず。

このオブジェクトは単一の tf.train.Features というオブジェクトを内包する。

tf.train.Features

tf.train.Features は、名前から複数の特徴量を束ねるオブジェクトっぽいけど、まあ概ねその理解で正しいと思う。 概ね、というのは前述したとおりメタデータ的なものや説明変数も含まれるため。

このオブジェクトは複数の tf.train.Feature を内包する。

tf.train.Feature

tf.train.Feature は、特定の特徴量ないしメタデータや説明変数に対応したオブジェクト。

このオブジェクトは単一の tf.train.BytesList または tf.train.FloatList または tf.train.Int64List を内包する。

tf.train.BytesList

tf.train.BytesList は、特徴量としてバイト列のリストを扱うために用いるオブジェクト。

このオブジェクトは bytes 型のリストを内包する。 任意のバイト列を扱えるので、何らかのオブジェクトをシリアライズしたものを入れることができる。 詳しくは後述するけど、この特性は割と重要になってくる。 なぜなら、他の tf.train.FloatListtf.train.Int64List は一次元配列しか扱えないため。

tf.train.FloatList

tf.train.FloatList は、特徴量として浮動小数点のリストを扱うために用いるオブジェクト。

このオブジェクトは浮動小数点のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。

tf.train.Int64List

tf.train.Int64List は、特徴量として整数のリストを扱うために用いるオブジェクト。

このオブジェクトは整数のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。

基本的な使い方

一通りのオブジェクトの説明が終わったので、ここからは実際にコードを実行しながら試してみよう。 先ほどの説明とは反対に、ボトムアップでの実行になる。 これは、そうでないとオブジェクトを組み立てられないため。

まず、最もプリミティブなオブジェクトである tf.train.Int64Listtf.train.FloatListtf.train.BytesList から。 これらは前述したとおりバイト列・浮動小数点・整数のリストを内包するオブジェクトになっている。

>>> int64_list = tf.train.Int64List(value=[1, 2, 3])
>>> int64_list
value: 1
value: 2
value: 3

>>> float_list = tf.train.FloatList(value=[1., 2., 3.])
>>> float_list
value: 1.0
value: 2.0
value: 3.0

>>> bytes_list = tf.train.BytesList(value=[b'x', b'y', b'z'])
>>> bytes_list
value: "x"
value: "y"
value: "z"

前述したとおり、value には一次元配列しか渡せないらしい。 渡そうとすると次のようにエラーになる。

>>> import numpy as np
>>> x = np.random.randint(low=0, high=100, size=(3, 2))
>>> tf.train.Int64List(value=x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: only integer scalar arrays can be converted to a scalar index
>>> tf.train.Int64List(value=[[1, 2], [3, 4]])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: [1, 2] has type list, but expected one of: int, long

この仕様だと、画像データとか扱うときに面倒くさくない?と思うはず。 そんなときは、多次元配列を次のようにバイト列にシリアライズしてやれば良い。

>>> serialized_x = tf.io.serialize_tensor(x)
>>> serialized_x
<tf.Tensor: shape=(), dtype=string, numpy=b'\x08\t\x12\x08\x12\x02\x08\x03\x12\x02\x08\x02"0\x08\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00\x006\x00\x00\x00\x00\x00\x00\x00H\x00\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00'>

バイト列になっていれば tf.train.BytesList に入れることができる。

>>> tf.train.BytesList(value=[serialized_x.numpy()])
value: "\010\t\022\010\022\002\010\003\022\002\010\002\"0\010\000\000\000\000\000\000\000^\000\000\000\000\000\000\000\017\000\000\000\000\000\000\0006\000\000\000\000\000\000\000H\000\000\000\000\000\000\000G\000\000\000\000\000\000\000"

なお、もちろん多次元配列は Flatten して、別で保存しておいた shape の情報を使って復元してもかまわない。

続いては tf.train.Feature を使って先ほどの *List オブジェクトをラップする。 型ごとに引数が異なるため、そこだけ注意する。

>>> int64_feature = tf.train.Feature(int64_list=int64_list)
>>> int64_feature
int64_list {
  value: 1
  value: 2
  value: 3
}

>>> float_feature = tf.train.Feature(float_list=float_list)
>>> float_feature
float_list {
  value: 1.0
  value: 2.0
  value: 3.0
}

>>> bytes_feature = tf.train.Feature(bytes_list=bytes_list)
>>> bytes_feature
bytes_list {
  value: "x"
  value: "y"
  value: "z"
}

続いては、tf.train.Features を使って、複数の tf.train.Feature を束ねる。

>>> feature_mappings = {
...     'feature0': int64_feature,
...     'feature1': float_feature,
...     'feature2': bytes_feature,
... }
>>> features = tf.train.Features(feature=feature_mappings)
>>> features
feature {
  key: "feature0"
  value {
    int64_list {
      value: 1
      value: 2
      value: 3
    }
  }
}
feature {
  key: "feature1"
  value {
    float_list {
      value: 1.0
      value: 2.0
      value: 3.0
    }
  }
}
feature {
  key: "feature2"
  value {
    bytes_list {
      value: "x"
      value: "y"
      value: "z"
    }
  }
}

あとは tf.train.Example でラップするだけ。

>>> example = tf.train.Example(features=features)
>>> example
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature1"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "x"
        value: "y"
        value: "z"
      }
    }
  }
}

上記で完成した tf.train.Example オブジェクトがデータセットの中の特定のサンプルに対応することになる。 まあ、使っているのがダミーデータなのでちょっとイメージがつきにくいかもしれないけど。

tf.train.Example オブジェクトは SerializeToString() メソッドを使うことでバイト列にシリアライズできる。 つまり、.tfrecord の拡張子がついた TFRecord ファイルは、このシリアライズされたバイト列が書き込まれている。

>>> serialized_data = example.SerializeToString()
>>> serialized_data
b'\nL\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03'

ちなみに、これまでに登場したオブジェクトも、それぞれ単独で SerializeToString() を使えばシリアライズできる。

>>> int64_list.SerializeToString()
b'\n\x03\x01\x02\x03'
>>> int64_feature.SerializeToString()
b'\x1a\x05\n\x03\x01\x02\x03'
>>> features.SerializeToString()
b'\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@'

そして、シリアライズしたバイト列は、tf.train.Example.FromString() 関数を使ってデシリアライズできる。

>>> deserialized_object = tf.train.Example.FromString(serialized_data)
>>> deserialized_object
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature1"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "x"
        value: "y"
        value: "z"
      }
    }
  }
}

データセットを TFRecord ファイルに変換する

基本的な使い方がわかったところで、続いては実際にデータセットを TFRecord 形式のファイルに変換してみよう。

使う題材は特に何でも良いんだけど、今回は tensorflow-datasets 経由でロードした CIFAR10 を使うことにする。

>>> import tensorflow_datasets as tfds
>>> ds_train = tfds.load('cifar10', split='train')

このデータセットには (32, 32, 3) の形状を持った画像のテンソルと、それに対応したラベルが入っている。 画像のデータが一次元になっていないので、わざわざ Flatten する代わりに前述したシリアライズしてバイト列にする作戦でいこう。

>>> from pprint import pprint
>>> pprint(ds_train.element_spec)
{'id': TensorSpec(shape=(), dtype=tf.string, name=None),
 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None),
 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}

まず、特定のサンプルに対応したテンソルとラベルを前述した手順でシリアライズする関数を次のように定義する。

>>> def serialize_example(image, label):
...     """1 サンプルを Protocol Buffers で TFRecord フォーマットにシリアライズする関数"""
...     # 画像データをバイト列にシリアライズする
...     serialized_image = tf.io.serialize_tensor(image)
...     image_bytes_list = tf.train.BytesList(value=[serialized_image.numpy()])
...     # ラベルデータ
...     label_int64_list = tf.train.Int64List(value=[label.numpy()])
...     # 特徴量を Features にまとめる
...     feature_mappings = {
...         'image': tf.train.Feature(bytes_list=image_bytes_list),
...         'label': tf.train.Feature(int64_list=label_int64_list),
...     }
...     features = tf.train.Features(feature=feature_mappings)
...     # Example にまとめる
...     example_proto = tf.train.Example(features=features)
...     # バイト列にする
...     return example_proto.SerializeToString()
... 

続いて、データセットから取り出したサンプルに上記の関数を定義するヘルパー関数を次のように定義する。

>>> def tf_serialize_example(element):
...     """シリアライズ処理を tf.data.Dataset に適用するためのヘルパー関数"""
...     # イメージとラベルを取り出す
...     image = element['image']
...     label = element['label']
...     tf_string = tf.py_function(
...         serialize_example, 
...         (image, label),
...         tf.string,
...     )
...     return tf.reshape(tf_string, ())
... 

Dataset API を使って、上記の関数をデータセットに適用する。

>>> serialized_ds_train = ds_train.map(tf_serialize_example)

イテレータにしてサンプルをひとつ取り出してみよう。

>>> ite = iter(serialized_ds_train)
>>> next(ite)
<tf.Tensor: shape=(), dtype=string, numpy=b'\n\xb6\x18\n\x0e\n\x05label\x12\x05\x1a\x03\n\x01\x07\n\xa3\x18\n\x05image\x12\x99
...

ちゃんとシリアライズされたバイト列が確認できる。

あとは、シリアライズしたバイト列が取り出せる Dataset オブジェクトを引数にして tf.data.experimental.TFRecordWriter#write() を呼び出すだけ。

>>> filename = 'cifar10-train.tfrecord'
>>> writer = tf.data.experimental.TFRecordWriter(filename)
>>> writer.write(serialized_ds_train)

上記はデータセットを丸ごと 1 つのファイルにしてる。 公式ドキュメントを見ると、パフォーマンスを考えると 100 ~ 200MB 程度のサイズで複数に分割するのがおすすめらしい。 これは、おそらく GCS とかにアップロードして並列で読み出すときの話。

カレントディレクトリを確認すると、次のようにファイルが書き出されているはず。

$ du -m cifar10-train.tfrecord
161    cifar10-train.tfrecord
$ file cifar10-train.tfrecord 
cifar10-train.tfrecord: data

TFRecord ファイルからデータを読み出す

次は上記のファイルを読み込んでデシリアライズする。 まず、tf.data.TFRecordDataset に TFRecord ファイルのパスを指定する。 これで、シリアライズしたバイト列を読み出せる Dataset オブジェクトが得られる。

>>> loaded_ds_train = tf.data.TFRecordDataset(filename)

上記からは tf.Example に対応したバイト列が 1 つずつ読み出せる。 なので、それを元のテンソルに戻す関数を次のように定義する。

>>> def deserialize_example(example_proto):
...     """バイト列をデシリアライズしてオブジェクトに戻す関数"""
...     # バイト列のフォーマット
...     feature_description = {
...         'image': tf.io.FixedLenFeature([], tf.string),
...         'label': tf.io.FixedLenFeature([], tf.int64),
...     }
...     # Tensor オブジェクトの入った辞書に戻す
...     parsed_element = tf.io.parse_single_example(example_proto,
...                                                 feature_description)
...     # 画像はバイト列になっているのでテンソルに戻す
...     parsed_element['image'] = tf.io.parse_tensor(parsed_element['image'],
...                                                  out_type=tf.uint8)
...     return parsed_element
... 

上記を先ほどの Dataset オブジェクトに適用する。

>>> deserialized_ds_train = loaded_ds_train.map(deserialize_example)

試しに中身を取り出してみると、ちゃんと画像とラベルのテンソルが復元できていることがわかる。

>>> ite = iter(deserialized_ds_train)
>>> next(ite)
{'image': <tf.Tensor: shape=(32, 32, 3), dtype=uint8, numpy=
array([[[143,  96,  70],
        [141,  96,  72],
        [135,  93,  72],
        ...,
        [ 96,  37,  19],
        [105,  42,  18],
        [104,  38,  20]],

       [[128,  98,  92],
        [146, 118, 112],
        [170, 145, 138],
        ...,
        [108,  45,  26],
        [112,  44,  24],
        [112,  41,  22]],

       [[ 93,  69,  75],
        [118,  96, 101],
        [179, 160, 162],
        ...,
        [128,  68,  47],
        [125,  61,  42],
        [122,  59,  39]],

       ...,

       [[187, 150, 123],
        [184, 148, 123],
        [179, 142, 121],
        ...,
        [198, 163, 132],
        [201, 166, 135],
        [207, 174, 143]],

       [[187, 150, 117],
        [181, 143, 115],
        [175, 136, 113],
        ...,
        [201, 164, 132],
        [205, 168, 135],
        [207, 171, 139]],

       [[195, 161, 126],
        [187, 153, 123],
        [186, 151, 128],
        ...,
        [212, 177, 147],
        [219, 185, 155],
        [221, 187, 157]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=7>}

いじょう。

参考

www.tensorflow.org