CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: ユニットテストを書いてみよう

ソフトウェアエンジニアにとって、不具合に対抗する最も一般的な方法は自動化されたテストを書くこと。 テストでは、書いたプログラムが誤った振る舞いをしないか確認する。 一口に自動テストといっても、扱うレイヤーによって色々なものがある。 今回は、その中でも最もプリミティブなテストであるユニットテストについて扱う。 ユニットテストでは、関数やクラス、メソッドといった単位の振る舞いについてテストを書いていく。

Python には標準ライブラリとして unittest というパッケージが用意されている。 これは、文字通り Python でユニットテストを書くためのパッケージとなっている。 このエントリでは、最初に unittest パッケージを使ってユニットテストを書く方法について紹介する。 その上で、さらに効率的にテストを記述するためにサードパーティ製のライブラリである pytest を使っていく。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V        
Python 3.7.3

はじめてのユニットテスト

まずは最も単純な例を使ってユニットテストの書き方を説明していく。

そもそもテストを書くからには、テストする対象が必要になる。 そこで、次のように greet() という関数を用意した。 この関数は、呼び出されると特定の文字列を返すようになっている。

# -*- coding: utf-8 -*-


def greet():
    """挨拶のメッセージを返す関数"""
    return 'Hello, World!'

上記の内容を helloworld.py という名前で保存する。

これで、上記を helloworld モジュールとしてインポートして使えるようになる。

$ python -c "import helloworld; print(helloworld.greet())"
Hello, World!

続いて、上記のモジュールに対応するテストを記述する。 まず、テストを書くには unittest.TestCase クラスを継承したクラスを定義する。 そのクラスの中にテストをメソッドとして記述していく。 テストのメソッドは、必ず名前の先頭が test から始まるようにする。 これは後述するテストランナーが名前を元にテストコードを探すため。 そして、テストではテスト対象から得られる値もしくは状態が期待する内容と一致するかを比較する。 比較する方法として unittest.TestCase クラスには assertEqual()assertTrue() といったメソッドが用意されている。

以下に unittest パッケージを使ったユニットテストのサンプルコードを示す。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

import helloworld  # テスト対象のモジュールをインポートする


class TestHelloWorld(unittest.TestCase):
    """helloworld モジュールのテストを記述するクラス"""

    def test_greet(self):
        """greet() 関数をテストするメソッド"""
        # テスト対象の関数を呼び出す
        message = helloworld.greet()
        # 関数の返り値が期待した内容と一致するか確認する
        self.assertEqual(message, 'Hello, World!')


if __name__ == '__main__':
    # スクリプトとして実行された場合の処理
    unittest.main(verbosity=2)

上記を test_helloworld.py という名前で保存しよう。 実はこの名前が重要で、後述するテストランナーは test から始まる名前を元にテストコードを探索する。

テストを実行する準備が整ったので、手始めに上記をスクリプトとして実行してみよう。 先ほどのテストコードは、スクリプトとして実行された場合にも unittest.main() 関数が呼ばれるようにしてある。

$ python test_helloworld.py
test_greet (__main__.TestHelloWorld) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

上記から、先ほど記述した test_greet() 関数が実行されて、テストが正しくパスしたことが分かる。

テストはスクリプトとして実行する以外にも、特定のディレクトリ以下から自動的に探して実行する方法もある。 それには、次のように Python のインタプリタで -m unittest として unittest モジュールが実行されるようにする。 その上で discover というコマンドを実行するとカレントディレクトリ以下のテストコードを名前を頼りに自動で探して実行できる。

$ python -m unittest discover -v
test_greet (test_helloworld.TestHelloWorld) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

なお、上記は Python の公式では「テストディスカバリ」という名前の機能として提供されている。 一般的には、テストを探して実行する機能は「テストランナー」と呼ばれる。

もう少し実用的な例を見てみる

先ほどはテストする対象が固定の文字列を返すだけだったので、あまりテストをする意味合いが感じられなかったかもしれない。 続いては、もう少しだけ実用的な例を見ていこう。

以下のサンプルコードでは、有名な FizzBuzz を実装している。 fizzbuzz() 関数では、渡された整数が 3 または 5 で割り切れるかを判定して返す値を切り替える。 3 と 5 の両方で割れるときは 'FizzBuzz' を、3 だけで割れるときは 'Fizz' を、5 だけで割れるときは 'Buzz' を返す。 なお、いずれでも割れないときは単に数字を文字列にして返すこととする。

# -*- coding: utf-8 -*-


def fizzbuzz(n):
    if n % 3 == 0 and n % 5 == 0:
        return 'FizzBuzz'

    if n % 3 == 0:
        return 'Fizz'

    if n % 5 == 0:
        return 'Buzz'

    return str(n)

上記を fizzbuzz.py という名前で保存する。

それでは、先ほどの FizzBuzz が正しく振る舞うかテストコードを書いて確かめてみることにしよう。 要領は先ほどと変わらない。 関数に入力される値と返り値に対して、期待される内容を比較していけば良い。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

import fizzbuzz


class TestFizzBuzz(unittest.TestCase):

    def test_fizzbuzz(self):
        # 代表的な入力と出力のパターンを列挙する
        expects = {
            1: '1',
            2: '2',
            3: 'Fizz',
            4: '4',
            5: 'Buzz',
            6: 'Fizz',
            7: '7',
            8: '8',
            9: 'Fizz',
            10: 'Buzz',
            11: '11',
            12: 'Fizz',
            13: '13',
            14: '14',
            15: 'FizzBuzz',
            16: '16',
        }
        for n, expect in expects.items():
            # 特定の入力に大して期待される値が返ってくるか確認する
            result = fizzbuzz.fizzbuzz(n)
            self.assertEqual(result, expect)


if __name__ == '__main__':
    unittest.main(verbosity=2)

上記を test_fizzbuzz.py という名前で保存しよう。

先ほどと同じようにテストディスカバリを実行してみよう。 新たに追加されたテストコードが正しくパスすれば上手くいっている。

$ python -m unittest discover -v
test_fizzbuzz (test_fizzbuzz.TestFizzBuzz) ... ok
test_greet (test_helloworld.TestHelloWorld) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK

テストしにくい部分をモックと入れ替える

ユニットテストを書いていると、どうしてもテストしにくい部分が出てくる。 典型的な例としては、テスト対象が動作するのに何らかの依存関係があって、それがないと動かないような場合がある。 あるいは、厳密に処理するとあまりにも多くの時間がかかってしまうような場合も考えられる。 そういった場合には、依存している部分をモックと呼ばれる代用部品に入れ替えると良い。

テストしにくい例として以下のサンプルコードを用意した。 このコードの中では do_something() という関数が定義されている。 また、この関数は内部で _take_a_long_time_to_do() という時間のかかる処理を実行している。 なお、実際にやっている処理は最初の greet() 関数と同じで特定の文字列を返すだけとなっている。

# -*- coding: utf-8 -*-

import time


def do_something():
    _take_a_long_time_to_do()
    return 'Hello, World!'


def _take_a_long_time_to_do():
    time.sleep(10)

上記を foobar.py という名前で保存しておこう。

まずは愚直にテストコードを書いてみる

最初は、何も考えずに上記に対応するテストコードを書いてみよう。 やっていることは最初の例と何ら変わらない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

import foobar


class TestFooBar(unittest.TestCase):

    def test_do_something(self):
        # do_something() を呼ぶ
        message = foobar.do_something()
        # 返り値を比較する
        self.assertEqual(message, 'Hello, World!')


if __name__ == '__main__':
    unittest.main(verbosity=2)

上記を test_foobar.py という名前で保存しておく。

準備ができたら上記のテストコードを実行してみよう。 このテストが完了するには 10 秒を要する。

$ python test_foobar.py
test_do_something (__main__.TestFooBar) ... ok

----------------------------------------------------------------------
Ran 1 test in 10.001s

OK

内部的に読んでいる関数をモックに置き換えてみる

続いては、テストコードの実行時間を短縮するために内部的に呼んでいる関数をモックに置き換えてみよう。 モックへの置き換えはいくつかのやり方があるものの、今回は @patch デコレータを使うことにする。

以下のサンプルコードでは foobar モジュールの _take_a_long_time_to_do() 関数をモックに置き換えている。 モックへの置き換えは、テストコードに @patch() デコレータで置き換えたいオブジェクトのパスを指定する。 置き換えられたオブジェクトの振る舞いは、テストコードのメソッドに引数として渡されるモックオブジェクトでカスタマイズできる。 ただし、今回は置き換えるオブジェクトの動作にテスト対象の関数が特に依存していないので特にカスタマイズは必要ない。 本来であれば返り値やプロパティをいじることになる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest
from unittest.mock import patch

import foobar


