CUBE SUGAR CONTAINER

技術系のこと書きます。

リモートサーバの Jupyter Notebook を SSH Port Forwarding 経由で使う

一般的に Jupyter Notebook はローカルの環境にインストールして使うことが多い。 ただ、ローカルの環境は計算資源が乏しい場合もある。 そんなときは IaaS などリモートにあるサーバで Jupyter Notebook を使いたい場面が存在する。 ただ、セキュリティのことを考えると Jupyter Notebook の Web UI をインターネットに晒したくはない。

そこで、今回は SSH Port Forwarding を使って Web UI をインターネットに晒すことなく使う方法について書く。 このやり方ならリモートサーバに SSH でログインしたユーザだけが Jupyter Notebook を使えるようになる。 また、Web UI との通信も SSH 経由になるので HTTP over SSL/TLS (HTTPS) を使わなくても盗聴のリスクを下げられる。

リモートサーバを想定した環境は次の通り。 話を単純にするために環境は Vagrant で作ってある。

vagrant $ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.1 LTS"
vagrant $ uname -r
4.15.0-29-generic

そこに接続するクライアントの環境は次の通り。

client $ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.13.6
BuildVersion:   17G65
client $ ssh -V
OpenSSH_7.6p1, LibreSSL 2.6.2

必要なパッケージをインストールする

ここからは、すでにリモートの Ubuntu マシンに SSH でログインしている前提で話を進める。

まずは必要なパッケージをインストールする。 ログインするたびに Jupyter Notebook を起動するコマンドを入力するのも面倒なので、最終的に Supervisord でデーモン化することにした。

vagrant $ sudo apt-get update
vagrant $ sudo apt-get -y install jupyter-notebook supervisor

今回は OS のパッケージ管理システム経由でインストールしてるけど pip を使うとかはお好みで。

まずはサクッと試す

ひとまず手っ取り早く今回やることの本質を示す。

最初にリモートサーバ上で Jupyter Notebook を起動する。 これで TCP/8888 で Jupyter Notebook の Web UI が動く。

vagrant $ jupyter notebook

ターミナルに Web UI のアクセストークンが表示されるのでメモしておこう。

続いて、クライアントの別のターミナルを開いて、改めてリモートサーバに SSH でログインする。 このとき SSH Port Forwarding を使って、リモートサーバの TCP/8888 をローカルホストのポートにマッピングする。

client $ ssh -L 8888:localhost:8888 <username>@<remotehost>

今回は Vagrant の環境を使っているのでこんな感じ。 恒久的に設定を入れたいなら Vagrantfile を編集する。

client $ vagrant ssh-config > ssh.config
client $ ssh -L 8888:localhost:8888 -F ssh.config default

あとは、クライアントのブラウザでローカルホストにマッピングしたポート番号を開くだけ。

client $ open http://localhost:8888

すると、Jupyter Notebook の Web UI でアクセストークンを入力する画面が表示される。 先ほど Jupyter Notebook を起動するときにターミナルに表示されたトークンを入力しよう。

f:id:momijiame:20181015080902p:plain

これで、いつもの見慣れた Web UI が表示されるはず。 あとは使うだけ。

f:id:momijiame:20181014022447p:plain

以上で、今回やることの本質は示せた。

ただ、上記の操作は毎回やるには結構めんどくさいしセキュリティをあまり考慮していない。 そこで、ここからは運用をできるだけ楽に、そしてセキュアな環境を手に入れるべく手順を記載していく。

以降の手順を試すときは、一旦先ほど起動した Jupyter Notebook は停止しておこう。

アクセス制御をかける

リモートサーバを想定しているので、念のため必要なポート以外はファイアウォールを使って閉じておく。

SSH に使うポートだけを残して、それ以外は全て閉じる。 SSH に使うポート番号を 22 以外にしているときは、適宜読み替える感じで。

vagrant $ sudo ufw allow 22
vagrant $ sudo ufw default DENY
vagrant $ yes | sudo ufw enable
vagrant $ sudo ufw status
Status: active

To                         Action      From
--                         ------      ----
22                         ALLOW       Anywhere                  
22 (v6)                    ALLOW       Anywhere (v6)             

ファイアウォールの設定を変更するときはリモートサーバから追い出されないように注意しよう。

Jupyter Notebook を起動するユーザを追加する

若干好みの問題にも近いけど、念のため Jupyter Notebook を起動する専用のユーザを追加しておく。

vagrant $ sudo useradd -m -s $SHELL jupyter

Jupyter Notebook を設定する

ここからは Jupyter Notebook を設定していく。

まずは先ほど作ったユーザにログインする。

vagrant $ sudo su - jupyter

続いて、設定ファイルを生成する。

jupyter $ jupyter notebook --generate-config
Writing default config to: /home/jupyter/.jupyter/jupyter_notebook_config.py

Jupyter Notebook の作業ディレクトリを用意する。

jupyter $ mkdir -p /home/$(whoami)/jupyter-working

設定ファイルを編集する。

jupyter $ sed -i.back \
  -e "s:^#c.NotebookApp.token = .*$:c.NotebookApp.token = u'':" \
  -e "s:^#c.NotebookApp.ip = .*$:c.NotebookApp.ip = 'localhost':" \
  -e "s:^#c.NotebookApp.open_browser = .*$:c.NotebookApp.open_browser = False:" \
  -e "s:^#c.NotebookApp.notebook_dir = .*$:c.NotebookApp.notebook_dir = '/home/$(whoami)/jupyter-working':" \
  /home/$(whoami)/.jupyter/jupyter_notebook_config.py
jupyter $ cat ~/.jupyter/jupyter_notebook_config.py | sed -e "/^#/d" -e "/^$/d"
c.NotebookApp.ip = 'localhost'
c.NotebookApp.notebook_dir = '/home/jupyter/jupyter-working'
c.NotebookApp.open_browser = False
c.NotebookApp.token = u''

それぞれの設定の内容や意図としては以下のような感じ。

  • c.NotebookApp.ip = 'localhost'
    • Jupyter Notebook が Listen するアドレスをループバックアドレスにする
    • もしファイアウォールがなくてもインターネットからは Jupyter Notebook の WebUI に疎通がなくなる
  • c.NotebookApp.notebook_dir = '/home/jupyter/jupyter-working'
    • Jupyter Notebook の作業ディレクトリを専用ユーザのディレクトリにする
    • 仮に Web UI が不正アクセスを受けたときにも影響範囲を小さくとどめる (気休め程度)
  • c.NotebookApp.open_browser = False
    • 起動時にブラウザを開く動作を抑制する
    • ローカル環境ではないので起動するときにブラウザを起動する必要はない
  • c.NotebookApp.token = u''
    • Jupyter Notebook の Web UI にビルトインで備わっている認証を使わない
    • 認証は SSH によるログインで担保する場合の設定 (心配なときは後述する共通パスワードなどを設定する)

(オプション) Jupyter Notebook の Web UI に共通パスワードをかける

SSH のログイン以外にも認証をかけたいときは、例えばシンプルなものだと共通パスワードが設定できる。

Jupyter Notebook の Web UI に共通パスワードをかけるには jupyter notebook password コマンドを実行する。

jupyter $ jupyter notebook password
Enter password: 
Verify password: 
[NotebookPasswordApp] Wrote hashed password to /home/jupyter/.jupyter/jupyter_notebook_config.json

すると、ソルト付きの暗号化されたパスワードが設定ファイルとしてできる。

jupyter $ cat ~/.jupyter/jupyter_notebook_config.json 
{
  "NotebookApp": {
    "password": "sha1:217911554b0b:f2fa9cd9f336951c335bdaa06a6c16eb6286c192"
  }
}

上記のやり方だとハッシュのアルゴリズムが SHA1 固定っぽい。 もし、より頑丈なものが使いたいときは次のように Python のインタプリタ経由で生成する。

