CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: pandas の DataFrame, Series, Index を拡張する

Python でデータ分析をするときに、ほぼ必ずといって良いほど使われるパッケージとして pandas がある。 そのままでも便利な pandas だけど、代表的なオブジェクトの DataFrame, Series, Index には実は独自の拡張を加えることもできる。 これがなかなか面白いので、今回はその機能について紹介してみる。

ただし、あらかじめ断っておくと注意点もある。 独自の拡張を加えると、本来は存在しないメソッドやプロパティがオブジェクトに生えることになる。 そのため、便利だからといってこの機能を使いすぎると、コードの可読性が低下する恐れもある。 使うなら、後から別の人がコードを読むときにも困らないようにしたい。 具体的には、使用するにしても最小限に留めたり、あるいはパッケージ化やドキュメント化をしておくことが挙げられる。

今回使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V
Python 3.6.5
$ pip list --format=columns | grep -i pandas
pandas          0.23.3

下準備

まずは pandas をインストールしておこう。

$ pip install pandas

ここからは Python の REPL を使って解説していく。

$ python

DataFrame を拡張する

まずは一番よく使うであろう DataFrame の拡張方法から。 あんまり実用的な例じゃないけど、ひとまず API がどんな感じになってるかを紹介したいので。

pandas のオブジェクトを拡張するときは、基本的に pandas.api.extensions 以下にある API を用いる。 例えば DataFrame を拡張するなら @pandas.api.extensions.register_dataframe_accessor() デコレータでクラスを修飾する。 次のサンプルコードでは DataFrame に helloworld という名前空間で greet() メソッドと length プロパティを追加している。

>>> import pandas as pd
>>> 
>>> # "helloworld" という名前空間で DataFrame を拡張する
... @pd.api.extensions.register_dataframe_accessor('helloworld')
... class HelloWorldDataFrameAccessor(object):
...     """DataFrameを拡張するためのクラス"""
...     def __init__(self, df):
...         self._df = df
...     # DataFrame#helloworld に greet() メソッドを追加する
...     def greet(self):
...         """標準出力にメッセージを出す"""
...         print('Hello, World!')
...     # DataFrame#helloworld に length プロパティを追加する
...     @property
...     def length(self):
...         """DataFrameの長さを返す"""
...         return len(self._df)
... 

これだけで DataFrame の拡張ができる。

実際に DataFrame のインスタンスを作って、上記の動作を確認してみよう。

>>> df = pd.DataFrame(list(range(1, 11)), columns=['n'])
>>> df
    n
0   1
1   2
2   3
3   4
4   5
5   6
6   7
7   8
8   9
9  10

特に意味はないけど DataFrame#helloworld#greet() メソッドを実行すると標準出力にメッセージが出るようになる。

>>> df.helloworld.greet()
Hello, World!

あとは DataFrame#helloworld#length プロパティを参照すると DataFrame の長さが得られるようになる。

>>> df.helloworld.length
10

たしかに DataFrame に自分で拡張したメソッドやプロパティを生やすことができた。

Series を拡張する

続いては Series の拡張方法を紹介する。 基本的にやることは先ほどと同じなので、次はもうちょっと実用的な例を紹介してみる。

例えば Series をマルチプロセスで並列に処理したい、というシチュエーションを考えてみよう。 使ってるマシンの CPU コアがたくさんあって、扱うデータセットが大きいときは結構やりたくなるんじゃないかな? 典型的には、次のような高階関数を用意するはず。

>>> import multiprocessing as mp
>>> import numpy as np
>>> 
>>> def parallelize(f, data, n_jobs=None):
...     """関数の適用をマルチプロセスで処理する"""
...     if n_jobs is None:
...         # 並列度の指定がなければ CPU のコア数を用いる
...         n_jobs = mp.cpu_count()
...     # データを並列度の数で分割する
...     split_data = np.array_split(data, n_jobs)
...     # プロセスプールを用意する
...     with mp.Pool(n_jobs) as pool:
...         # 各プロセスで関数を適用した結果を結合して返す
...         return pd.concat(pool.map(f, split_data))
... 

続いて、マルチプロセスで適用したい関数を適当に用意する。

>>> def square(x):
...     return x * x
... 

そして、こんな感じで使う。

>>> parallelize(square, df.n)
0      1
1      4
2      9
3     16
4     25
5     36
6     49
7     64
8     81
9    100
Name: n, dtype: int64

先ほどの使い方でも構わないんだけど pandas のオブジェクトから直接呼び出せると便利そうなので拡張してみよう。 次のようにして Series に parallel という名前空間で apply() メソッドを追加する。

>>> # "parallel" という名前空間で Series を拡張する
... @pd.api.extensions.register_series_accessor('parallel')
... class ParallelSeriesAccessor(object):
...     """Seriesを拡張するためのクラス"""
...     def __init__(self, s):
...         self._s = s
...     # Series#parallel に apply() というメソッドを定義する
...     def apply(self, f):
...         """Series に対して関数を並列で適用する"""
...         return parallelize(f, self._s)
... 

すると Series#parallel#apply() メソッドが使えるようになる。

>>> df.n.parallel.apply(square)
0      1
1      4
2      9
3     16
4     25
5     36
6     49
7     64
8     81
9    100
Name: n, dtype: int64

呼び出し方が違うだけで、やっていることは先ほどと変わらない。

Index を拡張する

続いては Index を拡張してみよう。

以下のサンプルコードでは Index が整数という前提で偶数・奇数だけを抜き出す機能を追加している。 また、あんまり実用性がない例になっちゃった。

>>> # "sampling" という名前空間で Index を拡張する
... @pd.api.extensions.register_index_accessor('sampling')
... class SamplingIndexAccessor(object):
...     """Indexを拡張するためのクラス"""
...     def __init__(self, idx):
...         self._idx = idx
...     # Index#sampling に even というプロパティを定義する
...     @property
...     def even(self):
...         return self._idx[self._idx % 2 == 0]
...     # Index#sampling に odd というプロパティを定義する
...     @property
...     def odd(self):
...         return self._idx[self._idx % 2 != 0]
... 

早速試してみよう。

