ニューラルネットワークを実装するためのフレームワークの Keras は LightGBM と似たようなコールバックの機構を備えている。 そして、いくつか標準で用意されているコールバックがある。
そんな中に ModelCheckpoint
というコールバックがあって、これが意外と便利そうだった。
このコールバックは良いスコアを記録したモデルを自動的にディスクに永続化するためのもの。
そこで、今回は似たような機能のコールバックを LightGBM に移植してみることにした。
使った環境は次の通り。
$ sw_vers ProductName: Mac OS X ProductVersion: 10.14.6 BuildVersion: 18G95 $ python -V Python 3.7.4
下準備
まずは今回使うパッケージをインストールしておく。
$ pip install lightgbm scikit-learn
自動でモデルを永続化してくれるコールバック
自動でモデルを保存してくれるコールバックを実装したサンプルコードが次の通り。
コールバックは ModelCheckpointCallback
という名前のクラスとして実装している。
使い方はコメントを読むことで大体わかると思う。
チェックするメトリックのスコアが過去に記録されたものを上回ったときに、そのモデルをディスクに書き出す。
サンプルコードではコールバックによって保存された学習済みモデルを復元して、ホールドアウトしておいたデータを予測させている。
#!/usr/bin/env python # -*- coding: utf-8 -*- import pickle import numpy as np import lightgbm as lgb from sklearn.metrics import accuracy_score from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.model_selection import StratifiedKFold class ModelCheckpointCallback(object): """モデルをディスクに永続化するためのコールバック""" def __init__(self, save_file, monitor_metric, pickle_protocol=2): # モデルを保存するファイルパス self._save_file = save_file # スコアを確認するメトリックの名前 self._monitor_metric = monitor_metric # 永続化に使う pickle のプロトコル (デフォルトでは互換性優先) self._pickle_protocol = pickle_protocol self._best_score = None def _is_higher_score(self, metric_score, is_higher_better): if self._best_score is None: # 過去にスコアが記録されていなければ問答無用でベスト return True if is_higher_better: return metric_score < metric_score else: return metric_score > metric_score def _save_model(self, model): if isinstance(self._save_file, str): # 文字列ならファイルパスと仮定する with open(self._save_file, mode='wb') as fp: pickle.dump(model, fp, protocol=self._pickle_protocol) else: # それ以外は File-like object と仮定する pickle.dump(model, self._save_file, protocol=self._pickle_protocol) def __call__(self, env): evals = env.evaluation_result_list for _, name, score, is_higher_better, _ in evals: # チェックするメトリックを選別する if name != self._monitor_metric: continue # 対象のメトリックが見つかっても過去のスコアよりも性能が悪ければ何もしない if not self._is_higher_score(score, is_higher_better): return # ベストスコアならモデルを永続化する self._save_model(env.model) return # メトリックが見つからなかった raise ValueError('monitoring metric not found') def accuracy(preds, data): """精度 (Accuracy) を計算する関数""" y_true = data.get_label() y_pred = np.where(preds > 0.5, 1, 0) acc = accuracy_score(y_true, y_pred) # name, result, is_higher_better return 'accuracy', acc, True def main(): # Iris データセットを読み込む dataset = datasets.load_breast_cancer() X, y = dataset.data, dataset.target # デモ用にデータセットを分割する X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=42) # LightGBM 用のデータセット表現に直す lgb_train = lgb.Dataset(X_train, y_train) # XXX: lightgbm.engine._CVBooster を pickle で永続化できるようにする # lightgbm.train() で学習するときは不要な処理 def __getstate__(self): return self.__dict__.copy() setattr(lgb.engine._CVBooster, '__getstate__', __getstate__) def __setstate__(self, state): self.__dict__.update(state) setattr(lgb.engine._CVBooster, '__setstate__', __setstate__) # 学習済みモデルを取り出すためのコールバックを用意する model_filename = 'lgb-cvbooster-model.pkl' checkpoint_cb = ModelCheckpointCallback(save_file=model_filename, monitor_metric='binary_logloss') callbacks = [ checkpoint_cb, ] # データセットを 5-Fold CV で学習する lgbm_params = { 'objective': 'binary', 'metric': 'binary_logloss', } folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) cv_result = lgb.cv(lgbm_params, lgb_train, num_boost_round=1000, early_stopping_rounds=10, folds=folds, seed=42, feval=accuracy, callbacks=callbacks, verbose_eval=10, ) # CV の結果を出力する print('eval accuracy:', cv_result['accuracy-mean'][-1]) # 学習が終わったらモデルはディスクに永続化されている with open(model_filename, mode='rb') as fp: restored_model = pickle.load(fp) # 復元したモデルを使って Hold-out したデータを推論する y_pred_proba_list = restored_model.predict(X_test, restored_model.best_iteration) y_pred_probas = np.mean(y_pred_proba_list, axis=0) y_pred = np.where(y_pred_probas > 0.5, 1, 0) acc = accuracy_score(y_test, y_pred) print('test accuracy:', acc) if __name__ == '__main__': main()
上記を実行してみよう。
$ python lgbcpcb.py ...(snip)... eval accuracy: 0.9508305647840531 test accuracy: 0.958041958041958
ちゃんとテスト用にホールドアウトしておいたデータを予測できていることがわかる。
なお、ファイルシステムを確認するとカレントディレクトリにモデルが直列化されたファイルがあるはず。
$ file lgb-cvbooster-model.pkl lgb-cvbooster-model.pkl: data
めでたしめでたし。
スマートPythonプログラミング: Pythonのより良い書き方を学ぶ
- 作者: もみじあめ
- 発売日: 2016/03/12
- メディア: Kindle版
- この商品を含むブログ (1件) を見る