CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: __getattr__() のあるオブジェクトを直列化しようとしてハマった話

今回は特殊メソッドの __getattr__() があるオブジェクトを pickle で直列化・非直列化 (SerDe) しようとしたらハマった話について。

まず、特殊メソッドの __getattr__() をクラスに実装してあると、そのインスタンスは未定義のアトリビュートにアクセスが生じたとき呼び出しがトラップされる。 そして、この __getattr__() を実装したクラスのインスタンスを pickle で SerDe しようとしたとき思わぬ挙動となった。 結論から先に述べると __getattr__() を実装してあると __getstate__()__setstate__() の呼び出しまでトラップされてしまう。 これらのメソッドは SerDe の振る舞いをオーバーライドするための特殊メソッドとなっている。 この問題の対策としては __getattr__() のある SerDe が必要なクラスには __getstate__()__setstate__() を実装しておくことが考えられる。

なお、pickle を使ったオブジェクトの SerDe の概要については、以下のエントリを参照のこと。

blog.amedama.jp

使った環境は次の通り。

$ sw_vers           
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V
Python 3.7.4

特殊メソッド __getattr__() がないときの振る舞いについて

まずは __getattr__() を実装していないクラスを直列化・非直列化 (SerDe) してみる。 以下のサンプルコードでは Example というクラスのインスタンスをバイト列にしてから元のオブジェクトに戻している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Example(object):
    """SerDe されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    # Example クラスをインスタンス化する
    o = Example('World')
    # メソッドを呼び出す
    o.greet()

    # オブジェクトをバイト列にシリアライズする
    # このときオブジェクトに __getstate__() があれば呼ばれる
    # このサンプルコードにはないためデフォルトの振る舞いになる
    s = pickle.dumps(o)

    # バイト列からオブジェクトをデシリアライズする
    # このときオブジェクトに __setstate__() があれば呼ばれる
    # このサンプルコードにはないためデフォルトの振る舞いになる
    restored_o = pickle.loads(s)

    # 復元したオブジェクトのメソッドを呼び出す
    restored_o.greet()


if __name__ == '__main__':
    main()

上記を実行した結果が次の通り。 ちゃんとオブジェクトをバイト列にして、また元のオブジェクトに戻せていることがわかる。

$ python serde1.py     
Hello, World!
Hello, World!

特殊メソッド __getattr__() があるときの振る舞いについて

続いては __getattr__() のあるオブジェクトを SerDe してみる。 ただ、先ほどの Example クラスに直接 __getattr__() を追加するのはユースケースとして考えにくいので、ちょっとアレンジを加えてある。 Example クラスはそのままに、そのラッパーとして動作する Wrapper クラスを用意して、そこに __getattr__() メソッドを実装した。 こういったプロキシのようなクラスは、プロキシする先のオブジェクトの呼び出しを中継するために __getattr__() を使うことが多い。 このような状況で Wrapper クラスのインスタンスを SerDe すると上手くいかない、というのが今回の本題となる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            # ラップするオブジェクトのアトリビュートを取得して呼び出す
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    # Wrapper でオブジェクトをラップする
    w = Wrapper(o)
    # ラッパー経由でメソッドを呼び出す
    w.greet()

    # XXX: __getstate__() が __getattr__() 経由で呼ばれようとする
    s = pickle.dumps(w)

    # XXX: __setstate__() が __getattr__() 経由で呼ばれようとする
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、次のように直列化するタイミングでエラーになる。 見ると __getstate__()Example のオブジェクトにない、という内容のようだ。

$ python serde2.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: __getstate__
Traceback (most recent call last):
  File "serde2.py", line 50, in <module>
    main()
  File "serde2.py", line 41, in main
    s = pickle.dumps(w)
  File "serde2.py", line 17, in _wrapper
    func = getattr(self.wrap_target, item)
AttributeError: 'Example' object has no attribute '__getstate__'

Example クラスに __*state__() を実装すれば解決...しない

では、エラーメッセージに習って Example クラスに __getstate__()__setstate__() を実装すれば解決するだろうか? 試してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))

    def __getstate__(self):
        """__getstate__() を明示的に定義する"""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """__setstate__() を明示的に定義する"""
        self.__dict__.update(state)


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # XXX: __getstate__() が __getattr__() 経由で呼ばれようとする
    s = pickle.dumps(w)

    # XXX: __setstate__() が __getattr__() 経由で呼ばれようとする
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

残念ながら、今度は以下のようなエラーになる。 そもそも SerDe したいのは Wrapper クラスのインスタンスなので Example クラスに実装しても解決できない。

$ python serde3.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: __getstate__
trapped undefined access: __setstate__
Traceback (most recent call last):
  File "serde3.py", line 56, in <module>
    main()
  File "serde3.py", line 50, in main
    restored_w = pickle.loads(s)
  File "serde3.py", line 17, in _wrapper
    func = getattr(self.wrap_target, item)
AttributeError: 'function' object has no attribute '__setstate__'

このときのエラーメッセージがまた分かりにくくて、どうして function オブジェクトでエラーになるんだ、となる。

Wrapper クラスに __*state__() を実装してみる

ということで、今度は Wrapper クラスの方に __getstate__()__setstate__() を実装してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper

    def __getstate__(self):
        """__getstate__() を明示的に定義する"""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """__setstate__() を明示的に定義する"""
        self.__dict__.update(state)


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # __getstate__() が明示的に定義されているため __getattr__() は呼ばれない
    s = pickle.dumps(w)

    # __setstate__() が明示的に定義されているため __getattr__() は呼ばれない
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

今度は次のようにエラーにならず SerDe できた。 Wrapper クラスに __getstate__()__setstate__() が定義されているため、呼び出しが __getattr__() にトラップされることがない。

$ python serde4.py
trapped undefined access: greet
Hello, World!
trapped undefined access: greet
Hello, World!

サードパーティーのライブラリで問題が発生しているとき

先ほどのようにクラスにメソッドを定義して救えるのは、自分で定義したクラスで問題が発生している場合に限られる。 もし、サードパーティ製のライブラリで同様の問題が生じた場合には、どのような解決策があるだろうか。 幸いなことに Python は既存のクラスにも動的にメソッドを追加できる。

以下のサンプルコードでは SerDe する直前で対象のクラスに __getstate__()__setstate__() を動的に追加している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # オブジェクトに __getstate__() を動的に追加する
    def __getstate__(self):
        return self.__dict__.copy()
    setattr(Wrapper, '__getstate__', __getstate__)

    # オブジェクトに __setstate__() を動的に追加する
    def __setstate__(self, state):
        self.__dict__.update(state)
    setattr(Wrapper, '__setstate__', __setstate__)

    s = pickle.dumps(w)

    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

上記を実行してみよう。 ちゃんと SerDe できていることがわかる。

$ python serde5.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: greet
Hello, World!

macOS で CH34x のシリアルコンソールを使う

Arduino などで使われていることがある CH34x のチップを macOS から使う方法について。

基本的には以下のリポジトリに詳細が載っている。

github.com

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84

インストール

もし過去に古いドライバを手動でインストールしたことがあるときは下部に記載したアンインストールを先に実行する。

Homebrew Cask を使ってドライバをインストールする。

$ brew cask reinstall wch-ch34x-usb-serial-driver 

マシンを再起動するか、あるいは以下のコマンドを実行してカーネルモジュールを読み込む。

$ sudo kextload /Library/Extensions/usbserial.kext

これで tty.wchusbserial から始まるデバイスファイルが見えるようになるはず。

$ ls /dev/tty.wchusbserial*
tty.wchusbserial141120

あとは一般的なシリアルデバイスとして screen なり pyserial などから使えば良い。

$ screen /dev/tty.wchusbserial141120 9600

手動で古いドライバを削除する

過去に古いドライバを手動でインストールしたことがあるときは、以下の手順にもとづいてアンインストールする。

まず、カーネルモジュールをアンロードする。

$ sudo kextunload /Library/Extensions/usbserial.kext
$ sudo kextunload /System/Library/Extensions/usb.kext

そして、カーネルモジュールのファイルを削除する。

$ sudo rm -rf /System/Library/Extensions/usb.kext
$ sudo rm -rf /Library/Extensions/usbserial.kext

いじょう。

Python: インポートするだけで Kivy が日本語を表示できるようになる japanize-kivy を作った

Python の GUI フレームワークである Kivy は、そのままだと日本語が表示できない。 そこで、インポートするだけで日本語を表示できるようにするパッケージ japanize-kivy を作った。

github.com

知っている人はピンと来るはずだけど名前や思想は以下のパッケージをインスパイアしている。

github.com

使った環境は次の通り。 パッケージがサポートする Python は 3.6 以上を想定している。

$ sw_vers  
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V
Python 3.7.4

もくじ

インストール

pip からインストールできる。

$ pip install japanize-kivy

試す

Python のインタプリタを起動する。

$ python

japanize_kivy パッケージをインポートする。

>>> import japanize_kivy

あとは日本語を含む Kivy のアプリケーションを用意する。

>>> from kivy.app import App
>>> from kivy.uix.boxlayout import BoxLayout
>>> from kivy.uix.label import Label
>>> class GreetingApp(App):
...     def build(self):
...         main_screen = BoxLayout()
...         label = Label(text='こんにちは、世界')
...         main_screen.add_widget(label)
...         return main_screen
... 
>>> GreetingApp().run()

以下のように日本語が表示できるようになる。

f:id:momijiame:20190730183416p:plain

インポートしないと、次のように日本語が豆腐になる。

f:id:momijiame:20190730183451p:plain

フォントのライセンスに関して

日本語を表示するためのフォントは IPAex ゴシックフォントを使わせてもらっている。 そのため、本パッケージを利用する上ではライセンスへの同意が必要となる。

次のようにするとライセンスが表示されるので、IPA への感謝と共に同意してほしい。

>>> japanize_kivy.show_license()

いじょう。

Python: Kivy と Matplotlib でデータセットの確認ツールを書いてみる

以前、このブログで Kivy で作った GUI に Matplotlib のグラフを埋め込む方法について書いた。

blog.amedama.jp

今回は、これを応用したツール作りをしてみる。 といっても、やっていることは単純で先の例にボタンを付けてインタラクティブにした程度にすぎない。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V          
Python 3.7.4

下準備

下準備として必要なパッケージをインストールしておく。

$ pip install kivy matplotlib scikit-learn

Digit データセットの内容を表示してみる

今回書いてみたサンプルコードが次の通り。 内容としては scikit-learn に同梱されている Digit データセットの内容を表示させてみることにした。 ボタンを使って表示するデータを前後に進めたり戻したりできる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn import datasets
from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.lang import Builder
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg


kv_def = '''
<RootWidget>:
    orientation: 'vertical'

    GraphView:
        id: graph_view
        size_hint_y: 0.8

    BoxLayout:
        size_hint_y: 0.2

        Button:
            id: prev_button
            text: '< Prev'
            on_press: root.ids.graph_view.prev()

        Button:
            id: next_button
            text: 'Next >'
            on_press: root.ids.graph_view.next()

<GraphView>:
'''
Builder.load_string(kv_def)


class GraphView(BoxLayout):
    """Matplotlib のグラフを表示するウィジェット"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # データセットを読み込んでおく
        self.dataset = datasets.load_digits()
        # 表示するデータのインデックス
        self.cursor = 0

        # 描画領域を用意する
        self.fig, self.ax = plt.subplots()

        # 描画を初期化する
        self._update_view()

        # グラフをウィジェットとして追加する
        widget = FigureCanvasKivyAgg(self.fig)
        self.add_widget(widget)

    def _update_view(self):
        """描画を更新するメソッド"""
        # 以前の内容を消去する
        self.ax.clear()
        self.ax.axis('off')

        # データを取得する
        img_data = self.dataset.data[self.cursor]
        label = self.dataset.target[self.cursor]

        # データを描画する
        self.ax.imshow(img_data.reshape(8, 8),
                       cmap=cm.gray_r,
                       interpolation='nearest')
        title_msg = 'index={idx}, label={label}'.format(idx=self.cursor,
                                                        label=label)
        self.ax.set_title(title_msg, color='red')

        # 再描画する
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def next(self):
        """次へボタンを押したときのコールバック"""
        if self.cursor < len(self.dataset.data) - 1:
            self.cursor += 1
        self._update_view()

    def prev(self):
        """戻るボタンを押したときのコールバック"""
        if self.cursor > 0:
            self.cursor -= 1
        self._update_view()


