CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: MLflow Tracking を使ってみる

MLflow は MLOps に関連した OSS のひとつ。 いくつかのコンポーネントに分かれていて、それぞれを必要に応じて独立して使うことができる。 今回は、その中でも実験の管理と可視化を司る MLflow Tracking を試してみることにした。

機械学習のプロジェクトでは試行錯誤することが多い。 その際には、パラメータやモデルの構成などを変えながら何度も実験を繰り返すことになる。 すると、回数が増えるごとに使ったパラメータや得られた結果、モデルなどの管理が煩雑になってくる。 MLflow Tracking を使うことで、その煩雑さが軽減できる可能性がある。

使った環境は次のとおり。

$ sw_vers          
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G5033
$ python -V       
Python 3.7.7
$ mlflow --version
mlflow, version 1.8.0

もくじ

登場する概念について

まず、MLflow Tracking では ExperimentRun というモノを作っていく。 これらは、特定の目的を持った実験と、それに 1:N で紐付いた各試行を表している。 つまり、試行錯誤の度に Run が増えることになる。 そして、それぞれの Run には次のような情報が、またもや 1:N で紐づく。

  • Parameter

    • データの前処理やモデルの学習に使ったパラメータ
  • Metric

    • 各ラウンドの損失や、交差検証のスコアといったメトリック
  • Artifact

    • 学習済みモデルや、特徴量の重要度など実験した結果として得られる成果物
  • Tag

    • 後から実験を探しやすくしたりするためのメタな情報

要するに Experiment > Run > Parameter, Metric, ... ということ。 これらの情報は、典型的には Python のスクリプトから記録される。 そして、記録された情報は後からスクリプトや WebUI 経由で確認できる。

データを保存する場所について

具体的な使い方を紹介する前に、前述した情報がどのように管理されるのか解説しておく。 まず、MLflow Tracking はクライアントとサーバに分かれてる。 そして、データは基本的にサーバに保存される。 クライアントは、データを保存するサーバの場所を Tracking URI と Artifact URI という 2 つの URI で指定する。

ただし、クライアントとサーバは 1 台のマシンで兼ねることもできる。 また、データを記録する方法も、追加でソフトウェアなどを必要としないローカルファイルで完結させることができる。 そのため、クライアント・サーバ方式といっても 1 人で使い始める分には環境構築などの作業は全く必要ない。 あくまで、複数人のチームで記録されたデータを共用・共有したいときに専用のサーバが必要となる。

Tracking URI について

Tracking URI は、概ね Run の中で Artifact 以外の情報を記録するところ。 記録される情報は、基本的には Key-Value 形式になっている。

利用できるバックエンドには次のようなものがある。

  • ローカルファイル

    • マウントされているブロックストレージにファイルとして記録される
      • 手っ取り早く使うならこれ (デフォルト)
    • 形式: file:<path>
  • リレーショナルデータベース

    • SQLAlchemy 経由で RDB にテーブルのエントリとして記録される
    • 形式: <dialect>+<driver>://<username>:<password>@<host>:<port>/<database>
  • REST API

    • MLflow Server というサーバを立ち上げて、そこに記録される
      • 後述するが、記録先の実体はローカルファイルだったりリレーショナルデータベースだったり選べる
    • 形式: http(s)://<host>:<port>

その他、詳細は以下に記載されている。

www.mlflow.org

Artifact URI について

Artifact URI は、その名の通り Artifact を記録するところ。 記録される情報はファイル (バイト列) になる。

利用できるバックエンドには次のようなものがある。

  • ローカルファイル

    • マウントされているブロックストレージにファイルとして保存される
      • 手っ取り早く使うならこれ (デフォルト)
    • 形式: file:<path>
  • 各種クラウド (オブジェクト) ストレージ

    • Amazon S3, Azure Blob Storage, Google Cloud Storage などにファイルとして保存される
    • 形式: s3://<bucket> など
  • (S)FTP サーバ

    • (S)FTP サーバにファイルとして保存される
    • 形式: (s)ftp://<user>@<host>/<path>
  • HDFS

    • HDFS (Hadoop Distributed File System) にファイルとして保存される
    • 形式: hdfs://<path>

