CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM の学習率を動的に制御する

LightGBM の学習率は基本的に低い方が最終的に得られるモデルの汎化性能が高くなることが経験則として知られている。 しかしながら、学習率が低いとモデルの学習に多くのラウンド数、つまり計算量を必要とする。 そこで、今回は学習率を学習の過程において動的に制御するコールバックを実装してみた。

きっかけは以下のツイートを見たこと。

なるほど面白そう。

下準備

使用するライブラリをあらかじめインストールしておく。

$ pip install lightgbm seaborn scikit-learn

学習率を動的に制御するコールバック

早速だけど、以下が学習率を動的に制御するコールバックを実装したサンプルコードとなる。 コールバックの本体は LrSchedulingCallback というクラスで実装している。 このクラスをインスタンス化するときに、制御方法を記述した関数を渡す。 以下であれば sample_scheduler_func() という名前で定義した。 この関数は学習の履歴などを元に新たな学習率を決めて返すインターフェースとなっている。 今回はお試しとして 10 ラウンドごとに学習率を下限の 0.01 まで半減させ続けるという単純な戦略を記述してみた。 もちろん、これがベストというわけではなくて、あくまでサンプルとして簡単なものを書いてみたに過ぎない。 なお、EarlyStopping していないのは学習の過程を最後まで観察するため。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import lightgbm as lgb
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedKFold


def sample_scheduler_func(current_lr, eval_history, best_round, is_higher_better):
    """次のラウンドで用いる学習率を決定するための関数 (この中身を好きに改造する)

    :param current_lr: 現在の学習率 (指定されていない場合の初期値は None)
    :param eval_history: 検証用データに対する評価指標の履歴
    :param best_round: 現状で最も評価指標の良かったラウンド数
    :param is_higher_better: 高い方が性能指標として優れているか否か
    :return: 次のラウンドで用いる学習率

    NOTE: 学習を打ち切りたいときには callback.EarlyStopException を上げる
    """
    # 学習率が設定されていない場合のデフォルト
    current_lr = current_lr or 0.05

    # 試しに 20 ラウンド毎に学習率を半分にしてみる
    if len(eval_history) % 20 == 0:
        current_lr /= 2

    # 小さすぎるとほとんど学習が進まないので下限も用意する
    min_threshold = 0.001
    current_lr = max(min_threshold, current_lr)

    return current_lr


class LrSchedulingCallback(object):
    """ラウンドごとの学習率を動的に制御するためのコールバック"""

    def __init__(self, strategy_func):
        # 学習率を決定するための関数
        self.scheduler_func = strategy_func
        # 検証用データに対する評価指標の履歴
        self.eval_metric_history = []

    def __call__(self, env):
        # 現在の学習率を取得する
        current_lr = env.params.get('learning_rate')

        # 検証用データに対する評価結果を取り出す (先頭の評価指標)
        first_eval_result = env.evaluation_result_list[0]
        # スコア
        metric_score = first_eval_result[2]
        # 評価指標は大きい方が優れているか否か
        is_higher_better = first_eval_result[3]

        # 評価指標の履歴を更新する
        self.eval_metric_history.append(metric_score)
        # 現状で最も優れたラウンド数を計算する
        best_round_find_func = np.argmax if is_higher_better else np.argmin
        best_round = best_round_find_func(self.eval_metric_history)

        # 新しい学習率を計算する
        new_lr = self.scheduler_func(current_lr=current_lr,
                                     eval_history=self.eval_metric_history,
                                     best_round=best_round,
                                     is_higher_better=is_higher_better)

        # 次のラウンドで使う学習率を更新する
        update_params = {
            'learning_rate': new_lr,
        }
        env.model.reset_parameter(update_params)
        env.params.update(update_params)

    @property
    def before_iteration(self):
        # コールバックは各イテレーションの後に実行する
        return False