class RootWidget(BoxLayout):
    pass


class ViewerApp(App):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Digit dataset viewer'

    def build(self):
        root_widget = RootWidget()
        return root_widget


def main():
    # アプリケーションを開始する
    app = ViewerApp()
    # ここでスレッドがブロックする
    app.run()


if __name__ == '__main__':
    main()

上記を実行してみる。

$ python digitviewer.py

すると、次のような GUI が表示される。

f:id:momijiame:20190725060813g:plain

応用すればアノテーションに使うツールなんかも作れるだろうね。 いじょう。

Python: LightGBM の学習率を動的に制御する

LightGBM の学習率は基本的に低い方が最終的に得られるモデルの汎化性能が高くなることが経験則として知られている。 しかしながら、学習率が低いとモデルの学習に多くのラウンド数、つまり計算量を必要とする。 そこで、今回は学習率を学習の過程において動的に制御するコールバックを実装してみた。

きっかけは以下のツイートを見たこと。

なるほど面白そう。

下準備

使用するライブラリをあらかじめインストールしておく。

$ pip install lightgbm seaborn scikit-learn

学習率を動的に制御するコールバック

早速だけど、以下が学習率を動的に制御するコールバックを実装したサンプルコードとなる。 コールバックの本体は LrSchedulingCallback というクラスで実装している。 このクラスをインスタンス化するときに、制御方法を記述した関数を渡す。 以下であれば sample_scheduler_func() という名前で定義した。 この関数は学習の履歴などを元に新たな学習率を決めて返すインターフェースとなっている。 今回はお試しとして 10 ラウンドごとに学習率を下限の 0.01 まで半減させ続けるという単純な戦略を記述してみた。 もちろん、これがベストというわけではなくて、あくまでサンプルとして簡単なものを書いてみたに過ぎない。 なお、EarlyStopping していないのは学習の過程を最後まで観察するため。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import lightgbm as lgb
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedKFold


def sample_scheduler_func(current_lr, eval_history, best_round, is_higher_better):
    """次のラウンドで用いる学習率を決定するための関数 (この中身を好きに改造する)

    :param current_lr: 現在の学習率 (指定されていない場合の初期値は None)
    :param eval_history: 検証用データに対する評価指標の履歴
    :param best_round: 現状で最も評価指標の良かったラウンド数
    :param is_higher_better: 高い方が性能指標として優れているか否か
    :return: 次のラウンドで用いる学習率

    NOTE: 学習を打ち切りたいときには callback.EarlyStopException を上げる
    """
    # 学習率が設定されていない場合のデフォルト
    current_lr = current_lr or 0.05

    # 試しに 20 ラウンド毎に学習率を半分にしてみる
    if len(eval_history) % 20 == 0:
        current_lr /= 2

    # 小さすぎるとほとんど学習が進まないので下限も用意する
    min_threshold = 0.001
    current_lr = max(min_threshold, current_lr)

    return current_lr


class LrSchedulingCallback(object):
    """ラウンドごとの学習率を動的に制御するためのコールバック"""

    def __init__(self, strategy_func):
        # 学習率を決定するための関数
        self.scheduler_func = strategy_func
        # 検証用データに対する評価指標の履歴
        self.eval_metric_history = []

    def __call__(self, env):
        # 現在の学習率を取得する
        current_lr = env.params.get('learning_rate')

        # 検証用データに対する評価結果を取り出す (先頭の評価指標)
        first_eval_result = env.evaluation_result_list[0]
        # スコア
        metric_score = first_eval_result[2]
        # 評価指標は大きい方が優れているか否か
        is_higher_better = first_eval_result[3]

        # 評価指標の履歴を更新する
        self.eval_metric_history.append(metric_score)
        # 現状で最も優れたラウンド数を計算する
        best_round_find_func = np.argmax if is_higher_better else np.argmin
        best_round = best_round_find_func(self.eval_metric_history)

        # 新しい学習率を計算する
        new_lr = self.scheduler_func(current_lr=current_lr,
                                     eval_history=self.eval_metric_history,
                                     best_round=best_round,
                                     is_higher_better=is_higher_better)

        # 次のラウンドで使う学習率を更新する
        update_params = {
            'learning_rate': new_lr,
        }
        env.model.reset_parameter(update_params)
        env.params.update(update_params)

    @property
    def before_iteration(self):
        # コールバックは各イテレーションの後に実行する
        return False