その他、詳細は以下に記載されている。

www.mlflow.org

色々とあるけど、結局のところ 1 人で使い始めるなら両方ともローカルファイルにすれば良い。 デフォルトでは、どちらもクライアントを実行した場所の mlruns というディレクトリが使われる。 これは、Traking URI と Artifact URI の両方に file:./mlruns を指定した状態ということ。

ちなみに、チームで使いたいけど自分で構築とか運用したくないってときは、開発の中心となっている Databricks がサーバ部分のマネージドサービスを提供している。

下準備

前置きが長くなったけど、ここからやっと実際に試していく。

はじめに、必要なパッケージをインストールしておこう。 なお、mlflow 以外は、後ほど登場するサンプルコードを動かすためだけに必要なもの。

$ pip install mlflow scikit-learn lightgbm matplotlib

インストールすると mlflow コマンドが使えるようになる。

$ mlflow --version
mlflow, version 1.8.0

基本的な使い方

とりあえず Python の REPL を使って、基本的な使い方を紹介する。 まずは REPL を起動しておく。

$ python

MLflow のパッケージをインポートする。

>>> import mlflow

実験の試行を開始する。 これには start_run() 関数を使う。 なお、デフォルトでは Experiment として Default という領域が使われる。

>>> mlflow.start_run()
<ActiveRun: >

実験に使ったパラメータを log_param() 関数で記録する。

>>> mlflow.log_param(key='foo', value='bar')

得られたメトリックなどの情報は log_metric() 関数で記録する。

>>> mlflow.log_metric(key='logloss', value=1.0)

実験には set_tag() 関数でタグが付与できる。

>>> mlflow.set_tag(key='hoge', value='fuga')

尚、これらは関数名を複数形にすると辞書型で複数の Key-Value を一度に記録できる。

アーティファクトについては、ちょっと面倒くさい。 アーティファクトはバイト列のファイルなので、まずはローカルにファイルを用意して、それをコピー (転送) することになる。 典型的には、一時ディレクトリを tempfile モジュールで作って、中身のファイルを作ったらそれをコピーすれば良い。 一時ディレクトリ以下のファイルは、処理が終わったら自動的に削除される。

>>> import tempfile
>>> import pathlib
>>> with tempfile.TemporaryDirectory() as d:
...     filename = 'test-artifact'
...     artifact_path = pathlib.Path(d) / filename
...     with open(artifact_path, 'w') as fp:
...         print('Hello, World!', file=fp)
...     mlflow.log_artifact(artifact_path)
... 

あとは実験の試行を終了するだけ。

>>> mlflow.end_run()

Python の REPL を終了して、カレントディレクトリを確認してみよう。 デフォルトの Tracking URI / Artifact URI として file:./mlruns が使われるため mlruns というディレクトリが作られている。

$ find mlruns 
mlruns
mlruns/0
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/metrics
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/metrics/logloss
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/artifacts
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/artifacts/test-artifact
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/tags
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/tags/hoge
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.user
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.source.name
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.source.type
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/params
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/params/foo
mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/meta.yaml
mlruns/0/meta.yaml
mlruns/.trash

最初の階層は、それぞれの Experiment を表している。 DefaultExperiment には ID として 0 が付与されていることがわかる。

$ cat mlruns/0/meta.yaml 
artifact_location: file:///Users/amedama/Documents/temporary/helloworld/mlruns/0
experiment_id: '0'
lifecycle_stage: active
name: Default

その下の階層は、それぞれの Run を表している。