def main():
    # Titanic データセットを読み込む
    dataset = sns.load_dataset('titanic')

    # 重複など不要な特徴量は落とす
    X = dataset.drop(['survived',
                      'class',
                      'who',
                      'embark_town',
                      'alive'], axis=1)
    y = dataset.survived

    # カテゴリカル変数を指定する
    categorical_columns = ['pclass',
                           'sex',
                           'embarked',
                           'adult_male',
                           'deck',
                           'alone']
    X = X.astype({c: 'category'
                  for c in categorical_columns})

    # LightGBM のデータセット表現に直す
    lgb_train = lgb.Dataset(X, y)

    # コールバックを用意する
    lr_scheduler_cb = LrSchedulingCallback(strategy_func=sample_scheduler_func)
    callbacks = [
        lr_scheduler_cb,
    ]

    # 二値分類を LogLoss で評価する
    lgb_params = {
        'objective': 'binary',
        'metrics': 'binary_logloss',
        'min_data_in_leaf': 10,
    }
    # 5-Fold CV
    skf = StratifiedKFold(n_splits=5,
                          shuffle=True,
                          random_state=42)

    # 動的に学習率を制御した場合
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        callbacks=callbacks,
                        )
    dynamic_lr = cv_results['binary_logloss-mean']

    # 学習率を 0.1 に固定した場合
    lgb_params.update({'learning_rate': 0.1})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_1 = cv_results['binary_logloss-mean']

    # 学習率を 0.05 に固定した場合
    lgb_params.update({'learning_rate': 0.05})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_05 = cv_results['binary_logloss-mean']

    # 学習率を 0.01 に固定した場合
    lgb_params.update({'learning_rate': 0.01})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_01 = cv_results['binary_logloss-mean']

    # 最小の損失を比較する
    print('min loss value (lr=dynamic):', min(dynamic_lr))
    print('min loss value (lr=0.1):', min(static_lr_0_1))
    print('min loss value (lr=0.05):', min(static_lr_0_05))
    print('min loss value (lr=0.01):', min(static_lr_0_01))

    # 最小の損失が得られたラウンド数を比較する
    print('min loss round (lr=dynamic):', np.argmin(dynamic_lr))
    print('min loss round (lr=0.1):', np.argmin(static_lr_0_1))
    print('min loss round (lr=0.05):', np.argmin(static_lr_0_05))
    print('min loss round (lr=0.01):', np.argmin(static_lr_0_01))

    # グラフにプロットする
    sns.lineplot(np.arange(len(dynamic_lr)),
                 dynamic_lr,
                 label='LR=dynamic')
    sns.lineplot(np.arange(len(static_lr_0_1)),
                 static_lr_0_1,
                 label='LR=0.1')
    sns.lineplot(np.arange(len(static_lr_0_05)),
                 static_lr_0_05,
                 label='LR=0.05')
    sns.lineplot(np.arange(len(static_lr_0_01)),
                 static_lr_0_01,
                 label='LR=0.01')
    plt.title('learning rate control comparison')
    plt.xlabel('rounds')
    plt.ylabel('logloss')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみる。 最も性能が良かったモデルのロジスティック損失とラウンド数が表示される。 最適なモデルの性能、学習に要するラウンド数ともに 0.1 固定と 0.01 固定の間にあることが分かる。

$ python dynamiclr.py
...(snip)...
min loss value (lr=dynamic): 0.421097448234137
min loss value (lr=0.1): 0.42265029913071334
min loss value (lr=0.05): 0.4221001657363532
min loss value (lr=0.01): 0.42100303104081405
min loss round (lr=dynamic): 84
min loss round (lr=0.1): 18
min loss round (lr=0.05): 38
min loss round (lr=0.01): 196

そして、各条件における検証用データに対する評価指標の推移をプロットしたグラフが次の通り。 学習率を動的に制御しているパターンは、0.1 固定ほどではないにせよ早く性能が収束していることが分かる。 まあ、とはいえこれくらいなら lr=0.01 ~ 0.05 の間に似たような特性の学習率がいるかもしれない。

f:id:momijiame:20190720222635p:plain

いじょう。 こんな上手くいくスケジューラが書けた、みたいな話があったら教えてほしいな。

GNU Coreutils の shred でストレージのデータを削除する

HDD や SSD といったストレージを廃棄あるいは売却するとき、単に保存されているファイルを削除しただけでは復元のリスクが高い。 これは、本のメタファーでいえば索引の部分を消しただけで本文は丸々残っている、といった状況になっているため。 そこで、何回かデータを実際に上書きすることで復元のリスクを低減させることが望ましい。 今回は GNU Coreutils の shred を使ってランダムなデータをストレージに書き込んでみる。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132

まずは Homebrew で GNU Coreutils をインストールしておく。

$ brew install coreutils

続いて、データを消去したいストレージのデバイス名を diskutil list で確認しておく。 今回であれば /dev/disk2 にあった。

$ diskutil list
...(snip)...

/dev/disk2 (external, physical):
   #:                       TYPE NAME                    SIZE       IDENTIFIER
   0:     FDisk_partition_scheme                        *80.0 GB    disk2

あとは、このデバイスを -v オプションで指定して shred コマンドを使うだけ。 デフォルトではランダムなデータを 3 回に渡って書き込むことになる。