def main():
    # Titanic データセットを読み込む
    dataset = sns.load_dataset('titanic')

    # 重複など不要な特徴量は落とす
    X = dataset.drop(['survived',
                      'class',
                      'who',
                      'embark_town',
                      'alive'], axis=1)
    y = dataset.survived

    # カテゴリカル変数を指定する
    categorical_columns = ['pclass',
                           'sex',
                           'embarked',
                           'adult_male',
                           'deck',
                           'alone']
    X = X.astype({c: 'category'
                  for c in categorical_columns})

    # LightGBM のデータセット表現に直す
    lgb_train = lgb.Dataset(X, y)

    # コールバックを用意する
    lr_scheduler_cb = LrSchedulingCallback(strategy_func=sample_scheduler_func)
    callbacks = [
        lr_scheduler_cb,
    ]

    # 二値分類を LogLoss で評価する
    lgb_params = {
        'objective': 'binary',
        'metrics': 'binary_logloss',
        'min_data_in_leaf': 10,
    }
    # 5-Fold CV
    skf = StratifiedKFold(n_splits=5,
                          shuffle=True,
                          random_state=42)

    # 動的に学習率を制御した場合
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        callbacks=callbacks,
                        )
    dynamic_lr = cv_results['binary_logloss-mean']

    # 学習率を 0.1 に固定した場合
    lgb_params.update({'learning_rate': 0.1})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_1 = cv_results['binary_logloss-mean']

    # 学習率を 0.05 に固定した場合
    lgb_params.update({'learning_rate': 0.05})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_05 = cv_results['binary_logloss-mean']

    # 学習率を 0.01 に固定した場合
    lgb_params.update({'learning_rate': 0.01})
    cv_results = lgb.cv(lgb_params, lgb_train,
                        num_boost_round=500,
                        verbose_eval=1,
                        folds=skf, seed=42,
                        )
    static_lr_0_01 = cv_results['binary_logloss-mean']

    # 最小の損失を比較する
    print('min loss value (lr=dynamic):', min(dynamic_lr))
    print('min loss value (lr=0.1):', min(static_lr_0_1))
    print('min loss value (lr=0.05):', min(static_lr_0_05))
    print('min loss value (lr=0.01):', min(static_lr_0_01))

    # 最小の損失が得られたラウンド数を比較する
    print('min loss round (lr=dynamic):', np.argmin(dynamic_lr))
    print('min loss round (lr=0.1):', np.argmin(static_lr_0_1))
    print('min loss round (lr=0.05):', np.argmin(static_lr_0_05))
    print('min loss round (lr=0.01):', np.argmin(static_lr_0_01))

    # グラフにプロットする
    sns.lineplot(np.arange(len(dynamic_lr)),
                 dynamic_lr,
                 label='LR=dynamic')
    sns.lineplot(np.arange(len(static_lr_0_1)),
                 static_lr_0_1,
                 label='LR=0.1')
    sns.lineplot(np.arange(len(static_lr_0_05)),
                 static_lr_0_05,
                 label='LR=0.05')
    sns.lineplot(np.arange(len(static_lr_0_01)),
                 static_lr_0_01,
                 label='LR=0.01')
    plt.title('learning rate control comparison')
    plt.xlabel('rounds')
    plt.ylabel('logloss')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみる。 最も性能が良かったモデルの損失とラウンド数が表示される。 最適なモデルの性能、学習に要するラウンド数ともに 0.1 固定と 0.01 固定の間にあることが分かる。

