CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM でカスタムメトリックを扱う

今回は LightGBM で、組み込みで用意されていない独自の評価指標 (カスタムメトリック) を扱う方法について。 ユースケースとしては、学習自体は別の評価指標を使って進めつつ、本来の目標としている評価指標を同時に確認するといったもの。 例えば、精度 (Accuracy) やマシューズ相関係数 (Matthews Correlation Coefficient) は、学習にそのまま用いることは難しい。 しかしながら、最終的な目標としている評価指標がそれらになっていることはよくある。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.4
BuildVersion:   18E226
$ python -V                  
Python 3.7.3
$ pip list | grep -i lightgbm
lightgbm        2.2.3  

下準備

$ brew install libomp
$ pip install lightgbm scikit-learn matplotlib

独自の評価指標を用いる

以下のサンプルコードでは、学習には LogLoss を使いつつ、同時に Accuracy を計算している。 独自の評価指標を計算するときは train() 関数や cv() 関数で feval というオプションを用いる。 指定するのは評価指標を計算する関数で、引数はモデルが予測した値と学習に使ったデータの二つ。 データからは get_label() という関数で真のラベルを取得できる。 注意点として、モデルが予測した値は多値分類問題であっても一次元の配列になっているため reshape する必要がある。 評価指標を計算する関数では、返り値として評価指標の名前、スコア、そしてスコアが大きい方が優れているのか否かを表す真偽値を返す。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import lightgbm as lgb
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split


def accuracy(preds, data):
    """精度 (Accuracy) を計算する関数"""
    # 正解ラベル
    y_true = data.get_label()
    # 推論の結果が 1 次元の配列になっているので直す
    N_LABELS = 3  # ラベルの数
    reshaped_preds = preds.reshape(N_LABELS, len(preds) // N_LABELS)
    # 最尤と判断したクラスを選ぶ 
    y_pred = np.argmax(reshaped_preds, axis=0)
    # メトリックを計算する
    acc = np.mean(y_true == y_pred)
    # name, result, is_higher_better
    return 'accuracy', acc, True


def main():
    # Iris データセットを読み込む
    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    # 学習用データと検証用データに分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # LightGBM のデータセットの表現に直す
    lgb_train = lgb.Dataset(X_train, y_train)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

    lgbm_params = {
        'objective': 'multiclass',
        'num_class': 3,
    }

    evals_result = {}
    lgb.train(lgbm_params,
              lgb_train,
              # メトリックを追跡する対象のデータセット
              valid_sets=[lgb_eval, lgb_train],
              # 上記の名前
              valid_names=['eval', 'train'],
              num_boost_round=1000,
              # メトリックの履歴を残すオブジェクト
              evals_result=evals_result,
              # 独自メトリックを計算する関数
              feval=accuracy,
              )

    # 組み込みのメトリック
    eval_metric_logloss = evals_result['eval']['multi_logloss']
    train_metric_logloss = evals_result['train']['multi_logloss']

    # カスタムメトリック
    eval_metric_acc = evals_result['eval']['accuracy']
    train_metric_acc = evals_result['train']['accuracy']

    # グラフにプロットする
    _, ax1 = plt.subplots(figsize=(8, 4))
    ax1.plot(eval_metric_logloss, label='eval logloss', c='r')
    ax1.plot(train_metric_logloss, label='train logloss', c='b')
    ax1.set_ylabel('logloss')
    ax1.set_xlabel('rounds')
    ax1.legend()

    ax2 = ax1.twinx()
    ax2.plot(eval_metric_acc, label='eval accuracy', c='g')
    ax2.plot(train_metric_acc, label='train accuracy', c='y')
    ax2.set_ylabel('accuracy')
    ax2.legend()

    plt.grid()
    plt.show()


if __name__ == '__main__':
    main()

サンプルコードでは、評価指標の推移を折れ線グラフとしてプロットしている。

上記を実行してみよう。 ログを確認すると、ちゃんと Accuracy に関する情報も出力される。

$ python lgbcm.py
...
[1000] train's multi_logloss: 8.23321e-05    train's accuracy: 1 eval's multi_logloss: 0.396726  eval's accuracy: 0.947368

上記を実行すると、以下のようなグラフが得られる。

f:id:momijiame:20190331004325p:plain

上記のグラフを見ると、学習に使った LogLoss に加えて Accuracy の推移も確認できる。 どうやら、イテレーション数が 100 を越えないあたりから検証データに対する LogLoss が増加に転じているようだ。 検証データに対する LogLoss は増加することなく現象し続けており、過学習を起こしていることがわかる。 検証データの LogLoss が増加するタイミングでは、学習データに対する Accuracy が増加すると共に検証データの Accuracy も低下している。

注意点

今回のように複数の評価指標を LightGBM に計算させるときは early_stopping_rounds との併用において注意が必要になる。 というのも early_stopping_rounds で early stopping の対象になるのが、いずれかの評価指標で条件を満たした場合になっているため。 つまり、本来は LogLoss の値が増加に転じた場合に止めたいのに、まだ現象を続けている段階で別の評価指標が条件を満たすと止まってしまう恐れがある。