>>> df.index.sampling.even
Int64Index([0, 2, 4, 6, 8], dtype='int64')
>>> df.index.sampling.odd
Int64Index([1, 3, 5, 7, 9], dtype='int64')

ちゃんと偶数・奇数だけ取り出すことができた。

めでたしめでたし。

参考

Extending Pandas — pandas 0.23.3 documentation

Pythonによるデータ分析入門 第2版 ―NumPy、pandasを使ったデータ処理

Pythonによるデータ分析入門 第2版 ―NumPy、pandasを使ったデータ処理

  • 作者: Wes McKinney,瀬戸山雅人,小林儀匡,滝口開資
  • 出版社/メーカー: オライリージャパン
  • 発売日: 2018/07/26
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

Python: gzip モジュールを使ってデータを圧縮・解凍する

今回は Python の標準ライブラリの gzip モジュールの使い方について。 上手く使えば Python から大きなデータを扱うときにディスクの節約になるかな。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V
Python 3.6.5

まずは Python の REPL を起動しておく。

$ python

基本的な使い方

gzip モジュールの基本的な使い方としては、組み込み関数の open() っぽい使い勝手の gzip.open() 関数を使う。 この関数を通して得られたファイルライクオブジェクトに書き込むと、自動的に書き込んだデータが GZip で圧縮される。

試しに、実際にデータを書き込んでみよう。

>>> import gzip
>>> with gzip.open('example.txt.gz', mode='wt') as fp:
...     fp.write('Hello, World!\n')
... 
14

注意点としては、テキストデータ (ユニコード文字列) を扱うときは mode 引数に t を明示的に指定しなきゃいけない。 これは gzip.open() 関数がデフォルトではバイナリデータを扱うように作られているため。 明示的に t を指定しないとバイナリモードになる。 Python 3 における組み込みの open() 関数はテキストモードがデフォルトなので、ここは気をつける必要がある。

別のターミナルからファイルを確認すると、ちゃんと GZip 圧縮されたファイルができている。

$ file example.txt.gz           
example.txt.gz: gzip compressed data, was "example.txt", last modified: Wed Aug  1 13:23:58 2018, max compression

gzcat コマンドで内容を確認すると、ちゃんと書き込んだ内容が見える。

$ gzcat example.txt.gz 
Hello, World!

元の Python インタプリタに戻って、今度は読み込みをしてみよう。

>>> with gzip.open('example.txt.gz', mode='rt') as fp:
...     fp.read()
... 
'Hello, World!\n'

ちゃんと元の文字列が解凍できた。

gzip コマンドで圧縮したファイルからデータを読み出してみる

続いては Python 以外のアーカイバを使って圧縮したファイルを解凍できるか試してみよう。

gzip コマンドを使って圧縮したファイルを用意しておく。

$ echo "Hello, GZip" > greet.txt
$ gzip greet.txt
$ file greet.txt.gz  
greet.txt.gz: gzip compressed data, was "greet.txt", last modified: Wed Aug  1 13:27:00 2018, from Unix

先ほどと同じようにファイルからデータを読み込んでみよう。

>>> with gzip.open('greet.txt.gz', mode='rt') as fp:
...     fp.read()
... 
'Hello, GZip\n'

ちゃんと読み出せた。

ただし、公式ドキュメントを読むとサポートしていない形式もあるようだ。

13.2. gzip — gzip ファイルのサポート — Python 3.6.5 ドキュメント

バイナリデータを扱ってみる

登場機会として多そうなのは pickle モジュールとの組み合わせかな。 これも試してみよう。

pickle モジュールについては以下の記事で取り扱った。

blog.amedama.jp

以下の辞書データを GZip ファイルとして保存したい。

>>> d = {'message': 'Hello, World!'}

そこで、まずは pickle モジュールを使って、上記をバイト列に変換する。

>>> import pickle
>>> data = pickle.dumps(d)

こんな感じになった。

>>> data
b'\x80\x03}q\x00X\x07\x00\x00\x00messageq\x01X\r\x00\x00\x00Hello, World!q\x02s.'

上記のバイト列を gzip モジュール経由でファイルに書き込む。

>>> with gzip.open('dict.pickle.gz', mode='wb') as fp:
...     fp.write(data)
... 
41

別のターミナルから確認すると、ちゃんと GZip ファイルができている。

$ file dict.pickle.gz 
dict.pickle.gz: gzip compressed data, was "dict.pickle", last modified: Wed Aug  1 13:28:15 2018, max compression

書き込みはできたので、今度は読み込みを。

>>> with gzip.open('dict.pickle.gz', mode='rb') as fp:
...     data = fp.read()
... 
>>> data
b'\x80\x03}q\x00X\x07\x00\x00\x00messageq\x01X\r\x00\x00\x00Hello, World!q\x02s.'

さっきと同じバイト列が得られた。

pickle モジュールに読み込ませると、ちゃんと辞書データが元に戻せた。

>>> pickle.loads(data)
{'message': 'Hello, World!'}

いじょう。

unzip で "need PK compat. v5.1 (can do v4.5)" と言われて解凍できない件

ある日、パスワードつきの ZIP ファイルを macOS 組み込みの unzip コマンドで解凍しようとしたところ、タイトルのようなエラーになった。 今回は、その対処方法と、そもそもどういったときに起こるのかについて。

結論から先に要約してしまうと、次の通り。

  • ZIP ファイルのフォーマットにはバージョンがある
  • unzip コマンドがサポートしているバージョンが不足すると、このエラーになる
  • エラーメッセージにある "PK" とはオリジナルの ZIP アーカイバの名称からきている
  • 必要なバージョンをサポートしている ZIP アーカイバを代わりに使えば解決する

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ unzip | head -n 2
UnZip 6.00 of 20 April 2009, by Info-ZIP.  Maintained by C. Spieler.  Send
bug reports using http://www.info-zip.org/zip-bug.html; see README for details.
$ brew info p7zip
p7zip: stable 16.02 (bottled)
7-Zip (high compression file archiver) implementation
https://p7zip.sourceforge.io/
/usr/local/Cellar/p7zip/16.02_1 (103 files, 4.7MB) *
  Poured from bottle on 2018-07-31 at 23:22:00
