CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM で学習済みモデルを自動で永続化するコールバックを書いてみた

ニューラルネットワークを実装するためのフレームワークの Keras は LightGBM と似たようなコールバックの機構を備えている。 そして、いくつか標準で用意されているコールバックがある。

keras.io

そんな中に ModelCheckpoint というコールバックがあって、これが意外と便利そうだった。 このコールバックは良いスコアを記録したモデルを自動的にディスクに永続化するためのもの。 そこで、今回は似たような機能のコールバックを LightGBM に移植してみることにした。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G95
$ python -V
Python 3.7.4

下準備

まずは今回使うパッケージをインストールしておく。

$ pip install lightgbm scikit-learn

自動でモデルを永続化してくれるコールバック

自動でモデルを保存してくれるコールバックを実装したサンプルコードが次の通り。 コールバックは ModelCheckpointCallback という名前のクラスとして実装している。 使い方はコメントを読むことで大体わかると思う。 チェックするメトリックのスコアが過去に記録されたものを上回ったときに、そのモデルをディスクに書き出す。 サンプルコードではコールバックによって保存された学習済みモデルを復元して、ホールドアウトしておいたデータを予測させている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle

import numpy as np
import lightgbm as lgb
from sklearn.metrics import accuracy_score
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold


class ModelCheckpointCallback(object):
    """モデルをディスクに永続化するためのコールバック"""

    def __init__(self, save_file, monitor_metric, pickle_protocol=2):
        # モデルを保存するファイルパス
        self._save_file = save_file
        # スコアを確認するメトリックの名前
        self._monitor_metric = monitor_metric
        # 永続化に使う pickle のプロトコル (デフォルトでは互換性優先)
        self._pickle_protocol = pickle_protocol
        self._best_score = None

    def _is_higher_score(self, metric_score, is_higher_better):
        if self._best_score is None:
            # 過去にスコアが記録されていなければ問答無用でベスト
            return True

        if is_higher_better:
            return metric_score < metric_score
        else:
            return metric_score > metric_score

    def _save_model(self, model):
        if isinstance(self._save_file, str):
            # 文字列ならファイルパスと仮定する
            with open(self._save_file, mode='wb') as fp:
                pickle.dump(model, fp,
                            protocol=self._pickle_protocol)
        else:
            # それ以外は File-like object と仮定する
            pickle.dump(model, self._save_file,
                        protocol=self._pickle_protocol)

    def __call__(self, env):
        evals = env.evaluation_result_list
        for _, name, score, is_higher_better, _ in evals:
            # チェックするメトリックを選別する
            if name != self._monitor_metric:
                continue
            # 対象のメトリックが見つかっても過去のスコアよりも性能が悪ければ何もしない
            if not self._is_higher_score(score, is_higher_better):
                return
            # ベストスコアならモデルを永続化する
            self._save_model(env.model)
            return
        # メトリックが見つからなかった
        raise ValueError('monitoring metric not found')


def accuracy(preds, data):
    """精度 (Accuracy) を計算する関数"""
    y_true = data.get_label()
    y_pred = np.where(preds > 0.5, 1, 0)
    acc = accuracy_score(y_true, y_pred)
    # name, result, is_higher_better
    return 'accuracy', acc, True


def main():
    # Iris データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    # デモ用にデータセットを分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # LightGBM 用のデータセット表現に直す
    lgb_train = lgb.Dataset(X_train, y_train)

    # XXX: lightgbm.engine._CVBooster を pickle で永続化できるようにする
    #      lightgbm.train() で学習するときは不要な処理
    def __getstate__(self):
        return self.__dict__.copy()
    setattr(lgb.engine._CVBooster, '__getstate__', __getstate__)
    def __setstate__(self, state):
        self.__dict__.update(state)
    setattr(lgb.engine._CVBooster, '__setstate__', __setstate__)

    # 学習済みモデルを取り出すためのコールバックを用意する
    model_filename = 'lgb-cvbooster-model.pkl'
    checkpoint_cb = ModelCheckpointCallback(save_file=model_filename,
                                            monitor_metric='binary_logloss')
    callbacks = [
        checkpoint_cb,
    ]

    # データセットを 5-Fold CV で学習する
    lgbm_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
    }
    folds = StratifiedKFold(n_splits=5,
                            shuffle=True,
                            random_state=42)
    cv_result = lgb.cv(lgbm_params,
                       lgb_train,
                       num_boost_round=1000,
                       early_stopping_rounds=10,
                       folds=folds,
                       seed=42,
                       feval=accuracy,
                       callbacks=callbacks,
                       verbose_eval=10,
                       )

    # CV の結果を出力する
    print('eval accuracy:', cv_result['accuracy-mean'][-1])

    # 学習が終わったらモデルはディスクに永続化されている
    with open(model_filename, mode='rb') as fp:
        restored_model = pickle.load(fp)

    # 復元したモデルを使って Hold-out したデータを推論する
    y_pred_proba_list = restored_model.predict(X_test,
                                               restored_model.best_iteration)
    y_pred_probas = np.mean(y_pred_proba_list, axis=0)
    y_pred = np.where(y_pred_probas > 0.5, 1, 0)
    acc = accuracy_score(y_test, y_pred)
    print('test accuracy:', acc)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python lgbcpcb.py
...(snip)...
eval accuracy: 0.9508305647840531
test accuracy: 0.958041958041958

ちゃんとテスト用にホールドアウトしておいたデータを予測できていることがわかる。

なお、ファイルシステムを確認するとカレントディレクトリにモデルが直列化されたファイルがあるはず。

$ file lgb-cvbooster-model.pkl
lgb-cvbooster-model.pkl: data

めでたしめでたし。