class TestFooBar(unittest.TestCase):

    # foobar._take_a_long_time_to_do() をモックに置き換える
    @patch('foobar._take_a_long_time_to_do')
    def test_do_something(self, patched_object):
        # モックに置き換えられた状態で do_something() 関数をテストする
        message = foobar.do_something()
        self.assertEqual(message, 'Hello, World!')
        # モックが呼び出されたことを確認する
        self.assertTrue(patched_object.called)


if __name__ == '__main__':
    unittest.main(verbosity=2)

先ほどと同じようにテストコードを実行してみよう。 今度は 10 秒もかからずにテストが完了する。

$ python test_foobar.py 
test_do_something (__main__.TestFooBar) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

本当に必要な部分をモックに置き換える

先ほどの例では、テスト対象が内部的に呼び出している関数をモックに置き換えることで実行時間を短縮できた。 しかし、実は先ほどのやり方には問題がある。 というのも、モックに置き換えたのがアンダースコアから始まる隠し関数だったため。 これの何が問題かというと、テストコードが実装に依存することを意味している。

テストコードが実装に依存することの最大の問題点は、メンテナンスコストが高くつくこと。 例えば、リファクタリングなどをしただけでもテストが正しくパスしなくなる恐れがある。 そのため、テストコードは外部に公開しているインターフェースに対して記述するのが基本となる。 もし、内部的にしか呼び出されていない関数にテストコードを書いていると感じたなら、それは要注意な状態といえる。 テストを書くのであれば、まずインターフェースは何処なのか、それはどう振る舞うべきなのかを考えた上で書くようにしよう。

例えば先ほどの例であれば、内部的に呼んでいる _take_a_long_time_to_do() 関数よりも time.sleep() 関数をモックに置き換えてしまった方が良いかもしれない。 この先 time.sleep() の挙動が変わるような事態は、ちょっとやそっとでは起こらないだろう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest
from unittest.mock import patch

import foobar


class TestFooBar(unittest.TestCase):

    # time.sleep() をモックに置き換えてみる
    @patch('time.sleep')
    def test_do_something(self, patched_object):
        message = foobar.do_something()
        self.assertEqual(message, 'Hello, World!')
        self.assertTrue(patched_object.called)


if __name__ == '__main__':
    unittest.main(verbosity=2)

実行結果は先ほどと変わらないけど、変更に対する耐性は先ほどとは段違いなはず。

$ python test_foobar.py 
test_do_something (__main__.TestFooBar) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

標準ライブラリの unittest を使った例は、ここまでで一旦おわりにする。

pytest で、より効率的にテストを書く

ここからは、サードパーティ製のテストフレームワークである pytest について見ていこう。 実際のところ、巷のライブラリなどで標準ライブラリの unittest をそのまま使ってテストを書いている例は少ない。 多くの場合、サードパーティ製のテストフレームワークとして pytest や nose などを使う場合が多い。 その中でも、最近は pytest がデファクトになりつつある。

pytest はサードパーティ製のライブラリなので pip を使ってインストールする必要がある。

$ pip install pytest

実は pytest は unittest と上位互換性がある。 そのため、既存の unittest を使ったプロジェクトにも後から導入しやすい。 試しに pytest のテストランナーで、これまでに書いた unittest のテストコードを実行してみよう。

$ pytest -v         
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python3.7
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
collected 3 items                                                                                                                                                  

test_fizzbuzz.py::TestFizzBuzz::test_fizzbuzz PASSED                                                                                                         [ 33%]
test_foobar.py::TestFooBar::test_do_something PASSED                                                                                                         [ 66%]
test_helloworld.py::TestHelloWorld::test_greet PASSED                                                                                                        [100%]

===================================================================== 3 passed in 0.08 seconds =====================================================================

ちゃんとテストが実行できてパスしたことが分かる。

最初の例を pytest 流に書き直してみる

先ほどは unittest で書いたテストコードも pytest から実行できることを示した。 とはいえ、pytest には pytest 流のテストコードの書き方がある。 試しに、最初に書いたテストコードを pytest 流に書き直してみよう。

書き直したサンプルコードが次の通り。 最初の例よりも、だいぶこざっぱりしている。 例えばテストを書くのにクラスを定義する必要はなく、単なる関数で構わない。 また、値を比較するにも専用の関数やメソッドは必要なくて単なる assert 文を使っている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pytest

import helloworld


def test_greet():
    """テストコードは単なる関数で良い"""
    message = helloworld.greet()
    # 比較は assert 文を使うだけで良い
    assert message == 'Hello, World!'


if __name__ == '__main__':
    pytest.main(['-v', __file__])

上記を実行してみよう。 こざっぱりした内容でも、ちゃんとテストとして機能していることが分かる。

$ python test_helloworld.py 
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
collected 1 item                                                                                                                                                   

test_helloworld.py::test_greet PASSED                                                                                                                        [100%]

===================================================================== 1 passed in 0.03 seconds =====================================================================

FizzBuzz のテストも書き直してみる

続いては FizzBuzz のテストも書き直してみよう。 こちらは、pytest の parametrize という機能を使うとキレイに書くことができる。

以下が parametrize を使った FizzBuzz のテストコードになる。 この機能ではデコレータを使うことでテストの外側に入力と期待される出力の組を定義できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pytest

import fizzbuzz


# 入力と期待される出力を parametrize で定義する
@pytest.mark.parametrize('n, expect', [
    (1, '1'),
    (2, '2'),
    (3, 'Fizz'),
    (4, '4'),
    (5, 'Buzz'),
    (6, 'Fizz'),
    (7, '7'),
    (8, '8'),
    (9, 'Fizz'),
    (10, 'Buzz'),
    (11, '11'),
    (12, 'Fizz'),
    (13, '13'),
    (14, '14'),
    (15, 'FizzBuzz'),
    (16, '16'),
])
def test_fizzbuzz(n, expect):
    # テストコードがシンプルに保たれる
    assert fizzbuzz.fizzbuzz(n) == expect


if __name__ == '__main__':
    pytest.main(['-v', __file__])

上記を実行すると、各パラメータの組み合わせに応じてテストが走ることが確認できる。

$ python test_fizzbuzz.py 
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
collected 16 items                                                                                                                                                 

test_fizzbuzz.py::test_fizzbuzz[1-1] PASSED                                                                                                                  [  6%]
test_fizzbuzz.py::test_fizzbuzz[2-2] PASSED                                                                                                                  [ 12%]
test_fizzbuzz.py::test_fizzbuzz[3-Fizz] PASSED                                                                                                               [ 18%]
test_fizzbuzz.py::test_fizzbuzz[4-4] PASSED                                                                                                                  [ 25%]
test_fizzbuzz.py::test_fizzbuzz[5-Buzz] PASSED                                                                                                               [ 31%]
test_fizzbuzz.py::test_fizzbuzz[6-Fizz] PASSED                                                                                                               [ 37%]
test_fizzbuzz.py::test_fizzbuzz[7-7] PASSED                                                                                                                  [ 43%]
test_fizzbuzz.py::test_fizzbuzz[8-8] PASSED                                                                                                                  [ 50%]
test_fizzbuzz.py::test_fizzbuzz[9-Fizz] PASSED                                                                                                               [ 56%]
test_fizzbuzz.py::test_fizzbuzz[10-Buzz] PASSED                                                                                                              [ 62%]
test_fizzbuzz.py::test_fizzbuzz[11-11] PASSED                                                                                                                [ 68%]
test_fizzbuzz.py::test_fizzbuzz[12-Fizz] PASSED                                                                                                              [ 75%]
test_fizzbuzz.py::test_fizzbuzz[13-13] PASSED                                                                                                                [ 81%]
test_fizzbuzz.py::test_fizzbuzz[14-14] PASSED                                                                                                                [ 87%]
test_fizzbuzz.py::test_fizzbuzz[15-FizzBuzz] PASSED                                                                                                          [ 93%]
test_fizzbuzz.py::test_fizzbuzz[16-16] PASSED                                                                                                                [100%]

==================================================================== 16 passed in 0.07 seconds =====================================================================

多彩なプラグインを使いこなす

pytest には色々な機能を持ったプラグインが存在することも魅力の一つといえる。

例えばテストを実行するのと一緒に flake8 を実行できる pytest-flake8 は使われることが多い。

$ pip install pytest-flake8

このプラグインを使うと、テストランナーに --flake8 オプションを渡すことで flake8 を実行できるようになる。

$ pytest -v --flake8
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python3.7
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
plugins: flake8-1.0.4
collected 24 items                                                                                                                                                 

