今回は機械学習において学習済みのモデルを取り回す方法の一つとして pickle を扱う方法を取り上げてみる。 尚、使うフレームワークによっては pickle 以外の方法があらかじめ提供されている場合もある。 例えば学習済みモデルのパラメータを文字列などの形でダンプできるようになっているものとか。
ちなみに pickle という機能自体は機械学習に限らず色々な用途に応用が効く。 より汎用な解説については以前に別の記事でまとめたことがある。
使う環境は以下の通り。
$ sw_vers ProductName: Mac OS X ProductVersion: 10.13.4 BuildVersion: 17E202 $ python -V Python 3.6.5
下準備
今回はサンプルとして scikit-learn を使うため、まずはインストールしておく。
$ pip install scikit-learn
インストールが終わったら Python のインタプリタを起動する。
$ python
学習済みモデルを用意する
まずは学習済みモデルを用意する。 今回はサンプルとして Iris データセットを k-NN 分類器で学習させることにした。
まずは Iris データセットを読み込む。 この際、一番最後のデータについては学習から除外しておく。 これは、後ほど保存・復元したモデルに識別させるため。
>>> from sklearn import datasets >>> iris = datasets.load_iris() >>> X, y = iris.data[:-1], iris.target[:-1]
続いては k-NN 分類器のモデルを用意する。
>>> from sklearn.neighbors import KNeighborsClassifier >>> clf = KNeighborsClassifier()
Iris データセットを使ってモデルを学習させる。
>>> clf.fit(X, y) KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=5, p=2, weights='uniform')
これで学習済みモデルが用意できた。
学習済みモデルを保存する
続いては用意できた学習済みモデルを pickle を使って保存する。
以下のコードで model.pickle
というファイル名でモデルを保存できる。
>>> import pickle >>> with open('model.pickle', mode='wb') as fp: ... pickle.dump(clf, fp) ...
別のバージョンの Python からも読み込むことが考えられるときは protocol
オプションに 2
を指定しておくと良い。
こうすれば Python 2.3 以降であれば保存したモデルを復元できるようになる。
>>> with open('model.pickle', mode='wb') as fp: ... pickle.dump(clf, fp, protocol=2) ...
これでモデルがファイルに保存できたので一旦インタプリタを終了しておく。
>>> exit()
学習済みモデルを復元する
続いては、先ほどファイルに保存した学習済みモデルを別のインタプリタから復元できることを確認しよう。
$ python
次のコードで、先ほど保存した model.pickle
というファイルから学習済みモデルを復元できる。
>>> import pickle >>> with open('model.pickle', mode='rb') as fp: ... clf = pickle.load(fp) ...
たしかに k-NN 分類器のモデルが得られた。
>>> clf KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=5, p=2, weights='uniform')
先ほど学習に使わなかった一番最後のデータを識別させてみよう。 まずはデータセットからデータを取り出す。
>>> from sklearn import datasets >>> iris = datasets.load_iris() >>> X_last, y_last = iris.data[-1], iris.target[-1]
学習済みモデルで識別させてみる。
>>> clf.predict([X_last])
array([2])
モデルはデータをラベル 2
と判断した。
正解ラベルも 2
なので正しく識別できている。
>>> y_last
2
めでたしめでたし。
スマートPythonプログラミング: Pythonのより良い書き方を学ぶ
- 作者: もみじあめ
- 発売日: 2016/03/12
- メディア: Kindle版
- この商品を含むブログ (1件) を見る