jupyter $ python3
Python 3.6.6 (default, Sep 12 2018, 18:26:19) 
[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from notebook.auth import passwd
>>> passwd('jupyter-server-password', algorithm='sha512')
'sha512:d197670d2987:19bb2eedfc6fde56f1a9fc04d403999c3f03a99af368e528f45ee9a68f01a7c5f07e375bd34ec176d1c66a0f2e8ef7615ebcf9e524a23ace5ab6dd5a930398d4'

生成した暗号化済みパスワードを、次のような形で Jupyter Notebook の設定ファイルに入力すれば良い。

c.NotebookApp.password_required = True
c.NotebookApp.password = u'sha512:d197670d2987:19bb2eedfc6fde56f1a9fc04d403999c3f03a99af368e528f45ee9a68f01a7c5f07e375bd34ec176d1c66a0f2e8ef7615ebcf9e524a23ace5ab6dd5a930398d4'

上記の共通パスワード方式を含む Jupyter Notebook の認証周りについては以下の公式ドキュメントを参照のこと。

Running a notebook server — Jupyter Notebook 5.7.0 documentation

Jupyter Notebook を Supervisord 経由で起動する

続いては Jupyter Notebook をデーモン化する設定に入る。

一旦、元の管理者権限をもったユーザに戻る。

jupyter $ exit
logout

Supervisord の設定ファイルを用意する。

vagrant $ cat << 'EOF' | sudo tee /etc/supervisor/conf.d/jupyter.conf > /dev/null
[program:jupyter]
command=jupyter notebook
user=jupyter
stdout_logfile=/var/log/supervisor/jupyter.log
redirect_stderr=true
autostart=true
autorestart=true
EOF

Supervisord を起動する。

vagrant $ sudo systemctl enable supervisor
vagrant $ sudo systemctl reload supervisor

ちゃんと Jupyter Notebook が起動しているかを確認する。

vagrant $ ps auxww | grep [j]upyter
jupyter   4689 27.0  5.4 183560 55088 ?        S    16:31   0:01 /usr/bin/python3 /usr/bin/jupyter-notebook
vagrant $ ss -tlnp | grep :8888
LISTEN   0         128               127.0.0.1:8888             0.0.0.0:*       
LISTEN   0         128                   [::1]:8888                [::]:*       

もし、上手く立ち上がっていないときはログから原因を調べよう。

vagrant $ sudo tail /var/log/supervisor/supervisord.log 
vagrant $ sudo tail /var/log/supervisor/jupyter.log 

(オプション) ログインシェルを無効化する

もし Jupyter Notebook 専用に作ったユーザをシェル経由で操作するつもりがなければ、ログインシェルを無効化しておく。

vagrant $ sudo usermod -s /usr/sbin/nologin jupyter

こうするとシェル経由でユーザにログインできなくなる。

vagrant $ grep jupyter /etc/passwd
jupyter:x:1001:1001::/home/jupyter:/usr/sbin/nologin
vagrant $ sudo su - jupyter
This account is currently not available.

デーモンプログラムを起動するユーザは、不正アクセスを受けた場合の影響を小さくする意図でこうすることが多い。

SSH Port Forwarding 経由で Jupyter Notebook の Web UI にアクセスする

ここまでで、リモートサーバ上の Jupyter Notebook の設定は終わった。

一旦リモートサーバから SSH でログアウトする。

vagrant $ exit

改めて SSH Port Forwarding を有効にしてリモートサーバにログインする。 このときリモートサーバの TCP/8888 ポートを、ローカルホストのポートにマッピングする。 ユーザ名やホスト名は適宜読み替える。

client $ ssh -L 8888:localhost:8888 <username>@<remotehost>

今回は Vagrant の環境を使っているので、こんな感じで。

client $ vagrant ssh-config > ssh.config
client $ ssh -L 8888:localhost:8888 -F ssh.config default

あとは、クライアントのブラウザでローカルホストにマッピングしたポート番号を開く。

client $ open http://localhost:8888

すると、見覚えのある Web UI が表示される。 オプションの共通パスワード認証を使っていないのであれば、いきなりいつもの画面になるはず。

f:id:momijiame:20181014022447p:plain

あとは、もしポータビリティとかを考えるのであればお好みで Docker イメージとかにする感じで。

めでたしめでたし。

Python: デコレータについて

Python の特徴的な構文の一つにデコレータがある。 便利な機能なんだけど、最初はとっつきにくいかもしれない。 そこで、今回はデコレータについて一通り色々と書いてみる。 先に断っておくと、とても長い。

これを読むと、以下が分かる。

  • デコレータの本質
    • デコレータはシンタックスシュガー (糖衣構文) に過ぎない
  • デコレータの作り方
    • 引数を取るデコレータと取らないデコレータ
  • デコレータの用途
    • 用途はラッピングとマーキングの二つに大別できる
  • デコレータの種類
    • デコレータは関数、メソッド、インスタンスで作れる
  • デコレータの対象
    • デコレートできるのは関数、メソッド以外にクラスもある

今回使った環境は次の通り。 尚、紹介するコードの中には、一部に Python 3 以降でないと動作しないものが含まれている。

$ python -V
Python 3.6.6

デコレータについて

まずはデコレータのおさらかいから。 デコレータは、その名の通りオブジェクトをデコレーション (装飾) するための機能。 構文としては、デコレートしたいオブジェクトの前で @ を先頭につけて使う。 デコレートできるオブジェクトの種類は、関数、メソッド、クラスがサポートされている。

標準モジュールにも、組み込みでいくつかのデコレータがある。 その中の一つを見てみよう。 以下のサンプルコードでは functools モジュールの lru_cache というデコレータを使っている。 このデコレータを使うと、デコレートした関数を簡単にメモ化できる。 メモ化というのは、ようするに関数の戻り値をキャッシュすること。 サンプルコードでは足し算をする add() という関数をメモ化している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import lru_cache


# 関数をメモ化するデコレータ
@lru_cache()
def add(a, b):
    # 実際に関数が処理された場合を区別するための出力
    print('calculate')
    return a + b


def main():
    # 同じ引数で 2 回呼び出す
    print(add(1, 2))
    print(add(1, 2))


if __name__ == '__main__':
    main()

ポイントは add() 関数の中で calculate という文字列を出力しているところ。 これで、実際に関数が呼び出されたのか、それともキャッシュされた値が返ったのか区別できる。

それでは、上記を保存して実行してみよう。 サンプルコードでは同じ引数 (1, 2) を使って add() 関数を 2 回呼び出している。

$ python cache.py
calculate
3
3

実行しても calculate が 1 回しか出力されない。 つまり 2 回目の呼び出しではキャッシュされた値が返っていることが分かる。 見事に @lru_cache デコレータが機能しているようだ。

デコレータの本質

おさらいが終わったところで、早速本題に入る。 デコレータという機能は、実はシンタックスシュガー (糖衣構文) に過ぎない。 シンタックスシュガーというのは、プログラミング言語において、ある書き方に対して別の書き方ができるようにしたもの。 デコレータがシンタックスシュガーということは、つまり同じ内容はデコレータを使わなくても書けるということ。

先ほどのサンプルコードを、デコレータを使わない形に直してみよう。 つまり、足し算をする add() 関数をデコレータを使わずに functools.lru_cache でメモ化している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import lru_cache


# デコレータ構文は使っていない
def add(a, b):
    print('calculate')
    return a + b


# デコレータの代わりになる書き方
# デコレータ構文は、以下を書きやすくしたシンタックスシュガーにすぎない
add = lru_cache()(add)


def main():
    # 同じ引数で 2 回呼び出す
    print(add(1, 2))
    print(add(1, 2))


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。

$ python cache.py
calculate
3
3

ちゃんとメモ化が動作していることが分かる。

先ほどのサンプルコードでは functools.lru_cache をデコレータとして使っていない。 代わりに、次のようなコードが登場している。 これは lru_cache() を通して add() 関数を代入し直している。 ようするに add() 関数の内容が lru_cache() の返り値で上書きされることになる。

add = lru_cache()(add)

つまり、最初のコードで登場した以下と上記は本質的に等価ということ。

@lru_cache()
def add(a, b):

これは理解する上で重要なポイントで、デコレータを使って書かれたコードは、必ず使わずに書くこともできる。

デコレータの作り方

続いてはデコレータの作り方を見ていく。 前述したように、デコレータは単なるシンタックスシュガーで、やっていることは単なる返り値を使った上書きだった。 それさえ分かっていればデコレータの作り方は理解しやすい。

例えば、関数をデコレートするデコレータについて考えてみよう。 これまで理解した内容から考えれば「関数を受け取って、代わりとなる関数を返す」ものを作れば良い。

以下のサンプルコードでは deco という名前でデコレータを作っている。 見て分かる通り、普通の関数と見た目は何ら変わらない。 つまり deco はデコレータとして動作する関数、ということになる。 デコレータとして動作するために、引数 func という名前で関数を受け取って、代わりとなる wrapper() という関数の参照を返している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def deco(func):
    """デコレートした関数の前後に処理を挟み込む自作デコレータ"""

    def wrapper(*args, **kwargs):
        """本来の関数の代わりに返される関数"""
        print('before')  # 本来の関数が呼び出される前に実行される処理
        result = func(*args, **kwargs)  # 本来の関数の呼び出し
        print('after')  # 本来の関数が呼び出された後に実行される処理
        return result  # 本来の関数の返り値を返す

    # 引数で関数を受け取って、代わりに別の関数を返す
    return wrapper


# @deco デコレータで greet() 関数をデコレートしている
@deco
def greet():
    """文字列を書き出すだけの関数"""
    print('Hello, World!')


def main():
    # デコレータでデコレートされた関数を呼び出す
    greet()


if __name__ == '__main__':
    main()

このデコレータは、本来の関数の呼び出しの前後に文字列の出力を挟み込むものになっている。 @deco を使ってデコレートする対象は greet() という関数で、内容は文字列を出力するだけ。

上記を保存して実行してみよう。 greet() 関数が出力する文字列の前後に @deco で追加した処理が挟み込まれていることが分かる。

$ python deco.py
before
Hello, World!
after

念のため、デコレータを使わないパターンも見ておこう。 繰り返しになるけど、デコレータはただのシンタックスシュガーなので、必ず使わない形にも直せる。 デコレータを使わない形にすれば、やっていることがよく分かる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def deco(func):
    """デコレートした関数の前後に処理を挟み込む自作デコレータ"""

    def wrapper(*args, **kwargs):
        """本来の関数の代わりに呼び出される関数"""
        print('before')  # 本来の関数が呼び出される前に実行される処理
        result = func(*args, **kwargs)  # 本来の関数の呼び出し
        print('after')  # 本来の関数が呼び出された後に実行される処理
        return result  # 本来の関すが返した結果を返す

    # 引数で関数を受け取って、別の関数を返している
    return wrapper


def greet():
    """文字列を書き出すだけの関数"""
    print('Hello, World!')


# デコレータは単なるシンタックスシュガーに過ぎないため、
# 必ず以下のような代入文に置き換えることができる
greet = deco(greet)


def main():
    # デコレータでデコレートされた関数を呼び出す
    greet()


if __name__ == '__main__':
    main()

ようするに deco() 関数が greet() 関数の参照を受け取って、代わりに wrapper() 関数の参照を返しているだけ。

上記を保存して実行してみよう。

$ python deco.py  
before
Hello, World!
after

ちゃんと動作している。

引数を受け取るデコレータ

先ほどのサンプルコードで登場した deco デコレータは lru_cache デコレータと違うところが一つあった。 それは、デコレータとして使うとき後ろにカッコがあるかないか。

lru_cache の例を思い出すと、後ろにカッコがついていた。

@lru_cache()
def add(a, b):

それに対して deco の例では、後ろにカッコがない。

@deco
def greet():

上記の違いは、デコレータが引数を受け取るか受け取らないか。 例えば lru_cache であれば、キャッシュする数の上限を設定するために maxsize というオプションがあったりするため。 つまり、こんな感じで書ける。

@lru_cache(maxsize=32)
def add(a, b):

先ほどの deco を引数を受け取れるように書き換えてみよう。 次のサンプルコードでは deco デコレータが本来の処理の前後に挿入するメッセージを引数で指定できるようにしている。 コード上の変化としては、先ほどよりも deco のネストが増していることが分かる。 引数の受け取らないパターンで deco という名前だった関数が今度は wrapper という名前になって、新しい deco がそれを返している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def deco(before_msg='before', after_msg='after'):
    """引数を受け取るデコレータ (ネストが一段増える)"""

    def wrapper(func):
        def _wrapper(*args, **kwargs):
            print(before_msg)
            result = func(*args, **kwargs)
            print(after_msg)
            return result
        return _wrapper

    return wrapper


# デコレータにカッコがあって引数を受け取っている
@deco('mae', 'ato')
def greet():
    """文字列を書き出すだけの関数"""
    print('Hello, World!')


def main():
    # デコレータでデコレートされた関数を呼び出す
    greet()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 今度はデコレータを使うときに指定した引数にもとづいて前後の出力が変化している。

$ python decoargs.py 
mae
Hello, World!
ato

引数を取るパターンでは、取らないパターンよりも何をやっているのかが分かりにくいかもしれない。 これも、デコレータを使わない形に書き直すと理解しやすくなる。 以下のサンプルコードは、同じ内容をデコレータを使わない形に直してある。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def deco(before_msg='before', after_msg='after'):
    """引数を受け取るデコレータ (ネストが一段増える)"""

    def wrapper(func):
        def _wrapper(*args, **kwargs):
            print(before_msg)
            result = func(*args, **kwargs)
            print(after_msg)
            return result
        return _wrapper

    return wrapper


def greet():
    """文字列を書き出すだけの関数"""
    print('Hello, World!')


# デコレータ構文を使わずに書いたパターン
greet = deco('mae', 'ato')(greet)
# より冗長に、分かりやすく書くと以下のようになる
# wrap_func = deco('mae', 'ato')
# greet = wrap_func(greet)


def main():
    greet()


if __name__ == '__main__':
    main()

上記を見ると、関数を上書きする工程が二段階に分かれていることが見て取れる。 冗長に分かりやすく書いたパターンでは、まず deco() が関数を上書きするのに使う関数 (変数 wrap_func) を返している。 そして、その関数を使って対象の関数 greet() を上書きしている。 これが引数を受け取るデコレータの動作原理ということ。

以降は、デコレータを使わずに書いたパターンを示すことは基本的には省略する。 しかし、デコレータが単なるシンタックスシュガーで、使わないパターンに必ず書き直せるという点は意識しながら読むと理解が深まると思う。

デコレータの用途

デコレータの基本が分かったところで、次は用途について考えてみる。 デコレータの用途は、大きく分けて「ラッピング」と「マーキング」の二つがある。 これまで紹介してきた内容は、用途が全て前者の「ラッピング」だった。

ラッピング

それでは、まずラッピングの用途から見ていこう。 これは、これまでにも紹介してきた通り元の関数などをデコレータを通して上書きするというもの。 以下のサンプルコードでは関数の返り値に 2 倍をかけて返すデコレータ double を定義している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def double(func):
    """デコレートした関数の返り値を 2 倍にするデコレータ"""
    def wrapper(*args, **kwargs):
        # 本来の関数の返り値に 2 をかけて返す
        return func(*args, **kwargs) * 2
    return wrapper


# 返り値を倍にするデコレータをつける
@double
def add(a, b):
    """足し算をする関数"""
    return a + b


def main():
    # 1 + 2 を計算すると...?
    print('1 + 2 =', add(1, 2))  # 1 + 2 = 3 ... 6!!


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 @double デコレータによってデコレートされた add() 関数は、計算結果を倍にして返すように上書きされる。

$ python wrapping.py 
1 + 2 = 6

ラッピング用途での注意点

ちなみに、ラッピング用途でデコレータを使うときは一つ注意点がある。 それは、ラッピング用途のデコレータが、デコレートしたオブジェクトを代わりの何かで上書きするという性質に由来している。

以下のサンプルコードを見てほしい。 このコードでは、デコレートされた関数 add() の名前を __name__ プロパティから取得して出力している。 もちろん、本来の意図としては add という文字列が出力されてほしいはず。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def double(func):
    """デコレートした関数の返り値を 2 倍にするデコレータ"""
    def wrapper(*args, **kwargs):
        # 本来の関数の返り値に 2 をかけて返す
        return func(*args, **kwargs) * 2
    return wrapper


# 返り値を倍にするデコレータをつける
@double
def add(a, b):
    """足し算をする関数"""
    return a + b


def main():
    # add() 関数の名前は?
    print('add()\'s name:', add.__name__)


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。

$ python name.py 
add()'s name: wrapper

なんと、残念ながら wrapper という出力になってしまった。

ここまで読んできていれば、理由は何となく想像がつくと思う。 ようするにデコレータを通して add() 関数は wrapper() 関数に置き換えられてしまっている。 そのため add() 関数のつもりで扱うと、実際には置き換えられた関数だった、ということが起こる。

この問題は functools.wraps を使うと解決できる。 以下のサンプルコードでは、デコレータが返す代わりの関数を functools.wraps でデコレートしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import wraps


def double(func):
    """デコレートした関数の返り値を 2 倍にするデコレータ"""
    # デコレータが返す関数を functools.wraps でデコレートする
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 本来の関数の返り値に 2 をかけて返す
        return func(*args, **kwargs) * 2
    return wrapper


# 返り値を倍にするデコレータをつける
@double
def add(a, b):
    """足し算をする関数"""
    return a + b


def main():
    # add() 関数の名前は?
    print('add()\'s name:', add.__name__)


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 今度はちゃんと add という名前が出力された。

$ python name.py 
add()'s name: add

このように functools.wraps を使うと、置き換える関数が元の関数の性質を引き継げる。

マーキング

もう一つの用途としてマーキングを見てみよう。 この用途では、デコレータは受け取ったオブジェクトをそのまま返す。 ただし、受け取ったオブジェクトを何処かに記録しておいて、それを後から利用することになる。 以下のサンプルコードでは @register デコレータでデコレートした関数は _MARKED_FUNCTIONS というリストに保存される。 そして、保存されたリストから関数を呼び出している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


# デコレータでデコレートした関数を入れるリスト
_MARKED_FUNCTIONS = []


def register(func):
    """関数を登録するデコレータ"""

    # デコレートした関数をリストに追加する
    _MARKED_FUNCTIONS.append(func)

    # 受け取った関数をそのまま返す
    return func


# デコレータを使って、それぞれの関数をマーキングしていく
@register
def greet_morning():
    print('Good morning!')


@register
def greet_afternoon():
    print('Good afternoon!')


@register
def greet_evening():
    print('Good evening!')


def main():
    # リストに追加された関数を確認する
    print(_MARKED_FUNCTIONS)
    # 先頭の一つを呼び出してみる
    _MARKED_FUNCTIONS[0]()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 デコレートした関数がリストに保存されて、それを後から呼び出すことができている。

$ python marking.py 
[<function greet_morning at 0x10b4d1598>, <function greet_afternoon at 0x10b59b378>, <function greet_evening at 0x10b59b400>]
Good morning!

マーキング用途のデコレータは、典型的にはイベントハンドラで用いられる。 例えば Web アプリケーションフレームワークの Flask は、マーキングした関数がクライアントからのアクセスを捌くハンドラになる。

もちろん、上記のコードもデコレータを使わない形に直せる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


# デコレータでデコレートした関数を入れるリスト
_MARKED_FUNCTIONS = []


def register(func):
    """関数を登録するデコレータ"""

    # デコレートした関数をリストに追加する
    _MARKED_FUNCTIONS.append(func)

    # 受け取った関数をそのまま返す
    return func


# デコレータを使って、それぞれの関数をマーキングしていく
def greet_morning():
    print('Good morning!')


def greet_afternoon():
    print('Good afternoon!')


def greet_evening():
    print('Good evening!')


# デコレータを使わずに書き換えたパターン
greet_morning = register(greet_morning)
greet_afternoon = register(greet_afternoon)
greet_evening = register(greet_evening)


def main():
    # リストに追加された関数を確認する
    print(_MARKED_FUNCTIONS)
    # 先頭の一つを呼び出してみる
    _MARKED_FUNCTIONS[0]()


if __name__ == '__main__':
    main()

上記を見て分かる通り、デコレータはモジュールが読み込まれるタイミングで解釈される。 そのため、あらかじめデコレートされた関数の情報を収集するようなこともできるというわけ。

関数以外で作るデコレータ

ここまで紹介してきたデコレータは、全て関数を使って実装されていた。 しかし、デコレータはそれ以外を使った作り方もある。

メソッドで作るデコレータ

例えば、以下のサンプルコードを見てほしい。 ここでは Decorator クラスの deco() というインスタンスメソッドでデコレータを実装している。 内容は最初に自作した処理の前後に出力を挿入するものだ。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Decorator(object):

    def deco(self, func):
        """デコレータとして機能するメソッド"""

        def wrapper(*args, **kwargs):
            print('before')
            result = func(*args, **kwargs)
            print('after')
            return result

        return wrapper


# クラスをインスタンス化する
instance = Decorator()


# インスタンスメソッドで作ったデコレータ
@instance.deco
def greet():
    print('Hello, World!')


def main():
    greet()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 ちゃんとデコレータとして機能していることが分かる。

$ python instance.py 
before
Hello, World!
after

インスタンスメソッドとしてデコレータを実装すると嬉しいのは、インスタンスごとにコンテキストを持たせられるところ。 以下のサンプルコードにおいて japaneseenglish という二つのインスタンスは、それぞれ異なる引数で初期化されている。 そして、それぞれが別の関数をデコレートしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Decorator(object):

    def __init__(self, before_msg='before', after_msg='after'):
        # 前後に挿入するメッセージ
        self.before_msg = before_msg
        self.after_msg = after_msg

    def deco(self, func):
        """デコレータとして機能するメソッド"""

        def wrapper(*args, **kwargs):
            print(self.before_msg)
            result = func(*args, **kwargs)
            print(self.after_msg)
            return result

        return wrapper


# インスタンスごとにコンテキストが持てるのがポイント
japanese = Decorator('mae', 'ato')
english = Decorator('before', 'after')


@japanese.deco
def greet_morning():
    print('Good morning')


@english.deco
def greet_afternoon():
    print('Good afternoon')


def main():
    greet_morning()
    greet_afternoon()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 すると、初期化したときの引数によってデコレートされた結果が異なっていることが分かる。

$ python context.py 
mae
Good morning
ato
before
Good afternoon
after

このようにインスタンスメソッドでデコレータを作ると、インスタンスにコンテキストをもたせられるというメリットがある。

呼び出し可能オブジェクトで作るデコレータ

メソッドで作るデコレータの変わり種として、呼び出し可能オブジェクトを使うパターンも考えられる。 これはクラスに特殊メソッド __call__() を実装するというもの。 この特殊メソッドを実装すると、インスタンス自体を関数みたいに実行できるようになる。 で、その特殊メソッド __call__() がデコレータとして動作するとしたら?という。

以下のサンプルコードでは特殊メソッド __call__() がデコレータとして動作する。 そのためインスタンス化したオブジェクトの instance が、そのままデコレータとして使えている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Decorator(object):

    def __call__(self, func):
        """呼び出し可能オブジェクトを作るための特殊メソッド

        デコレータとして動作する"""

        def wrapper(*args, **kwargs):
            print('before')
            result = func(*args, **kwargs)
            print('after')
            return result

        return wrapper


# クラスをインスタンス化する
instance = Decorator()


# 呼び出し可能オブジェクトなので
# インスタンスそのものがデコレータとして使える
@instance
def greet():
    print('Hello, World!')


def main():
    greet()


if __name__ == '__main__':
    main()

実行結果はこれまでと変わらないので省略する。

デコレートする対象

ここまでの例では、デコレートする対象は全て関数だった。 しかし、デコレータは関数以外もデコレートすることができる。

メソッドをデコレートする

以下のサンプルコードでは、おなじみの @deco デコレータがインスタンスメソッドをデコレートしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def deco(func):
    """デコレートした関数の前後に処理を挟み込む自作デコレータ"""

    def wrapper(*args, **kwargs):
        """本来の関数の代わりに呼び出される関数"""
        print('before')  # 本来の関数が呼び出される前に実行される処理
        result = func(*args, **kwargs)  # 本来の関数の呼び出し
        print('after')  # 本来の関数が呼び出された後に実行される処理
        return result  # 本来の関すが返した結果を返す

    # 引数で関数を受け取って、別の関数を返している
    return wrapper


class MyClass(object):

    # インスタンスメソッドをデコレートする
    @deco
    def greet(self):
        print('Hello, World!')


def main():
    # クラスをインスタンス化する
    obj = MyClass()
    # デコレートされたメソッドを呼び出す
    obj.greet()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 ちゃんと動くことが分かる。

$ python method.py 
before
Hello, World!
after

ちなみにメソッドをデコレートするときの注意点が一つある。 それは、置き換える関数の第一引数にインスタンスオブジェクトを受け取れるようにすること。 Python のメソッドは、典型的には self という名前で第一引数にインスタンスを受け取る。 置き換える関数が、この一つ余分な引数を受け取れるようになっていないと動作しない。 先ほどのサンプルコードでは引数を (*args, **kwargs) という任意の形で受け取れるようにしていたので、特に気にすることはなかった。

クラスをデコレートする

デコレータはクラスをデコレートすることもできる。

以下のサンプルコードでは @deco デコレータが MyClass をデコレートしている。 @deco デコレータでは、クラスが持っているメソッドを上書きして回っている。 上書きされたメソッドは、呼び出されたタイミングでその旨が出力されるようになる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import inspect


def deco(cls):
    """クラスオブジェクトを引数に取るデコレータ

    XXX: Python 3 でしか動作しない"""

    # クラスからメソッド一覧を取得する
    methods = inspect.getmembers(cls, predicate=inspect.isfunction)

    # クラスのメソッドを上書きして回る
    for method_name, method_object in methods:
        wrapped_method = logging_wrapper(method_object)
        setattr(cls, method_name, wrapped_method)

    # 受け取ったクラスはそのまま返す
    return cls


def logging_wrapper(func):
    """関数の呼び出しを記録するラッパー"""
    def _wrapper(*args, **kwargs):
        # 本当は logging モジュールを使うべき
        print('call:', func.__name__)
        result = func(*args, **kwargs)
        return result
    return _wrapper


# クラスをデコレータでデコレートする
@deco
class MyClass(object):

    def greet_morning(self):
        print('Good morning!')

    def greet_afternoon(self):
        print('Good afternoon!')

    def greet_evening(self):
        print('Good evening!')


def main():
    # デコレータでデコレートされたクラスをインスタンス化する
    o = MyClass()
    # いくつかのメソッドを呼び出す (実はデコレータによって上書き済み)
    o.greet_morning()
    o.greet_afternoon()
    o.greet_evening()


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 上書きされたメソッドによって、呼び出しが記録されていることが分かる。

$ python clsdeco.py 
call: greet_morning
Good morning!
call: greet_afternoon
Good afternoon!
call: greet_evening
Good evening!

まとめ

今回扱った内容は以下の通り。

  • デコレータの本質
    • デコレータはシンタックスシュガー (糖衣構文) に過ぎない
  • デコレータの作り方
    • 引数を取るデコレータと取らないデコレータ
  • デコレータの用途
    • 用途はラッピングとマーキングの二つに大別できる
  • デコレータの種類
    • デコレータは関数、メソッド、インスタンスで作れる
  • デコレータの対象
    • デコレートできるのは関数、メソッド以外にクラスもある

上記さえ理解していれば、あとは目的に応じてどのようなデコレータを作れば良いかが自動的に決まる。

参考

www.python.org

www.python.org

Python: メモ化した内容を percache で永続化する

プログラムを高速化する手法の一つとしてメモ化がある。 これは、関数の返り値をキャッシュしておくことで、同じ呼び出しがあったときにそれを使い回すというもの。

今回は、メモ化でキャッシュした内容を補助記憶装置に永続化できる Python のパッケージとして percache を使ってみる。 キャッシュを補助記憶装置に永続化すると、その分だけ読み書きにはオーバーヘッドがかかる。 しかしながら、計算に多量の時間がかかる場合にはそれでもメリットがありそう。

ただし、先に断っておくと世間的にはほとんど使われていないパッケージなので実際に使うときは十分に検討した方が良い。 キャッシュの機構は、慎重にならないと不具合や脆弱性を生みやすいところなので、特に気をつけた方が良いと思う。 今回の動機としては、元々は似たようなパッケージを自分で書こうか悩んでいて、探したら API がいけてたので試してみたという感じ。

使った環境は次の通り。

$ sw_vers             
ProductName:    Mac OS X
ProductVersion: 10.13.6
BuildVersion:   17G65
$ python -V
Python 3.6.6

下準備

まずは percache をインストールしておく。

$ pip install percache

標準モジュールの functools.lru_cache を使ったメモ化

percache の説明に入る前に、一般的なメモ化について扱っておく。 まず、補助記憶装置への永続化のないメモ化については Python の標準モジュールに実装がある。 具体的には functools.lru_cache を使うと簡単に実現できる。

サンプルコードを以下に示す。 このコードの中では add() 関数を @lru_cache() デコレータを使ってメモ化している。 add() 関数では ab という二つの引数を足し算をして、その結果を返す。 それをメモ化しているということは、つまり引数の ab が以前に呼び出した値と同じだったら、そのときの戻り値を使い回す。 また、関数の中では、実際に処理されたときと返り値が使い回されたときを区別できるように、デバッグ用の文字列を出力している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import lru_cache


# add() 関数をメモ化する
@lru_cache()
def add(a, b):
    print('calculate')  # 実際に処理されたことを確認するための出力
    return a + b


def main():
    # メモ化された関数を二回呼び出す
    print(add(1, 2))
    print(add(1, 2))


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 すると、関数は二回呼び出されているにも関わらず calculate という文字列は一回しか出力されていない。 これは、二回目の呼び出しは引数が同じなので戻り値が使い回されたことを示している。

$ python memoize.py
calculate
3
3

ただし functools.lru_cache を使ったメモ化では、キャッシュが補助記憶装置に永続化されない。 そのため Python のプロセスが終了すると、主記憶装置にキャッシュしていた内容も揮発してしまう。 その証拠に、先ほどのプログラムを再度実行すると、同じ出力になる。 このとき、もしキャッシュしていた内容が補助記憶装置に永続化されているなら calculate という文字列は出力されないはず。

$ python memoize.py
calculate
3
3

percache でキャッシュを永続化してみる

続いては今回の本題として percache を使ってキャッシュを補助記憶装置に永続化してみる。 使い方は percache.Cache のインスタンスを作って、そのインスタンスをデコレータとして使うというもの。 Flask 的な API になっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import percache

# my-cache というファイル名でキャッシュを永続化する
cache = percache.Cache('my-cache')


# add() 関数をメモ化する
@cache
def add(a, b):
    print('calculate')
    return a + b


def main():
    import sys
    # 引数を整数として解釈して add() 関数を呼び出す
    result = add(int(sys.argv[1]), int(sys.argv[2]))
    print(result)


if __name__ == '__main__':
    main()

上記を保存したら、同じ引数を使ってプログラムを何度か実行してみよう。 すると、同じ引数を使うと二度目以降は calculate という文字列が表示されない。 つまり、補助記憶装置に永続化されたキャッシュが使われていることが分かる。

$ python pmemoize.py 1 2
calculate
3
$ python pmemoize.py 1 2
3
$ python pmemoize.py 1 3
calculate
4
$ python pmemoize.py 1 3
4

percache の実装

基本的な使い方が分かったところで、続いては percache の実装について見ていこう。 このパッケージは一つのモジュールに収まるほどシンプルな作りになっている。

github.com

どのようにキャッシュの永続化が実現されているか確認しよう。 まず、先ほどの例を実行すると my-cache というファイルがカレントディレクトリにできているはず。 これが、永続化されたキャッシュ内容を保存するためのファイルになる。

$ ls my-cache 
my-cache
$ file my-cache 
my-cache: GNU dbm 1.x or ndbm database, little endian, 64-bit

中身を確認するために Python の REPL を起動しよう。

$ python

キャッシュの永続化は標準モジュールの shelve を使って実装されている。 次のコードを実行すると、キャッシュされた内容が確認できる。

>>> import shelve
>>> with shelve.open('my-cache') as s:
...     for key, value in s.items():
...         print(key, value)
... 
936247610f625403ba55b32ab4dddfc6abd7c2ee 4
de71ece6a221c54c692400a6294839b2c02fd4f2:atime 1535584568.42677
936247610f625403ba55b32ab4dddfc6abd7c2ee:atime 1535584570.2073479
de71ece6a221c54c692400a6294839b2c02fd4f2 3

shelve というモジュールは Python の辞書ライクなオブジェクトを補助記憶装置に永続化するための機能になっている。

12.3. shelve --- Python オブジェクトの永続化 — Python 3.6.6 ドキュメント

先ほど確認した内容から、いくつか分かることがある。 まず、辞書のキーとしてはハッシュと思われる値が使われており、それに対応する値にはメモ化した関数の計算結果が保存されている。 そして、それとは別に永続化した時刻と思われる内容についても保存されているようだ。

ちなみに辞書のキーとなる値については、ソースコードを確認したところ次のようなアルゴリズムで生成されていた。 まず、関数名と repr() した引数の内容を文字列として連結して UTF-8 でバイト列にエンコードする。 そして、その内容を SHA1 でハッシュ化する。

試しに add() 関数に 12 を渡した際のハッシュを手作業で生成してみよう。

>>> args = ''.join(['add', repr(1), repr(2)]).encode('utf-8')
>>> hashlib.sha1(args).hexdigest()
'de71ece6a221c54c692400a6294839b2c02fd4f2'

上記が、先ほど確認した辞書のキーで、結果が 3 になっているものと一致していることが分かる。

バックエンドをオリジナルの実装に入れ替える

percache はキャッシュを永続化する部分をオリジナルの実装に入れ替えることもできる。 ちなみに、キャッシュを永続化する部分の実装を percache ではバックエンドと呼んでいる。 例えば、やろうと思えばクラウド上のストレージに永続化するバックエンドを書くこともできるはず。

以下にサンプルコードとして、保存されるキャッシュの件数を制限できるバックエンドを書いてみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict
import shelve

import percache


class LimitedSizeBackend(OrderedDict):
    """永続化する結果の数を制限したバックエンド"""

    def __init__(self, filename, limit_size=-1):
        self._filename = filename
        self._limit_size = limit_size
        self._shelve_dict = shelve.open(self._filename)

        self._load()
        self._check_size()

    def _load(self):
        """shelve から永続化されているデータを読み込む"""
        self.update(self._shelve_dict)

    def _check_size(self):
        if self._limit_size < 0:
            # サイズ上限が負なら何もしない
            return

        # サイズ上限に収まるように一番古い要素を削除する
        # NOTE: percache は 1 つのメモ化に 2 つ要素を使う
        while len(self) > self._limit_size * 2:
            # FIFO (Queue)
            self.popitem(last=False)

    def __setitem__(self, key, value):
        """要素の追加があったとき呼ばれる特殊メソッド"""
        super(LimitedSizeBackend, self).__setitem__(key, value)
        self._check_size()

    def _save(self):
        """shelve にデータを書き込んで永続化する"""
        # 一旦既存のデータをクリアする
        self._shelve_dict.clear()
        # 現在のデータを書き戻す
        self._shelve_dict.update(self)
        # 永続化する
        self._shelve_dict.sync()

    def sync(self):
        """shelve として振る舞う (ダックタイピング) ために必要"""
        self._save()

    def close(self):
        """shelve として振る舞う (ダックタイピング) ために必要"""
        self._save()


# サイズ上限が 1 のバックエンドを用意する
backend = LimitedSizeBackend('limited-cache', 1)
# バックエンドを設定してキャッシュオブジェクトを作る
cache = percache.Cache(backend)


# add() 関数をメモ化する
@cache
def add(a, b):
    print('calculate')
    return a + b


def main():
    import sys
    # プログラムの第一引数と第二引数を整数として add() 関数に渡す
    result = add(int(sys.argv[1]), int(sys.argv[2]))
    print(result)


if __name__ == '__main__':
    main()

サンプルコードではキャッシュのサイズを 1 に制限している。 これはつまり、直近一件の返り値だけがキャッシュされるということ。

上記を保存したら色々な値を入れて動作を確認してみよう。 同じ値を入れる限りはキャッシュの結果が使われるものの、別の値を入力すれば忘却することが分かる。

$ python limited.py 1 2
calculate
3
$ python limited.py 1 2
3
$ python limited.py 1 3
calculate
4
$ python limited.py 1 3
4
$ python limited.py 1 2
calculate
3
$ python limited.py 1 2
3

ちなみに、キャッシュのアルゴリズムでは、際限なくサイズが膨れ上がらないような仕組みを入れることが非常に重要となる。 具体的には、キャッシュする件数を制限したり、あるいは一定時間使われないものを消去するといった内容が考えられる。 そういった仕組みがないと、キャッシュによってシステムのリソースを使い尽くす可能性がある。

また、ユーザからの入力を元にキャッシュしているときにそうした仕組みがないと、意図的にリソースを枯渇させることも可能になってしまう。 これは DoS (Denial of Service) 攻撃への脆弱性になる。 ここらへんの制約は、主記憶装置よりも補助記憶装置の方がゆるい。 ただし、キャッシュの保存先が補助記憶装置だとしても、実際に使うときは主記憶装置の上に展開されることを忘れてはいけない。

読み書きがある毎に永続化する

percache は、デフォルトだと明示的に Cache#sync() メソッドや Cache#close() メソッドを呼ばないとバックエンドへの読み書きが発生しない。 これは、コストの高い補助記憶装置へのアクセスを最小限に留めるためと考えられる。 ただし、オプションの livesyncTrue を指定すれば、値の更新が生じた時点でバックエンドに読み書きが生じる。

以下のサンプルコードでは、デバッグ用に sync() メソッドと close() メソッドが呼ばれると文字列を出力するバックエンドを定義している。 それを用いた上で livesync オプションに True を指定している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import percache


class DebugBackend(dict):
    """デバッグ用のバックエンド (ディスクへの永続化はしない)"""

    def sync(self):
        """sync() メソッドが呼ばれたときに標準出力に書く"""
        print('sync')

    def close(self):
        """close() メソッドが呼ばれたときに標準出力に書く"""
        print('close')

# 値の更新がある毎に永続化する
debug_backend = DebugBackend()
cache = percache.Cache(debug_backend, livesync=True)


@cache
def add(a, b):
    print('calculate')
    return a + b


def main():
    print(add(1, 2))
    print(add(1, 2))


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 計算自体は一回しか実施されていないが、戻り値が返るごとにバックエンドへの読み書きが生じていることが分かる。

$ python livesync.py
calculate
sync
3
sync
3
close

オブジェクトの文字列表現をカスタマイズする

前述した通り percache では関数の引数を組み込み関数の repr() に渡して、その結果を元に辞書のキーに使うハッシュを作る。 つまり、関数の引数に渡すオブジェクトは repr() で適切な文字列を返すようになっていないといけない。

ここで補足しておくと 、組み込み関数の repr() というのは渡されたオブジェクトの文字列表現を取り出すために用いられる。 この repr() 関数にオブジェクトが渡されたときの振る舞いは、特殊メソッドの __repr__() を定義することでオーバーライドできる。

実際に試してみよう。 まずは User という自作クラスを用意する。 このクラスはインスタンス化するときに名前と年齢をメンバ変数に格納する。 そして、このクラスには特殊メソッドの __repr__() が定義されていない。

>>> class User(object):
...     def __init__(self, name, age):
...         self._name = name
...         self._age = age
... 

このクラスをインスタンス化して repr() に渡すと、デフォルトの文字列表現が返される。 これは、インスタンスの元となったクラス名やメモリ上の配置位置を示している。

>>> o = User('alice', 20)
>>> repr(o)
'<__main__.User object at 0x103519b00>'

上記のメモリ上の配置位置はオブジェクトを作る毎に変化する。 本来であれば、同じ名前と年齢を持ったオブジェクトからは同じ repr() がほしい。 そうなっていないと、同じ値を持っているにも関わらず生成したハッシュが異なってしまう。 それではメモ化するときのキーとしては使えない。

そこで、次のように特殊メソッドの __repr__() を定義する。 ここにはクラス名と名前と年齢が文字列に埋め込まれている。

>>> class User(object):
...     def __init__(self, name, age):
...         self._name = name
...         self._age = age
...     def __repr__(self):
...         """repr() で呼ばれる特殊メソッド"""
...         params = {
...             'name': self._name,
...             'age': self._age,
...         }
...         repr_msg = '<User name:{name} age:{age}>'.format(**params)
...         return repr_msg
... 

実際にインスタンス化したオブジェクトを repr() 関数に渡してみよう。

>>> o = User('alice', 20)
>>> repr(o)
'<User name:alice age:20>'

ちゃんとクラス名と名前と年齢を使ってオブジェクトの文字列表現が返るようになった。 これならオブジェクトが異なっても、同じ値さえ持っていれば同じハッシュが生成できる。

このように、自作のクラスについては上記のように特殊メソッドの __repr__() を実装してやればいい。 ただ、実際には自分で作っていないオブジェクトをメモ化した関数に渡すことも考えられる。

percache では、この問題の解決方法も用意してある。 具体的には組み込み関数 repr() の代わりにオブジェクトの文字列表現を取り出すための関数が登録できる。

以下のサンプルコードでは _repr() という関数でオブジェクトの文字列表現を取り出すための関数を定義している。 そして、それを Cache クラスのコンストラクタに渡している。 _repr() 関数の中では、全てのオブジェクトの文字列表現を生成する。 ただし、カスタマイズしたいオブジェクト以外については単純に repr() 関数の出力を返すだけで良い。 このサンプルコードでは User クラスのときだけ特別扱いして、名前と年齢を元に文字列表現を作っている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import percache


class User(object):
    def __init__(self, name, age):
        self._name = name
        self._age = age


def _repr(args):
    if isinstance(args, User):
        # 引数に User のインスタンスが渡されたときの処理
        params = {
            'name': args._name,
            'age': args._age,
        }
        repr_msg = '<User name:{name} age:{age}>'.format(**params)
        return repr_msg

    # それ以外のオブジェクト
    return repr(args)


cache = percache.Cache('user-cache', repr=_repr)


@cache
def process(user):
    print('calculate')
    return user._name, user._age


def main():
    o1 = User('alice', 20)
    print(process(o1))
    o2 = User('alice', 20)
    print(process(o2))


if __name__ == '__main__':
    main()

変数の o1o2 は異なるオブジェクトだけど、持っているメンバ変数の内容は同じなので等価と見なせる。

上記を実行してみよう。 calculate という文字列は一度しか出力されていないことから、ちゃんと戻り値が使い回されていることが分かる。

$ python myrepr.py
calculate
('alice', 20)
('alice', 20)

めでたしめでたし。

Python: グローバルスコープにあるオブジェクトの __del__() でインポートしたときの挙動について

今回は Python の __del__() メソッドでちょっと不思議な挙動を目にしてから色々と調べてみた話。 具体的には、グローバルスコープにあるオブジェクトの __del__() で別のモジュールをインポートしてるとき、そのオブジェクトがプロセス終了時に破棄されると場合によっては例外になる。 ただし、これは Python の仕様かというとかなり微妙で CPython の 3.x 系でしか同じ問題は観測できていない。 少なくとも、同じ CPython でも 2.x 系や、同じ Python 3.x 系の実装でも PyPy3 では発生しない。 おそらく実装上の都合によるものだと思う。

使った環境は次の通り。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.1 LTS"
$ uname -r
4.15.0-20-generic
$ python3 -V
Python 3.6.5
$ python -V
Python 2.7.15rc1

オブジェクトの __del__() メソッドについて

まずは前提となる知識から。 Python のオブジェクトは __del__() という特殊メソッドを定義することで、自身が破棄されるときの挙動をオーバーライドできる。 この __del__() メソッドは、デストラクタやファイナライザとも呼ばれる。

例えば、以下のサンプルコードでは Example クラスに __del__() メソッドを定義している。 それを main() 関数の中でインスタンス化した後に del 文を使って明示的に破棄している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Example(object):

    def __init__(self):
        """オブジェクトが作られるとき呼び出される"""
        print('born:', id(self))

    def __del__(self):
        """オブジェクトが破棄されるとき呼び出される"""
        print('died:', id(self))


def main():
    print('making')
    o = Example()  # オブジェクトを作る
    print('deleting')
    del o  # オブジェクトを明示的に破棄する
    print('done')


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python3 explicitdel.py
making
born: 139716100508416
deleting
died: 139716100508416
done

del 文が発行されたタイミングで __del__() メソッドが呼び出されていることが分かる。

また、明示的に del 文を発行しなくても、オブジェクトの寿命と共に呼び出される。 次のサンプルコードでは、先ほどとは異なり明示的に del 文を発行していない。 ただし、関数スコープの終了と共にオブジェクトは破棄されることが期待できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Example(object):

    def __init__(self):
        """オブジェクトが作られるとき呼び出される"""
        print('born:', id(self))

    def __del__(self):
        """オブジェクトが破棄されるとき呼び出される"""
        print('died:', id(self))


def main():
    print('making')
    o = Example()
    # 明示的にオブジェクトを破棄しない
    print('done')


if __name__ == '__main__':
    print('start')
    main()
    print('end')

上記を実際に実行してみよう。

$ python3 implicitdel.py 
start
making
born: 140412486589464
done
died: 140412486589464
end

たしかに main() 関数が終了するタイミングで __del__() メソッドが呼び出されていることが分かる。

本題

ここからが本題なんだけど、例えば以下のようなコードがあったとする。 特徴としては二つある。 まずひとつ目は __del__() メソッドの中で別のモジュールをインポートしているところ。 そしてふたつ目が、そのオブジェクトがグローバルスコープにあるところ。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Example(object):

    def __init__(self):
        """オブジェクトが作られるとき呼び出される"""
        print('born:', id(self))

    def __del__(self):
        """オブジェクトが破棄されるとき呼び出される"""
        import sys  # __del__() の中で他のモジュールをインポートする
        print('died:', id(self))


# オブジェクトをグローバルスコープに設置する
print('making')
o = Example()
print('done')
# プロセスが終了するタイミングでオブジェクトが破棄される

で、これを 3.x 系の CPython で実行すると、次のような例外になる。

$ python3 gimplicitdel.py 
making
born: 139889227380104
done
Exception ignored in: <bound method Example.__del__ of <__main__.Example object at 0x7f3a7fb4b588>>
Traceback (most recent call last):
  File "globaldel.py", line 13, in __del__
ImportError: sys.meta_path is None, Python is likely shutting down

なんかもう Python が終了しようとしてるからインポートするの無理っすみたいなエラーになってる。

ちなみに、同じ CPython でも 2.x 系であれば例外にならない。

$ python gimplicitdel.py
making
('born:', 140683376878608)
done
('died:', 140683376878608)

回避策 (ワークアラウンド)

この問題が起こる理由は一旦置いといて、とりあえず回避する方法としては以下の三つがある。 尚 Python 2.x を使うという選択肢は、もちろんない。

  1. __del__() メソッド内でインポートするのをやめる
  2. オブジェクトを del 文で明示的に破棄する
  3. CPython 以外の実装を使う

ひとつ目の回避策は一見するともっともで、そもそもファイルの先頭以外でのインポートは PEP8 に準拠していない。 とはいえ、現実にはそうもいかない場合があって。 例えば標準パッケージでもファイルの先頭以外でインポートしている例は見つかる。 このようなコードを間接的にでも呼び出すと、同じ問題が起こる。

github.com

上記でインポートしているモジュール dbm は、環境によっては存在しない。 そこで、あるときだけ使うためにこのようなコードになっているんだと思う。

ふたつ目の選択肢は、ひとつ目がダメなときは現実的になってくるかもしれない。 Python は atexit というモジュールを使ってインタプリタの終了ハンドラが登録できる。 その中で del 文を発行すれば良いと思う。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import atexit


class Example(object):

    def __init__(self):
        """オブジェクトが作られるとき呼び出される"""
        print('born:', id(self))

    def __del__(self):
        """オブジェクトが破棄されるとき呼び出される"""
        import sys  # __del__() の中で他のモジュールをインポートする
        print('died:', id(self))


# オブジェクトをグローバルスコープに設置する
print('making')
o = Example()
print('making done')

def _atexit():
    """オブジェクトを後始末する"""
    print('atexit')
    global o  # グローバルスコープの変数を変更する
    print('deleting')
    del o  # 終了ハンドラの中で明示的にオブジェクトを破棄する
    print('deleting done')


def main():
    # 終了ハンドラに登録する
    atexit.register(_atexit)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python3 delatexit.py 
making
born: 140689479625640
making done
atexit
deleting
died: 140689479625640
deleting done

たしかに、今度はエラーにならない。

ただ、ここまで書いておいてなんだけど、そもそも本当に回避する必要はあるのか?という点も議論の余地があるかもしれない。 問題は Python のプロセスが終了するタイミングの話なので、放っておいても結局のところオブジェクトとか関係なくメモリは開放される。 とはいえ、例外のせいで終了時のステータスコードは非ゼロにセットされるし、対処しないと気持ち悪いのはたしか。

ソースコードから問題について調べる

ここからは、この問題について CPython のソースコードを軽く追ってみた話。

まず、前述した例外を出しているのは以下のようだ。

https://github.com/python/cpython/blob/v3.6.6/Lib/importlib/_bootstrap.py#L873,L876

コメントには PyImport_Cleanup() 関数が呼ばれている最中か、あるいは既に呼ばれたことで起こると書いてある。

ドキュメントを確認すると、この関数は用途が内部利用なもののインポート関連のリソースを開放する目的があるらしい。

モジュールのインポート — Python 3.6.6 ドキュメント

ソースコードでいうと、以下に該当する。

https://github.com/python/cpython/blob/v3.6.6/Python/import.c#L335

上記の関数が呼ばれるのは、以下の二箇所かな。

https://github.com/python/cpython/blob/v3.6.6/Python/pylifecycle.c#L608

https://github.com/python/cpython/blob/v3.6.6/Python/pylifecycle.c#L881

上記の PyImport_Cleanup() という関数がオブジェクトが破棄されるよりも前に呼び出されているとアウトっぽいことが分かった。

続いては、実際にタイミングを動的解析で調べてみよう。 まずは GDB と Python3 のデバッグ用パッケージをインストールする。

$ sudo apt-get install gdb python3-dbg

次に gdb コマンド経由で Python を起動する。

$ gdb --args python3 gimplicitdel.py

お目当ての関数にブレークポイントを打つ。

(gdb) b PyImport_Cleanup
Breakpoint 1 at 0x573b80: file ../Python/import.c, line 336.

プログラムを走らせるとブレークポイントに引っかかった。

(gdb) run
Starting program: /usr/bin/python3 gimplicitdel.py
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
making
born: 140737352545616
done

Breakpoint 1, PyImport_Cleanup () at ../Python/import.c:336
336 ../Python/import.c: No such file or directory.

バックトレースはこんな感じ。 さっき確認した場所から呼ばれてる。

(gdb) bt
#0  PyImport_Cleanup () at ../Python/import.c:336
#1  0x0000000000426906 in Py_FinalizeEx () at ../Python/pylifecycle.c:608
#2  0x0000000000426b15 in Py_FinalizeEx () at ../Python/pylifecycle.c:740
#3  0x0000000000441c22 in Py_Main (argc=argc@entry=2, argv=argv@entry=0xa8f260)
    at ../Modules/main.c:830
#4  0x0000000000421ff4 in main (argc=2, argv=<optimized out>)
    at ../Programs/python.c:69

プログラムを進めると、今度は前述した例外が上がった。

(gdb) c
Continuing.
Exception ignored in: <bound method Example.__del__ of <__main__.Example object at 0x7ffff7e7b550>>
Traceback (most recent call last):
  File "gimplicitdel.py", line 13, in __del__
ImportError: sys.meta_path is None, Python is likely shutting down
[Inferior 1 (process 11210) exited normally]

たしかにオブジェクトの __del__() メソッドが呼ばれるより前に PyImport_Cleanup() 関数が呼ばれているようだ。

いじょう。

Python: scikit-learn のロジスティック回帰を使ってみる

最近、意外とロジスティック回帰が使われていることに気づいた。 もちろん世間にはもっと表現力のある分類器がたくさんあるけど、問題によってどれくらい複雑なモデルが適しているかは異なる。 それに、各特徴量がどのように働くか重みから確認したり、単純なモデルなのでスコアをベンチマークとして利用する、といった用途もあるらしい。 今回は、そんなロジスティック回帰を scikit-learn の実装で試してみる。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.13.6
BuildVersion:   17G65
$ python -V               
Python 3.6.6
$ pip list --format=columns | grep -i scikit-learn
scikit-learn    0.19.2 

下準備

まずは scikit-learn をインストールしておく。

$ pip install scikit-learn

乳がんデータセットをロジスティック回帰で分類してみる

以下にロジスティック回帰を使って乳がんデータセットを分類するサンプルコードを示す。 とはいえ scikit-learn は API が統一されているので、分類器がロジスティック回帰になってる以外に特筆すべき点はないかも。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate


def main():
    # 乳がんデータセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    # ロジスティック回帰
    clf = LogisticRegression()
    # Stratified K-Fold CV で性能を評価する
    skf = StratifiedKFold(shuffle=True)
    scoring = {
        'acc': 'accuracy',
        'auc': 'roc_auc',
    }
    scores = cross_validate(clf, X, y, cv=skf, scoring=scoring)

    print('Accuracy (mean):', scores['test_acc'].mean())
    print('AUC (mean):', scores['test_auc'].mean())


if __name__ == '__main__':
    main()

実行してみよう。 だいたい精度 (Accuracy) の平均で 0.947 前後のスコアが得られた。

$ python logistic.py
Accuracy (mean): 0.9472570314675578
AUC (mean): 0.991659762496548

正直そんなに高くないけど、ロジスティック回帰くらい単純なモデルではこれくらいなんだなっていう指標にはなると思う。

ロジスティック回帰の利点

ロジスティック回帰の良いところは、モデルが単純で解釈も容易なところ。 例えば、基本的に線形モデルの眷属なので、各特徴量の重み (傾き) が確認できる。

実際に、それを体験してみよう。 今回例に挙げた乳がんデータセットは、腫瘍の特徴を記録した 30 次元の教師データだった。

>>> from sklearn import datasets
>>> dataset = datasets.load_breast_cancer()
>>> dataset.feature_names
array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',
       'mean smoothness', 'mean compactness', 'mean concavity',
       'mean concave points', 'mean symmetry', 'mean fractal dimension',
       'radius error', 'texture error', 'perimeter error', 'area error',
       'smoothness error', 'compactness error', 'concavity error',
       'concave points error', 'symmetry error',
       'fractal dimension error', 'worst radius', 'worst texture',
       'worst perimeter', 'worst area', 'worst smoothness',
       'worst compactness', 'worst concavity', 'worst concave points',
       'worst symmetry', 'worst fractal dimension'], dtype='<U23')