fizzbuzz.py::FLAKE8 PASSED                                                                                                                                   [  4%]
foobar.py::FLAKE8 PASSED                                                                                                                                     [  8%]
helloworld.py::FLAKE8 PASSED                                                                                                                                 [ 12%]
test_fizzbuzz.py::FLAKE8 PASSED                                                                                                                              [ 16%]
test_fizzbuzz.py::test_fizzbuzz[1-1] PASSED                                                                                                                  [ 20%]
test_fizzbuzz.py::test_fizzbuzz[2-2] PASSED                                                                                                                  [ 25%]
test_fizzbuzz.py::test_fizzbuzz[3-Fizz] PASSED                                                                                                               [ 29%]
test_fizzbuzz.py::test_fizzbuzz[4-4] PASSED                                                                                                                  [ 33%]
test_fizzbuzz.py::test_fizzbuzz[5-Buzz] PASSED                                                                                                               [ 37%]
test_fizzbuzz.py::test_fizzbuzz[6-Fizz] PASSED                                                                                                               [ 41%]
test_fizzbuzz.py::test_fizzbuzz[7-7] PASSED                                                                                                                  [ 45%]
test_fizzbuzz.py::test_fizzbuzz[8-8] PASSED                                                                                                                  [ 50%]
test_fizzbuzz.py::test_fizzbuzz[9-Fizz] PASSED                                                                                                               [ 54%]
test_fizzbuzz.py::test_fizzbuzz[10-Buzz] PASSED                                                                                                              [ 58%]
test_fizzbuzz.py::test_fizzbuzz[11-11] PASSED                                                                                                                [ 62%]
test_fizzbuzz.py::test_fizzbuzz[12-Fizz] PASSED                                                                                                              [ 66%]
test_fizzbuzz.py::test_fizzbuzz[13-13] PASSED                                                                                                                [ 70%]
test_fizzbuzz.py::test_fizzbuzz[14-14] PASSED                                                                                                                [ 75%]
test_fizzbuzz.py::test_fizzbuzz[15-FizzBuzz] PASSED                                                                                                          [ 79%]
test_fizzbuzz.py::test_fizzbuzz[16-16] PASSED                                                                                                                [ 83%]
test_foobar.py::FLAKE8 PASSED                                                                                                                                [ 87%]
test_foobar.py::TestFooBar::test_do_something PASSED                                                                                                         [ 91%]
test_helloworld.py::FLAKE8 PASSED                                                                                                                            [ 95%]
test_helloworld.py::test_greet PASSED                                                                                                                        [100%]

==================================================================== 24 passed in 0.28 seconds =====================================================================

あるいはテストカバレッジを計測するための pytest-cov というプラグインも有名。 これは ptyest と coverage のインテグレーションを提供している。

$ pip install pytest-cov

このプラグインは --cov というオプションをつけることでテストカバレッジの計測ができるようになる。

$ pytest -v --cov=.
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python3.7
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
plugins: cov-2.7.1, flake8-1.0.4
collected 18 items                                                                                                                                                 

test_fizzbuzz.py::test_fizzbuzz[1-1] PASSED                                                                                                                  [  5%]
test_fizzbuzz.py::test_fizzbuzz[2-2] PASSED                                                                                                                  [ 11%]
test_fizzbuzz.py::test_fizzbuzz[3-Fizz] PASSED                                                                                                               [ 16%]
test_fizzbuzz.py::test_fizzbuzz[4-4] PASSED                                                                                                                  [ 22%]
test_fizzbuzz.py::test_fizzbuzz[5-Buzz] PASSED                                                                                                               [ 27%]
test_fizzbuzz.py::test_fizzbuzz[6-Fizz] PASSED                                                                                                               [ 33%]
test_fizzbuzz.py::test_fizzbuzz[7-7] PASSED                                                                                                                  [ 38%]
test_fizzbuzz.py::test_fizzbuzz[8-8] PASSED                                                                                                                  [ 44%]
test_fizzbuzz.py::test_fizzbuzz[9-Fizz] PASSED                                                                                                               [ 50%]
test_fizzbuzz.py::test_fizzbuzz[10-Buzz] PASSED                                                                                                              [ 55%]
test_fizzbuzz.py::test_fizzbuzz[11-11] PASSED                                                                                                                [ 61%]
test_fizzbuzz.py::test_fizzbuzz[12-Fizz] PASSED                                                                                                              [ 66%]
test_fizzbuzz.py::test_fizzbuzz[13-13] PASSED                                                                                                                [ 72%]
test_fizzbuzz.py::test_fizzbuzz[14-14] PASSED                                                                                                                [ 77%]
test_fizzbuzz.py::test_fizzbuzz[15-FizzBuzz] PASSED                                                                                                          [ 83%]
test_fizzbuzz.py::test_fizzbuzz[16-16] PASSED                                                                                                                [ 88%]
test_foobar.py::TestFooBar::test_do_something PASSED                                                                                                         [ 94%]
test_helloworld.py::test_greet PASSED                                                                                                                        [100%]

---------- coverage: platform darwin, python 3.7.3-final-0 -----------
Name                 Stmts   Miss  Cover
----------------------------------------
fizzbuzz.py              8      0   100%
foobar.py                6      0   100%
helloworld.py            2      0   100%
test_fizzbuzz.py         6      1    83%
test_foobar.py          10      1    90%
test_helloworld.py       7      1    86%
----------------------------------------
TOTAL                   39      3    92%


==================================================================== 18 passed in 0.13 seconds =====================================================================

一般的な pytest のディレクトリ構成

ところでここまでテスト対象とテストコードを一つのディレクトリに雑然と放り込んできた。 巷のライブラリなどを見ると pytest を使ったプロジェクトでは、次のように tests というディレクトリを専用に用意することが多いように思う。

$ mkdir tests
$ mv test_*.py tests
$ touch tests/__init__.py

テスト対象のモジュール・パッケージについては、tests と同じ階層の別ディレクトリに入れられる場合が多い。

$ mkdir example
$ mv *.py example
$ touch example/__init__.py

ようするに、こんな感じ。

$ tree
.
├── example
│   ├── __init__.py
│   ├── fizzbuzz.py
│   ├── foobar.py
│   └── helloworld.py
└── tests
    ├── __init__.py
    ├── test_fizzbuzz.py
    ├── test_foobar.py
    └── test_helloworld.py

2 directories, 8 files

ただ、上記のような変更を加えると、先ほど書いたテストコードは少しだけ修正が必要になる。 というのも helloworldfizzbuzz モジュールが example パッケージ配下に移動しているため。 そこで、次のようにインポート文を修正する。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pytest

# example 配下にある helloworld モジュールをインポートする
from example import helloworld


def test_greet():
    message = helloworld.greet()
    assert message == 'Hello, World!'


if __name__ == '__main__':
    pytest.main(['-v', __file__])

インポート文を変更したらテストランナーを実行してみよう。 次のように、ちゃんとテストがパスすれば上手くいっている。

$ pytest -v        
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.7.3, pytest-4.6.3, py-1.8.0, pluggy-0.12.0 -- /Users/amedama/.virtualenvs/py37/bin/python3.7
cachedir: .pytest_cache
rootdir: /Users/amedama/Documents/temporary/ut
plugins: cov-2.7.1, flake8-1.0.4
collected 18 items                                                                                                                                                 

tests/test_fizzbuzz.py::test_fizzbuzz[1-1] PASSED                                                                                                            [  5%]
tests/test_fizzbuzz.py::test_fizzbuzz[2-2] PASSED                                                                                                            [ 11%]
tests/test_fizzbuzz.py::test_fizzbuzz[3-Fizz] PASSED                                                                                                         [ 16%]
tests/test_fizzbuzz.py::test_fizzbuzz[4-4] PASSED                                                                                                            [ 22%]
tests/test_fizzbuzz.py::test_fizzbuzz[5-Buzz] PASSED                                                                                                         [ 27%]
tests/test_fizzbuzz.py::test_fizzbuzz[6-Fizz] PASSED                                                                                                         [ 33%]
tests/test_fizzbuzz.py::test_fizzbuzz[7-7] PASSED                                                                                                            [ 38%]
tests/test_fizzbuzz.py::test_fizzbuzz[8-8] PASSED                                                                                                            [ 44%]
tests/test_fizzbuzz.py::test_fizzbuzz[9-Fizz] PASSED                                                                                                         [ 50%]
tests/test_fizzbuzz.py::test_fizzbuzz[10-Buzz] PASSED                                                                                                        [ 55%]
tests/test_fizzbuzz.py::test_fizzbuzz[11-11] PASSED                                                                                                          [ 61%]
tests/test_fizzbuzz.py::test_fizzbuzz[12-Fizz] PASSED                                                                                                        [ 66%]
tests/test_fizzbuzz.py::test_fizzbuzz[13-13] PASSED                                                                                                          [ 72%]
tests/test_fizzbuzz.py::test_fizzbuzz[14-14] PASSED                                                                                                          [ 77%]
tests/test_fizzbuzz.py::test_fizzbuzz[15-FizzBuzz] PASSED                                                                                                    [ 83%]
tests/test_fizzbuzz.py::test_fizzbuzz[16-16] PASSED                                                                                                          [ 88%]
tests/test_foobar.py::TestFooBar::test_do_something PASSED                                                                                                   [ 94%]
tests/test_helloworld.py::test_greet PASSED                                                                                                                  [100%]