$ sudo shred -v /dev/disk2
shred: /dev/disk2: pass 1/3 (random)...
shred: /dev/disk2: pass 1/3 (random)...69MiB
shred: /dev/disk2: pass 1/3 (random)...144MiB
shred: /dev/disk2: pass 1/3 (random)...219MiB
shred: /dev/disk2: pass 1/3 (random)...293MiB
...(snip)

いじょう。

dd コマンドの進捗を確認する

dd コマンドの進捗を確認したいときは macOS であれば SIGINFO を、Linux (GNU Coreutils) であれば SIGUSR1 を送れば良い。 また、GNU Coreutils の dd には status=progress というオプションもある。

macOS

まずは macOS から。

使った環境は次の通り。

$ sw_vers                          
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132

適当にランダムな値でファイルを作らせる。

$ dd if=/dev/urandom of=example bs=1m count=1024

別のターミナルを開いたら killall を使って dd に SIGINFO を送りつける。

$ sudo killall -INFO dd

すると、次のように現状が表示される。

$ dd if=/dev/urandom of=example bs=1m count=1024
298+0 records in
298+0 records out
312475648 bytes transferred in 11.867853 secs (26329585 bytes/sec)

定期的に表示させたいときは watch コマンドと組み合わせると良い。

$ brew install watch

以下のようにすると 1 秒ごとに SIGINFO を送ることができる。

$ watch -n 1 sudo killall -INFO dd

結果として、次のように 1 秒ごとに進捗が表示される。

$ dd if=/dev/urandom of=example bs=1m count=1024
...(snip)...
501+0 records in
500+0 records out
524288000 bytes transferred in 19.181651 secs (27332788 bytes/sec)
528+0 records in
528+0 records out
553648128 bytes transferred in 20.254623 secs (27334408 bytes/sec)
557+0 records in
557+0 records out
584056832 bytes transferred in 21.357118 secs (27347175 bytes/sec)
...(snip)...

Linux (GNU Coreutils)

続いて Linux を。

使った環境は次の通り。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.2 LTS"
$ uname -r
4.15.0-54-generic
$ dd --version
dd (coreutils) 8.28
Copyright (C) 2017 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>.
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.

Written by Paul Rubin, David MacKenzie, and Stuart Kemp.

先ほどと同じように適当なファイルを作らせておく。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024

GNU Coreutils の dd であれば SIGUSR1 を送る。

$ sudo killall -USR1 dd

次のように進捗が表示される。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024
95+0 records in
95+0 records out
99614720 bytes (100 MB, 95 MiB) copied, 1.4874 s, 67.0 MB/s

定期的に表示させたいときは、先ほどと同じように watch と組み合わせれば良い。

$ sudo apt-get -y install procps
$ watch -n 1 sudo killall -USR1 dd

次のように定期的に進捗が表示されるようになる。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024
...(snip)...
348+3 records in
347+3 records out
365930880 bytes (366 MB, 349 MiB) copied, 6.29171 s, 58.2 MB/s
408+3 records in
407+3 records out
428845440 bytes (429 MB, 409 MiB) copied, 7.32246 s, 58.6 MB/s
467+3 records in
467+3 records out
491760000 bytes (492 MB, 469 MiB) copied, 8.34007 s, 59.0 MB/s
...(snip)...

あるいは、もっと単純に status=progress というオプションを付けても良い。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024 status=progress
605028352 bytes (605 MB, 577 MiB) copied, 10 s, 60.4 MB/s

いじょう。

Python: pandas-should というパッケージを作ってみた

pandas を使ってデータ分析などをしていると、自分が意図した通りのデータになっているか、たまに確認することになると思う。 確認する方法としてはグラフにプロットしてみたり、あるいは assert 文を使って shape などを確認することが考えられる。

今回紹介する pandas-should は後者の「assert 文を使った内容の確認」を、なるべく簡単に分かりやすく記述するために作ってみた。

github.com

使った環境は次の通り。 なお、パッケージ自体は Python 3.6 以降で動作する。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V
Python 3.7.4

インストール

インストールは pip からできる。

$ pip install pandas-should

使い方

パッケージをインストールできたら、とりあえず Python のインタプリタを起動しておく。

$ python

あとは pandas_should をインポートするだけ。

>>> import pandas_should

インポートすると pandas の DataFrame と Series のインスタンスに should というアトリビュートがひそかに生えてくる。

>>> import pandas as pd
>>> df = pd.DataFrame([1, 2, 3], columns=['id'])
>>> df.should
<pandas_should.dataframe.ShouldDataFrameAccessor object at 0x1083d6160>
>>> s = pd.Series([1, 2, 3])
>>> s.should
<pandas_should.series.ShouldSeriesAccessor object at 0x1196a36a0>