From: https://github.com/Homebrew/homebrew-core/blob/master/Formula/p7zip.rb

概要

事象としては、次のように解凍できないファイルがある。 エラーメッセージでは PK のバージョンが足りない的なことを言っている。

$ unzip greet.txt.zip 
Archive:  greet.txt.zip
   skipping: greet.txt               need PK compat. v5.1 (can do v4.5)

上記について調べたところ ZIP ファイルのフォーマットにもバージョンがあることを知った。 また、ZIP アーカイバのオリジナル実装の名前は PKZIP という名前で、上記の "PK" はそこからきているらしい。 ようするに macOS 組み込みの unzip コマンドが新しい ZIP フォーマット (v5.1) に対応していないってことのようだ。

解決策

解決策としては、組み込みの unzip の代わりに p7zip をインストールして使えば良い。

$ brew install p7zip

まあ p7zip に限らず v5.1 フォーマットをサポートしているアーカイバなら何でも良いはず。

7za コマンドに e (extract) オプションをつけて解凍する。 パスワードを確認されるのでターミナルに入力する。

$ 7za e greet.txt.zip 

7-Zip (a) [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=utf8,Utf16=on,HugeFiles=on,64 bits,4 CPUs x64)

Scanning the drive for archives:
1 file, 215 bytes (1 KiB)

Extracting archive: greet.txt.zip
--
Path = greet.txt.zip
Type = zip
Physical Size = 215

Enter password (will not be echoed):
Everything is Ok

Size:       13
Compressed: 215

今度はちゃんと解凍できた。

$ cat greet.txt
Hello, World

unzip (6.00) で解凍できないファイルの作り方

p7zip を使って unzip コマンドで解凍できないファイルの作り方についても書いておく。

とりあえず圧縮前のファイルを適当に用意する。

$ echo "Hello, World" > greet.txt

あとは最近の暗号化形式を指定してパスワードつき ZIP ファイルを作る。

$ 7za a -tzip -ppassword -mem=AES256 greet.txt.zip greet.txt

これで v5.1 フォーマットの ZIP ファイルができた。

$ file greet.txt.zip 
greet.txt.zip: Zip archive data, at least v5.1 to extract
$ unzip greet.txt.zip 
Archive:  greet.txt.zip
   skipping: greet.txt               need PK compat. v5.1 (can do v4.5)

ちなみに暗号化形式を指定しないで作ると v2.0 フォーマットの ZIP ファイルになった。

$ rm greet.txt.zip
$ 7za a -tzip -ppassword greet.txt.zip greet.txt
$ file greet.txt.zip                            
greet.txt.zip: Zip archive data, at least v2.0 to extract

これなら macOS の unzip コマンドでも解凍できる。

$ unzip greet.txt.zip
Archive:  greet.txt.zip
[greet.txt.zip] greet.txt password: 
 extracting: greet.txt

いじょう。

Python: Selenium + Headless Chrome で Web ページ全体のスクリーンショットを撮る

スクレイピングした Web サイトからページ全体のスクリーンショットを撮影したい機会があった。 そこで Selenium の Python バインディングと Headless Chrome を使ったところ実現できたのでメモしておく。 ちなみに、ページ全体でなければ Headless Chrome 単体でも撮れる。 その方法についても末尾に補足として記載しておいた。

使った環境は次の通り。

$ sw_vers             
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V   
Python 3.6.5
$ pip list --format=columns | grep -i selenium
selenium   3.13.0 
$ chromedriver --version
ChromeDriver 2.40.565386 (45a059dc425e08165f9a10324bd1380cc13ca363)

下準備

まずは Chrome Driver をインストールしておく。

$ brew cask install chromedriver

続いて Selenium の Python バインディングをインストールする。

$ pip install selenium

インストールできたら Python の REPL を起動する。

$ python

これで下準備ができた。

スクリーンショットを撮る

まずは必要なパッケージをインポートする。

>>> from selenium import webdriver
>>> from selenium.webdriver.chrome.options import Options

続いて Chrome Driver を Headless モードで起動する。

>>> options = Options()
>>> options.add_argument('--headless')
>>> driver = webdriver.Chrome(options=options)

スクリーンショットを撮りたいページの内容を取得する。

>>> driver.get('https://www.python.org')

続いてページ全体の横幅と高さを取得したら、それをウィンドウサイズとしてセットする。 これによって Web ページ全体のスクリーンショットを一つの画像として撮影できる。 ちなみに Headless モードにしていないと、ここで設定できる値がモニターの解像度に依存してしまう。

>>> page_width = driver.execute_script('return document.body.scrollWidth')
>>> page_height = driver.execute_script('return document.body.scrollHeight')
>>> driver.set_window_size(page_width, page_height)

ウィンドウサイズを変更したら、ページ内の要素が読み込まれるまで少し待った方が良い。 読み込みを待ちたい要素があらかじめ分かっているなら Wait をかけることもできそう。

5. Waits — Selenium Python Bindings 2 documentation

あとはスクリーンショットを撮るだけ。 上手くいけば返り値として True が返ってくる。

>>> driver.save_screenshot('screenshot.png')
True

上手くいったらドライバと REPL を終了する。

>>> driver.quit()
>>> exit()

撮影した画像を確認してみよう。

$ open screenshot.png

次のような結果が得られた。

f:id:momijiame:20180728183758p:plain

めでたしめでたし。

補足

ちなみに Headless Chrome 単体でもスクリーンショット自体は撮れる。 ただし、このやり方では残念ながらページ全体を撮ることはできない。

$ alias chrome="/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome"
$ chrome --headless --disable-gpu --screenshot https://www.python.org

こんな感じになっちゃう。

f:id:momijiame:20180728184056p:plain

参考

developers.google.com

Python: パラメータ選択を伴う機械学習モデルの交差検証について