==================================================================== 18 passed in 0.13 seconds =====================================================================

そんなかんじで。

インターネットに疎通のないマシンに SSH Remote Port Forwarding + Squid で Web にアクセスさせる

インターネットに直接つながっていないマシンというのは意外とよくある。 とはいえ、そういったマシンでも当然のことながらセットアップ等の作業は必要になる。 その際、作業に必要なファイルは大抵の場合に SCP などで転送することになると思う。 とはいえ、もしマシンから直接 Web につながれば楽ができるはず。 今回は、そういった状況で SSH Remote Port Forwarding と Squid (Web プロキシのソフトウェア) を使って Web への疎通を提供する方法を試してみる。 なお、提供できるのはあくまで HTTP/HTTPS 等に限られ、ICMP や UDP といったプロトコルは通らない。

構成について

最初に、ざっくりとした構成図を以下に示す。 インターネットにつながらないマシンは SSH Server を想定する。 そして、SSH Client がインターネットにつながるマシンで、SSH Server に接続しに行く。 SSH Client のマシンでは Squid を 8080 ポートで起動している。 今回の主眼は、SSH Client のマシンの 8080 ポートを Remote Port Forwarding で SSH Server にも見えるようにするところ。 SSH Server は自身の 8080 ポートを Web プロキシとして利用することで、Web への疎通が手に入る。

f:id:momijiame:20190616102958p:plain

使った環境について

続いては、今回の検証に使った環境について説明する。

インターネットにつながっているマシン (SSH Client) としては、以下の通り macOS のマシンを用意した。

$ sw_vers    
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132

このマシンは、次のようにインターネットに疎通がある。

$ ping -c 3 8.8.8.8
PING 8.8.8.8 (8.8.8.8): 56 data bytes
64 bytes from 8.8.8.8: icmp_seq=0 ttl=56 time=9.156 ms
64 bytes from 8.8.8.8: icmp_seq=1 ttl=56 time=30.316 ms
64 bytes from 8.8.8.8: icmp_seq=2 ttl=56 time=18.479 ms

--- 8.8.8.8 ping statistics ---
3 packets transmitted, 3 packets received, 0.0% packet loss
round-trip min/avg/max/stddev = 9.156/19.317/30.316/8.659 ms

そして、インターネットにつながっていないマシン (SSH Server) は、以下の通り Ubuntu 18.04 LTS のマシンを用意した。 このマシンは VirtualBox を使って先ほどの macOS 上で仮想マシンとして稼働させている。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.2 LTS"
$ uname -r
4.15.0-51-generic

こちらは以下のようにインターネットに疎通がない。 かろうじて 192.168.33.0/24 のネットワークを経由して前述した macOS のマシンと疎通がある。

$ ip addr show
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
    inet6 ::1/128 scope host 
       valid_lft forever preferred_lft forever
2: enp0s8: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc fq_codel state UP group default qlen 1000
    link/ether 08:00:27:42:06:7b brd ff:ff:ff:ff:ff:ff
    inet 192.168.33.10/24 brd 192.168.33.255 scope global enp0s8
       valid_lft forever preferred_lft forever
    inet6 fe80::a00:27ff:fe42:67b/64 scope link 
       valid_lft forever preferred_lft forever
$ ip route show
192.168.33.0/24 dev enp0s8 proto kernel scope link src 192.168.33.10 
$ ping -c 3 8.8.8.8
connect: Network is unreachable

なお、このマシンは Vagrant + VirtualBox を使って仮想マシンとして用意した。 前述した macOS 上で稼働している。 ただし SSH については vagrant コマンドを使う代わりに、次のようにしてログインしている。

$ ssh -i .vagrant/machines/default/virtualbox/private_key \
  -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
  -l vagrant 192.168.33.10

Web プロキシをインストールする

ここからは、実際に SSH Server のマシンに Web への疎通を提供するまでに必要な作業について記載していく。 まずは SSH Client 側のマシンに Squid (Web プロキシ) をインストールする。

使っているのが macOS なので Homebrew を使うと楽にインストールできる。

$ brew install squid

コンフィグを編集する場合には、初期状態のものをコピーしてバックアップしておく。

$ cp $(brew --prefix)/etc/squid.conf{,.bak}

今回はコンフィグの編集例として、使うポートをよく利用される 8080 に変更してみる。

$ sed -i -e "
  s:^http_port.*$:http_port 8080:
" $(brew --prefix)/etc/squid.conf
$ grep http_port $(brew --prefix)/etc/squid.conf      
http_port 8080

あと、念のため次のようにしてどのような ACL が入っているのか、あらかじめ確認しておいた方が良いと思う。 オープンプロキシになってると危険なので。 とはいえ、おそらくデフォルトでプライベートアドレスからのアクセスしか認めないようになっているはず。

$ grep acl $(brew --prefix)/etc/squid.conf
$ grep deny $(brew --prefix)/etc/squid.conf

Squid のサービスを開始する。

$ brew services start squid

次のように 8080 ポートを Listen していれば良い。

$ lsof -i:8080 | grep -i squid
squid     58763 amedama   14u  IPv6 0x4e9127c629e616d5      0t0  TCP *:http-alt (LISTEN)

Remote Port Forwarding で Web プロキシを利用する

続いては SSH Server 側の作業に入る。

やることは単純で、次のように SSH するときに -R オプションで Remote Port Forwarding する。 以下では自身の 8080 ポートをリモートの localhost:8080 で見られるようにしている。

$ ssh -i .vagrant/machines/default/virtualbox/private_key \
      -R 8080:localhost:8080 \
      -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
      -l vagrant 192.168.33.10

ログインしたらポートを localhost:8080 で Listen していることを確認しておこう。

$ ss -tlnp | grep 8080
LISTEN   0         128               127.0.0.1:8080             0.0.0.0:*       
LISTEN   0         128                   [::1]:8080                [::]:*       

これで localhost:8080 経由で Web プロキシが使えるようになった。 あとは一般的なプロキシを使うのと同じ。 例えば環境変数に設定を入れておこう。

$ export http_proxy=http://localhost:8080
$ export https_proxy=$http_proxy

試しに wget を使って Ubuntu のイメージファイルのハッシュファイルを取得してみよう。

$ wget http://ftp.riken.jp/Linux/ubuntu-releases/bionic/SHA256SUMS

次のように、ちゃんと取得できれば上手くいっている。

$ cat SHA256SUMS 
22580b9f3b186cc66818e60f44c46f795d708a1ad86b9225c458413b638459c4 *ubuntu-18.04.2-desktop-amd64.iso
ea6ccb5b57813908c006f42f7ac8eaa4fc603883a2d07876cf9ed74610ba2f53 *ubuntu-18.04.2-live-server-amd64.iso

いじょう。

Squid Proxy Server 3.1: Beginner's Guide (English Edition)

Squid Proxy Server 3.1: Beginner's Guide (English Edition)

Squid: The Definitive Guide: The Definitive Guide (Definitive Guides) (English Edition)

Squid: The Definitive Guide: The Definitive Guide (Definitive Guides) (English Edition)

Python: 定期実行のアルゴリズムについて

今回は割と小ネタで、特定の処理を定期実行するようなプログラムを書く場合について考えてみる。 ただし、前提としてあくまで定期実行は Python の中で処理して cron 的なものには頼らないものとする。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V       
Python 3.7.3

ダメなパターン: 定期実行の時間だけ単純に sleep する

最初に考えられるのは、定期実行したい間隔で time.sleep() のような関数を使ってインターバルを入れるというもの。 ただし、このパターンでは肝心の定期実行したい処理にかかる時間が考慮できていない。

以下のサンプルコードでは 3 秒ごとに定期実行しているつもりでいる。 しかし、肝心の定期実行したい処理には 2 秒かかっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 2 秒くらいかかる
    time.sleep(2)


def schedule(interval_sec, callable_task, args=None, kwargs=None):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    while True:
        callable_task(*args, **kwargs)  # ここで時間を食う
        time.sleep(interval_sec)  # さらにスリープしてしまう


def main():
    # 3 秒ごとに実行している...つもり
    schedule(interval_sec=3, callable_task=interval_task)


if __name__ == '__main__':
    main()

上記を実行してみる。 表示を見て分かる通り 3 + 2 = 5 秒の間隔で時刻が表示されてしまっている。

$ python sched1.py 
11:24:16.698956
11:24:21.709522
11:24:26.710433
11:24:31.718227
...

ダメなパターン: 処理にかかる時間を開始・終了の前後で毎回計測して計算する

続いて考えられるのが、定期実行したい処理の前後で開始・終了時刻を計測してスリープする時間を補正するというもの。 ただし、この場合は定期実行したい処理以外にかかる処理時間が考慮できていない。