この should 経由で色々とできて、例えば行数が一致することを確認したいなら have_length() を使う。

>>> df.should.have_length(3)
True

基本的にメソッドは真偽値を返すので、アサーションに使うならこうする。

>>> assert df.should.have_length(3)

ここからは使いそうな API を幾つか紹介していく。

DataFrame

まずは DataFrame から。

要素に Null (NaN or NaT) が含まれるか調べたい

普通に書くと、こんな感じになると思う。

>>> not df.isnull().any(axis=None)
True

pandas-should を使うと、こう書ける。

>>> df.should.have_not_null()
True

あるいは Null が含まれることを期待するのであれば、こう。

>>> df.should.have_null()
False

要素のレンジを調べたい

各要素が特定のレンジ (値の範囲) に収まっているか知りたいときは、こう書く。 値の範囲には、指定した最小値と最大値も含まれる。

>>> df.should.fall_within_range(1, 3)
True

下限だけ指定したいときは greater_than() を使う。

>>> df.should.greater_than(0)
True

greater_than() では指定した値は含まれないので、含みたいときは greater_than_or_equal() を使う。

>>> df.should.greater_than_or_equal(1)
True

長いのでエイリアスとして gt()gte() も使える。

>>> df.should.gt(1)
False
>>> df.should.gte(1)
True

上限についても同様。 こちらもエイリアスとして lt()lte() が使える。

>>> df.should.less_than(3)
False
>>> df.should.less_than_or_equal(3)
True

形状 (Shape) を調べたい

続いて DataFrame の形状を調べる方法について。

比較対象が必要なので新たに DataFrame を用意しておく。

>>> data1 = [
...     ('apple', 98, True),
...     ('banana', 128, True),
... ]
>>> df1 = pd.DataFrame(data1, columns=['name', 'price', 'fruit'])
>>> data2 = [
...     ('carrot', 198, False),
...     ('dates', 498, True),
... ]
>>> df2 = pd.DataFrame(data2, columns=['name', 'price', 'fruit'])

同じ行数や列数であることを確認したいときは have_same_length()have_same_width() を使う。

>>> df1.should.have_same_length(df2)
True
>>> df1.should.have_same_width(df2)
True

前述した通り、整数で指定したいときは have_width()have_length() が使える。

>>> df1.should.have_width(2)
True
>>> df1.should.have_length(2)
True

ちなみに have_same_*() は複数の DataFrame との比較もできる。

>>> data3 = [
...     ('eggplant', 128, False),
... ]
>>> df3 = pd.DataFrame(data3, columns=['name', 'price', 'fruit'])
>>> data4 = [
...     ('fig', 298, True),
... ]
>>> df4 = pd.DataFrame(data4, columns=['name', 'price', 'fruit'])

例えば二つの DataFrame の行数を加算したものと同じになるか調べたいときは以下のようにする。 具体的なユースケースとしては、結合前と結合後の DataFrame の行数が一致しているか調べるときとか。

>>> df1.should.have_same_length(df3, df4)
True

行数についても同様。

>>> df1.should.have_same_width(df3, df4)
True

行と列を別々に比較するのがめんどいときは be_shaped_like() で一気に比較できる。

>>> df1.should.be_shaped_like(df2)
True

このメソッドにはタプルとか整数も渡せる。

>>> df1.should.be_shaped_like(df2.shape)  # tuple
True
>>> df1.should.be_shaped_like(df2.shape[0], df2.shape[1])  # int, int
True

Series

続いては Series について。

要素に Null (NaN or NaT) が含まれるか調べたい

要素に Null が含まれるか調べたいときは DataFrame と同じやり方が使える。

>>> s.should.have_not_null()
True
>>> s.should.have_null()
False

要素のレンジを調べたい

Series に関しても DataFrame と同じように、要素のレンジ (値の範囲) を調べられる。 追加で説明することは特にないかな。

>>> s.should.fall_within_range(1, 3)
True
>>> s.should.gt(1)
False
>>> s.should.gte(1)
True
>>> s.should.lt(3)
False
>>> s.should.lte(3)
True

形状 (Shape) を調べたい

Series に関しては列数という概念がないけど、次のように行数に関しては DataFrame と同じやり方が使える。

>>> s2 = pd.Series([4, 5, 6])
>>> s.should.have_same_length(s2)
True
>>> s.should.have_length(3)
True

そんなかんじで。 こういう API があると便利で欲しいみたいなのがあれば教えてほしい。

Python: Kivy で Matplotlib のグラフをプロットする