$ python dynamiclr.py
...(snip)...
min loss value (lr=dynamic): 0.421097448234137
min loss value (lr=0.1): 0.42265029913071334
min loss value (lr=0.05): 0.4221001657363532
min loss value (lr=0.01): 0.42100303104081405
min loss round (lr=dynamic): 84
min loss round (lr=0.1): 18
min loss round (lr=0.05): 38
min loss round (lr=0.01): 196

そして、各条件における検証用データに対する評価指標の推移をプロットしたグラフが次の通り。 学習率を動的に制御しているパターンは、0.1 固定ほどではないにせよ早く性能が収束していることが分かる。 まあ、とはいえこれくらいなら lr=0.01 ~ 0.05 の間に似たような特性の学習率がいるかもしれない。

いじょう。 こんな上手くいくスケジューラが書けた、みたいな話があったら教えてほしいな。

GNU Coreutils の shred でストレージのデータを削除する

HDD や SSD といったストレージを廃棄あるいは売却するとき、単に保存されているファイルを削除しただけでは復元のリスクが高い。 これは、本のメタファーでいえば索引の部分を消しただけで本文は丸々残っている、といった状況になっているため。 そこで、何回かデータを実際に上書きすることで復元のリスクを低減させることが望ましい。 今回は GNU Coreutils の shred を使ってランダムなデータをストレージに書き込んでみる。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132