以下のサンプルコードでは定期実行したい処理の前後で時刻を計測してスリープする時間を補正している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 2 秒くらいかかる
    time.sleep(2)


def schedule(interval_sec, callable_task, args=None, kwargs=None):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    while True:
        # 処理開始時間を取得する
        start_timing = datetime.now()

        callable_task(*args, **kwargs)

        # 処理完了時間を取得する
        end_timing = datetime.now()
        # 実行間隔との差分を取る
        time_delta_sec = (end_timing - start_timing).total_seconds()
        # スリープすべき時間を計算する
        sleep_sec = interval_sec - time_delta_sec

        time.sleep(max(sleep_sec, 0))


def main():
    # 3 秒ごとに実行される...と良いなあ
    schedule(interval_sec=3, callable_task=interval_task)


if __name__ == '__main__':
    main()

上記を実行してみると、だいたい 3 秒ごとに時刻が表示されるため一見すると上手くいっているように見える。 しかし、よく見ると 1 ~ 10 ミリ秒の単位は単調に時刻が増加していることが分かる。 これは、定期実行以外の処理にかかる時間が考慮できていないため、実際には間隔がわずかに長くなってしまっている。

$ python sched2.py 
11:25:41.820993
11:25:44.826223
11:25:47.831453
11:25:50.836683
11:25:53.840000
11:25:56.844355
11:25:59.849602
...

また、間隔が少し長くなる以外にも、もう一つの問題がある。 定期実行したい処理が実行間隔よりもかかってしまうと、スリープする時間がマイナスになってしまう。 そのため、スリープを全く入れなくても実際より実行間隔が長くなってしまう。 いわゆるバッチの突き抜けみたいな状態。

以下のサンプルコードでは実行間隔として 3 秒を意図しているにもかかわらず、実際には定期実行の処理には 4 秒かかる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 4 秒くらいかかる
    time.sleep(4)


def schedule(interval_sec, callable_task, args=None, kwargs=None):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    while True:
        start_timing = datetime.now()

        # ここが意図した実行間隔よりも長くかかる (バッチの突き抜け)
        callable_task(*args, **kwargs)

        end_timing = datetime.now()
        time_delta_sec = (end_timing - start_timing).total_seconds()
        sleep_sec = interval_sec - time_delta_sec

        # sleep_sec が負の値になる
        time.sleep(max(sleep_sec, 0))


def main():
    schedule(interval_sec=3, callable_task=interval_task)


if __name__ == '__main__':
    main()

上記を実行すると、次のように本来意図した 3 秒ではなく 4 秒間隔で実行される。

$ python sched3.py 
11:26:52.933087
11:26:56.938491
11:27:00.943913
11:27:04.948757
11:27:08.949384

解決策: 特定の基準時刻を元にスリープする時間を補正しつつ別のスレッドで実行する

先ほどの問題点を解決するために二つの施策が必要となる。 まずひとつ目は一つの基準時刻を設けて、それを元にスリープする時刻を補正する。 これで、実行間隔がやや長くなってしまう問題が解決できる。 もうひとつは定期実行の処理を別のスレッドで実行することで、バッチの突き抜けを防止できる。

以下のサンプルコードでは特定の基準時刻を元にスリープする時間を補正している。 具体的には、処理の最初で取得した時刻から剰余演算でスリープすべき時間を計算する。 その上で定期実行の処理は別のスレッドを起動している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time
import threading


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 4 秒くらいかかる
    time.sleep(4)


def schedule(interval_sec, callable_task,
             args=None, kwargs=None):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    # 基準時刻を作る
    base_timing = datetime.now()
    while True:
        # 処理を別スレッドで実行する
        t = threading.Thread(target=callable_task,
                             args=args, kwargs=kwargs)
        t.start()

        # 基準時刻と現在時刻の剰余を元に、次の実行までの時間を計算する
        current_timing = datetime.now()
        elapsed_sec = (current_timing - base_timing).total_seconds()
        sleep_sec = interval_sec - (elapsed_sec % interval_sec)

        time.sleep(max(sleep_sec, 0))


def main():
    schedule(interval_sec=3, callable_task=interval_task)


if __name__ == '__main__':
    main()

上記の実行結果は次の通り。 定期実行の処理に 4 秒かかったとしても、正しく 3 秒間隔で処理が実行できていることが分かる。 また、ミリ秒単位についても単調増加していない。

$ python sched4.py 
11:27:38.941382
11:27:41.946648
11:27:44.942504
11:27:47.946613
11:27:50.942494
...

オプション: スレッドプールを使う

先ほどの処理では単純に定期実行の度にスレッドを起動していた。 しかし、スレッドの生成には時間的・空間的な計算量がかかる。 もし、オーバーヘッドを小さくしたり、メモリを過剰に使われたくないときはスレッドプールを利用することが検討できるはず。

以下のサンプルコードではワーカーが 10 のスレッドプールを使って先ほどと同じ処理を実行している。 これで、10 を越えるスレッドが同時に生成されることがなくなる。 また、一度作られたスレッドは再利用されるため時間的な計算量でもわずかながら有利になるはず。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time
from concurrent.futures import ThreadPoolExecutor


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 4 秒くらいかかる
    time.sleep(4)


def schedule(interval_sec, callable_task,
             args=None, kwargs=None,
             workers_n=10):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    base_timing = datetime.now()

    # 必要以上にスレッドが生成されないようにスレッドプールを使う
    with ThreadPoolExecutor(max_workers=workers_n) as executor:
        while True:
            future = executor.submit(callable_task,
                                     *args, **kwargs)

            current_timing = datetime.now()
            elapsed_sec = (current_timing - base_timing).total_seconds()
            sleep_sec = interval_sec - (elapsed_sec % interval_sec)

            time.sleep(max(sleep_sec, 0))


def main():
    schedule(interval_sec=3, callable_task=interval_task)


if __name__ == '__main__':
    main()

上記の実行結果は次の通り。

$ python sched5.py 
11:30:30.000741
11:30:33.005632
11:30:36.002308
11:30:39.003651
11:30:42.002279
...

ただし、スレッドプールにはプールの上限に達したときにバッチの突き抜けが起こるという問題がある。 メモリが枯渇して OOM Killer に殺されるか、殺されないけど突き抜けるかは状況によって選ぶのが良いと思う。 とはいえメモリが枯渇するほどスケジュール実行のスレッドが生成される状況って、暴走しているような場合くらいな気もする?

以下のサンプルコードではワーカーの数を 1 に制限することで、意図的に突き抜けを起こるようにしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import time
from concurrent.futures import ThreadPoolExecutor


def interval_task():
    """定期的に実行したい何か時間のかかる処理"""
    now = datetime.now()
    print(now.strftime('%H:%M:%S.%f'))
    # 実行に 4 秒くらいかかる
    time.sleep(4)


def schedule(interval_sec, callable_task,
             args=None, kwargs=None,
             workers_n=10):
    """何らかの処理を定期的に実行する関数"""
    args = args or []
    kwargs = kwargs or {}
    base_timing = datetime.now()

    with ThreadPoolExecutor(max_workers=workers_n) as executor:
        while True:
            future = executor.submit(callable_task,
                                     *args, **kwargs)

            current_timing = datetime.now()
            elapsed_sec = (current_timing - base_timing).total_seconds()
            sleep_sec = interval_sec - (elapsed_sec % interval_sec)

            time.sleep(max(sleep_sec, 0))


def main():
    # 並列度を 1 にして時間のかかる処理を 1 秒ごとに実行した場合
    # スレッドが空くまで待たされる
    schedule(interval_sec=1, callable_task=interval_task,
             workers_n=1)


if __name__ == '__main__':
    main()

上記を実行してみよう。 たしかに突き抜けて処理に 4 秒かかっていることが分かる。

$ python sched7.py 
11:32:48.013018
11:32:52.014658
11:32:56.018407
11:33:00.024064
11:33:04.029571
...

なお、巷にはスケジュール実行するためのライブラリも色々とあって、それらを使うことで色々と楽ができる。 ただし、意図通りに動作させるためには上記のような考慮点についてあらかじめ検討しておく必要がある。 また、リアルタイム OS でない限り今回用いたようなコードで正しく定期実行されるという保証は実のところないはず。

いじょう。

Python: Keras で imdb データセットを読もうとするとエラーになる問題と回避策について

今回は、表題の通り Keras の API を使ってダウンロードできる imdb データセットを読もうとするとエラーになる問題について。

これは数ヶ月前から既知の問題で、以下のチケットが切られている。 内容については細かく読まなくても、詳しくは後述する。

github.com

問題を修正するコードは Git リポジトリの HEAD にはマージされている。 しかし、現時点 (2019-06-14) ではまだ修正済みのバージョンがリリースされていない。

github.com

そして、この問題について検索すると、以下の二つの回避策の提案が見つかる。

  • NumPy のバージョンを 1.16.2 以下にダウングレードする
  • インストール済みの Keras のソースコードを手動で書き換える