Kivy は最近人気のある Python のクロスプラットフォームな GUI のフレームワーク。 今回はそんな Kivy で作った GUI 上に Matplotlib のグラフをプロットしてみる。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V                    
Python 3.7.3

下準備

まずは Kivy と Matplotlib をインストールしておく。

$ pip install kivy matplotlib numpy

続いて Kivy Garden を使って Matplotlib 用のプラグイン (garden.matplotlib) をインストールする。

$ garden install matplotlib

これで Kivy で Matplotlib を使う準備ができた。

Kivy で Matplotlib のグラフをプロットする

以下の Kivy で Matplotlib のグラフをプロットするサンプルコードを示す。 garden.matplotlib を使うと Figure#canvas のインスタンスをウィジェットとして追加できるようになる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.label import Label
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Kivy 上で Matplotlib を使うために必要な準備
matplotlib.use('module://kivy.garden.matplotlib.backend_kivy')


class GraphApp(App):
    """Matplotlib のグラフを表示するアプリケーション"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Matplotlib graph on Kivy'

    def build(self):
        # メインの画面
        main_screen = BoxLayout()
        main_screen.orientation = 'vertical'

        # 上部にラベルを追加しておく
        label_text = 'The following is a graph of Matplotlib'
        label = Label(text=label_text)
        label.size_hint_y = 0.2
        main_screen.add_widget(label)

        # サイン波のデータを用意する
        x = np.linspace(-np.pi, np.pi, 100)
        y = np.sin(x)
        # 描画する領域を用意する
        fig, ax = plt.subplots()
        # プロットする
        ax.plot(x, y)
        # Figure#canvas をウィジェットとして追加する
        main_screen.add_widget(fig.canvas)

        return main_screen


def main():
    # アプリケーションを開始する
    app = GraphApp()
    app.run()


if __name__ == '__main__':
    main()

上記を実行する。

$ python kvplot.py

すると、次のような結果が得られる。

f:id:momijiame:20190710212812p:plain

ちゃんと描画できてるね。

FigureCanvasKivyAgg を使う場合

別のやり方として Figure オブジェクトを FigureCanvasKivyAgg でラップするやり方もある。

サンプルコードは次の通り。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg
from kivy.uix.label import Label
import numpy as np
import matplotlib.pyplot as plt


class GraphApp(App):
    """Matplotlib のグラフを表示するアプリケーション"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Matplotlib graph on Kivy'

    def build(self):
        main_screen = BoxLayout()
        main_screen.orientation = 'vertical'

        label_text = 'The following is a graph of Matplotlib'
        label = Label(text=label_text)
        label.size_hint_y = 0.2
        main_screen.add_widget(label)

        x = np.linspace(-np.pi, np.pi, 100)
        y = np.sin(x)
        fig, ax = plt.subplots()
        ax.plot(x, y)

        # Figure をラップする
        widget = FigureCanvasKivyAgg(fig)
        # ウィジェットとして追加する
        main_screen.add_widget(widget)

        return main_screen


def main():
    app = GraphApp()
    app.run()


if __name__ == '__main__':
    main()

表示される内容は変わらない。

ウィジェットのクラスとして定義する

さらに、ウィジェットをクラスとして定義した上で、その中にグラフを埋め込む場合には、次のようにする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.label import Label
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg
import numpy as np
import matplotlib.pyplot as plt


class GraphView(BoxLayout):
    """Matplotlib のグラフを表示するためのウィジェット"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        x = np.linspace(-np.pi, np.pi, 100)
        y = np.sin(x)
        fig, ax = plt.subplots()
        ax.plot(x, y)

        widget = FigureCanvasKivyAgg(fig)
        self.add_widget(widget)


class GraphApp(App):
    """Matplotlib のグラフを表示するアプリケーション"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Matplotlib graph on Kivy'

    def build(self):
        main_screen = BoxLayout()
        main_screen.orientation = 'vertical'

        label_text = 'The following is a graph of Matplotlib'
        label = Label(text=label_text)
        label.size_hint_y = 0.2
        main_screen.add_widget(label)

        # ウィジェットを生成して追加する
        graph = GraphView()
        main_screen.add_widget(graph)

        return main_screen


def main():
    # アプリケーションを開始する
    app = GraphApp()
    app.run()


if __name__ == '__main__':
    main()

こちらも表示される内容は変わらない。

KV 言語を使う

Kivy は KV 言語という DSL でレイアウトを制御できる。 もし、KV 言語も併用したいという場合であれば次のようにする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.label import Label
from kivy.lang import Builder
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg
import numpy as np
import matplotlib.pyplot as plt