>>> len(dataset.feature_names)
30

目的変数は 0 が良性で 1 が悪性となっている。

>>> import numpy as np
>>> np.unique(y)
array([0, 1])

ホールドアウト検証を使ってデータを分割したら、モデルを学習させよう。

>>> from sklearn.model_selection import train_test_split
>>> X, y = dataset.data, dataset.target
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=42)
>>> from sklearn.linear_model import LogisticRegression
>>> clf = LogisticRegression()
>>> clf.fit(X_train, y_train)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)

すると、学習したモデルで切片 (LogisticRegression#intercept_) と重み (LogisticRegression#coef_) が確認できる。 重みは各特徴量ごとにある。

>>> clf.intercept_
array([0.40407439])
>>> clf.coef_
array([[ 2.18931343e+00,  1.51512837e-01, -1.57814199e-01,
        -1.03404299e-03, -1.29170075e-01, -4.23805008e-01,
        -6.47620520e-01, -3.37002545e-01, -1.97619418e-01,
        -3.23607668e-02, -6.88409834e-02,  1.48012177e+00,
         4.81243097e-02, -1.05177866e-01, -1.40690243e-02,
        -3.50323361e-02, -7.06715773e-02, -3.93587747e-02,
        -4.81468850e-02, -2.01238862e-03,  1.20675464e+00,
        -3.93262696e-01, -4.96613892e-02, -2.45385329e-02,
        -2.43248181e-01, -1.21314110e+00, -1.60969567e+00,
        -6.01906976e-01, -7.28573372e-01, -1.21974174e-01]])
>>> len(clf.coef_[0])
30

元のデータセットを標準化していないので、重みの値の大小については単純な比較が難しい。 ただ、特徴量がプラスであればその特徴量は悪性の方向に、反対にマイナスなら良性の方向に働く。

学習済みモデルの切片と重みから推論内容を確認する

ここからは学習済みモデルの切片と重みから計算した内容が推論と一致することを確認してみる。 ロジスティック回帰は線形回帰の式をシグモイド関数で 0 ~ 1 に変換したものになっている。

以下の式は線形回帰の式で、重みが  w で切片が  b に対応する。 ようするに特徴量と重みをかけて切片を足すだけ。

 y = wX + b

上記をシグモイド関数に放り込むと値が 0 ~ 1 の範囲に収まる。

 z = \frac{1}{1 + e^{-y}}

これがロジスティック回帰の出力となる。

実際に上記を学習済みモデルで確認してみよう。 例えば検証用データの先頭は悪性と判定されている。

>>> clf.predict([X_test[0]])
array([1])
>>> clf.predict_proba([X_test[0]])
array([[0.19165157, 0.80834843]])

悪性の確率は 0.80834843 になる。

正解を確認すると、たしかに悪性のようだ。

>>> y_test[0]
1

それでは学習済みモデルから計算した内容と上記が一致するかを確認してみよう。 まずは線形回帰の式を作る。 これが上記の  y = wX + b に対応する。

>>> import numpy as np
>>> y = np.sum(clf.coef_ * X_test[0]) + clf.intercept_
>>> y
array([1.4393142])

続いてシグモイド関数を定義しておく。

>>> def sigmoid(x):
...     return 1 / (1 + np.exp(-x))
... 

あとは、さきほど得られた結果をシグモイド関数に放り込むだけ。

>>> z = sigmoid(y)
>>> z
array([0.80834843])

結果は 0.80834843 となって、見事に先ほど得られた内容と一致している。

ばっちり。

Python: scikit-learn の Pipeline 機能をデバッグする

今回はだいぶ小ネタ。 以前にこのブログでも記事にしたことがある scikit-learn の Pipeline 機能について。

blog.amedama.jp

scikit-learn の Pipeline 機能は機械学習に必要となる複数の工程を一つのパイプラインで表現できる。 ただ、パイプラインを組んでしまうと途中のフェーズで出力がどうなっているか、とかが確認しにくい問題がある。 この問題について調べると以下の StackOverflow が見つかるんだけど、なかなかシンプルな解決方法だった。

stackoverflow.com

先に概要を述べると、特に何もしないフェーズを用意して、そこでデバッグ用の出力をするというもの。

下準備

まずは必要になるパッケージをインストールしておく。

$ pip install pandas scikit-learn scipy numpy

Pipeline に組み込むデバッグ用のオブジェクト

早速だけど以下にサンプルコードを示す。 このコードでは Debug という名前でデバッグ用のクラスを用意している。 このクラスは scikit-learn の Pipeline に組み込むことができる。 実体としては Pipeline が使うメソッドでデバッグ用の出力をするだけの内容になっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-


from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin

import pandas as pd


class Debug(BaseEstimator, TransformerMixin):
    """Pipelineをデバッグするためのクラス"""

    def transform(self, X):
        # 受け取った X を DataFrame に変換して先頭部分だけを出力する
        print(pd.DataFrame(X).head())
        # データはそのまま横流しする
        return X

    def fit(self, X, y=None, **fit_params):
        # 特に何もしない
        return self


def main():
    # Iris データセットを読み込む
    dataset = datasets.load_iris()
    X, y = dataset.data, dataset.target

    # パイプラインを構成する
    steps = [
        ('pca', PCA()),
        ('debug', Debug()),  # PCA の出力結果を確認する
        ('rf', RandomForestClassifier()),
    ]
    pipeline = Pipeline(steps=steps)

    # 学習
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
    pipeline.fit(X_train, y_train)

    # 推論
    y_pred = pipeline.predict(X_test)
    score = accuracy_score(y_test, y_pred)
    print(score)


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、Iris データセットを PCA (主成分分析) した結果の先頭部分が標準出力にプリントされる。 二回出力されているのは、モデルの学習 (fit() メソッド) と評価 (predict()) の二回で呼び出されているため。

$ python debug.py                            
          0         1         2         3
0  0.321625 -0.235144  0.057917  0.125637
1  3.355396  0.578542 -0.331641  0.076760
2  0.606069 -0.315582  0.300426  0.187366
3 -2.727847  0.438338  0.013295  0.002542
4  3.455577  0.501194 -0.562918  0.098940
          0         1         2         3
0  0.868341 -0.114257 -0.250542  0.271719
1 -2.233869  0.987378 -0.045914 -0.029639
2  3.746741  0.287862 -0.513685 -0.094163
3  0.760309 -0.111519  0.023542  0.020324
4  1.283430  0.320953 -0.507830 -0.063090
0.96

ばっちり。

Python: 層化抽出法を使ったK-分割交差検証 (Stratified K-Fold CV)

K-分割交差検証 (K-Fold CV) を用いた機械学習モデルの評価では、元のデータセットを K 個のサブセットに分割する。 そして、分割したサブセットの一つを検証用に、残りの K - 1 個を学習用に用いる。

上記の作業で、元のデータセットを K 個のサブセットに分割する工程に着目してみよう。 果たして、どのようなルールにもとづいて分割するのが良いのだろうか? このとき、誤ったやり方で分割すると、モデルの学習が上手くいかなかったり、汎化性能を正しく評価できない恐れがある。

今回は、分割方法として層化抽出法を用いたK-分割交差検証 (Stratified K-Fold CV) について書いてみる。 この方法を使うと、学習用データと検証用データで目的変数の偏りが少なくなる。 実装には scikit-learn の sklearn.model_selection.StratifiedKFold を用いた。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.13.6
BuildVersion:   17G65
$ python -V
Python 3.6.6
$ pip list --format=columns | grep -i scikit-learn
scikit-learn 0.19.2

下準備

まずは今回使う Python のパッケージをインストールしておこう。

$ pip install scikit-learn numpy scipy

本当は怖い KFold CV

セクションのタイトルはちょっと煽り気味になっちゃったけど、実際のところ知っていないと怖い。 例えば scikit-learn が実装している KFold は、データの分割になかなか大きな落とし穴がある。

次のサンプルコードでは sklearn.model_selection.KFold を使ってデータセットを分割している。 問題を単純化するために、データセットには 4 つの要素しか入っていない。 そして、目的変数に相当する変数 y には 01 が 2 つずつ入っている。 このデータセットを 2 つに分割 (2-Fold) したとき、どのような結果が得られるだろうか?

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from sklearn.model_selection import KFold


def main():
    # 目的変数のつもり
    y = np.array([0, 0, 1, 1])
    # 説明変数のつもり
    X = np.arange(len(y))

    # デフォルトでは先頭からの並びで分割される
    # 目的変数の並びに規則性があると確実に偏りが生じる
    kf = KFold(n_splits=2)
    for train_index, test_index in kf.split(X, y):
        # どう分割されたか確認する
        print('TRAIN:', y[train_index], 'TEST:', y[test_index])


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 学習用とテスト用のサブセットで、目的変数が偏ってしまっている。

$ python kfold.py                                 
TRAIN: [1 1] TEST: [0 0]
TRAIN: [0 0] TEST: [1 1]

仮に、上記のような偏ったデータを機械学習モデルに学習させて評価させた場合を考えてみよう。 最初の試行では学習データの目的変数に 1 しかないので 0 のパターンをモデルは覚えることができない。 そして、覚えていない 0 だけのデータでモデルが評価されることになる。 もちろん、次の試行でも同様に学習データと検証用データが偏ることになる。 これでは正しくモデルを学習させて評価することはできない。

どうしてこんなことが起こるかというと、デフォルトで K-Fold はデータの並び順にもとづいて分割するため。 例えば、先ほどの例と同じようにデータの並び順に規則性のある Iris データセットを使って検証してみよう。

>>> from sklearn import datasets
>>> dataset = datasets.load_iris()
>>> dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

上記のように Iris データセットはあやめの各品種 (目的変数) ごとに規則性を持ってデータが並んでいる。

このように並び順に規則性を持ったデータセットを scikit-learn の KFold のデフォルトパラメータで分割してみよう。

>>> from sklearn.model_selection import KFold
>>> kf = KFold(n_splits=2)
>>> ite = kf.split(dataset.data, dataset.target)
>>> train_index, test_index = next(ite)
>>> dataset.target[train_index]
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2])
>>> dataset.target[test_index]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1])