最初のやり方は、実は潜在的に脆弱性のある NumPy のバージョンを使うことを意味している。 また、二番目のやり方は正直あまりやりたくない類のオペレーションのはず。 そこで、上記とは異なる第三の回避策としてモンキーパッチを使う方法を提案してみる。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V                          
Python 3.7.3
$ pip list | egrep -i "(keras|numpy)"
Keras               2.2.4  
Keras-Applications  1.0.8  
Keras-Preprocessing 1.1.0  
numpy               1.16.4 

再現環境を作る

ひとまず再現環境を作るための準備として Keras とバックエンドの TensorFlow をインストールしておく。

$ pip install keras tensorflow

準備ができたら Python のインタプリタを起動する。

$ python

問題を再現する

この問題を再現するのは非常に簡単で、ただ imdb データセットを読み込もうとすれば良い。

まずは imdb モジュールをインポートする。

>>> from keras.datasets import imdb
Using TensorFlow backend.

そして、load_data() 関数を呼ぶだけ。 すると、以下のように例外になってしまう。

>>> imdb.load_data()
Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 3s 0us/step
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/keras/datasets/imdb.py", line 59, in load_data
    x_train, labels_train = f['x_train'], f['y_train']
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/numpy/lib/npyio.py", line 262, in __getitem__
    pickle_kwargs=self.pickle_kwargs)
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/numpy/lib/format.py", line 696, in read_array
    raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False

問題の詳細

この問題は NumPy の脆弱性に対する対応と、imdb が Pickle 形式の npz フォーマットで配布されていることに起因している。

まず、発端は以下の脆弱性 CVE-2019-6446 に始まる。 この脆弱性は、誤って信頼できない (細工された) Pickle を NumPy で読み込んでしまうと任意のコード実行が生じるというもの。

nvd.nist.gov

上記の脆弱性に対する対応として、NumPy はバージョン 1.16.3 以降で以下のようにコードを修正した。 具体的には、意図的にフラグ (allow_pickle=True) を有効にしない限り Pickle フォーマットのデータを読めないようにしている。

github.com

その煽りを受けたのが Keras の imdb データセットだった。 Pickle 形式のデータセットを NumPy デフォルトのオプションで読み込んでいた。 そのため、前述したように NumPy 1.16.3 以降を使うと例外になってしまう。

上記のような事情があるため、前述した通り Web を探すと以下のような回避策が提案されている。

  • NumPy のバージョンを 1.16.2 以下にダウングレードする
  • インストール済みの Keras のソースコードを書き換える

とはいえ、どちらもあまりやりたくないのは前述した通り。

第三の選択肢 (モンキーパッチ)

そこで提案するのが、モンキーパッチを使うやり方。 これは、データセットを読み込むタイミングだけ、一時的にピンポイントでコードを動的に書き換えてしまうというもの。 問題は NumPy の load() 関数がデフォルトのオプションのまま呼び出される点にある。 だとすると、関数が呼ばれるタイミングだけオプションを一時的に上書きしてしまえば良い。

具体的には、次のように関数のパラメータを部分適用して上書きする。

>>> from functools import partial
>>> import numpy as np
>>> np.load = partial(np.load, allow_pickle=True)  # monkey patch

この状態なら、エラーにならずにデータセットを読み込むことができる。

>>> from keras.datasets import imdb
Using TensorFlow backend.
>>> imdb.load_data()  # エラーにならずデータが得られる

もし、そのままになっているのが気持ち悪いのであれば、読み込みが終わった後でまた元のパラメータに戻してやれば良い。

>>> np.load = partial(np.load, allow_pickle=False)

まあ次の Keras のリリース版が出るまでの短い間だけ必要な回避策だけど、スクリプト言語ならこんなやり方もありますよということで。

PythonとKerasによるディープラーニング

PythonとKerasによるディープラーニング

パスワード付き ZIP ファイルを hashcat + JtR + GPU で総当たりしてみる

少し前に以下のツイートが話題になっていた。 hashcat というツールと GTX 2080 Ti を 4 台積んだマシンで ZIP ファイルのパスワードを探索するというもの。 このツイートでは 15 桁までわずか 15 時間 (!) で探索できたとしている。 その探索速度はなんと 22.7 ZH/s (Z = ゼッタ = Giga<Tera<Peta<Exa<Zetta) に及ぶらしい。

ただし、これは PKWARE 社の暗号化方式に存在する脆弱性を利用して計算量を削減した場合の結果らしい。 一般的に用いられている形式については、ここまで高速には探索できないとのこと。 今回のエントリは、上記を見て一般的なものはどれくらいのスピードで探索できるのか気になって実際に試してみた。

使った環境は次の通り。 GPU には Tesla V100 を 1 台使っている。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.2 LTS"
$ uname -r
4.15.0-1033-gcp
$ nvidia-smi
Sat Jun  8 05:14:29 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    37W / 300W |      0MiB / 16130MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

下準備

まずは必要となるパッケージを最初に一通りインストールしておく。

$ sudo apt-get update
$ sudo apt-get -y install clinfo wget git zip p7zip-full build-essential libssl-dev zlib1g-dev 

OpenCL (NVIDIA CUDA Runtime) をインストールする

hashcat の動作には OpenCL のランタイムが必要になる。 そこで CUDA のランタイムをインストールする。

まず、以下の Web サイトから CUDA のインストール用リポジトリの入った deb ファイルを取得する。

developer.nvidia.com

wget などを使ってダウンロードしてくれば良い。

$ wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-repo-ubuntu1804_10.1.168-1_amd64.deb

リポジトリを登録して CUDA をインストールする。

$ sudo dpkg -i cuda-repo-ubuntu1804_10.1.168-1_amd64.deb
$ sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
$ sudo apt-get update
$ sudo apt-get -y install cuda

インストールが終わったら、一旦マシンを再起動しておく。

$ sudo shutdown -r now

すると、次のように NVIDIA のグラフィックドライバと CUDA がインストールされた。

$ nvidia-smi
Sat Jun  8 05:14:29 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    37W / 300W |      0MiB / 16130MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

次のように OpenCL のランタイムが認識されている。

$ clinfo | grep -A 3 "Number of platforms"
Number of platforms                               1
  Platform Name                                   NVIDIA CUDA
  Platform Vendor                                 NVIDIA Corporation
  Platform Version                                OpenCL 1.2 CUDA 10.1.152

hashcat をインストールする

hashcat は各種ハッシュ値を探索するためのツール。

現時点でリリース済みのバージョン (v5.1.0) から HEAD はだいぶ差分があるようなので Git のリポジトリからインストールする。

$ git clone https://github.com/hashcat/hashcat.git
$ cd hashcat
$ make && sudo make install

hashcat -I コマンドで以下のように GPU が認識されていれば上手くいっている。

$ hashcat -I
hashcat (v5.1.0-1138-g581839d4) starting...

CUDA Info:
==========

CUDA.Version.: 10.1

Backend Device ID #1 (Alias: #2)
  Name...........: Tesla V100-SXM2-16GB
  Processor(s)...: 80
  Clock..........: 1530
  Memory.........: 16130 MB

OpenCL Info:
============

OpenCL Platform ID #1
  Vendor..: NVIDIA Corporation
  Name....: NVIDIA CUDA
  Version.: OpenCL 1.2 CUDA 10.1.152

  Backend Device ID #2 (Alias: #1)
    Type...........: GPU
    Vendor.ID......: 32
    Vendor.........: NVIDIA Corporation
    Name...........: Tesla V100-SXM2-16GB
    Version........: OpenCL 1.2 CUDA
    Processor(s)...: 80
    Clock..........: 1530
    Memory.........: 4032/16130 MB allocatable
    OpenCL.Version.: OpenCL C 1.2 
    Driver.Version.: 418.67

JtR (JohnTheRipper) をインストールする

hashcat はハッシュの探索に特化したツールなので、肝心のハッシュ値は別のツールを使って調べる必要がある。 ZIP ファイルに関しては JohnTheRipper というツールを使うのが一般的なようだ。

こちらも Git のリポジトリからインストールしておく。

$ git clone https://github.com/magnumripper/JohnTheRipper.git
$ cd JohnTheRipper/src
$ ./configure
$ make -s clean && make -sj$(grep processor /proc/cpuinfo | wc -l)
$ sudo make install
$ sudo ln -s $(pwd)/../run/zip2john /usr/local/bin/

以下のように zip2john コマンドが使えるようになっていれば良い。