kv_def = '''
<RootWidget>:
    orientation: 'vertical'

    Label:
        text: 'The following is a graph of Matplotlib'
        size_hint_y: 0.2

    GraphView:

<GraphView>:
'''
Builder.load_string(kv_def)


class GraphView(BoxLayout):
    """Matplotlib のグラフを表示するためのウィジェット"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        x = np.linspace(-np.pi, np.pi, 100)
        y = np.sin(x)
        fig, ax = plt.subplots()
        ax.plot(x, y)

        widget = FigureCanvasKivyAgg(fig)
        self.add_widget(widget)


class RootWidget(BoxLayout):
    """子を追加していくためのウィジェットを用意しておく"""


class GraphApp(App):
    """Matplotlib のグラフを表示するアプリケーション"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Matplotlib graph on Kivy'

    def build(self):
        return RootWidget()


def main():
    # アプリケーションを開始する
    app = GraphApp()
    app.run()


if __name__ == '__main__':
    main()

こちらも表示される内容は変わらない。

グラフで表示される内容を動的に更新したい

グラフに表示される内容を動的に更新したい場合のサンプルも以下に示す。 基本的には普通に Matplotlib を使って動的なグラフを描くのと変わらない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.clock import Clock
from kivy.lang import Builder
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg
import numpy as np
import matplotlib.pyplot as plt


kv_def = '''
<RootWidget>:
    orientation: 'vertical'

    Label:
        text: 'The following is a graph of Matplotlib'
        size_hint_y: 0.2

    GraphView:

<GraphView>:
'''
Builder.load_string(kv_def)


class GraphView(BoxLayout):
    """Matplotlib のグラフを表示するためのウィジェット"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # 初期化に用いるデータ
        x = np.linspace(-np.pi, np.pi, 100)
        y = np.sin(x)
        # 描画状態を保存するためのカウンタ
        self.counter = 0

        # Figure, Axis を保存しておく
        self.fig, self.ax = plt.subplots()
        # 最初に描画したときの Line も保存しておく
        self.line, = self.ax.plot(x, y)

        # ウィジェットとしてグラフを追加する
        widget = FigureCanvasKivyAgg(self.fig)
        self.add_widget(widget)

        # 1 秒ごとに表示を更新するタイマーを仕掛ける
        Clock.schedule_interval(self.update_view, 0.01)

    def update_view(self, *args, **kwargs):
        # データを更新する
        self.counter += np.pi / 100  # 10 分の pi ずらす
        # ずらした値を使ってデータを作る
        x = np.linspace(-np.pi + self.counter,
                        np.pi + self.counter,
                        100)
        y = np.sin(x)
        # Line にデータを設定する
        self.line.set_data(x, y)
        # グラフの見栄えを調整する
        self.ax.relim()
        self.ax.autoscale_view()
        # 再描画する
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()


class RootWidget(BoxLayout):
    """子を追加していくためのウィジェットを用意しておく"""


class GraphApp(App):
    """Matplotlib のグラフを表示するアプリケーション"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Matplotlib graph on Kivy'

    def build(self):
        return RootWidget()


def main():
    # アプリケーションを開始する
    app = GraphApp()
    app.run()


if __name__ == '__main__':
    main()

実行すると、こんな表示が得られる。 うにょうにょ。

f:id:momijiame:20190714123106g:plain

いじょう。

gRPC の通信を Wireshark でキャプチャしてみる

今回は、最近よく使われている gRPC の通信を Wireshark でキャプチャしてみる。 ちなみに、現行の Wireshark だと gRPC をちゃんと解釈できるみたい。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V                    
Python 3.7.3
$ wireshark --version | head -n 1
Wireshark 3.0.2 (v3.0.2-0-g621ed351d5c9)

Python で gRPC サーバ・クライアントを書く

通信をキャプチャするためには、まず gRPC のサーバとクライアントを用意する必要がある。

まずは gRPC を使う上で必要なツールキットをインストールしておく。

$ pip install grpcio-tools

gRPC では様々なプログラミング言語を用いてサーバとクライアントを記述できる。 そのために、どの言語を使う場合にも共通のスキーマを定義した上で、それを各言語用にコンパイルする。 共通のスキーマを定義するには Protocol Buffers というフォーマットを用いる。

以下は Protocol Buffers のスキーマを定義するファイルとなっている。 この中では HelloWorld というサービス上で greet() という RPC が定義されている。 greet() でやり取りするのは HelloRequestHelloReply というメッセージ。

syntax = "proto3";

service HelloWorld {
  rpc greet (HelloRequest) returns (HelloReply) {}
}