$ head mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/meta.yaml
artifact_uri: file:///Users/amedama/Documents/temporary/helloworld/mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/artifacts
end_time: 1591267604170
entry_point_name: ''
experiment_id: '0'
lifecycle_stage: active
name: ''
run_id: 3de33a8f39294c4f8bd404a5d5bccf39
run_uuid: 3de33a8f39294c4f8bd404a5d5bccf39
source_name: ''
source_type: 4

ここにパラメータやメトリックなどが記録される。 メトリックに記録されている各列は、実行時刻、値、ステップ数を表している。 ステップ数というのは、たとえばニューラルネットワークのエポックだったり、ブースティングマシンのラウンドだったりする。 先ほど実行した log_metric() 関数の引数として、実は指定できた。

$ head mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/params/foo 
bar
$ head mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/metrics/logloss 
1591267315737 1.0 0
$ cat mlruns/0/3de33a8f39294c4f8bd404a5d5bccf39/artifacts/test-artifact 
Hello, World!

時刻には 1,000 倍した UNIX time が記録されている。

$ date -r $((1591267315737 / 1000))
202064日 木曜日 194155秒 JST

記録された情報を WebUI で確認する

今のところ「ふーん」という感じだと思うので、記録された情報を WebUI からも確認してみよう。 確認用の WebUI を立ち上げるために、mlruns ディレクトリのある場所で mlflow ui コマンドを実行する。

$ mlflow ui

そして、ブラウザで localhost:5000 を閲覧する。

$ open http://localhost:5000

すると、こんな感じで過去に記録された情報が確認できる。 メトリックで試行を並べ替えたり、ステップ毎の値を可視化する機能も備わっている。

f:id:momijiame:20200604200638p:plain
MLflow Tracking WebUI

ちなみに WebUI を使う以外にも、ちょっと面倒だけどスクリプトから確認することもできる。

$ python

たとえば、メトリックの logloss が最も良いものを絞り込んでみる。 まあ、まだ 1 回しか実行してないんだけどね。

>>> import mlflow
>>> tracking_uri = 'file:./mlruns'
>>> client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)
>>> experiment = client.get_experiment_by_name('Default')
>>> run = client.search_runs(experiment.experiment_id,
...                          order_by=['metric.logloss asc'],
...                          max_results=1)[0]
>>> run.data.params
{'foo': 'bar'}
>>> run.data.metrics
{'logloss': 1.0}

複数人のチームで使いたいとき

複数人のチームで使うときは、いくつかの選択肢があるけど基本的には MLflow Server を用意して、そこに皆でアクセスする。 MLflow Server というのは Tracking URI として使える REST API と、先ほど確認した WebUI が一緒になったもの。

注意点として、MLflow Server はあくまで Tracking URI と WebUI のエンドポイントを提供するもの。 なので、Artifact URI の実体については別に用意しなければいけない。 たとえばクラウドストレージが使えるならそれを使っても良し、自分たちで FTP サーバを立ち上げたり HDFS のクラスタを組むことも考えられるだろう。

ここでは MLflow Server について軽く解説しておく。 まず、MLflow Server は mlflow server コマンドで起動できる。 起動するときに、Tracking URI (--backend-store-uri) と Artifact URI (--default-artifact-root) を指定する。

$ mlflow server \
    --backend-store-uri sqlite:///tracking.db \
    --default-artifact-root file:/tmp/artifacts \
    --host 0.0.0.0

--backend-store-uri は、MLflow サーバがクライアントから REST API 経由で受け取ったデータを記録する場所。 一方で、--default-artifact-root はサーバに接続してきたクライアントに「Artifact はここに保存してね」と伝えられる場所に過ぎない。 つまり、MLflow Server がプロキシしてくれるわけではないのでクライアントから接続性のある URI を指定する必要がある。 ここでは横着して /tmp 以下を指定してしまっている。 しかし、こんな風にするなら本来は NFS などでクライアントがすべて /tmp 以下に共有ディレクトリをマウントする必要がある。

