CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM の cv() 関数の実装について

今回は LightGBM の cv() 関数について書いてみる。 LightGBM の cv() 関数は、一般的にはモデルの性能を評価する交差検証に使われる。 一方で、この関数から取り出した学習済みモデルを推論にまで使うユーザもいる。 今回は、その理由やメリットとデメリットについて書いてみる。

cv() 関数から取り出した学習済みモデルを使う理由とメリット・デメリットについて

一部のユーザの間では有名だけど、LightGBM の cv() 関数は各 Fold の決定木の増やし方に特色がある。 まず、LightGBM では決定木の集まりを Booster というオブジェクトで管理している。 Booster が内包する決定木の本数は、ラウンド (イテレーション) 数として認識できる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1930

ちなみに、train() 関数を使って得られるのは、この Booster というオブジェクト。 一般的に、train() 関数を使って自前で交差検証をするときは、この Booster を Fold 毎にひとつずつ学習させることになる。

一方で、cv() 関数では全 Fold を並列で、複数の Booster を一度に学習させる。 具体的には、すべての Fold で歩調を合わせながら、それぞれの Booster のラウンド数をひとつずつ増やしている。 このとき、検証用のデータに対するメトリックも、ラウンド (イテレーション) 毎に「全 Fold の平均」で計算される。 つまり、全 Fold の平均的なメトリックが悪化するタイミングで Early Stopping がかかる。

言いかえると、cv() 関数から得られる学習済みモデルは Booster が内包する決定木の本数がすべて同じに揃う。 それに対して、train() 関数を使ってひとつずつ Booster を学習する方法では、ラウンド数が Fold によってバラつくことになる。 バラつきが小さいときは良いけど、ときには大きく偏ることもあって、その際は性能の見積もりや推論に悪影響があると考えられる。 この点から、cv() 関数では Fold 毎の偏りを考慮した、ようするに無難なモデルを得られることが期待できる。 なお、各 Fold から複数の Booster が得られるので、推論するときは Averaging などで対応する。

また、ターゲットの情報を使った特徴量抽出やスタッキングをするときも、この点は都合が良い。 これらのユースケースでは、一般的にはリークを防ぐために Out-of-Fold で処理することになる。 となると、データの全体を使って学習することが難しいので、各 Fold ごとに学習したモデルを使えると使い勝手が良い。

と、ここまでメリットばかり説明してきたけど、もちろんデメリットもある。 前述したとおり、cv() 関数では各 Fold の Booster を同時に並列で学習させていく。 そのため、学習に使うデータやモデルを一度にメモリに載せることになる。 つまり、train() 関数を使って Booster をひとつずつ学習するときよりも、相対的にメモリの制約は厳しくなると考えられる。 また、他の Fold を使って補える部分もあるとはいえ Out-of-Fold したデータは学習に使えない点もデメリットとして挙げられる。

cv() 関数の実装について

ここからは LightGBM のコードを軽く追いかけてみよう。

はじめに、LightGBM のコアといえる部分は C++ で書かれている。 Python では、それを ctypes モジュールを使った Binding として呼び出している。

自身の Python 実行環境で LightGBM のインストール先パスがわかっているときは LightGBM の共有ライブラリを探してみると良い。 上記でいうコアは Python 実行環境の中に「lib_lightgbm.so」として存在している。

$ python -c "import site; print (site.getsitepackages())"
['/Users/amedama/.virtualenvs/py38/lib/python3.8/site-packages']
$ file  ~/.virtualenvs/py38/lib/python3.8/site-packages/lightgbm/lib_lightgbm.so
/Users/amedama/.virtualenvs/py38/lib/python3.8/site-packages/lightgbm/lib_lightgbm.so: Mach-O 64-bit dynamically linked shared library x86_64

上記の共有ライブラリは Python Binding の Booster クラスから呼ばれている。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1930

ctypes モジュールで読み込んだライブラリを _LIB として呼び出している部分がそれ。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1988,L1991

そして、train() 関数や cv() 関数は、上記の Booster を学習させるためのラッパーになっている。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L18

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L394

cv() 関数に着目して読んでいくと、以下で全 Fold の Booster を同時に更新していることがわかる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L592

また、Early Stopping は全 Fold の平均的なメトリックを元に発火することが確認できる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L593,L609

いじょう。