データが順番通りに真っ二つにされていることが上記から確認できる。

無作為抽出法を用いたK-分割交差検証

上記のようなデータの偏りを減らす方法として無作為抽出法 (Random Sampling) を使うやり方がある。 これは、順番に依存せず無作為にデータを選んでサブセットを作るというもの。

例えば scikit-learn の KFold であれば、オプションに shuffle=True を渡すと無作為抽出になる。 次のサンプルコードで試してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from sklearn.model_selection import KFold


def main():
    # 目的変数のつもり
    y = np.array([0, 0, 1, 1])
    # 説明変数のつもり
    X = np.arange(len(y))

    # 無作為抽出法を使って分割する (実行結果は試行によって異なる)
    kf = KFold(n_splits=2, shuffle=True)
    for train_index, test_index in kf.split(X, y):
        # どう分割されたか確認する
        print('TRAIN:', y[train_index], 'TEST:', y[test_index])


if __name__ == '__main__':
    main()

上記を実行してみよう。 試行にもよるけど、ちゃんとデータが偏らずに分割されるパターンもある。

$ python rndkfold.py
TRAIN: [0 1] TEST: [0 1]
TRAIN: [0 1] TEST: [0 1]
$ python rndkfold.py
TRAIN: [0 0] TEST: [1 1]
TRAIN: [1 1] TEST: [0 0]