起動したサーバを使って実際にデータを記録してみよう。 まずは別のターミナルで Python の REPL を起動する。

$ python

set_tracking_uri() 関数で Tracking URI として MLflow Server のエンドポイントを指定する。

>>> import mlflow
>>> tracking_uri = 'http://localhost:5000'
>>> mlflow.set_tracking_uri(tracking_uri)

すると、データの記録先が次のように設定される。

>>> mlflow.get_tracking_uri()
'http://localhost:5000'
>>> mlflow.get_artifact_uri()
'file:///tmp/artifacts/0/26d9e4204c20401eb7d2807a93be8b75/artifacts'

何か適当にデータを記録してみよう。

>>> mlflow.start_run()
>>> mlflow.log_param(key='foo', value='bar')
>>> mlflow.log_metric(key='logloss', value=0.5)
>>> import tempfile
>>> import pathlib
>>> with tempfile.TemporaryDirectory() as d:
...     filename = 'test-artifact'
...     artifact_path = pathlib.Path(d) / filename
...     with open(artifact_path, 'w') as fp:
...         print('Hello, World!', file=fp)
...     mlflow.log_artifact(artifact_path)
... 
>>> mlflow.end_run()

これで、アーティファクトについては前述したとおり /tmp 以下に保存される。

$ cat /tmp/artifacts/0/26d9e4204c20401eb7d2807a93be8b75/artifacts/test-artifact 
Hello, World!

アーティファクト以外の情報は MLflow Server の方に記録される。 今回であれば MLflow Server を実行したディレクトリにある SQLite3 のデータベースに入っている。

$ sqlite3 tracking.db 'SELECT * FROM experiments'
0|Default|file:///tmp/artifacts/0|active
$ sqlite3 tracking.db 'SELECT * FROM runs'       
26d9e4204c20401eb7d2807a93be8b75||UNKNOWN|||amedama|FINISHED|1591272053874|1591272141536||active|file:///tmp/artifacts/0/26d9e4204c20401eb7d2807a93be8b75/artifacts|0

ちなみに mlflow server コマンドの実装は mlflow.server:app にある Flask の WSGI アプリケーションを gunicorn でホストするコマンドをキックしてるだけ。 なので、自分で WSGI サーバを立ててアプリケーションをホストしても構わないだろう。 認証が必要なら前段にリバースプロキシを置いて好きにやれば良いと思う。 MLflow Server の解説については、ここまでで一旦おわり。

スクリプトに組み込んでみる

続いては、実際に MLflow Tracking を機械学習を扱うコードに組み込んでみよう。 そんなに良い例でもないけど乳がんデータセットを RandomForest で分類する過程を MLflow Tracking で記録してみる。 ここでは set_tracking_uri()set_experiment() していないのでデフォルト値 (mlruns / Default) が使われる。 ちなみに、コードに書かなくても、環境変数の MLFLOW_TRACKING_URIMLFLOW_EXPERIMENT_NAME を使う方法もある。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle
import json
import tempfile
import pathlib

import mlflow
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate


def main():
    # データセットの読み込み
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target
    feature_names = dataset.feature_names

    # 5-Fold Stratified CV でスコアを確認する
    clf = RandomForestClassifier(n_estimators=100,
                                 random_state=42,
                                 verbose=1)
    folds = StratifiedKFold(shuffle=True, random_state=42)
    # モデルの性能を示すメトリック
    result = cross_validate(clf, X, y,
                            cv=folds,
                            return_estimator=True,
                            n_jobs=-1)
    # 学習済みモデル
    estimators = result.pop('estimator')

    # 結果を記録する Run を始める
    with mlflow.start_run():

        # 実験に使った Parameter を記録する
        mlflow.log_params({
            'n_estimators': clf.n_estimators,
            'random_state': clf.random_state,
        })

        # 実験で得られた Metric を記録する
        for key, value in result.items():
            # 数値なら int/float な必要があるので平均値に直して書き込む
            mlflow.log_metric(key=key,
                              value=value.mean())

        # 実験で得られた Artifact を記録する
        for index, estimator in enumerate(estimators):

            # 一時ディレクトリの中に成果物を書き出す
            with tempfile.TemporaryDirectory() as d:

                # 学習済みモデル
                clf_filename = f'sklearn.ensemble.RandomForestClassifier.{index}'
                clf_artifact_path = pathlib.Path(d) / clf_filename
                with open(clf_artifact_path, 'wb') as fp:
                    pickle.dump(estimator, fp)

                # 特徴量の重要度
                imp_filename = f'sklearn.ensemble.RandomForestClassifier.{index}.feature_importances_.json'  # noqa
                imp_artifact_path = pathlib.Path(d) / imp_filename
                with open(imp_artifact_path, 'w') as fp:
                    importances = dict(zip(feature_names, estimator.feature_importances_))
                    json.dump(importances, fp, indent=2)

                # ディレクトリにあるファイルを Artifact として登録する
                mlflow.log_artifacts(d)


if __name__ == '__main__':
    main()

上記を実行する。

$ python bcrf.py

これで、各モデルの学習に使ったパラメータやメトリック、特徴量の重要度などが記録される。

$ cat mlruns/0/15d6ba2c32e240a382ff1efab8814e47/params/n_estimators 
100
$ cat mlruns/0/15d6ba2c32e240a382ff1efab8814e47/metrics/test_score 
1591273430658 0.9560937742586555 0
$ head mlruns/0/15d6ba2c32e240a382ff1efab8814e47/artifacts/sklearn.ensemble.RandomForestClassifier.0.feature_importances_.json 
{
  "mean radius": 0.04924964271136709,
  "mean texture": 0.01737644949893605,
  "mean perimeter": 0.07531323166620862,
  "mean area": 0.05398223096565245,
  "mean smoothness": 0.007976941962365291,
  "mean compactness": 0.01205471060787445,
  "mean concavity": 0.04868260289112485,
  "mean concave points": 0.08748636077564236,
  "mean symmetry": 0.0037492071919759205,

各種フレームワークの自動ロギング

先ほどのコードを見て分かる通り、MLflow Tracking のコードを組み込むのは結構めんどくさい。 そこで、MLflow Tracking は各種フレームワークの学習を自動で記録するインテグレーションを提供している。 ただし、この機能は今のところ Experimental な点に注意が必要。

以下のサンプルコードは LightGBM の学習を自動で記録するもの。 MLflow Tracking を動作させている部分は mlflow.lightgbm.autolog() を呼び出している一行だけ。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from mlflow import lightgbm as mlflow_lgb
import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def main():
    # LightGBM の学習を自動でトラッキングする
    mlflow_lgb.autolog()

    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target
    feature_names = list(dataset.feature_names)

    # 訓練データと検証データに分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # データセットを生成する
    lgb_train = lgb.Dataset(X_train, y_train,
                            feature_name=feature_names)
    lgb_eval = lgb.Dataset(X_test, y_test,
                           reference=lgb_train,
                           feature_name=feature_names)

    lgbm_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbosity': -1,
    }
    # ここでは MLflow Tracking がパッチした train() 関数が呼ばれる
    booster = lgb.train(lgbm_params,
                        lgb_train,
                        valid_sets=lgb_eval,
                        num_boost_round=1000,
                        early_stopping_rounds=100,
                        verbose_eval=10,
                        )

    # 学習済みモデルを使って検証データを予測する
    y_pred_proba = booster.predict(X_test,
                                   num_iteration=booster.best_iteration)

    # 検証データのスコアを確認する
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    test_score = accuracy_score(y_test, y_pred)
    print(test_score)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python bclgb.py

...