$ zip2john
Usage: zip2john [options] [zip file(s)]
Options for 'old' PKZIP encrypted files only:
 -a <filename>   This is a 'known' ASCII file. This can be faster, IF all
    files are larger, and you KNOW that at least one of them starts out as
    'pure' ASCII data.
 -o <filename>   Only use this file from the .zip file.
 -c This will create a 'checksum only' hash.  If there are many encrypted
    files in the .zip file, then this may be an option, and there will be
    enough data that false possitives will not be seen.  If the .zip is 2
    byte checksums, and there are 3 or more of them, then we have 48 bits
    knowledge, which 'may' be enough to crack the password, without having
    to force the user to have the .zip file present.
 -m Use "file magic" as known-plain if applicable. This can be faster but
    not 100% safe in all situations.
 -2 Force 2 byte checksum computation.

NOTE: By default it is assumed that all files in each archive have the same
password. If that's not the case, the produced hash may be uncrackable.
To avoid this, use -o option to pick a file at a time.

パスワード付き ZIP ファイルを用意する

以下のようにしてパスワードが password の ZIP ファイルを作る。 辞書攻撃であれば一瞬で解ける脆弱なパスワードだけど、今回は総当たりなのでサンプルとしては構わないかな。

$ echo "Hello, World" > greet.txt
$ zip -e --password=password greet.txt.zip greet.txt
  adding: greet.txt (stored 0%)
$ file greet.txt.zip
greet.txt.zip: Zip archive data, at least v1.0 to extract

ハッシュ値を取得する

zip2john コマンドを使って次のようにハッシュ値を記録したファイルを作成する。

$ zip2john greet.txt.zip | cut -d ":" -f 2 > greet.txt.zip.hash
ver 1.0 efh 5455 efh 7875 greet.txt.zip/greet.txt PKZIP Encr: 2b chk, TS_chk, cmplen=25, decmplen=13, crc=40F63A90

取得できたハッシュ値は以下。

$ cat greet.txt.zip.hash
$pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*e6c233aef1ba5a982f025c9bcdcdc86c4fa27c949c7871dc01*$/pkzip2$

hashcat でハッシュ値を探索する

hashcat は計算対象のハッシュ値を自動では認識してくれない。 なので、先ほどのハッシュ値の内容と以下のページの内容を見比べて適切なハッシュ形式を探す。 今回であれば 172xx のいずれかだろう、となる。

hashcat.net

ここまでできたら、あとは hashcat コマンドを使って探索するだけ。 ハッシュ形式は -m オプションで指定する。 -a オプションはアタックモードで、総当たり (Brute-force) なら 3 を指定する。 -w オプションはワークロードプロファイルで、全力で探索するときは 4 を指定する。

$ hashcat -m 17210 -a 3 -w 4 \
    --session helloworld \
    -o result.txt \
    greet.txt.zip.hash

実行すると、次のように探索が始まって状況が表示される。 探索速度は 28043.6 MH/s なので、約 28 GH/s となる。

$ hashcat -m 17210 -a 3 -w 4 \
    --session helloworld \
    -o result.txt \
    greet.txt.zip.hash
...

Session..........: helloworld
Status...........: Running
Hash.Name........: PKZIP (Uncompressed)
Hash.Target......: $pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*e...kzip2$
Time.Started.....: Sat Jun  8 05:24:42 2019 (15 secs)
Time.Estimated...: Sat Jun  8 05:27:58 2019 (3 mins, 1 sec)
Guess.Mask.......: ?1?2?2?2?2?2?2?3 [8]
Guess.Charset....: -1 ?l?d?u, -2 ?l?d, -3 ?l?d*!$@_, -4 Undefined 
Guess.Queue......: 8/15 (53.33%)
Speed.#1.........: 28043.6 MH/s (93.32ms) @ Accel:32 Loops:1024 Thr:1024 Vec:1
Recovered........: 0/1 (0.00%) Digests, 0/1 (0.00%) Salts
Progress.........: 437382021120/5533380698112 (7.90%)
Rejected.........: 0/437382021120 (0.00%)
Restore.Point....: 5242880/68864256 (7.61%)
Restore.Sub.#1...: Salt:0 Amplifier:6144-7168 Iteration:0-1024
Candidates.#1....: 0uc3aen1 -> Ciskoe86
Hardware.Mon.#1..: Temp: 54c Util: 99% Core:1530MHz Mem: 877MHz Bus:16

...

デフォルトではアルファベットと数字だけが探索対象なので、8 桁でも数分もあれば見つかる。 結果は -o オプションでファイルに書き出しているので、表示させてみよう。

$ cat result.txt 
$pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*e6c233aef1ba5a982f025c9bcdcdc86c4fa27c949c7871dc01*$/pkzip2$:password

ちなみに探索を中止してもセッションに名前をつけてあれば、以下のように再開できる。

$ hashcat --session helloworld --restore

セッションの情報は以下のように保存されている。

$ ls ~/.hashcat/sessions/
hashcat.log  hashcat.restore  hellosymbol.log  hellosymbol.restore  helloworld.log  helloworld.restore

探索する文字種別に記号を含めてみる

探索が必要な空間は、文字種別と桁数で指数関数的に増える。 そのため探索する文字種別に記号を含めると、単位時間あたりに探せる桁数はぐっと落ちることになる。 そこで、次はパスワードに記号を含めて試してみよう。

$ zip -e --password='pswd_+' greet.txt.zip greet.txt
updating: greet.txt (stored 0%)

ハッシュを取り直す。

$ zip2john greet.txt.zip | cut -d ":" -f 2 > greet.txt.zip.hash
ver 1.0 efh 5455 efh 7875 greet.txt.zip/greet.txt PKZIP Encr: 2b chk, TS_chk, cmplen=25, decmplen=13, crc=40F63A90
$ cat greet.txt.zip.hash 
$pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*7f04312cbe0aab6e4a19fad645aecda94ea6fee0c2b3710fb0*$/pkzip2$

使用する文字種別は -1 ~ -4 オプションで登録できる。 それをマスク (以下の ?1?1... となっている部分) として利用する。 以下では記号を含む一通りの文字種別で 10 桁までインクリメンタルに探索する設定となる。

$ hashcat -m 17210 -a 3 -w 4 \
    --session hellosymbol \
    -o result.txt \
    -1 ?a \
    --increment \
    greet.txt.zip.hash \
    ?1?1?1?1?1?1?1?1?1?1

ちなみに、今回のケースではわずか 7 桁でも探索に 40 分ほどかかることが分かった。

...

Session..........: hellosymbol                     
Status...........: Exhausted
Hash.Name........: PKZIP (Uncompressed)
Hash.Target......: $pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*7...kzip2$
Time.Started.....: Sat Jun  8 05:41:19 2019 (28 secs)
Time.Estimated...: Sat Jun  8 05:41:47 2019 (0 secs)
Guess.Mask.......: ?1?1?1?1?1?1 [6]
Guess.Charset....: -1 ?a, -2 Undefined, -3 Undefined, -4 Undefined 
Guess.Queue......: 6/10 (60.00%)
Speed.#1.........: 26579.9 MH/s (86.36ms) @ Accel:32 Loops:1024 Thr:1024 Vec:1
Recovered........: 0/1 (0.00%) Digests, 0/1 (0.00%) Salts
Progress.........: 735091890625/735091890625 (100.00%)
Rejected.........: 0/735091890625 (0.00%)
Restore.Point....: 81450625/81450625 (100.00%)
Restore.Sub.#1...: Salt:0 Amplifier:8192-9025 Iteration:0-1024
Candidates.#1....: 3<5x$~ ->  ~ ~}z
Hardware.Mon.#1..: Temp: 53c Util: 99% Core:1530MHz Mem: 877MHz Bus:16

とはいえ、実際に使ったパスワードの長さは 6 桁なので、数十秒あれば見つかる。

$ cat result.txt 
$pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*e6c233aef1ba5a982f025c9bcdcdc86c4fa27c949c7871dc01*$/pkzip2$:password
$pkzip2$1*2*2*0*19*d*40f63a90*0*43*0*19*40f6*2adb*7f04312cbe0aab6e4a19fad645aecda94ea6fee0c2b3710fb0*$/pkzip2$:pswd_+

より強固な暗号化方式を用いる

例えば、もっと強固な方式を用いると探索速度はどのように変化するだろうか? 例として AES256 を使って暗号化してみることにした。

$ 7za a -tzip -ppassword -mem=AES256 greet.txt.zip greet.txt

ハッシュを取り直す。

$ zip2john greet.txt.zip | cut -d ":" -f 2 > greet.txt.zip.hash
$ cat greet.txt.zip.hash 
$zip2$*0*3*0*cf86e2828c42995d3a631f9dfe159ce8*2fa8*d*0a5619e079e4f6389e7d8da029*55fd0328aae560645e58*$/zip2$

実行してみよう。

$ hashcat -m 13600 -a 3 -w 4 \
    --session zipaes \
    -o result.txt \
    -1 ?a \
    --increment \
    greet.txt.zip.hash \
    ?1?1?1?1?1?1?1?1?1?1

すると、このパターンでは探索速度がわずか 2194 kH/s (2 MH/s) しか出ていない。 先ほどと比べると、およそ 10,000 分の 1 となった。

