MLflow は MLOps に関連した OSS のひとつ。
いくつかのコンポーネントに分かれていて、それぞれを必要に応じて独立して使うことができる。
今回は、その中でも実験の管理と可視化を司る MLflow Tracking を試してみることにした。
機械学習のプロジェクトでは試行錯誤することが多い。
その際には、パラメータやモデルの構成などを変えながら何度も実験を繰り返すことになる。
すると、回数が増えるごとに使ったパラメータや得られた結果、モデルなどの管理が煩雑になってくる。
MLflow Tracking を使うことで、その煩雑さが軽減できる可能性がある。
使った環境は次のとおり。
$ sw_vers
ProductName: Mac OS X
ProductVersion: 10.14 .6
BuildVersion: 18G5033
$ python -V
Python 3.7 .7
$ mlflow --version
mlflow, version 1.8 .0
もくじ
登場する概念について
まず、MLflow Tracking では Experiment
と Run
というモノを作っていく。
これらは、特定の目的を持った実験と、それに 1:N
で紐付いた各試行を表している。
つまり、試行錯誤の度に Run
が増えることになる。
そして、それぞれの Run
には次のような情報が、またもや 1:N
で紐づく。
Parameter
Metric
各ラウンドの損失や、交差検証のスコアといったメトリック
Artifact
学習済みモデルや、特徴量の重要度など実験した結果として得られる成果物
Tag
要するに Experiment > Run > Parameter, Metric, ...
ということ。
これらの情報は、典型的には Python のスクリプトから記録される。
そして、記録された情報は後からスクリプトや WebUI 経由で確認できる。
データを保存する場所について
具体的な使い方を紹介する前に、前述した情報がどのように管理されるのか解説しておく。
まず、MLflow Tracking はクライアントとサーバに分かれてる。
そして、データは基本的にサーバに保存される。
クライアントは、データを保存するサーバの場所を Tracking URI と Artifact URI という 2 つの URI で指定する。
ただし、クライアントとサーバは 1 台のマシンで兼ねることもできる。
また、データを記録する方法も、追加でソフトウェアなどを必要としないローカルファイルで完結させることができる。
そのため、クライアント・サーバ方式といっても 1 人で使い始める分には環境構築などの作業は全く必要ない。
あくまで、複数人のチームで記録されたデータを共用・共有したいときに専用のサーバが必要となる。
Tracking URI について
Tracking URI は、概ね Run
の中で Artifact
以外の情報を記録するところ。
記録される情報は、基本的には Key-Value 形式になっている。
利用できるバックエンドには次のようなものがある。
ローカルファイル
マウントされているブロックストレージにファイルとして記録される
形式: file:<path>
リレーショナルデータベース
SQLAlchemy 経由で RDB にテーブルのエントリとして記録される
形式: <dialect>+<driver>://<username>:<password>@<host>:<port>/<database>
REST API
MLflow Server というサーバを立ち上げて、そこに記録される
後述するが、記録先の実体はローカルファイルだったりリレーショナルデータベースだったり選べる
形式: http(s)://<host>:<port>
その他、詳細は以下に記載されている。
www.mlflow.org
Artifact URI について
Artifact URI は、その名の通り Artifact
を記録するところ。
記録される情報はファイル (バイト列) になる。
利用できるバックエンドには次のようなものがある。
ローカルファイル
マウントされているブロックストレージにファイルとして保存される
形式: file:<path>
各種クラウド (オブジェクト) ストレージ
Amazon S3, Azure Blob Storage, Google Cloud Storage などにファイルとして保存される
形式: s3://<bucket>
など
(S)FTP サーバ
(S)FTP サーバにファイルとして保存される
形式: (s)ftp://<user>@<host>/<path>
HDFS
HDFS (Hadoop Distributed File System) にファイルとして保存される
形式: hdfs://<path>
その他、詳細は以下に記載されている。
www.mlflow.org
色々とあるけど、結局のところ 1 人で使い始めるなら両方ともローカルファイルにすれば良い。
デフォルトでは、どちらもクライアントを実行した場所の mlruns
というディレクトリが使われる。
これは、Traking URI と Artifact URI の両方に file:./mlruns
を指定した状態ということ。
ちなみに、チームで使いたいけど自分で構築とか運用したくないってときは、開発の中心となっている Databricks がサーバ部分のマネージドサービスを提供している。
下準備
前置きが長くなったけど、ここからやっと実際に試していく。
はじめに、必要なパッケージをインストールしておこう。
なお、mlflow
以外は、後ほど登場するサンプルコードを動かすためだけに必要なもの。
$ pip install mlflow scikit-learn lightgbm matplotlib
インストールすると mlflow
コマンドが使えるようになる。
$ mlflow --version
mlflow, version 1.8 .0
基本的な使い方
とりあえず Python の REPL を使って、基本的な使い方を紹介する。
まずは REPL を起動しておく。
$ python
MLflow のパッケージをインポートする。
>>> import mlflow
実験の試行を開始する。
これには start_run()
関数を使う。
なお、デフォルトでは Experiment
として Default
という領域が使われる。
>>> mlflow.start_run()
<ActiveRun: >
実験に使ったパラメータを log_param()
関数で記録する。
>>> mlflow.log_param(key='foo' , value='bar' )
得られたメトリックなどの情報は log_metric()
関数で記録する。
>>> mlflow.log_metric(key='logloss' , value=1.0 )
実験には set_tag()
関数でタグが付与できる。
>>> mlflow.set_tag(key='hoge' , value='fuga' )
尚、これらは関数名を複数形にすると辞書型で複数の Key-Value を一度に記録できる。
アーティファクトについては、ちょっと面倒くさい。
アーティファクトはバイト列のファイルなので、まずはローカルにファイルを用意して、それをコピー (転送) することになる。
典型的には、一時ディレクトリを tempfile
モジュールで作って、中身のファイルを作ったらそれをコピーすれば良い。
一時ディレクトリ以下のファイルは、処理が終わったら自動的に削除される。
>>> import tempfile
>>> import pathlib
>>> with tempfile.TemporaryDirectory() as d:
... filename = 'test-artifact'
... artifact_path = pathlib.Path(d) / filename
... with open (artifact_path, 'w' ) as fp:
... print ('Hello, World!' , file =fp)
... mlflow.log_artifact(artifact_path)
...
あとは実験の試行を終了するだけ。
>>> mlflow.end_run()
Python の REPL を終了して、カレントディレクトリを確認してみよう。
デフォルトの Tracking URI / Artifact URI として file:./mlruns
が使われるため mlruns
というディレクトリが作られている。
$ find mlruns
mlruns
mlruns/0
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/metrics
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/metrics/logloss
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/artifacts
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/artifacts/test-artifact
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/tags
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/tags/hoge
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.user
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.source .name
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/tags/mlflow.source .type
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/params
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/params/foo
mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/meta.yaml
mlruns/0 /meta.yaml
mlruns/.trash
最初の階層は、それぞれの Experiment
を表している。
Default
の Experiment
には ID として 0
が付与されていることがわかる。
$ cat mlruns/0 /meta.yaml
artifact_location: file:///Users/amedama/Documents/temporary/helloworld/mlruns/0
experiment_id: '0'
lifecycle_stage: active
name: Default
その下の階層は、それぞれの Run
を表している。
$ head mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/meta.yaml
artifact_uri: file:///Users/amedama/Documents/temporary/helloworld/mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/artifacts
end_time: 1591267604170
entry_point_name: ''
experiment_id: '0'
lifecycle_stage: active
name: ''
run_id: 3de33a8f39294c4f8bd404a5d5bccf39
run_uuid: 3de33a8f39294c4f8bd404a5d5bccf39
source_name: ''
source_type: 4
ここにパラメータやメトリックなどが記録される。
メトリックに記録されている各列は、実行時刻、値、ステップ数を表している。
ステップ数というのは、たとえばニューラルネットワークのエポックだったり、ブースティングマシンのラウンドだったりする。
先ほど実行した log_metric()
関数の引数として、実は指定できた。
$ head mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/params/foo
bar
$ head mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/metrics/logloss
1591267315737 1.0 0
$ cat mlruns/0 /3de33a8f39294c4f8bd404a5d5bccf39/artifacts/test-artifact
Hello, World!
時刻には 1,000 倍した UNIX time が記録されている。
$ date -r $(( 1591267315737 / 1000 ))
2020 年 6 月 4 日 木曜日 19 時41 分55 秒 JST
記録された情報を WebUI で確認する
今のところ「ふーん」という感じだと思うので、記録された情報を WebUI からも確認してみよう。
確認用の WebUI を立ち上げるために、mlruns
ディレクトリのある場所で mlflow ui
コマンドを実行する。
$ mlflow ui
そして、ブラウザで localhost:5000
を閲覧する。
$ open http://localhost:5000
すると、こんな感じで過去に記録された情報が確認できる。
メトリックで試行を並べ替えたり、ステップ毎の値を可視化する機能も備わっている。
MLflow Tracking WebUI
ちなみに WebUI を使う以外にも、ちょっと面倒だけどスクリプトから確認することもできる。
$ python
たとえば、メトリックの logloss
が最も良いものを絞り込んでみる。
まあ、まだ 1 回しか実行してないんだけどね。
>>> import mlflow
>>> tracking_uri = 'file:./mlruns'
>>> client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)
>>> experiment = client.get_experiment_by_name('Default' )
>>> run = client.search_runs(experiment.experiment_id,
... order_by=['metric.logloss asc' ],
... max_results=1 )[0 ]
>>> run.data.params
{'foo' : 'bar' }
>>> run.data.metrics
{'logloss' : 1.0 }
複数人のチームで使いたいとき
複数人のチームで使うときは、いくつかの選択肢があるけど基本的には MLflow Server を用意して、そこに皆でアクセスする。
MLflow Server というのは Tracking URI として使える REST API と、先ほど確認した WebUI が一緒になったもの。
注意点として、MLflow Server はあくまで Tracking URI と WebUI のエンドポイントを提供するもの。
なので、Artifact URI の実体については別に用意しなければいけない。
たとえばクラウドストレージが使えるならそれを使っても良し、自分たちで FTP サーバを立ち上げたり HDFS のクラスタを組むことも考えられるだろう。
ここでは MLflow Server について軽く解説しておく。
まず、MLflow Server は mlflow server
コマンドで起動できる。
起動するときに、Tracking URI (--backend-store-uri
) と Artifact URI (--default-artifact-root
) を指定する。
$ mlflow server \
--backend-store-uri sqlite:///tracking.db \
--default-artifact-root file:/tmp/artifacts \
--host 0.0 .0.0
--backend-store-uri
は、MLflow サーバがクライアントから REST API 経由で受け取ったデータを記録する場所。
一方で、--default-artifact-root
はサーバに接続してきたクライアントに「Artifact はここに保存してね」と伝えられる場所に過ぎない。
つまり、MLflow Server がプロキシしてくれるわけではないのでクライアントから接続性のある URI を指定する必要がある。
ここでは横着して /tmp
以下を指定してしまっている。
しかし、こんな風にするなら本来は NFS などでクライアントがすべて /tmp
以下に共有ディレクトリをマウントする必要がある。
起動したサーバを使って実際にデータを記録してみよう。
まずは別のターミナルで Python の REPL を起動する。
$ python
set_tracking_uri()
関数で Tracking URI として MLflow Server のエンドポイントを指定する。
>>> import mlflow
>>> tracking_uri = 'http://localhost:5000'
>>> mlflow.set_tracking_uri(tracking_uri)
すると、データの記録先が次のように設定される。
>>> mlflow.get_tracking_uri()
'http://localhost:5000'
>>> mlflow.get_artifact_uri()
'file:///tmp/artifacts/0/26d9e4204c20401eb7d2807a93be8b75/artifacts'
何か適当にデータを記録してみよう。
>>> mlflow.start_run()
>>> mlflow.log_param(key='foo' , value='bar' )
>>> mlflow.log_metric(key='logloss' , value=0.5 )
>>> import tempfile
>>> import pathlib
>>> with tempfile.TemporaryDirectory() as d:
... filename = 'test-artifact'
... artifact_path = pathlib.Path(d) / filename
... with open(artifact_path, 'w' ) as fp:
... print ('Hello, World!' , file=fp)
... mlflow.log_artifact(artifact_path)
...
>>> mlflow.end_run()
これで、アーティファクトについては前述したとおり /tmp
以下に保存される。
$ cat /tmp/artifacts/0 /26d9e4204c20401eb7d2807a93be8b75/artifacts/test-artifact
Hello, World!
アーティファクト以外の情報は MLflow Server の方に記録される。
今回であれば MLflow Server を実行したディレクトリにある SQLite3 のデータベースに入っている。
$ sqlite3 tracking.db 'SELECT * FROM experiments'
0 | Default| file:///tmp/artifacts/0 | active
$ sqlite3 tracking.db 'SELECT * FROM runs'
26d9e4204c20401eb7d2807a93be8b75|| UNKNOWN||| amedama| FINISHED| 1591272053874 | 1591272141536 || active| file:///tmp/artifacts/0 /26d9e4204c20401eb7d2807a93be8b75/artifacts| 0
ちなみに mlflow server
コマンドの実装は mlflow.server:app
にある Flask の WSGI アプリケーションを gunicorn でホストするコマンドをキックしてるだけ。
なので、自分で WSGI サーバを立ててアプリケーションをホストしても構わないだろう。
認証が必要なら前段にリバースプロキシを置いて好きにやれば良いと思う。
MLflow Server の解説については、ここまでで一旦おわり。
スクリプトに組み込んでみる
続いては、実際に MLflow Tracking を機械学習を扱うコードに組み込んでみよう。
そんなに良い例でもないけど乳がんデータセットを RandomForest で分類する過程を MLflow Tracking で記録してみる。
ここでは set_tracking_uri()
や set_experiment()
していないのでデフォルト値 (mlruns
/ Default
) が使われる。
ちなみに、コードに書かなくても、環境変数の MLFLOW_TRACKING_URI
や MLFLOW_EXPERIMENT_NAME
を使う方法もある。
import pickle
import json
import tempfile
import pathlib
import mlflow
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
def main ():
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target
feature_names = dataset.feature_names
clf = RandomForestClassifier(n_estimators=100 ,
random_state=42 ,
verbose=1 )
folds = StratifiedKFold(shuffle=True , random_state=42 )
result = cross_validate(clf, X, y,
cv=folds,
return_estimator=True ,
n_jobs=-1 )
estimators = result.pop('estimator' )
with mlflow.start_run():
mlflow.log_params({
'n_estimators' : clf.n_estimators,
'random_state' : clf.random_state,
})
for key, value in result.items():
mlflow.log_metric(key=key,
value=value.mean())
for index, estimator in enumerate (estimators):
with tempfile.TemporaryDirectory() as d:
clf_filename = f'sklearn.ensemble.RandomForestClassifier.{index}'
clf_artifact_path = pathlib.Path(d) / clf_filename
with open (clf_artifact_path, 'wb' ) as fp:
pickle.dump(estimator, fp)
imp_filename = f'sklearn.ensemble.RandomForestClassifier.{index}.feature_importances_.json'
imp_artifact_path = pathlib.Path(d) / imp_filename
with open (imp_artifact_path, 'w' ) as fp:
importances = dict (zip (feature_names, estimator.feature_importances_))
json.dump(importances, fp, indent=2 )
mlflow.log_artifacts(d)
if __name__ == '__main__' :
main()
上記を実行する。
$ python bcrf.py
これで、各モデルの学習に使ったパラメータやメトリック、特徴量の重要度などが記録される。
$ cat mlruns/0 /15d6ba2c32e240a382ff1efab8814e47/params/n_estimators
100
$ cat mlruns/0 /15d6ba2c32e240a382ff1efab8814e47/metrics/test_score
1591273430658 0.9560937742586555 0
$ head mlruns/0 /15d6ba2c32e240a382ff1efab8814e47/artifacts/sklearn.ensemble.RandomForestClassifier.0 .feature_importances_.json
{
"mean radius" : 0.04924964271136709 ,
"mean texture" : 0.01737644949893605 ,
"mean perimeter" : 0.07531323166620862 ,
"mean area" : 0.05398223096565245 ,
"mean smoothness" : 0.007976941962365291 ,
"mean compactness" : 0.01205471060787445 ,
"mean concavity" : 0.04868260289112485 ,
"mean concave points" : 0.08748636077564236 ,
"mean symmetry" : 0.0037492071919759205 ,
各種フレームワークの自動ロギング
先ほどのコードを見て分かる通り、MLflow Tracking のコードを組み込むのは結構めんどくさい。
そこで、MLflow Tracking は各種フレームワークの学習を自動で記録するインテグレーションを提供している。
ただし、この機能は今のところ Experimental な点に注意が必要。
以下のサンプルコードは LightGBM の学習を自動で記録するもの。
MLflow Tracking を動作させている部分は mlflow.lightgbm.autolog()
を呼び出している一行だけ。
from mlflow import lightgbm as mlflow_lgb
import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def main ():
mlflow_lgb.autolog()
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target
feature_names = list (dataset.feature_names)
X_train, X_test, y_train, y_test = train_test_split(X, y,
shuffle=True ,
random_state=42 )
lgb_train = lgb.Dataset(X_train, y_train,
feature_name=feature_names)
lgb_eval = lgb.Dataset(X_test, y_test,
reference=lgb_train,
feature_name=feature_names)
lgbm_params = {
'objective' : 'binary' ,
'metric' : 'binary_logloss' ,
'verbosity' : -1 ,
}
booster = lgb.train(lgbm_params,
lgb_train,
valid_sets=lgb_eval,
num_boost_round=1000 ,
early_stopping_rounds=100 ,
verbose_eval=10 ,
)
y_pred_proba = booster.predict(X_test,
num_iteration=booster.best_iteration)
y_pred = np.where(y_pred_proba > 0.5 , 1 , 0 )
test_score = accuracy_score(y_test, y_pred)
print (test_score)
if __name__ == '__main__' :
main()
上記を実行してみよう。
$ python bclgb.py
...
[180 ] valid_0's binary_logloss: 0.14104
[190] valid_0' s binary_logloss: 0.143199
Early stopping, best iteration is:
[97 ] valid_0's binary_logloss: 0.10515
0.958041958041958
すると、train()
関数を呼ぶときに使われたパラメータやメトリックが記録される。
$ find mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/metrics
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/metrics/valid_0-binary_logloss
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/metrics/stopped_iteration
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/metrics/best_iteration
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_gain.json
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_gain.png
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_split.json
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/MLmodel
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/conda.yaml
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/model/model.lgb
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/artifacts/feature_importance_split.png
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/tags
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.user
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.source .name
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.log-model.history
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/tags/mlflow.source .type
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/categorical_feature
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/feature_name
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/keep_training_booster
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/num_boost_round
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/early_stopping_rounds
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/objective
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/verbosity
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/verbose_eval
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/params/metric
mlruns/0 /aae1e01ffa04406db4eaa6031d1c1ffb/meta.yaml
ただし、当たり前だけど自分で計算した内容については自分で記録しなければ残らない。
今回であれば自分でホールドアウト検証で計算している test_score
は記録されていないことがわかる。
自分で計算したメトリックも記録するように修正してみよう。
import mlflow
from mlflow import lightgbm as mlflow_lgb
import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def main ():
mlflow_lgb.autolog()
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target
feature_names = list (dataset.feature_names)
X_train, X_test, y_train, y_test = train_test_split(X, y,
shuffle=True ,
random_state=42 )
lgb_train = lgb.Dataset(X_train, y_train,
feature_name=feature_names)
lgb_eval = lgb.Dataset(X_test, y_test,
reference=lgb_train,
feature_name=feature_names)
lgbm_params = {
'objective' : 'binary' ,
'metric' : 'binary_logloss' ,
'verbosity' : -1 ,
}
active_run = mlflow.start_run()
booster = lgb.train(lgbm_params,
lgb_train,
valid_sets=lgb_eval,
num_boost_round=1000 ,
early_stopping_rounds=100 ,
verbose_eval=10 ,
)
y_pred_proba = booster.predict(X_test,
num_iteration=booster.best_iteration)
y_pred = np.where(y_pred_proba > 0.5 , 1 , 0 )
test_score = accuracy_score(y_test, y_pred)
print (test_score)
mlflow.log_metric(key='test_score' , value=test_score)
mlflow.end_run()
if __name__ == '__main__' :
main()
上記をもう一度実行する。
$ python bclgb.py
すると、今度は test_score
も記録されていることがわかる。
$ cat mlruns/0 //fd0e67c47819488089431813e2028986/metrics/test_score
1591274826005 0.958041958041958 0
自作の autolog() 相当を作ってみる
ところで、先ほどの autolog()
関数がどのように実現されているのか気にならないだろうか。
これは、モンキーパッチを使うことで対象モジュールのコードを動的に書きかえている。
先ほどの LightGBM であれば lightgbm.train()
が MLflow Tracking を使うものに書きかえられた。
そこで、試しに自分でも autolog()
相当のものを書いてみることにした。
以下は scikit-learn の RandomForestClassifier#fit()
を書きかえたもの。
MLflow Tracking のモンキーパッチには gorilla というフレームワークが使われている。
import json
import tempfile
import pathlib
import numpy as np
import gorilla
import mlflow
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def autolog_sklearn_random_forest ():
"""scikit-learn の RandomForestClassifier#fit() をモンキーパッチする
モンキーパッチ後は、モデルの学習に関する情報が自動で MLflow Tracking に残る"""
@ gorilla.patch (RandomForestClassifier)
def fit (self, *args, **kwargs):
if not mlflow.active_run():
mlflow.start_run()
auto_end_run = True
else :
auto_end_run = False
attr_names = ['n_estimators' ,
'max_depth' ,
'min_samples_split' ,
'min_samples_leaf' ,
'random_state' ,
]
for attr_name in attr_names:
attr_value = getattr (self, attr_name)
mlflow.log_param(key=f'sklearn.ensemble.RandomForestClassifier.{attr_name}' ,
value=attr_value)
original = gorilla.get_original_attribute(RandomForestClassifier, 'fit' )
result = original(self, *args, **kwargs)
NOTE
XXX
tmpdir = tempfile.mkdtemp()
filename = 'sklearn.ensemble.RandomForestClassifier.feature_importances_.json'
artifact_path = pathlib.Path(tmpdir) / filename
with open (artifact_path, 'w' ) as fp:
json.dump(list (self.feature_importances_), fp, indent=2 )
mlflow.log_artifact(artifact_path)
if not auto_end_run:
mlflow.end_run()
return result
settings = gorilla.Settings(allow_hit=True , store_hit=True )
monkey_patch = gorilla.Patch(RandomForestClassifier, 'fit' , fit, settings=settings)
gorilla.apply(monkey_patch)
def main ():
autolog_sklearn_random_forest()
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y,
shuffle=True ,
random_state=42 )
clf = RandomForestClassifier(n_estimators=100 ,
random_state=42 ,
verbose=1 )
clf.fit(X_train, y_train)
y_pred_proba = clf.predict(X_test)
y_pred = np.where(y_pred_proba > 0.5 , 1 , 0 )
test_score = accuracy_score(y_test, y_pred)
print (test_score)
if __name__ == '__main__' :
main()
上記を実行してみよう。
$ myautolog.py
すると、次のようにパラメータなどが記録される。
$ find mlruns/0 /6c2abd9a92344534a45965005f7dfcc6
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/metrics
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/artifacts
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/artifacts/sklearn.ensemble.RandomForestClassifier.feature_importances_.json
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/tags
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.user
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.source .name
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/tags/mlflow.source .type
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.max_depth
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.random_state
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.n_estimators
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.min_samples_split
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.min_samples_leaf
mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/meta.yaml
$ cat mlruns/0 /6c2abd9a92344534a45965005f7dfcc6/params/sklearn.ensemble.RandomForestClassifier.n_estimators
100
モンキーパッチを使うと、共通で何度も使われるようなコードをトラッキングする手間がだいぶ減りそうだ。
ただ、対象のモジュールのインターフェースに追従しなければならない副作用もあるため、使い所は吟味しなければいけないだろう。
まとめ
今回は MLflow の Tracking というコンポーネントを試してみた。
総じて、使い勝手や設計などに筋の良さを感じた。
一人でもチームでも同じコードが使える点や、インテグレーション以外で他に何らかの機械学習フレームワークへの依存がないのはとても良い。
一方で、多少は仕方ないにせよ自分のコードに組み込むときに、モンキーパッチを書かない限りは似たようなコードを何度も書くハメになってつらそう。
あと、チームで使うときにサーバの運用とか Artifact の置き場所どうしよってところは悩みそうだと思った。
そんな感じで。