[180]  valid_0's binary_logloss: 0.14104
[190] valid_0's binary_logloss: 0.143199
Early stopping, best iteration is:
[97]   valid_0's binary_logloss: 0.10515
0.958041958041958

すると、train() 関数を呼ぶときに使われたパラメータやメトリックが記録される。

$ find mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb 
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/metrics
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/metrics/valid_0-binary_logloss
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/metrics/stopped_iteration
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/metrics/best_iteration
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_gain.json
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_gain.png
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_split.json
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/MLmodel
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/conda.yaml
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/model.lgb
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_split.png
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/tags
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.user
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.source.name
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.log-model.history
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.source.type
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/categorical_feature
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/feature_name
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/keep_training_booster
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/num_boost_round
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/early_stopping_rounds
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/objective
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/verbosity
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/verbose_eval
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/params/metric
mlruns/0/aae1e01ffa04406db4eaa6031d1c1ffb/meta.yaml

ただし、当たり前だけど自分で計算した内容については自分で記録しなければ残らない。 今回であれば自分でホールドアウト検証で計算している test_score は記録されていないことがわかる。

自分で計算したメトリックも記録するように修正してみよう。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import mlflow
from mlflow import lightgbm as mlflow_lgb
import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def main():
    # LightGBM の学習を自動でトラッキングする
    mlflow_lgb.autolog()

    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target
    feature_names = list(dataset.feature_names)

    # 訓練データと検証データに分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # データセットを生成する
    lgb_train = lgb.Dataset(X_train, y_train,
                            feature_name=feature_names)
    lgb_eval = lgb.Dataset(X_test, y_test,
                           reference=lgb_train,
                           feature_name=feature_names)

    lgbm_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbosity': -1,
    }

    # Run を開始する
    active_run = mlflow.start_run()

    # ここでは MLflow Tracking がパッチした train() 関数が呼ばれる
    booster = lgb.train(lgbm_params,
                        lgb_train,
                        valid_sets=lgb_eval,
                        num_boost_round=1000,
                        early_stopping_rounds=100,
                        verbose_eval=10,
                        )

    # 学習済みモデルを使って検証データを予測する
    y_pred_proba = booster.predict(X_test,
                                   num_iteration=booster.best_iteration)

    # 検証データのスコアを確認する
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    test_score = accuracy_score(y_test, y_pred)
    print(test_score)

    # 自分で計算したメトリックを記録する
    mlflow.log_metric(key='test_score', value=test_score)

    # Run を終了する
    mlflow.end_run()


if __name__ == '__main__':
    main()

上記をもう一度実行する。

$ python bclgb.py

すると、今度は test_score も記録されていることがわかる。

$ cat mlruns/0//fd0e67c47819488089431813e2028986/metrics/test_score
1591274826005 0.958041958041958 0

自作の autolog() 相当を作ってみる

ところで、先ほどの autolog() 関数がどのように実現されているのか気にならないだろうか。 これは、モンキーパッチを使うことで対象モジュールのコードを動的に書きかえている。 先ほどの LightGBM であれば lightgbm.train() が MLflow Tracking を使うものに書きかえられた。

そこで、試しに自分でも autolog() 相当のものを書いてみることにした。 以下は scikit-learn の RandomForestClassifier#fit() を書きかえたもの。 MLflow Tracking のモンキーパッチには gorilla というフレームワークが使われている。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import tempfile
import pathlib

