今回は表題の通り scikit-learn の手書き数字データセットをサポートベクターマシンを使って分類してみることにする。
下準備
あらかじめ必要な Python パッケージをインストールしておく。
$ pip install scikit-learn scipy matplotlib
分類対象を確認する
実際に分類する前にどんなデータなのかを確認しておく。 データセットの中からランダムにサンプリングしたものを matplotlib で可視化する。
#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import division from __future__ import unicode_literals from __future__ import print_function from matplotlib import pyplot as plt from matplotlib import cm import numpy as np from sklearn import datasets def main(): digits = datasets.load_digits() X = digits.data y = digits.target print('データセットの点数: {N}'.format(N=X.shape[0])) print('各データの次元数: {dimension}'.format(dimension=X.shape[1])) # データの中から 25 点を無作為に選び出す p = np.random.random_integers(0, len(X), 25) # 選んだデータとラベルを matplotlib で表示する samples = np.array(list(zip(X, y)))[p] for index, (data, label) in enumerate(samples): # 画像データを 5x5 の格子状に配置する plt.subplot(5, 5, index + 1) # 軸に関する表示はいらない plt.axis('off') # データを 8x8 のグレースケール画像として表示する plt.imshow(data.reshape(8, 8), cmap=cm.gray_r, interpolation='nearest') # 画像データのタイトルに正解ラベルを表示する plt.title(label, color='red') # グラフを表示する plt.show() if __name__ == '__main__': main()
上記を実行すると次のような結果が得られる。 ひとつひとつのデータは 64 (8x8) 次元のグレースケール画像になっている。 見た感じだいぶ粗め。
SVM で分類する
どういったデータなのかひとまず確認できたところで早速分類してみよう。 scikit-learn には SVM の分類器として sklearn.svm.SVC がある。
次のサンプルコードでは sklearn.svm.SVC に適切なパラメータを与えて手書き数字データセットを分類させている。 その際、K-fold 交差検証を使うことでその汎化性能 (未知のデータに対処する能力) を調べている。
#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import division from __future__ import unicode_literals from __future__ import print_function from sklearn import datasets from sklearn import cross_validation from sklearn import svm from sklearn import metrics def main(): digits = datasets.load_digits() X = digits.data y = digits.target scores = [] # K-fold 交差検証でアルゴリズムの汎化性能を調べる kfold = cross_validation.KFold(len(X), n_folds=10) for train, test in kfold: # デフォルトのカーネルは rbf になっている clf = svm.SVC(C=2**2, gamma=2**-11) # 訓練データで学習する clf.fit(X[train], y[train]) # テストデータの正答率を調べる score = metrics.accuracy_score(clf.predict(X[test]), y[test]) scores.append(score) # 最終的な正答率を出す accuracy = (sum(scores) / len(scores)) * 100 msg = '正答率: {accuracy:.2f}%'.format(accuracy=accuracy) print(msg) if __name__ == '__main__': main()
上記を実行すると、次のように 98.27% の正答率が得られた。
正答率: 98.27%
適切なパラメータの調べ方
先ほどのサンプルコードを見ると、パラメータの C と gamma がマジックナンバーになっている。 では、このマジックナンバーをどうやって調べたかというと、次のようにして汎化性能が高くなるものを総当りで調べた。
#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import division from __future__ import unicode_literals from __future__ import print_function import itertools import operator from sklearn import datasets from sklearn import cross_validation from sklearn import svm from sklearn import metrics def _print_result(percentage, C_n, gamma_n): """ 正答率とそれに使われたパラメータを出力する """ msg = '正答率 {percentage:.2f}% C=2^{C} gamma=2^{gamma}'.format( percentage=percentage, C=C_n, gamma=gamma_n, ) print(msg) def main(): # 数値画像のデータを読み込む digits = datasets.load_digits() X = digits.data y = digits.target # パラメータ C の候補 (2^-5 ~ 2^5) Cs = [(2 ** i, i) for i in range(-5, 5)] # パラメータ gamma の候補 (2^-12 ~ 2^-5) gammas = [(2 ** i, i) for i in range(-12, -5)] # 2^-12 ~ 2^-5 # 上記のパラメータが取りうる組み合わせ (デカルト積) を作る parameters = itertools.product(Cs, gammas) results = [] # 各組み合わせで正答率にどういった変化があるかを調べていく for (C, C_n), (gamma, gamma_n) in parameters: scores = [] # 正答率は K-fold 交差検定 (10 分割) で計算する kfold = cross_validation.KFold(len(X), n_folds=10) # 教師信号を学習用と検証用に分割する for train, test in kfold: # 前述したパラメータを使って SVM (RBF カーネル) を初期化する clf = svm.SVC(C=C, gamma=gamma) # 学習する clf.fit(X[train], y[train]) # 検証する score = metrics.accuracy_score(clf.predict(X[test]), y[test]) scores.append(score) # 正答率をパーセンテージにしてパラメータと共に表示する percentage = (sum(scores) / len(scores)) * 100 results.append((percentage, C_n, gamma_n)) _print_result(*results[-1]) # 正答率の最も高かったパラメータを出力する sorted_result = sorted(results, key=operator.itemgetter(0), reverse=True) print('--- 最適なパラメータ ---') _print_result(*sorted_result[0]) if __name__ == '__main__': main()
上記を実行すると、次のようにパラメータ毎の正答率が順に出力されたあとに最も正答率の高かった (= 汎化性能の良い) パラメータが表示される。
正答率 81.91% C=2^-5 gamma=2^-12 正答率 90.09% C=2^-5 gamma=2^-11 正答率 91.37% C=2^-5 gamma=2^-10 ...(省略)... 正答率 95.60% C=2^4 gamma=2^-8 正答率 81.53% C=2^4 gamma=2^-7 正答率 49.87% C=2^4 gamma=2^-6 --- 最適なパラメータ --- 正答率 98.27% C=2^2 gamma=2^-11
今後の課題
めでたしめでたし、とは残念ながらいかない。 上記の手法を 784 (28x28) 次元の MNIST データセットに適用してみたところ、今度は全然性能が得られない…。 次元と共に計算量も増えるし困ったもんだ。