層化抽出法を用いたK-分割交差検証

先ほどの無作為抽出法では試行によってはサブセットに偏りができる場合もあった。 もちろん、データセットが大きければ大きいほど大数の法則に従って偏りはできにくくなる。 とはいえゼロではないので、そこで登場するのが今回紹介する層化抽出法 (Stratified Sampling) を用いる方法となる。

層化抽出法を使うと、サブセットを作るときに目的変数の比率がなるべく元のままになるように分割できる。 次のサンプルコードでは、実装に StratifiedKFold を使うことで層化抽出法を使った分割を実現している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from sklearn.model_selection import StratifiedKFold


def main():
    # 目的変数のつもり
    y = np.array([0, 0, 1, 1])
    # 説明変数のつもり
    X = np.arange(len(y))

    # 層化抽出法を使って分割する
    kf = StratifiedKFold(n_splits=2, shuffle=True)
    for train_index, test_index in kf.split(X, y):
        # どう分割されたか確認する
        print('TRAIN:', y[train_index], 'TEST:', y[test_index])


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python skfold.py  
TRAIN: [0 1] TEST: [0 1]
TRAIN: [0 1] TEST: [0 1]
$ python skfold.py
TRAIN: [0 1] TEST: [0 1]
TRAIN: [0 1] TEST: [0 1]

何度実行しても偏りができないように分割されることが分かる。

試しに Iris データセットを使ったパターンも確認しておこう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from sklearn.model_selection import StratifiedKFold
from sklearn import datasets


def main():
    # Iris データセットを読み込む
    dataset = datasets.load_iris()
    X, y = dataset.data, dataset.target
    
    # 層化抽出法を使って分割する
    kf = StratifiedKFold(n_splits=2, shuffle=True)
    for train_index, test_index in kf.split(X, y):
        # どう分割されたか確認する
        print('TRAIN:', y[train_index], 'TEST:', y[test_index])


if __name__ == '__main__':
    main()

上記を実行すると分割した結果に偏りがないことが分かる。

$ python skfoldiris.py
TRAIN: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2] TEST: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2]
TRAIN: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2] TEST: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2]

めでたしめでたし。