CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM の学習曲線をコールバックで動的にプロットする

LightGBM の学習が進む様子は、学習させるときにオプションとして verbose_eval などを指定することでコンソールから確認できる。 ただ、もっと視覚的にリアルタイムで確認したいなーと思ったので、今回はコールバックと Matplotlib を使って学習曲線を動的にグラフとしてプロットしてみることにした。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V            
Python 3.7.3

下準備

下準備として LightGBM と Matplotlib をインストールしておく。 Seaborn は本来は必要ないんだけどデータセットの読み込みにだけ使っている。

$ pip install lightgbm matplotlib seaborn

学習曲線を動的にプロットする

今回書いてみたサンプルコードは次の通り。 Seaborn から Titanic データセットを読み込んで LightGBM のモデルが学習していく過程を可視化している。 グラフのプロットは LearningVisualizationCallback というコールバックを実装することで実現している。 そのままだとグラフが寂しいので、カスタムメトリックとして Accuracy も追加してみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import defaultdict

import numpy as np
import lightgbm as lgb
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from matplotlib import pyplot as plt


class LearningVisualizationCallback(object):
    """学習の過程を動的にプロットするコールバック"""

    def __init__(self, fig=None, ax=None):
        self._metrics = defaultdict(list)
        self._lines = {}

        # 初期化する
        self._fig = fig
        self._ax = ax
        if self._fig is None and self._ax is None:
            self._fig, self._ax = plt.subplots()
        self._fig.canvas.draw()
        self._fig.show()

    def __call__(self, env):
        # メトリックを保存する
        evals = env.evaluation_result_list
        for _, name, mean, _, _ in evals:
            self._metrics[name].append(mean)

        # 可視化する
        for name, values in self._metrics.items():

            # 初回だけ描画用オブジェクトを取得して保存しておく
            if name not in self._lines:
                line, = self._ax.plot(np.arange(len(values)),
                                      values)
                self._lines[name] = line
                line.set_label(name)

            # グラフデータを更新する
            line = self._lines[name]
            line.set_data(np.arange(len(values)), values)

        # グラフの見栄えを調整する
        self._ax.grid()
        self._ax.legend()
        self._ax.relim()
        self._ax.autoscale_view()

        # 再描画する
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()

    def show_until_close(self):
        """ウィンドウを閉じるまで表示し続ける"""
        plt.show()


def accuracy(preds, data):
    """精度 (Accuracy) を計算する関数
    NOTE: 表示が eval set の LogLoss だけだと寂しいので"""
    y_true = data.get_label()
    y_pred = np.where(preds > 0.5, 1, 0)
    acc = np.mean(y_true == y_pred)
    return 'accuracy', acc, True


def main():
    # Titanic データセットを読み込む
    dataset = sns.load_dataset('titanic')

    # 重複など不要な特徴量は落とす
    X = dataset.drop(['survived',
                      'class',
                      'who',
                      'embark_town',
                      'alive'], axis=1)
    y = dataset.survived

    # カテゴリカル変数を指定する
    categorical_columns = ['pclass',
                           'sex',
                           'embarked',
                           'adult_male',
                           'deck',
                           'alone']
    X = X.astype({c: 'category'
                  for c in categorical_columns})

    # LightGBM のデータセット表現に直す
    lgb_train = lgb.Dataset(X, y)

    # 学習の過程を可視化するコールバックを用意する
    visualize_cb = LearningVisualizationCallback()
    callbacks = [
        visualize_cb,
    ]

    # 二値分類を LogLoss で評価する
    lgb_params = {
        'objective': 'binary',
        'metrics': 'binary_logloss',
    }
    # 5-Fold CV
    skf = StratifiedKFold(n_splits=5,
                          shuffle=True,
                          random_state=42)
    lgb.cv(lgb_params, lgb_train,
           num_boost_round=1000,
           early_stopping_rounds=100,
           verbose_eval=10,
           folds=skf, seed=42,
           # Accuracy も確認する
           feval=accuracy,
           # コールバックを登録する
           callbacks=callbacks)

    # ウィンドウを閉じるまで表示し続ける
    visualize_cb.show_until_close()


if __name__ == '__main__':
    main()

上記に適当な名前をつけて実行してみよう。

$ python lgblearnviz.py

すると、モデルの学習に伴って次のようなアニメーションが表示される。

f:id:momijiame:20190606223725g:plain

いいかんじ。

なお、表示されているのは Validation Set に対するメトリックとなる。 Training Set も確認したかったんだけど、どうやら次のリリース (2.2.4?) でオプションに eval_train_metric が入るのを待つ必要がありそう。

あと、Jupyter Notebook で使うときは %matplotlib notebook マジックコマンドを使うと良い。