今回は、ハイパーパラメータ選びを含む機械学習モデルの交差検証について書いてみる。 このとき、交差検証のやり方がまずいと汎化性能を本来よりも高く見積もってしまう恐れがある。 汎化性能というのは、未知のデータに対処する能力のことを指す。 ようするに、いざモデルを実環境に投入してみたら想定よりも性能が出ない (Underperform) ということが起こる。 これを防ぐには、交差検証の中でも Nested Cross Validation (Nested CV) あるいは Double Cross Validation と呼ばれる手法を使う。

ハイパーパラメータの選び方としては、色々な組み合わせをとにかく試すグリッドサーチという方法を例にする。 また、モデルのアルゴリズムにはサポートベクターマシンを使った。 これは、サポートベクターマシンはハイパーパラメータの変更に対して敏感な印象があるため。

その他、使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V
Python 3.6.5

下準備

まずは、今回のサンプルコードで使うパッケージをあらかじめインストールしておく。

$ pip install scikit-learn numpy scipy tqdm matplotlib

続いて Python の REPL を起動する。

$ python

起動したら scikit-learn に組み込みで用意されている乳がんデータセットを読み込んでおこう。 これには、しこりに関する特徴量とそれが良性か悪性かの情報が含まれる。

>>> from sklearn import datasets
>>> 
>>> dataset = datasets.load_breast_cancer()
>>> X = dataset.data
>>> y = dataset.target

これで準備が整った。

ここからは、前提となる知識として交差検証に至る機械学習モデルを評価するやり方を一つずつ紹介していく。 おそらく、知っている内容も多いと思うので必要に応じて読み飛ばしてもらえると。

学習に用いたデータでモデルを評価する

まずは、最もダメなパターンから。 これは、モデルの学習に用いたデータを使って、そのモデルを評価するというもの。 これをやってしまうと、汎化性能は全く測れない。 なにせ全然未知ではなく、モデルが既に見たことのあるデータなのだから。

概念図はこんな感じ。

f:id:momijiame:20180723033816p:plain

とはいえ、ダメなパターンについても見ておくことは重要なので以下にサンプルコードを示す。 まずはサポートベクターマシンの分類器を用意して、データセットを全て使って学習させる。

>>> from sklearn.svm import SVC
>>> 
>>> svm = SVC(kernel='rbf')
>>> svm.fit(X, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

続いて、学習したモデルを使ってデータセット全てに対して予測する。

>>> y_pred = svm.predict(X)

これは、学習に使ったデータを、そのまま予測しているということ。

予測結果の精度 (Accuracy) を計算してみよう。

>>> from sklearn.metrics import accuracy_score
>>> 
>>> accuracy_score(y, y_pred)
1.0

なんと 100% の精度が得られた!ヤッター! …といっても、これはモデルが一度見たことのあるデータを予測しているだけなので、何ら驚くには値しない。 もちろん、これではモデルの汎化性能は測れない。

ホールドアウト検証 (Hold-out Validation)

続いては、汎化性能をそれなりに評価するための方法としてホールドアウト検証を紹介する。 これは、データセットを学習用とテスト用に分割する。 そして、学習用のデータをモデルに学習させた上で、テスト用のデータを使ってモデルの性能を評価するというもの。 テスト用のデータはモデルにとって見たことのない未知のデータなので、これは汎化性能を示す指標となりうる。

概念図はこんな感じ。

f:id:momijiame:20180723032847p:plain

先ほどと同様に、サンプルコードを示す。 まずは、データセットを学習用とテスト用に分割する。

>>> from sklearn.model_selection import train_test_split
>>> 
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, shuffle=True, random_state=42)

分割方法に再現性を持たせたい場合には random_state オプションを指定した方が良い。 この数値を指定して、他のオプションについても値が同じである限りデータが同じように分割される。 また、データの分割に偏りを作らないためには shuffle オプションを有効にしてランダムにデータを選択した方が良い。 もちろん、ただランダムに分割するだけでは偏りを取り除けない場合には、それ以外のやり方で分割する必要がある。

データを分割したら、学習用データの方を使ってモデルを学習する。

>>> svm = SVC(kernel='rbf')
>>> svm.fit(X_train, y_train)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

今度は 64.3% の精度が得られた。 こちらの方が、先ほどよりもモデルの性能を現実的に示している。

>>> y_pred = svm.predict(X_test)
>>> accuracy_score(y_test, y_pred)
0.6436170212765957

交差検証 (Cross Validation)

ホールドアウト検証法はデータを分割しているとはいっても一回だけの試行なので偏りが含まれる余地が比較的ある。 この偏りを減らすには交差検証というやり方を用いる。 これは、複数回に渡って異なる分割をしたデータに対し、それぞれでホールドアウト検証をして結果を合算するというもの。

概念図はこんな感じ。

f:id:momijiame:20180723032916p:plain

scikit-learn では KFold を使うと交差検証が楽にできる。 以下のサンプルコードでは分割数 (試行回数) として 4 を指定した。

>>> from sklearn.model_selection import KFold
>>> 
>>> kf = KFold(n_splits=4, shuffle=True, random_state=42)

交差検証のスコアは cross_val_score を使うと楽に計算できる。 といっても、これは先ほどのホールドアウト検証を KFold を使ってループしながら実行しているだけ。 自分で書いても全然問題はない。

>>> from sklearn.model_selection import cross_val_score
>>> 
>>> svm = SVC(kernel='rbf')
>>> scores = cross_val_score(svm, X=X, y=y, cv=kf)

結果としては、各ホールドアウト検証における性能が得られる。

>>> scores
array([0.62237762, 0.69014085, 0.61971831, 0.57746479])

一般的には、上記を単純に算術平均すると思う。

>>> average_score = scores.mean()
>>> average_score
0.6274253915098986

ちなみに分割数をデータ点数まで増やした場合は Leave-One-Out 検証法と呼ばれる。 機械学習系の文章の中では、よく LOO と省略されていることがある。

ハイパーパラメータの選択を含む交差検証

ここからが本題。 先ほどのサンプルコードでは、基本的にサポートベクターマシンのモデルをデフォルトのハイパーパラメータで扱っていた。 ただ、実際に使うときはハイパーパラメータの調整が必要になる。 このとき、ただ単純に交差検証をするだけだとモデルの性能を高く見積もってしまう恐れがある。