message HelloRequest {
  string name = 1;
}

message HelloReply {
  string message = 1;
}

上記を Python 用にコンパイルする。

$ python -m grpc_tools.protoc \
  --proto_path=. \
  --grpc_python_out=. \
  --python_out=. \
  helloworld.proto

すると、次のように pb2 という名前を含むファイルが二つ生成される。

$ ls | grep pb2
helloworld_pb2.py
helloworld_pb2_grpc.py

これらは、先ほどのスキーマ定義から生成された Python のモジュールになっている。

$ file helloworld_pb2.py 
helloworld_pb2.py: Python script text executable, ASCII text
$ file helloworld_pb2_grpc.py 
helloworld_pb2_grpc.py: Python script text executable, ASCII text

生成されたモジュールを使って gRPC のサーバを書いてみよう。 Protocol Buffers の定義にはインターフェースしか記述されていない。 そのため、内部で何をやるか実装を書いてやる必要がある。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from concurrent import futures
import time

import grpc

import helloworld_pb2
import helloworld_pb2_grpc


class HelloWorldService(helloworld_pb2_grpc.HelloWorldServicer):
    """サービスを定義する"""

    def greet(self, request, context):
        """RPC の中身"""
        # 受け取った内容を使って返信するメッセージを組み立てる
        message = 'Hello, {name}'.format(name=request.name)
        reply = helloworld_pb2.HelloReply(message=message)
        return reply


def main():
    # gRPC のサーバを用意する
    executor = futures.ThreadPoolExecutor(max_workers=10)
    server = grpc.server(executor)
    service = HelloWorldService()
    helloworld_pb2_grpc.add_HelloWorldServicer_to_server(service, server)

    # サーバを 37564 ポートで動作させる
    server.add_insecure_port('localhost:37564')
    server.start()

    try:
        while True:
            time.sleep(1)
    finally:
        server.stop(0)


if __name__ == '__main__':
    main()

サーバを起動してみよう。

$ python server.py

別のターミナルを開いてポートの状態を確認してみよう。 TCP で localhost:37564 を Listen していれば上手くいっている。

$ lsof -i | grep 37564
python3.7 11805 amedama    5u  IPv6 0xbd776a50dc8d08b1      0t0  TCP localhost:37564 (LISTEN)
python3.7 11805 amedama    6u  IPv6 0xbd776a50dc8d4231      0t0  TCP localhost:37564 (LISTEN)

続いてクライアントを記述しよう。 サーバが稼働している localhost:37564 ポートに接続させる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import grpc

import helloworld_pb2
import helloworld_pb2_grpc


def main():
    # 'localhost:37564' に接続する
    with grpc.insecure_channel('localhost:37564') as channel:
        # メッセージと共にリモートプロシジャを呼び出す
        stub = helloworld_pb2_grpc.HelloWorldStub(channel)
        reply = stub.greet(helloworld_pb2.HelloRequest(name='World'))
        # 返ってきた内容を表示する
        print('Reply:', reply.message)


if __name__ == '__main__':
    main()

上記を実行してみよう。 次のようにメッセージが表示されれば上手くいった。

$ python client.py 
Reply: Hello, World

通信を Wireshark でキャプチャする

さて、これだけだとあまり面白くないので、続いては上記の通信をキャプチャしてみよう。

パケットキャプチャをするために Wireshark をインストールする。

$ brew cask install wireshark

インストールできたら Wireshark を起動する。

$ wireshark

起動したら Loopback インターフェースのキャプチャを開始する。 また、ディスプレイフィルタのバーに tcp.port == 37564 を指定する。 これで余計な内容が表示されなくて済む。

f:id:momijiame:20190630152902p:plain

準備ができたら、先ほどの gRPC サーバを起動する。

$ python server.py

そして、gRPC クライアントを実行する。

$ python client.py

すると、次のように TCP の通信内容がキャプチャされる。

f:id:momijiame:20190630153107p:plain

適当にフレームを選択して「Follow > TCP Stream」すると一連の TCP の通信内容が確認できる。

f:id:momijiame:20190630153245p:plain

HTTP/2.0 という文字列から読み取れる通り、gRPC は通信部分のレイヤーに HTTP2 を採用している。

生の TCP だと読みにくいので、通信を HTTP/2 として解釈させてみよう。 先ほどと同じようにフレームを選択したら「Decode As...」を選択する。

f:id:momijiame:20190630153445p:plain

上記のような画面が開いたら右端のカラムを「HTTP2」にする。 これで、通信プロトコルが HTTP/2 として解釈される。

