CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: Keras で imdb データセットを読もうとするとエラーになる問題と回避策について

今回は、表題の通り Keras の API を使ってダウンロードできる imdb データセットを読もうとするとエラーになる問題について。

これは数ヶ月前から既知の問題で、以下のチケットが切られている。 内容については細かく読まなくても、詳しくは後述する。

github.com

問題を修正するコードは Git リポジトリの HEAD にはマージされている。 しかし、現時点 (2019-06-14) ではまだ修正済みのバージョンがリリースされていない。

github.com

そして、この問題について検索すると、以下の二つの回避策の提案が見つかる。

  • NumPy のバージョンを 1.16.2 以下にダウングレードする
  • インストール済みの Keras のソースコードを手動で書き換える

最初のやり方は、実は潜在的に脆弱性のある NumPy のバージョンを使うことを意味している。 また、二番目のやり方は正直あまりやりたくない類のオペレーションのはず。 そこで、上記とは異なる第三の回避策としてモンキーパッチを使う方法を提案してみる。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V                          
Python 3.7.3
$ pip list | egrep -i "(keras|numpy)"
Keras               2.2.4  
Keras-Applications  1.0.8  
Keras-Preprocessing 1.1.0  
numpy               1.16.4 

再現環境を作る

ひとまず再現環境を作るための準備として Keras とバックエンドの TensorFlow をインストールしておく。

$ pip install keras tensorflow

準備ができたら Python のインタプリタを起動する。

$ python

問題を再現する

この問題を再現するのは非常に簡単で、ただ imdb データセットを読み込もうとすれば良い。

まずは imdb モジュールをインポートする。

>>> from keras.datasets import imdb
Using TensorFlow backend.

そして、load_data() 関数を呼ぶだけ。 すると、以下のように例外になってしまう。

>>> imdb.load_data()
Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 3s 0us/step
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/keras/datasets/imdb.py", line 59, in load_data
    x_train, labels_train = f['x_train'], f['y_train']
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/numpy/lib/npyio.py", line 262, in __getitem__
    pickle_kwargs=self.pickle_kwargs)
  File "/Users/amedama/.virtualenvs/example/lib/python3.7/site-packages/numpy/lib/format.py", line 696, in read_array
    raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False

問題の詳細

この問題は NumPy の脆弱性に対する対応と、imdb が Pickle 形式の npz フォーマットで配布されていることに起因している。

まず、発端は以下の脆弱性 CVE-2019-6446 に始まる。 この脆弱性は、誤って信頼できない (細工された) Pickle を NumPy で読み込んでしまうと任意のコード実行が生じるというもの。

nvd.nist.gov

上記の脆弱性に対する対応として、NumPy はバージョン 1.16.3 以降で以下のようにコードを修正した。 具体的には、意図的にフラグ (allow_pickle=True) を有効にしない限り Pickle フォーマットのデータを読めないようにしている。

github.com

その煽りを受けたのが Keras の imdb データセットだった。 Pickle 形式のデータセットを NumPy デフォルトのオプションで読み込んでいた。 そのため、前述したように NumPy 1.16.3 以降を使うと例外になってしまう。

上記のような事情があるため、前述した通り Web を探すと以下のような回避策が提案されている。

  • NumPy のバージョンを 1.16.2 以下にダウングレードする
  • インストール済みの Keras のソースコードを書き換える

とはいえ、どちらもあまりやりたくないのは前述した通り。

第三の選択肢 (モンキーパッチ)

そこで提案するのが、モンキーパッチを使うやり方。 これは、データセットを読み込むタイミングだけ、一時的にピンポイントでコードを動的に書き換えてしまうというもの。 問題は NumPy の load() 関数がデフォルトのオプションのまま呼び出される点にある。 だとすると、関数が呼ばれるタイミングだけオプションを一時的に上書きしてしまえば良い。

具体的には、次のように関数のパラメータを部分適用して上書きする。

>>> from functools import partial
>>> import numpy as np
>>> np.load = partial(np.load, allow_pickle=True)  # monkey patch

この状態なら、エラーにならずにデータセットを読み込むことができる。

>>> from keras.datasets import imdb
Using TensorFlow backend.
>>> imdb.load_data()  # エラーにならずデータが得られる

もし、そのままになっているのが気持ち悪いのであれば、読み込みが終わった後でまた元のパラメータに戻してやれば良い。

>>> np.load = partial(np.load, allow_pickle=False)

まあ次の Keras のリリース版が出るまでの短い間だけ必要な回避策だけど、スクリプト言語ならこんなやり方もありますよということで。

PythonとKerasによるディープラーニング

PythonとKerasによるディープラーニング