まずは Homebrew で GNU Coreutils をインストールしておく。

$ brew install coreutils

続いて、データを消去したいストレージのデバイス名を diskutil list で確認しておく。 今回であれば /dev/disk2 にあった。

$ diskutil list
...(snip)...

/dev/disk2 (external, physical):
   #:                       TYPE NAME                    SIZE       IDENTIFIER
   0:     FDisk_partition_scheme                        *80.0 GB    disk2

あとは、このデバイスを -v オプションで指定して shred コマンドを使うだけ。 デフォルトではランダムなデータを 3 回に渡って書き込むことになる。

$ sudo shred -v /dev/disk2
shred: /dev/disk2: pass 1/3 (random)...
shred: /dev/disk2: pass 1/3 (random)...69MiB
shred: /dev/disk2: pass 1/3 (random)...144MiB
shred: /dev/disk2: pass 1/3 (random)...219MiB
shred: /dev/disk2: pass 1/3 (random)...293MiB
...(snip)

いじょう。

dd コマンドの進捗を確認する

dd コマンドの進捗を確認したいときは macOS であれば SIGINFO を、Linux (GNU Coreutils) であれば SIGUSR1 を送れば良い。 また、GNU Coreutils の dd には status=progress というオプションもある。

macOS

まずは macOS から。

使った環境は次の通り。

$ sw_vers                          
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132

適当にランダムな値でファイルを作らせる。

$ dd if=/dev/urandom of=example bs=1m count=1024

別のターミナルを開いたら killall を使って dd に SIGINFO を送りつける。

$ sudo killall -INFO dd

すると、次のように現状が表示される。

$ dd if=/dev/urandom of=example bs=1m count=1024
298+0 records in
298+0 records out
312475648 bytes transferred in 11.867853 secs (26329585 bytes/sec)

定期的に表示させたいときは watch コマンドと組み合わせると良い。

$ brew install watch

以下のようにすると 1 秒ごとに SIGINFO を送ることができる。

$ watch -n 1 sudo killall -INFO dd

結果として、次のように 1 秒ごとに進捗が表示される。

$ dd if=/dev/urandom of=example bs=1m count=1024
...(snip)...
501+0 records in
500+0 records out
524288000 bytes transferred in 19.181651 secs (27332788 bytes/sec)
528+0 records in
528+0 records out
553648128 bytes transferred in 20.254623 secs (27334408 bytes/sec)
557+0 records in
557+0 records out
584056832 bytes transferred in 21.357118 secs (27347175 bytes/sec)
...(snip)...

Linux (GNU Coreutils)

続いて Linux を。

使った環境は次の通り。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.2 LTS"
$ uname -r
4.15.0-54-generic
$ dd --version
dd (coreutils) 8.28
Copyright (C) 2017 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>.
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.

Written by Paul Rubin, David MacKenzie, and Stuart Kemp.

先ほどと同じように適当なファイルを作らせておく。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024

GNU Coreutils の dd であれば SIGUSR1 を送る。

$ sudo killall -USR1 dd

次のように進捗が表示される。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024
95+0 records in
95+0 records out
99614720 bytes (100 MB, 95 MiB) copied, 1.4874 s, 67.0 MB/s

定期的に表示させたいときは、先ほどと同じように watch と組み合わせれば良い。

$ sudo apt-get -y install procps
$ watch -n 1 sudo killall -USR1 dd

次のように定期的に進捗が表示されるようになる。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024
...(snip)...
348+3 records in
347+3 records out
365930880 bytes (366 MB, 349 MiB) copied, 6.29171 s, 58.2 MB/s
408+3 records in
407+3 records out
428845440 bytes (429 MB, 409 MiB) copied, 7.32246 s, 58.6 MB/s
467+3 records in
467+3 records out
491760000 bytes (492 MB, 469 MiB) copied, 8.34007 s, 59.0 MB/s
...(snip)...

あるいは、もっと単純に status=progress というオプションを付けても良い。

$ sudo dd if=/dev/urandom of=example bs=1M count=1024 status=progress
605028352 bytes (605 MB, 577 MiB) copied, 10 s, 60.4 MB/s

いじょう。