...

Session..........: zipaes
Status...........: Running
Hash.Name........: WinZip
Hash.Target......: $zip2$*0*3*0*cf86e2828c42995d3a631f9dfe159ce8*2fa8*.../zip2$
Time.Started.....: Sat Jun  8 05:45:57 2019 (11 secs)
Time.Estimated...: Sat Jun  8 05:46:34 2019 (26 secs)
Guess.Mask.......: ?1?1?1?1 [4]
Guess.Charset....: -1 ?a, -2 Undefined, -3 Undefined, -4 Undefined 
Guess.Queue......: 4/10 (40.00%)
Speed.#1.........:  2194.1 kH/s (74.89ms) @ Accel:32 Loops:249 Thr:1024 Vec:1
Recovered........: 0/1 (0.00%) Digests, 0/1 (0.00%) Salts
Progress.........: 23149125/81450625 (28.42%)
Rejected.........: 0/23149125 (0.00%)
Restore.Point....: 0/857375 (0.00%)
Restore.Sub.#1...: Salt:0 Amplifier:27-28 Iteration:249-498
Candidates.#1....: oari -> o ~}
Hardware.Mon.#1..: Temp: 50c Util: 99% Core:1530MHz Mem: 877MHz Bus:16

まとめ

  • パスワード付きの ZIP ファイルを hashcat + JtR + GPU で総当たりしてみた
    • さほど恐ろしさを覚える探索速度は出なかった
    • 仮に並列度を上げても総当たりなら定数倍の改善にとどまるはず
    • また、より強固な暗号化方式を使うとさらに総当たりが難しくなる
    • ただしパスワードは十分に長く記号を含んだものを使うことが前提となる
  • 今回使った環境での長さに関する相場感
    • 数字のみ: 12 桁の探索に 1 分
    • 数字 + アルファベット: 8 桁の探索に 3 分
    • 数字 + アルファベット + 記号: 7 桁の探索に 40 分

いじょう。

Python: LightGBM を Git のソースコードからインストールする

今回は LightGBM の Python パッケージを Git のソースコードからインストールする方法について。 まだリリースされていない最新の機能を使いたい、あるいは自分で改造したパッケージを使いたい、といった場合に。

なお、インストール方法は以下に記載されている。

github.com

使った環境は次の通り。

$ sw_vers                
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V
Python 3.7.3

下準備

まずはビルドに必要なパッケージをインストールしておく。

$ brew install cmake libomp

LightGBM のリポジトリをチェックアウトして python-package ディレクトリに移動しておく。

$ git clone https://github.com/microsoft/LightGBM.git
$ cd LightGBM/python-package 

インストール

公式のマニュアルを見ると、次のようにインストールすると書いてある。 ただ、これだと依存パッケージが一緒に入らない。

$ python setup.py install

なので、まずはソースコード配布物などのパッケージをまずはビルドした上で、それを使ってインストールするのが楽だと思う。

$ python setup.py sdist

これなら依存パッケージが同時に入る。

$ pip install dist/lightgbm-2.2.4.tar.gz

現時点 (2019-06-07) で未リリースのバージョンがインストールされた。 もちろん、これは Git の HEAD を使った開発版なので正式なバージョンがついているわけではない。

$ pip list | grep -i lightgbm                     
lightgbm     2.2.4  

作業ディレクトリを移動して lightgbm パッケージがインポートできることを確認する。

$ pushd /tmp && python -c "import lightgbm as lgb"

いじょう。

統計的学習の基礎 ―データマイニング・推論・予測―

統計的学習の基礎 ―データマイニング・推論・予測―

  • 作者: Trevor Hastie,Robert Tibshirani,Jerome Friedman,杉山将,井手剛,神嶌敏弘,栗田多喜夫,前田英作,井尻善久,岩田具治,金森敬文,兼村厚範,烏山昌幸,河原吉伸,木村昭悟,小西嘉典,酒井智弥,鈴木大慈,竹内一郎,玉木徹,出口大輔,冨岡亮太,波部斉,前田新一,持橋大地,山田誠
  • 出版社/メーカー: 共立出版
  • 発売日: 2014/06/25
  • メディア: 単行本
  • この商品を含むブログ (6件) を見る

Python: LightGBM の学習曲線をコールバックで動的にプロットする

LightGBM の学習が進む様子は、学習させるときにオプションとして verbose_eval などを指定することでコンソールから確認できる。 ただ、もっと視覚的にリアルタイムで確認したいなーと思ったので、今回はコールバックと Matplotlib を使って学習曲線を動的にグラフとしてプロットしてみることにした。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V            
Python 3.7.3

下準備

下準備として LightGBM と Matplotlib をインストールしておく。 Seaborn は本来は必要ないんだけどデータセットの読み込みにだけ使っている。

$ pip install lightgbm matplotlib seaborn

学習曲線を動的にプロットする

今回書いてみたサンプルコードは次の通り。 Seaborn から Titanic データセットを読み込んで LightGBM のモデルが学習していく過程を可視化している。 グラフのプロットは LearningVisualizationCallback というコールバックを実装することで実現している。 そのままだとグラフが寂しいので、カスタムメトリックとして Accuracy も追加してみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import defaultdict

import numpy as np
import lightgbm as lgb
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from matplotlib import pyplot as plt


class LearningVisualizationCallback(object):
    """学習の過程を動的にプロットするコールバック"""

    def __init__(self, fig=None, ax=None):
        self._metrics = defaultdict(list)
        self._lines = {}

        # 初期化する
        self._fig = fig
        self._ax = ax
        if self._fig is None and self._ax is None:
            self._fig, self._ax = plt.subplots()
        self._fig.canvas.draw()
        self._fig.show()

    def __call__(self, env):
        # メトリックを保存する
        evals = env.evaluation_result_list
        for _, name, mean, _, _ in evals:
            self._metrics[name].append(mean)

        # 可視化する
        for name, values in self._metrics.items():

            # 初回だけ描画用オブジェクトを取得して保存しておく
            if name not in self._lines:
                line, = self._ax.plot(np.arange(len(values)),
                                      values)
                self._lines[name] = line
                line.set_label(name)

            # グラフデータを更新する
            line = self._lines[name]
            line.set_data(np.arange(len(values)), values)

        # グラフの見栄えを調整する
        self._ax.grid()
        self._ax.legend()
        self._ax.relim()
        self._ax.autoscale_view()

        # 再描画する
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()

    def show_until_close(self):
        """ウィンドウを閉じるまで表示し続ける"""
        plt.show()


def accuracy(preds, data):
    """精度 (Accuracy) を計算する関数
    NOTE: 表示が eval set の LogLoss だけだと寂しいので"""
    y_true = data.get_label()
    y_pred = np.where(preds > 0.5, 1, 0)
    acc = np.mean(y_true == y_pred)
    return 'accuracy', acc, True


def main():
    # Titanic データセットを読み込む
    dataset = sns.load_dataset('titanic')

    # 重複など不要な特徴量は落とす
    X = dataset.drop(['survived',
                      'class',
                      'who',
                      'embark_town',
                      'alive'], axis=1)
    y = dataset.survived

    # カテゴリカル変数を指定する
    categorical_columns = ['pclass',
                           'sex',
                           'embarked',
                           'adult_male',
                           'deck',
                           'alone']
    X = X.astype({c: 'category'
                  for c in categorical_columns})

    # LightGBM のデータセット表現に直す
    lgb_train = lgb.Dataset(X, y)

    # 学習の過程を可視化するコールバックを用意する
    visualize_cb = LearningVisualizationCallback()
    callbacks = [
        visualize_cb,
    ]

    # 二値分類を LogLoss で評価する
    lgb_params = {
        'objective': 'binary',
        'metrics': 'binary_logloss',
    }
    # 5-Fold CV
    skf = StratifiedKFold(n_splits=5,
                          shuffle=True,
                          random_state=42)
    lgb.cv(lgb_params, lgb_train,
           num_boost_round=1000,
           early_stopping_rounds=100,
           verbose_eval=10,
           folds=skf, seed=42,
           # Accuracy も確認する
           feval=accuracy,
           # コールバックを登録する
           callbacks=callbacks)

    # ウィンドウを閉じるまで表示し続ける
    visualize_cb.show_until_close()


if __name__ == '__main__':
    main()

上記に適当な名前をつけて実行してみよう。

$ python lgblearnviz.py

すると、モデルの学習に伴って次のようなアニメーションが表示される。

f:id:momijiame:20190606223725g:plain

いいかんじ。

なお、表示されているのは Validation Set に対するメトリックとなる。 Training Set も確認したかったんだけど、どうやら次のリリース (2.2.4?) でオプションに eval_train_metric が入るのを待つ必要がありそう。

あと、Jupyter Notebook で使うときは %matplotlib notebook マジックコマンドを使うと良い。