CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM の cv() 関数から取得した学習済みモデルを SerDe する

今回は、前回のエントリを書くきっかけになったネタについて。

blog.amedama.jp

上記は今回扱う LightGBM の cv() 関数から取得した _CVBooster のインスタンスで起きた問題だった。 このインスタンスは、そのままでは pickle で直列化・非直列化 (SerDe) できずエラーになってしまう。

ちなみに LightGBM の cv() 関数から学習済みモデルを取得する件については以下のエントリに書いてある。

blog.amedama.jp

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V            
Python 3.7.4

下準備

準備として LightGBM と Scikit-learn をインストールしておく。

$ pip install lightgbm scikit-learn

問題が生じるコード

まずは件の問題が生じるコードから。 以下のサンプルコードでは、取得した _CVBooster のインスタンスを pickle で直列化しようとしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle

import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split


class ModelExtractionCallback(object):
    """lightgbm.cv() 関数からモデルを取り出すコールバック"""

    def __init__(self):
        self._model = None

    def __call__(self, env):
        self._model = env.model

    def _assert_called_cb(self):
        if self._model is None:
            raise RuntimeError('callback has not called yet')

    @property
    def boosters_proxy(self):
        self._assert_called_cb()
        return self._model

    @property
    def raw_boosters(self):
        self._assert_called_cb()
        return self._model.boosters

    @property
    def best_iteration(self):
        self._assert_called_cb()
        return self._model.best_iteration


def main():
    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    # デモ用にデータセットを分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        test_size=0.2,
                                                        random_state=42)


    # LightGBM のデータセット表現にラップする
    lgb_train = lgb.Dataset(X_train, y_train)

    # モデルを学習する
    extraction_cb = ModelExtractionCallback()
    callbacks = [
        extraction_cb,
    ]
    lgb_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    result = lgb.cv(lgb_params,
                    lgb_train,
                    num_boost_round=1000,
                    early_stopping_rounds=10,
                    folds=skf,
                    seed=42,
                    callbacks=callbacks,
                    verbose_eval=10)

    print('cv logloss:', result['binary_logloss-mean'][-1])

    # モデルを取り出す
    proxy = extraction_cb.boosters_proxy

    # モデルを SerDe する
    serialized_model = pickle.dumps(proxy)
    restored_model = pickle.loads(serialized_model)

    # Deserialize したオブジェクト
    print(restored_model)

    # Hold-out しておいたデータを予測させる
    y_pred_probas = restored_model.predict(X_test)
    y_pred_proba = np.array(y_pred_probas).mean(axis=0)
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    # Accuracy について評価する
    acc = accuracy_score(y_test, y_pred)
    print('test accuracy:', acc)


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、次のように例外になる。

$ python lgbcvbserde.py
...
cv logloss: 0.12616399920831986
Traceback (most recent call last):
  File "lgbcvbserde.py", line 99, in <module>
    main()
  File "lgbcvbserde.py", line 84, in main
    restored_model = pickle.loads(serialized_model)
  File "/Users/amedama/.virtualenvs/py37/lib/python3.7/site-packages/lightgbm/engine.py", line 262, in handler_function
    for booster in self.boosters:
TypeError: 'function' object is not iterable

これは、先のエントリに記述した通り以下の条件が重なることで生じている。

  • ラッパーとなる _CVBooster__getattr__() が実装されており __getstate__()__setstate() をトラップする
  • ラップされるオブジェクトに __getstate__()__setstate__() が実装されておりラッパー経由で呼ばれている

問題を修正するコード

問題の修正方法は先のエントリに記述した通り。 ラッパーとして動作するオブジェクト、今回であれば _CVBooster のインスタンスに __getstate__()__setstate__() が必要になる。 ただし、_CVBooster は LightGBM のパッケージなので直接ソースコードを修正することは望ましくない。 そのためモンキーパッチを駆使して解決する。

以下のサンプルコードではクラスに動的にメソッドを追加することで問題を修正している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle

import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split


class ModelExtractionCallback(object):
    """lightgbm.cv() 関数からモデルを取り出すコールバック"""

    def __init__(self):
        self._model = None

    def __call__(self, env):
        self._model = env.model

    def _assert_called_cb(self):
        if self._model is None:
            raise RuntimeError('callback has not called yet')

    @property
    def boosters_proxy(self):
        self._assert_called_cb()
        return self._model

    @property
    def raw_boosters(self):
        self._assert_called_cb()
        return self._model.boosters

    @property
    def best_iteration(self):
        self._assert_called_cb()
        return self._model.best_iteration


def main():
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        test_size=0.2,
                                                        random_state=42)


    lgb_train = lgb.Dataset(X_train, y_train)

    extraction_cb = ModelExtractionCallback()
    callbacks = [
        extraction_cb,
    ]
    lgb_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    result = lgb.cv(lgb_params,
                    lgb_train,
                    num_boost_round=1000,
                    early_stopping_rounds=10,
                    folds=skf,
                    seed=42,
                    callbacks=callbacks,
                    verbose_eval=10)

    print('cv logloss:', result['binary_logloss-mean'][-1])

    proxy = extraction_cb.boosters_proxy

    # lightgbm.engine._CVBooster のクラスに
    # __getstate__() と __setstate__() を動的に追加する
    def __getstate__(self):
        return self.__dict__.copy()
    setattr(lgb.engine._CVBooster, '__getstate__', __getstate__)

    def __setstate__(self, state):
        self.__dict__.update(state)
    setattr(lgb.engine._CVBooster, '__setstate__', __setstate__)

    serialized_model = pickle.dumps(proxy)
    restored_model = pickle.loads(serialized_model)

    print(restored_model)

    y_pred_probas = restored_model.predict(X_test)
    y_pred_proba = np.array(y_pred_probas).mean(axis=0)
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    acc = accuracy_score(y_test, y_pred)
    print('test accuracy:', acc)


if __name__ == '__main__':
    main()

上記を実行してみよう。 SerDe の部分は全く修正していないけど、今度は例外にならず実行できている。

$ python lgbcvbserde.py
...
cv logloss: 0.12616399920831986
<lightgbm.engine._CVBooster object at 0x114704090>
test accuracy: 0.9736842105263158

ちなみに、上記のように _CVBooster ごと直列化しようとするから今回のような問題になるのであって、中身の Booster を格納したリストを直列化するという選択肢もある。