HTTP/2 において HTTP のリクエストとレスポンスは HeadersData のメッセージでやり取りされる。 以下のように、まずリクエストが出ている。 メッセージの内容に World が含まれることが分かる。 また、呼び出す対象のリモートプロシジャは HTTP のパス部分に格納されるようだ。

f:id:momijiame:20190630153533p:plain

同様に、上記に対するレスポンスが以下になる。 メッセージに Hello, World という内容が含まれることが分かる。

f:id:momijiame:20190630153821p:plain

なかなか分かりやすいね。

Python: py4j で Java の API を Python から使う

今回は py4j を使って Java の API を Python から利用してみる。

py4j のアーキテクチャはサーバ・クライアントモデルになっている。 つまり、まず Java の API を Python から叩けるように、Java でゲートウェイサーバとなるプログラムを書く。 そして、Python からはネットワーク経由でそのゲートウェイサーバにアクセスする。 これは、RPC (Remote Procedure Call) の考え方に近い。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V
Python 3.7.3
$ java -version 
openjdk version "12.0.1" 2019-04-16
OpenJDK Runtime Environment (build 12.0.1+12)
OpenJDK 64-Bit Server VM (build 12.0.1+12, mixed mode, sharing)

下準備

下準備として Java (JDK) と py4j をインストールしておく。

$ brew cask install java
$ pip install py4j

ゲートウェイサーバを記述する

最初に、Python から利用したい Java の API に対してゲートウェイサーバを書く。

以下のサンプルコードでは HelloWorld クラスの API についてゲートウェイサーバを用意している。 基本的にはクラスのインスタンスを GatewayServer クラスに渡してやるだけ。 提供されるのは割り算の機能を持った div() メソッドになる。

import py4j.GatewayServer;

public class HelloWorld {

  public int div(int a, int b) {
    // 割り算の機能を提供するメソッド
    return a / b;
  }

  public static void main(String[] args) {
    // GatewayServer 経由で機能を提供する
    HelloWorld application = new HelloWorld();
    GatewayServer gateway = new GatewayServer(application);
    gateway.start();
    System.out.println("Starting server...");
  }
}

続いて、上記を Java バイトコードにコンパイルする。 ただし、それには py4j の jar ファイルが必要になる。

jar ファイルは py4j のインストール先にある。 なので、まずは py4j のインストールされている場所を確認する。

$ python -c "import py4j; print(py4j.__path__)"
['/Users/amedama/.virtualenvs/py37/lib/python3.7/site-packages/py4j']

次のように share ディレクトリ以下に jar ファイルがあった。

$ ls ~/.virtualenvs/py37/share/py4j
py4j0.10.8.1.jar

この jar ファイルにクラスパスを通しながら、先ほどのプログラムをコンパイルする。

$ javac -cp ~/.virtualenvs/py37/share/py4j/py4j0.10.8.1.jar HelloWorld.java

次のように class ファイルが完成すれば上手くいっている。

$ file HelloWorld.class 
HelloWorld.class: compiled Java class data, version 56.0

コンパイルできたら、ゲートウェイサーバのプログラムを起動する。 この際にも、jar ファイルにクラスパスを通す必要がある。

$ java -cp ~/.virtualenvs/py37/share/py4j/py4j0.10.8.1.jar:. HelloWorld
Starting server...

これで、デフォルトでは TCP:25333 で py4j のサービスが起動する。

$ lsof -i | grep 25333
java      10364 amedama    6u  IPv6 0xbd776a50dc8d3c71      0t0  TCP localhost:25333 (LISTEN)

Python から利用する

これで準備ができたらので、Python から利用してみよう。

まずは Python のインタプリタを起動する。

$ python

py4j 経由で Java の API を呼び出すために JavaGateway クラスのインスタンスを用意する。

>>> from py4j.java_gateway import JavaGateway
>>> java_gateway = JavaGateway()
>>> java_app = java_gateway.entry_point

あとは Java の API を呼び出すだけ。

>>> java_app.div(20, 10)
2

ちゃんと割り算ができている。

試しに Java のプログラム上で例外を発生させてみよう。 すると、次のように py4j.protocol.Py4JJavaError 例外となる。 例外の中には、Java 上で発生した例外の情報が入っている。

>>> java_app.div(1, 0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/amedama/.virtualenvs/py37/lib/python3.7/site-packages/py4j/java_gateway.py", line 1286, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/Users/amedama/.virtualenvs/py37/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling t.div.
: java.lang.ArithmeticException: / by zero
    at HelloWorld.div(HelloWorld.java:7)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:567)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.base/java.lang.Thread.run(Thread.java:835)

いじょう。