import numpy as np
import gorilla
import mlflow
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def autolog_sklearn_random_forest():
    """scikit-learn の RandomForestClassifier#fit() をモンキーパッチする

    モンキーパッチ後は、モデルの学習に関する情報が自動で MLflow Tracking に残る"""

    @gorilla.patch(RandomForestClassifier)
    def fit(self, *args, **kwargs):
        # 実行中の Run が存在するか確認する
        if not mlflow.active_run():
            # 無ければ新しく Run を作る
            mlflow.start_run()
            # 関数内で end_run() を呼ぶ必要があるか
            auto_end_run = True
        else:
            auto_end_run = False

        # 学習に使われたパラメータを記録する
        attr_names = ['n_estimators',
                      'max_depth',
                      'min_samples_split',
                      'min_samples_leaf',
                      'random_state',
                      ]
        for attr_name in attr_names:
            # インスタンスからアトリビュートの値を取り出して記録する
            attr_value = getattr(self, attr_name)
            mlflow.log_param(key=f'sklearn.ensemble.RandomForestClassifier.{attr_name}',
                             value=attr_value)

        # パッチ前のオブジェクトを取得する
        original = gorilla.get_original_attribute(RandomForestClassifier, 'fit')

        # パッチ前のオブジェクトの呼び出し
        result = original(self, *args, **kwargs)

        # メトリックを記録する (ここでは特に何もない)
        # NOTE: validation set の損失などを取得する手段があれば残す

        # アーティファクトを記録する
        # Gini Importance を記録する (XXX: feature names を渡す良い方法が思いつかない...)
        tmpdir = tempfile.mkdtemp()
        filename = 'sklearn.ensemble.RandomForestClassifier.feature_importances_.json'
        artifact_path = pathlib.Path(tmpdir) / filename
        with open(artifact_path, 'w') as fp:
            json.dump(list(self.feature_importances_), fp, indent=2)
        mlflow.log_artifact(artifact_path)

        # 関数内で Run を作っていたら終了する
        if not auto_end_run:
            mlflow.end_run()

        # 結果を返す
        return result

    # 既にパッチされているときは上書きしない
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    # RandomForestClassifier#fit() をモンキーパッチする
    monkey_patch = gorilla.Patch(RandomForestClassifier, 'fit', fit, settings=settings)
    gorilla.apply(monkey_patch)


def main():
    # RandomForestClassifier をパッチする
    autolog_sklearn_random_forest()

    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    # 訓練データと検証データに分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # RandomForest 分類器を用意する
    clf = RandomForestClassifier(n_estimators=100,
                                 random_state=42,
                                 verbose=1)
    # ここで呼び出されるのはパッチされたオブジェクトになる
    clf.fit(X_train, y_train)

    # 学習済みモデルを使って検証データを予測する
    y_pred_proba = clf.predict(X_test)

    # 検証データのスコアを確認する
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    test_score = accuracy_score(y_test, y_pred)
    print(test_score)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ myautolog.py

すると、次のようにパラメータなどが記録される。

$ find mlruns/0/6c2abd9a92344534a45965005f7dfcc6  
mlruns/0/6c2abd9a92344534a45965005f7dfcc6
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/metrics
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/artifacts
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/artifacts/sklearn.ensemble.RandomForestClassifier.feature_importances_.json
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/tags
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.user
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.source.name
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.source.type
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.max_depth
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.random_state
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.n_estimators
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.min_samples_split
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.min_samples_leaf
mlruns/0/6c2abd9a92344534a45965005f7dfcc6/meta.yaml
$ cat mlruns/0/6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.n_estimators 
100

モンキーパッチを使うと、共通で何度も使われるようなコードをトラッキングする手間がだいぶ減りそうだ。 ただ、対象のモジュールのインターフェースに追従しなければならない副作用もあるため、使い所は吟味しなければいけないだろう。

まとめ

今回は MLflow の Tracking というコンポーネントを試してみた。

総じて、使い勝手や設計などに筋の良さを感じた。 一人でもチームでも同じコードが使える点や、インテグレーション以外で他に何らかの機械学習フレームワークへの依存がないのはとても良い。 一方で、多少は仕方ないにせよ自分のコードに組み込むときに、モンキーパッチを書かない限りは似たようなコードを何度も書くハメになってつらそう。 あと、チームで使うときにサーバの運用とか Artifact の置き場所どうしよってところは悩みそうだと思った。

そんな感じで。