上記をサンプルコードと共に確認する。 まずは、先ほどと同じようにサポートベクターマシンのモデルと交差検証用のオブジェクトを用意する。

>>> svm = SVC(kernel='rbf')
>>> kf = KFold(n_splits=4, shuffle=True, random_state=42)

ハイパーパラメータの候補は、次のように辞書とリストを組み合わせて用意する。

>>> candidate_params = {
...     'C': [1, 10, 100],
...     'gamma': [0.01, 0.1, 1],
... }

GridSearchCV にモデルとハイパーパラメータの候補を渡して、データを学習させる。 GridSearchCV は名前に CV と入っている通り、内部的に交差検証を使いながら性能の良いハイパーパラメータの組み合わせを探してくれる。

>>> from sklearn.model_selection import GridSearchCV
>>> from multiprocessing import cpu_count
>>> 
>>> gs = GridSearchCV(estimator=svm, param_grid=candidate_params, cv=kf, n_jobs=cpu_count())
>>> gs.fit(X, y)
GridSearchCV(cv=KFold(n_splits=4, random_state=42, shuffle=True),
       error_score='raise',
       estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False),
       fit_params=None, iid=True, n_jobs=4,
       param_grid={'C': [1, 10, 100], 'gamma': [0.01, 0.1, 1]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

学習が終わると、最も性能の良かったハイパーパラメータで学習したモデルが GridSearchCV#best_estimator_ で得られる。

>>> gs.best_estimator_
SVC(C=10, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma=0.01, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

同様に GridSearchCV#best_params_ で最も性能の良かったハイパーパラメータの組み合わせが得られる。

>>> gs.best_params_
{'C': 10, 'gamma': 0.01}

また、GridSearchCV#best_score_ から上記のモデルが記録したスコアも得られる。

>>> gs.best_score_
0.6344463971880492

では、上記のハイパーパラメータを使ったモデルなら未知のデータに対して 63.4% の精度が得られるはずかというと、そうでもないらしい。 これは、一回の交差検証だけだと精度が偏って得られることも考えられるため。 ようするに、大きく外してはいないはずだけど見積もりとしては楽観的なものになる。 この、一回の交差検証だけで評価するやり方を Non-nested Cross Validation (Non-nested CV) という。

概念図としてはこんな感じ。 さっきの単純な交差検証をハイパーパラメータの組み合わせごとにやっているだけ。

f:id:momijiame:20180723033334p:plain

Nested Cross Validation (Nested CV)

前述した問題をどうやって解決するかというと、交差検証を二重にする。 この方法は Nested CV と呼ばれる。

概念図はこんな感じ。

f:id:momijiame:20180723033015p:plain

Nested CV では、交差検証を内側 (Inner CV) と外側 (Outer CV) の二重に分けている。 内側ではハイパーパラメータの選択に注力し、外側はできたモデルの評価に注力する。 ポイントとしては、それぞれで重複するデータをモデルに触らせていないところ。 一度でもモデルに見せたデータはその時点で汚れてしまうため、評価する上で二度と使うことはできない。

サンプルコードで Nested CV を見ていこう。 まずは、先ほどと同じようにグリッドサーチ用のオブジェクトまで作っておく。

>>> svm = SVC(kernel='rbf')
>>> kf = KFold(n_splits=4, shuffle=True, random_state=42)
>>> gs = GridSearchCV(estimator=svm, param_grid=candidate_params, cv=kf, n_jobs=cpu_count())

続いて、上記の GridSearchCV のインスタンスをさらに cross_val_score() 関数に突っ込む。

>>> scores = cross_val_score(gs, X=X, y=y, cv=kf)

上記は正直なかなか分かりにくいので、順を追って解説する。 まず、cross_val_score() 関数が前述した外側の交差検証になっている。 外側で分割した学習用データのみが GridSearchCV のインスタンスに渡される。 GridSearchCV のインスタンスは、渡された学習用データをさらに分割して学習用データとハイパーパラメータ調整用データにする。 そして、そのデータを使ってハイパーパラメータを選択する。 これが前述した内側の交差検証になる。 ハイパーパラメータの選択が終わったら、できあがったモデルが外側の交差検証で評価される。 ようするに、外側の交差検証の分割数×内側の交差検証の分割数×ハイパーパラメータの組み合わせの数だけホールドアウト検証を繰り返すことになる。

上記で得られたスコアが以下の通り。 ようするに、これは各内側の交差検証で性能の良かったモデルたちが外側の交差検証で記録した性能ということになる。

>>> scores
array([0.62937063, 0.69014085, 0.61971831, 0.58450704])

上記の算術平均は次の通り。 先ほどの Non-nested CV よりも、ほんの少しではあるが下がっている。

>>> average_score = scores.mean()
>>> average_score
0.6309342066384319

Non-nested CV に比べると、この Nested CV で記録した値の方が現実に則した汎化性能を表している、とされる。

Non-nested CV と Nested CV が記録するスコアを比較する

先ほどの例では Non-nested CV よりも Nested CV の方が低めのスコアが出た。 一回だけならたまたまということも考えられるので、念のため何度か試行してグラフにプロットしてみる。

次のサンプルコードでは Non-nested CV と Nested CV を 50 回繰り返して、それぞれが記録するスコアをプロットする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from multiprocessing import cpu_count

from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV

import numpy as np

from matplotlib import pyplot as plt

from tqdm import tqdm


def main():
    NUM_TRIALS = 50

    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X = dataset.data
    y = dataset.target

    # 候補となるハイパーパラメータ
    candidate_params = {
        'C': [1, 10, 100],
        'gamma': [0.01, 0.1, 1, 'auto'],
    }

    # 計測したスコアを保存するためのリスト
    scores_non_nested_cv = np.zeros(NUM_TRIALS)
    scores_nested_cv = np.zeros(NUM_TRIALS)

    # 何回か試してみる
    for i in tqdm(range(NUM_TRIALS)):

        # Non Nested CV
        svm = SVC(kernel='rbf')
        kf = KFold(n_splits=4, shuffle=True, random_state=i)
        gscv = GridSearchCV(estimator=svm, param_grid=candidate_params, cv=kf, n_jobs=cpu_count())
        gscv.fit(X, y)
        scores_non_nested_cv[i] = gscv.best_score_

        # Nested CV
        svm = SVC(kernel='rbf')
        kf = KFold(n_splits=4, shuffle=True, random_state=i)
        gs = GridSearchCV(estimator=svm, param_grid=candidate_params, cv=kf, n_jobs=cpu_count())
        scores = cross_val_score(gs, X=X, y=y, cv=kf)
        scores_nested_cv[i] = scores.mean()

    # スコア平均と標準偏差
    print('non nested cv: mean={:.5f} std={:.5f}'.format(scores_non_nested_cv.mean(), scores_non_nested_cv.std()))
    print('nested cv: mean={:.5f} std={:.5f}'.format(scores_nested_cv.mean(), scores_nested_cv.std()))

    # グラフを描画する
    plt.figure(figsize=(10, 6))
    plt.plot(scores_non_nested_cv, color='g', label='non nested cv')
    plt.plot(scores_nested_cv, color='b', label='nested cv')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

適当な名前をつけて上記を保存したら実行してみよう。

$ python cv.py 
 36%|█████████████████▎                              | 18/50 [00:47<01:24,  2.64s/it]

計算には時間が結構かかるので tqdm を使って進捗を表示させている。

tqdm については、以前に以下の記事で紹介している。

blog.amedama.jp

50 回試した上での精度の平均は次の通り。 Nested CV の方が 0.3% ほど精度を低く見積もっていることが分かる。

non nested cv: mean=0.63195 std=0.00261
nested cv: mean=0.62878 std=0.00255

得られたグラフは次の通り。 一部に例外はあるものの、基本的には Nested CV の方が Non-nested CV よりも精度を低く見積もっている。

f:id:momijiame:20180722140025p:plain

疑問と悩み

めでたしめでたし。 と、言いたいところなんだけど、いくつか自分でもまだ完全には腑に落ちていないところがある。

学習に使うデータが減る問題

Nested CV の方が Non-nested CV よりも精度の見積もりは低く出た。 とはいえ Nested CV では Non-nested CV よりもモデルの学習に使うデータ自体も減っている。 これは、データの分割が Non-nested CV では二つなのに対して Nested CV では三つになっているため。 具体的には Nested CV ではデータを学習用、ハイパーパラメータ調整用、検証用に分割することになる。 対して Non-nested CV では学習用と検証用にしか分割していない。

学習に使うデータが減れば、バイアス以外にもそれだけで精度が低くなる余地があるように感じる。 かといって、ハイパーパラメータ調整用や検証用のデータを減らすと、精度の分散が大きくなってモデル選択が難しくなるような気がする。 まあ、もちろんそれで Nested CV をやらない理由にはならないだろうけど。

一体どのハイパーパラメータを選べば良いのよ問題

Nested CV では、内側の交差検証で選ばれてくるモデルたちのハイパーパラメータがどれも同一とは限らないはず。 同一でない場合には、結局のところどのハイパーパラメータの組み合わせを選べば良いのよ?となる。 まあ、本当にバラバラならどれを使っても似たようなものなんだ、という理解にはつながるかもしれないけど。

もし結果をそのまま使いたいなら、内側の交差検証で選ばれた各モデルを使ってアンサンブル (Voting) すると良いのかな? 実際のところ、ハイパーパラメータの目星がついたからといって、改めてモデルに未分割の全データを学習させて同じ汎化性能が得られるとは限らない。 交差検証をしていないモデルからは、どんな結果が得られてもおかしくはないのだから。

参考

Nested versus non-nested cross-validation — scikit-learn 0.19.2 documentation

Nested Cross Validation: When Cross Validation Isn’t Enough

いじょう。

Python: tqdm で処理の進捗状況をプログレスバーとして表示する

最近は Python がデータ分析や機械学習の分野でも使われるようになってきた。 その影響もあって REPL や Jupyter Notebook 上でインタラクティブに作業することも増えたように感じる。 そんなとき、重い処理を走らせると一体いつ終わるのか分からず途方に暮れることもある。 今回紹介する tqdm は、走らせた処理の進捗状況をプログレスバーとして表示するためのパッケージ。 このパッケージ自体はかなり昔からあるんだけど、前述した通り利用環境の変化や連携するパッケージの増加によって便利さが増してきてる感じ。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V    
Python 3.6.5

もくじ

下準備

まずは tqdm をインストールしておく。

$ pip install tqdm
$ pip list --format=columns | grep tqdm 
tqdm       4.23.4

終わったら Python の REPL を起動する。

$ python

基本的な使い方

ここからは tqdm の基本的な使い方を紹介する。

その前に、まずは tqdm がない場合から考えてみる。 次のサンプルコードでは 100 回のループを 100 ミリ秒の間隔を空けて実行している。 実行すると、何もレスポンスがない状態で 10 秒間待たされることになる。

>>> import time
>>> 
>>> for _ in range(100):
...     time.sleep(0.1)
... 

実際に実行してみると 10 秒間とはいえ長く感じる。

それでは、続いて上記の処理に tqdm を導入してみる。 変更点は一箇所だけで、上記の range() 関数の結果を tqdm() 関数に渡すだけ。 これだけで tqdm は渡された内容を読み取って全体の処理と現在の進捗をプログレスバーとして表示してくれる。

>>> from tqdm import tqdm
>>> 
>>> for _ in tqdm(range(100)):
...     time.sleep(0.1)
... 
 63%|█████████████████████████████▌                 | 63/100 [00:06<00:03,  9.69it/s]

プログレスバーがあるだけで同じ待ち時間でも感じ方はだいぶ変わるはず。

上記を見ると、どういう仕組みなのか結構気になる。 そもそも tqdm に渡せるオブジェクトは一体なんなのか? 答えから言ってしまうと tqdm にはイテラブルなオブジェクトなら何でも渡せる。 イテラブルなオブジェクトというのは、具体的には iter() 関数を使ってイテレータが返ってくるもの。

そもそもイテレータって何?っていう話については以下の記事に書いた。

blog.amedama.jp

なので、もちろんリストを渡すこともできるし。

>>> for _ in tqdm(list([1, 2, 3, 4])):
...     time.sleep(1)
... 
100%|██████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.00s/it]

何なら文字列だって渡すことができる。

>>> for _ in tqdm('Hello, World!'):
...     time.sleep(0.1)
... 
100%|████████████████████████████████████████████████| 13/13 [00:01<00:00,  9.73it/s]

イテレータ自体には終わりを設ける必要はない。 なので、次のように無限に値を返し続けるイテレータを渡しても良い。 ただし、この場合はイテレータからオブジェクトを取り出した回数や、経過時間やスループットだけが表示される。

>>> from itertools import count
>>> 
>>> for _ in tqdm(count()):
...     time.sleep(0.01)
... 
417it [00:04, 85.10it/s]

ひとしきり満足したら Ctrl-C で止めよう。

pandas と連携させる

tqdm は pandas と連携させることもできる。

まずは pandas をインストールしよう。

$ pip install pandas
$ pip list --format=columns | grep pandas
pandas            0.23.3 

サンプルとなる DataFrame オブジェクトを用意しておく。

>>> import pandas as pd
>>> df = pd.DataFrame(list(range(10000)))

この状態では DataFrame には pregress_apply() というメソッドは存在しない。

>>> df.progress_apply
Traceback (most recent call last):
...(省略)...
AttributeError: 'DataFrame' object has no attribute 'progress_apply'

そこで、おもむろに tqdm をインポートしたら pandas() 関数を呼び出してみよう。

>>> from tqdm import tqdm
>>> tqdm.pandas()

すると DataFrame オブジェクトに progress_apply() メソッドが生えてきて使えるようになる。 これは単純に DataFrame#apply() メソッドの進捗表示ありバージョンと考えれば良い。

>>> df.progress_apply(lambda x: x ** 2, axis=1)
 96%|███████████████████████████████████████▎ | 9577/10000 [00:01<00:00, 5944.60it/s]

DataFrame#apply() 関数は結構重い処理をすることも多い (特に axis=1 のとき) ので、これは意外とありがたい。 ただし、それ以外のメソッドについてはこれまで通り何も表示されない。

Jupyter Notebook と連携させる

また、Jupyter Notebook と連携させることもできる。

まずは Jupyter Notebook 本体と ipyqidgets をインストールしておこう。

$ pip install notebook ipywidgets
$ pip list --format=columns | grep notebook
notebook            5.6.0

ノートブックのサーバを起動する。

$ jupyter notebook

適当なノートブックを新たに作ったら、次のコードをセルに入力して実行してみよう。 ターミナルとの違いはインポートするものが tqdm.tqdm から tqdm.tqdm_notebook に変わるだけ。

from tqdm import tqdm_notebook as tqdm
import time

for _ in tqdm(range(100)):
    time.sleep(0.1)

すると、次のようにプログレスバーが表示される。

f:id:momijiame:20180721130000p:plain

ちなみに、別に普通の tqdm.tqdm が使えないというわけではない。 試しに、最初に示した例を入力して実行してみよう。

from tqdm import tqdm
import time

for _ in tqdm(range(100)):
    time.sleep(0.1)

上記ほどしっかりとした表示ではないものの、次のようにちゃんと表示してくれる。

f:id:momijiame:20180721130420p:plain

めでたしめでたし。

Python: matplotlib で動的にグラフを生成する

今回は matplotlib を使って動的にグラフを生成する方法について。 ここでいう動的というのは、データを逐次的に作って、それを随時グラフに反映していくという意味を指す。 例えば機械学習のモデルを学習させるときに、その過程 (損失の減り方とか) を眺める用途で便利だと思う。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.13.5
BuildVersion:   17F77
$ python -V
Python 3.6.5
$ pip list --format=columns | egrep -i "(matplotlib|pillow)"
matplotlib      2.2.2  
Pillow          5.2.0  

もくじ

下準備

まずは今回使うパッケージをインストールしておく。

$ pip install matplotlib pillow

静的にグラフを生成する

動的な生成について説明する前に、まずは静的なグラフの生成から説明する。 といっても、これは一般的な matplotlib のグラフの作り方そのもの。 あらかじめ必要なデータを全て用意しておいて、それをグラフとしてプロットする。

この場合、当たり前だけどプロットする前に全てのデータが揃っていないといけない。 例えば機械学習なら、モデルの学習を終えて各エポックなりラウンドごとの損失が出揃っている状態まで待つ必要がある。

次のサンプルコードではサイン波のデータをあらかじめ作った上で、それを折れ線グラフにしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math

import numpy as np
from matplotlib import pyplot as plt


def main():
    # 描画領域
    fig = plt.figure(figsize=(10, 6))
    # 描画するデータ
    x = np.arange(0, 10, 0.1)
    y = [math.sin(i) for i in x]

    # グラフを描画する
    plt.plot(x, y) 

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を適当な名前でファイルに保存して実行してみよう。

$ python sin.py

すると、次のようなグラフが表示される。

f:id:momijiame:20180712235912p:plain

これが静的なグラフ生成の場合。

動的にグラフを生成する

続いて動的にグラフを生成する方法について。 これには matplotlib.animation パッケージを使う。 特に FuncAnimation を使うと作りやすい。

matplotlib.animation — Matplotlib 3.1.1 documentation

次のサンプルコードでは、先ほどの例と同じサイン波を動的に生成している。 ポイントは、グラフの再描画を担当する関数をコールバックとして FuncAnimation に登録すること。 そうすれば、あとは FuncAnimation が一定間隔でその関数を呼び出してくれる。 呼び出されるコールバック関数の中でデータを生成したりグラフを再描画する。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation


def _update(frame, x, y):
    """グラフを更新するための関数"""
    # 現在のグラフを消去する
    plt.cla()
    # データを更新 (追加) する
    x.append(frame)
    y.append(math.sin(frame))
    # 折れ線グラフを再描画する
    plt.plot(x, y)


def main():
    # 描画領域
    fig = plt.figure(figsize=(10, 6))
    # 描画するデータ (最初は空っぽ)
    x = []
    y = []

    params = {
        'fig': fig,
        'func': _update,  # グラフを更新する関数
        'fargs': (x, y),  # 関数の引数 (フレーム番号を除く)
        'interval': 10,  # 更新間隔 (ミリ秒)
        'frames': np.arange(0, 10, 0.1),  # フレーム番号を生成するイテレータ
        'repeat': False,  # 繰り返さない
    }
    anime = animation.FuncAnimation(**params)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

先ほどと同じようにファイルに保存したら実行する。

$ python sin.py

すると、次のように動的にグラフが描画される。

f:id:momijiame:20180712235931g:plain

ちなみに、上記のような GIF 画像や動画は次のようにすると保存できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation


def _update(frame, x, y):
    """グラフを更新するための関数"""
    # 現在のグラフを消去する
    plt.cla()
    # データを更新 (追加) する
    x.append(frame)
    y.append(math.sin(frame))
    # 折れ線グラフを再描画する
    plt.plot(x, y)


def main():
    # 描画領域
    fig = plt.figure(figsize=(10, 6))
    # 描画するデータ (最初は空っぽ)
    x = []
    y = []

    params = {
        'fig': fig,
        'func': _update,  # グラフを更新する関数
        'fargs': (x, y),  # 関数の引数 (フレーム番号を除く)
        'interval': 10,  # 更新間隔 (ミリ秒)
        'frames': np.arange(0, 10, 0.1),  # フレーム番号を生成するイテレータ
        'repeat': False,  # 繰り返さない
    }
    anime = animation.FuncAnimation(**params)

    # グラフを保存する
    anime.save('sin.gif', writer='pillow')


if __name__ == '__main__':
    main()

グラフを延々と描画し続ける

先ほどの例では frames オプションに渡すイテレータに終わりがあった。 具体的には 0 ~ 10 の範囲を 0.1 区切りで分割した 100 のデータに対してグラフを生成した。 また repeat オプションに False を指定することで繰り返し描画することも抑制している。

続いては、先ほどとは異なり frames オプションに終わりのないイテレータを渡してみよう。 こうすると、手動で止めるかメモリなどのリソースを食いつぶすまでは延々とデータを生成してグラフを描画することになる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import itertools
import math

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation


def _update(frame, x, y):
    """グラフを更新するための関数"""
    # 現在のグラフを消去する
    plt.cla()
    # データを更新 (追加) する
    x.append(frame)
    y.append(math.sin(frame))
    # 折れ線グラフを再描画する
    plt.plot(x, y)


def main():
    # 描画領域
    fig = plt.figure(figsize=(10, 6))
    # 描画するデータ
    x = []
    y = []

    params = {
        'fig': fig,
        'func': _update,  # グラフを更新する関数
        'fargs': (x, y),  # 関数の引数 (フレーム番号を除く)
        'interval': 10,  # 更新間隔 (ミリ秒)
        'frames': itertools.count(0, 0.1),  # フレーム番号を無限に生成するイテレータ
    }
    anime = animation.FuncAnimation(**params)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記をファイルに保存して実行してみよう。

$ python sin.py

ずーーーっとグラフが生成され続けるはず。

f:id:momijiame:20180715153531p:plain

Jupyter Notebook 上で動的にグラフを生成する

Jupyter Notebook 上で動的なグラフ生成をするときは、次のように %matplotlib nbagg マジックコマンドを使う。 また、意図的に pyplot.show() を呼び出す必要はない。

%matplotlib nbagg

import itertools
import math

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation


def _update(frame, x, y):
    """グラフを更新するための関数"""
    # 現在のグラフを消去する
    plt.cla()
    # データを更新 (追加) する
    x.append(frame)
    y.append(math.sin(frame))
    # 折れ線グラフを再描画する
    plt.plot(x, y)


# 描画領域
fig = plt.figure(figsize=(10, 6))
# 描画するデータ
x = []
y = []

params = {
    'fig': fig,
    'func': _update,  # グラフを更新する関数
    'fargs': (x, y),  # 関数の引数 (フレーム番号を除く)
    'interval': 10,  # 更新間隔 (ミリ秒)
    'frames': np.arange(0, 10, 0.1),  # フレーム番号を生成するイテレータ
    'repeat': False,  # 繰り返さない
}
anime = animation.FuncAnimation(**params)

データの更新間隔とグラフの再描画間隔をずらす

グラフの再描画はそこまで軽い処理でもないし、データの更新間隔とずらしたいときもあるかも。 そんなときはデータを更新するスレッドと、グラフを再描画するスレッドを分ける。

以下のサンプルコードではデータの更新用に新しくスレッドを起動している。 データの更新間隔が 100ms 間隔なのに対してグラフの再描画は 250ms 間隔にしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import threading
import math
import itertools

from matplotlib import pyplot as plt
from matplotlib import animation


def _redraw(_, x, y):
    """グラフを再描画するための関数"""
    # 現在のグラフを消去する
    plt.cla()
    # 折れ線グラフを再描画する
    plt.plot(x, y)


def main():
    # 描画領域
    fig = plt.figure(figsize=(10, 6))
    # 描画するデータ (最初は空っぽ)
    x = []
    y = []

    def _update():
        """データを一定間隔で追加するスレッドの処理"""
        for frame in itertools.count(0, 0.1):
            x.append(frame)
            y.append(math.sin(frame))
            # データを追加する間隔 (100ms)
            time.sleep(0.1)

    def _init():
        """データを一定間隔で追加するためのスレッドを起動する"""
        t = threading.Thread(target=_update)
        t.daemon = True
        t.start()

    params = {
        'fig': fig,
        'func': _redraw,  # グラフを更新する関数
        'init_func': _init,  # グラフ初期化用の関数 (今回はデータ更新用スレッドの起動)
        'fargs': (x, y),  # 関数の引数 (フレーム番号を除く)
        'interval': 250,  # グラフを更新する間隔 (ミリ秒)
    }
    anime = animation.FuncAnimation(**params)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記をファイルに保存して実行してみよう。

$ python sin.py

これまでに比べるとグラフの再描画間隔が長いので、ちょっとカクカクした感じでグラフが更新される。

めでたしめでたし。