CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PySpark で DataFrame にカラムを追加する

Apache Spark の Python 版インターフェースである PySpark で DataFrame オブジェクトにカラムを追加する方法について。 いくつかやり方があるので見ていく。 ちなみに DataFrame や、それを支える内部的な RDD はイミュータブル (不変) なオブジェクトになっている。 そのため、カラムを追加するときは既存のオブジェクトを変更するのではなく、新たなオブジェクトを作ることになる。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G87
$ python -V
Python 3.7.4
$ pip list | grep -i pyspark
pyspark            2.4.3

下準備

まずは PySpark のインタプリタを起動しておく。 今回に関しては分散処理はしないのでローカルモードで構わない。

$ pyspark

サンプルとしてユーザ情報を模したデータを作ってみる。 まずは、次のように RDD (Resilient Distributed Dataset) を作る。

>>> users_rdd = sc.parallelize([
...   ('Alice', 20),
...   ('Bob', 25),
...   ('Carol', 30),
...   ('Daniel', 30),
... ])

上記にスキーマを定義して DataFrame に変換する。

>>> from pyspark.sql.types import StructType
>>> from pyspark.sql.types import StructField
>>> from pyspark.sql.types import StringType
>>> from pyspark.sql.types import IntegerType
>>> df_schema = StructType([
...   StructField('name', StringType(), False),
...   StructField('age', IntegerType(), False),
... ])
>>> users_df = spark.createDataFrame(users_rdd, df_schema)

上手くいけば、次のようになる。

>>> users_df.show(truncate=False)
+------+---+
|name  |age|
+------+---+
|Alice |20 |
|Bob   |25 |
|Carol |30 |
|Daniel|30 |
+------+---+

今回はここに、年齢を倍にした double_age というカラムを追加してみる。

SparkSQL を使ってカラムを追加する

まずは SparkSQL を使ってカラムを追加してみる。

先ほどの DataFrame を SparkSQL から操作できるように登録しておく。

>>> users_df.registerTempTable('users')

あるいは、以下のようにしても良い。

>>> users_df.createOrReplaceTempView('users')

既存のカラムに加えて年齢を倍にしたカラムを追加するように SQL を用意する。

>>> query = """
... SELECT
...   name,
...   age,
...   age * 2 AS double_age
... FROM users
... """

そして SparkSession#sql() で実行する。

>>> new_users_df = spark.sql(query)

得られた DataFrame を見ると、ちゃんとカラムが新たに追加されている。

>>> new_users_df.show(truncate=False)
+------+---+----------+
|name  |age|double_age|
+------+---+----------+
|Alice |20 |40        |
|Bob   |25 |50        |
|Carol |30 |60        |
|Daniel|30 |60        |
+------+---+----------+

DataFrame API を使ってカラムを追加する

DataFrame に生えたメソッドを使ってカラムを追加する方法もある。 見栄えはだいぶ変わるけど、先ほどとやっていることは基本的に変わらない。

>>> new_users_df = users_df.withColumn('double_age', users_df.age * 2)

DataFrame API は、使っていくと「これ SQL 書いてるのと変わらなくね?」ってなってくる。 なので、個人的にはあまり出番がない。

>>> new_users_df.show(truncate=False)
+------+---+----------+
|name  |age|double_age|
+------+---+----------+
|Alice |20 |40        |
|Bob   |25 |50        |
|Carol |30 |60        |
|Daniel|30 |60        |
+------+---+----------+

RDD API を使ってカラムを追加する

最後に、Apache Spark の最もプリミティブなデータ表現である RDD の API を使って追加する方法について。 ただし、このやり方は UDF (User Defined Function) を使うので遅いはず。

まずは、次のように RDD を行単位で処理してカラムを追加する関数を用意する。

>>> def double_age(row):
...     """年齢を倍にしたカラムを追加する関数"""
...     return list(row) + [row['age'] * 2]
...

DataFrame の RDD に適用すると、次のようになる。

>>> new_users_rdd = users_df.rdd.map(double_age)
>>> new_users_rdd.collect()
[['Alice', 20, 40], ['Bob', 25, 50], ['Carol', 30, 60], ['Daniel', 30, 60]]

元の DataFrame に戻したいけど、そのままだとカラム名や型の情報がない。

>>> new_users_rdd.toDF().show(truncate=False)
+------+---+---+
|_1    |_2 |_3 |
+------+---+---+
|Alice |20 |40 |
|Bob   |25 |50 |
|Carol |30 |60 |
|Daniel|30 |60 |
+------+---+---+

そこで、元あった DataFrame のスキーマを改変する形で新たな DataFrame のスキーマを定義する。

>>> new_schema_fields = users_df.schema.fields + [StructField('double_age', IntegerType(), False)]
>>> new_schema = StructType(new_schema_fields)

用意したスキーマを使って DataFrame に変換する。

>>> new_user_df = new_users_rdd.toDF(new_schema)

これでカラム名や型の情報がちゃんとした DataFrame になった。

>>> new_user_df.show(truncate=False)
+------+---+----------+
|name  |age|double_age|
+------+---+----------+
|Alice |20 |40        |
|Bob   |25 |50        |
|Carol |30 |60        |
|Daniel|30 |60        |
+------+---+----------+

補足

ちなみにカラムを削除したいときは、次のように DataFrame API で DataFrame#drop() を呼び出せば良い。

>>> new_user_df.drop('age').show(truncate=False)
+------+----------+
|name  |double_age|
+------+----------+
|Alice |40        |
|Bob   |50        |
|Carol |60        |
|Daniel|60        |
+------+----------+

いじょう。

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

Python: Apache Spark のパーティションは要素が空になるときがある

PySpark とたわむれていて、なんかたまにエラーになるなーと思って原因を調べて分かった話。 最初、パーティションの中身は空になる場合があるとは思っていなかったので、結構おどろいた。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G87
$ pyspark --version
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.4.3
      /_/

Using Scala version 2.11.12, Java HotSpot(TM) 64-Bit Server VM, 1.8.0_121
Branch
Compiled by user  on 2019-05-01T05:08:38Z
Revision
Url
Type --help for more information.
$ python -V
Python 3.7.4
$ java -version
openjdk version "12.0.1" 2019-04-16
OpenJDK Runtime Environment (build 12.0.1+12)
OpenJDK 64-Bit Server VM (build 12.0.1+12, mixed mode, sharing)

下準備

下準備として PySpark をインストールしたら REPL を起動しておく。 今回の検証に関しては分散処理をしないローカルモードでも再現できる。

$ pip install pyspark
$ pyspark

サンプルデータを用意する

例えば SparkSession#range() を使ってサンプルの DataFrame オブジェクトを作る。

>>> df = spark.range(10)

中身は bigint 型の連番が格納されている。

>>> df
DataFrame[id: bigint]
>>> df.show(truncate=False)
+---+
|id |
+---+
|0  |
|1  |
|2  |
|3  |
|4  |
|5  |
|6  |
|7  |
|8  |
|9  |
+---+

今回使う環境ではこの DataFrame は 4 つのパーティションに分けて処理されることが分かる。 パーティションというのは Apache Spark が内部的に RDD (Resilient Distributed Dataset) を処理する際の分割数を指している。 RDD は Apache Spark の最も低レイヤなデータ表現で、DataFrame も最終的には RDD に変換されて処理される。

>>> df.rdd.getNumPartitions()
4

試しにパーティションに入っている要素の数をカウントしてみることにしよう。 次のような関数を用意する。

>>> def size_of_partition(map_of_rows):
...     """パーティションの要素の数を計算する関数"""
...     list_of_rows = list(map_of_rows)
...     size_of_list = len(list_of_rows)
...     return [size_of_list]
...

これを RDD#mapPartitions() 経由で呼び出す。 これでパーティションの中の要素の数をカウントできる。

>>> df.rdd.mapPartitions(size_of_partition).collect()
[2, 3, 2, 3]

各パーティションには 2 ないし 3 の要素が入っているようだ。

意図的にパーティションを空にしてみる

続いては、意図的にパーティションの中身をスカスカにするためにパーティションの分割数を増やしてみよう。 先ほど 4 だった分割数を 20 まで増やしてみる。 パーティションの分割数を増やすには RDD#repartition() が使える。

>>> reparted_rdd = df.rdd.repartition(20)
>>> reparted_rdd.getNumPartitions()
20

この状態でパーティションの要素の数をカウントすると、次のようになった。 要素の数として 0 が登場していることから、パーティションによっては中身が空なことが分かる。

>>> reparted_rdd.mapPartitions(size_of_partition).collect()
[0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0]

意外だったのは、要素数を均すリバランスがされていないこと。

ちなみに先ほど使った RDD#repartition()RDD#coalesce() をオプション shuffle=True で呼び出した場合と等価なようだ。 オプション shuffle=True は要素の順序を保持しないことを表している。

>>> df.rdd.coalesce(20, shuffle=True)  # df.rdd.repartition(20) と等価

ちなみに、要素の順序を保持したままパーティションを拡張することはできない。

>>> unshuffled_reparted_rdd = df.rdd.coalesce(20, shuffle=False)
>>> unshuffled_reparted_rdd.getNumPartitions()
4

オプションの shuffleFalse だとパーティションの分割数が増えていないことが分かる。

ようするにパーティションの分割数を増やしたいときは、要素の順序が必ず入れ替わると考えた方が良い。 先ほどパーティションを増やした RDD も、確認すると順番が入れ替わっている。

>>> reparted_rdd.map(lambda x: x).collect()
[Row(id=0), Row(id=1), Row(id=7), Row(id=8), Row(id=9), Row(id=5), Row(id=6), Row(id=2), Row(id=3), Row(id=4)]

RDD のままでもいいけど、ちょっと分かりにくいかもしれないので DataFrame に直すとこんな感じ。

>>> reparted_rdd.map(lambda x: x).toDF(df.schema).show(truncate=False)
+---+
|id |
+---+
|0  |
|1  |
|7  |
|8  |
|9  |
|5  |
|6  |
|2  |
|3  |
|4  |
+---+

いじょう。

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

Python: LightGBM の cv() 関数から取得した学習済みモデルを SerDe する

今回は、前回のエントリを書くきっかけになったネタについて。

blog.amedama.jp

上記は今回扱う LightGBM の cv() 関数から取得した _CVBooster のインスタンスで起きた問題だった。 このインスタンスは、そのままでは pickle で直列化・非直列化 (SerDe) できずエラーになってしまう。

ちなみに LightGBM の cv() 関数から学習済みモデルを取得する件については以下のエントリに書いてある。

blog.amedama.jp

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V            
Python 3.7.4

下準備

準備として LightGBM と Scikit-learn をインストールしておく。

$ pip install lightgbm scikit-learn

問題が生じるコード

まずは件の問題が生じるコードから。 以下のサンプルコードでは、取得した _CVBooster のインスタンスを pickle で直列化しようとしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle

import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split


class ModelExtractionCallback(object):
    """lightgbm.cv() 関数からモデルを取り出すコールバック"""

    def __init__(self):
        self._model = None

    def __call__(self, env):
        self._model = env.model

    def _assert_called_cb(self):
        if self._model is None:
            raise RuntimeError('callback has not called yet')

    @property
    def boosters_proxy(self):
        self._assert_called_cb()
        return self._model

    @property
    def raw_boosters(self):
        self._assert_called_cb()
        return self._model.boosters

    @property
    def best_iteration(self):
        self._assert_called_cb()
        return self._model.best_iteration


def main():
    # データセットを読み込む
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    # デモ用にデータセットを分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        test_size=0.2,
                                                        random_state=42)


    # LightGBM のデータセット表現にラップする
    lgb_train = lgb.Dataset(X_train, y_train)

    # モデルを学習する
    extraction_cb = ModelExtractionCallback()
    callbacks = [
        extraction_cb,
    ]
    lgb_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    result = lgb.cv(lgb_params,
                    lgb_train,
                    num_boost_round=1000,
                    early_stopping_rounds=10,
                    folds=skf,
                    seed=42,
                    callbacks=callbacks,
                    verbose_eval=10)

    print('cv logloss:', result['binary_logloss-mean'][-1])

    # モデルを取り出す
    proxy = extraction_cb.boosters_proxy

    # モデルを SerDe する
    serialized_model = pickle.dumps(proxy)
    restored_model = pickle.loads(serialized_model)

    # Deserialize したオブジェクト
    print(restored_model)

    # Hold-out しておいたデータを予測させる
    y_pred_probas = restored_model.predict(X_test)
    y_pred_proba = np.array(y_pred_probas).mean(axis=0)
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    # Accuracy について評価する
    acc = accuracy_score(y_test, y_pred)
    print('test accuracy:', acc)


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、次のように例外になる。

$ python lgbcvbserde.py
...
cv logloss: 0.12616399920831986
Traceback (most recent call last):
  File "lgbcvbserde.py", line 99, in <module>
    main()
  File "lgbcvbserde.py", line 84, in main
    restored_model = pickle.loads(serialized_model)
  File "/Users/amedama/.virtualenvs/py37/lib/python3.7/site-packages/lightgbm/engine.py", line 262, in handler_function
    for booster in self.boosters:
TypeError: 'function' object is not iterable

これは、先のエントリに記述した通り以下の条件が重なることで生じている。

  • ラッパーとなる _CVBooster__getattr__() が実装されており __getstate__()__setstate() をトラップする
  • ラップされるオブジェクトに __getstate__()__setstate__() が実装されておりラッパー経由で呼ばれている

問題を修正するコード

問題の修正方法は先のエントリに記述した通り。 ラッパーとして動作するオブジェクト、今回であれば _CVBooster のインスタンスに __getstate__()__setstate__() が必要になる。 ただし、_CVBooster は LightGBM のパッケージなので直接ソースコードを修正することは望ましくない。 そのためモンキーパッチを駆使して解決する。

以下のサンプルコードではクラスに動的にメソッドを追加することで問題を修正している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle

import numpy as np
import lightgbm as lgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split


class ModelExtractionCallback(object):
    """lightgbm.cv() 関数からモデルを取り出すコールバック"""

    def __init__(self):
        self._model = None

    def __call__(self, env):
        self._model = env.model

    def _assert_called_cb(self):
        if self._model is None:
            raise RuntimeError('callback has not called yet')

    @property
    def boosters_proxy(self):
        self._assert_called_cb()
        return self._model

    @property
    def raw_boosters(self):
        self._assert_called_cb()
        return self._model.boosters

    @property
    def best_iteration(self):
        self._assert_called_cb()
        return self._model.best_iteration


def main():
    dataset = datasets.load_breast_cancer()
    X, y = dataset.data, dataset.target

    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        test_size=0.2,
                                                        random_state=42)


    lgb_train = lgb.Dataset(X_train, y_train)

    extraction_cb = ModelExtractionCallback()
    callbacks = [
        extraction_cb,
    ]
    lgb_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    result = lgb.cv(lgb_params,
                    lgb_train,
                    num_boost_round=1000,
                    early_stopping_rounds=10,
                    folds=skf,
                    seed=42,
                    callbacks=callbacks,
                    verbose_eval=10)

    print('cv logloss:', result['binary_logloss-mean'][-1])

    proxy = extraction_cb.boosters_proxy

    # lightgbm.engine._CVBooster のクラスに
    # __getstate__() と __setstate__() を動的に追加する
    def __getstate__(self):
        return self.__dict__.copy()
    setattr(lgb.engine._CVBooster, '__getstate__', __getstate__)

    def __setstate__(self, state):
        self.__dict__.update(state)
    setattr(lgb.engine._CVBooster, '__setstate__', __setstate__)

    serialized_model = pickle.dumps(proxy)
    restored_model = pickle.loads(serialized_model)

    print(restored_model)

    y_pred_probas = restored_model.predict(X_test)
    y_pred_proba = np.array(y_pred_probas).mean(axis=0)
    y_pred = np.where(y_pred_proba > 0.5, 1, 0)
    acc = accuracy_score(y_test, y_pred)
    print('test accuracy:', acc)


if __name__ == '__main__':
    main()

上記を実行してみよう。 SerDe の部分は全く修正していないけど、今度は例外にならず実行できている。

$ python lgbcvbserde.py
...
cv logloss: 0.12616399920831986
<lightgbm.engine._CVBooster object at 0x114704090>
test accuracy: 0.9736842105263158

ちなみに、上記のように _CVBooster ごと直列化しようとするから今回のような問題になるのであって、中身の Booster を格納したリストを直列化するという選択肢もある。

Python: __getattr__() のあるオブジェクトを直列化しようとしてハマった話

今回は特殊メソッドの __getattr__() があるオブジェクトを pickle で直列化・非直列化 (SerDe) しようとしたらハマった話について。

まず、特殊メソッドの __getattr__() をクラスに実装してあると、そのインスタンスは未定義のアトリビュートにアクセスが生じたとき呼び出しがトラップされる。 そして、この __getattr__() を実装したクラスのインスタンスを pickle で SerDe しようとしたとき思わぬ挙動となった。 結論から先に述べると __getattr__() を実装してあると __getstate__()__setstate__() の呼び出しまでトラップされてしまう。 これらのメソッドは SerDe の振る舞いをオーバーライドするための特殊メソッドとなっている。 この問題の対策としては __getattr__() のある SerDe が必要なクラスには __getstate__()__setstate__() を実装しておくことが考えられる。

なお、pickle を使ったオブジェクトの SerDe の概要については、以下のエントリを参照のこと。

blog.amedama.jp

使った環境は次の通り。

$ sw_vers           
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V
Python 3.7.4

特殊メソッド __getattr__() がないときの振る舞いについて

まずは __getattr__() を実装していないクラスを直列化・非直列化 (SerDe) してみる。 以下のサンプルコードでは Example というクラスのインスタンスをバイト列にしてから元のオブジェクトに戻している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Example(object):
    """SerDe されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    # Example クラスをインスタンス化する
    o = Example('World')
    # メソッドを呼び出す
    o.greet()

    # オブジェクトをバイト列にシリアライズする
    # このときオブジェクトに __getstate__() があれば呼ばれる
    # このサンプルコードにはないためデフォルトの振る舞いになる
    s = pickle.dumps(o)

    # バイト列からオブジェクトをデシリアライズする
    # このときオブジェクトに __setstate__() があれば呼ばれる
    # このサンプルコードにはないためデフォルトの振る舞いになる
    restored_o = pickle.loads(s)

    # 復元したオブジェクトのメソッドを呼び出す
    restored_o.greet()


if __name__ == '__main__':
    main()

上記を実行した結果が次の通り。 ちゃんとオブジェクトをバイト列にして、また元のオブジェクトに戻せていることがわかる。

$ python serde1.py     
Hello, World!
Hello, World!

特殊メソッド __getattr__() があるときの振る舞いについて

続いては __getattr__() のあるオブジェクトを SerDe してみる。 ただ、先ほどの Example クラスに直接 __getattr__() を追加するのはユースケースとして考えにくいので、ちょっとアレンジを加えてある。 Example クラスはそのままに、そのラッパーとして動作する Wrapper クラスを用意して、そこに __getattr__() メソッドを実装した。 こういったプロキシのようなクラスは、プロキシする先のオブジェクトの呼び出しを中継するために __getattr__() を使うことが多い。 このような状況で Wrapper クラスのインスタンスを SerDe すると上手くいかない、というのが今回の本題となる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            # ラップするオブジェクトのアトリビュートを取得して呼び出す
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    # Wrapper でオブジェクトをラップする
    w = Wrapper(o)
    # ラッパー経由でメソッドを呼び出す
    w.greet()

    # XXX: __getstate__() が __getattr__() 経由で呼ばれようとする
    s = pickle.dumps(w)

    # XXX: __setstate__() が __getattr__() 経由で呼ばれようとする
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、次のように直列化するタイミングでエラーになる。 見ると __getstate__()Example のオブジェクトにない、という内容のようだ。

$ python serde2.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: __getstate__
Traceback (most recent call last):
  File "serde2.py", line 50, in <module>
    main()
  File "serde2.py", line 41, in main
    s = pickle.dumps(w)
  File "serde2.py", line 17, in _wrapper
    func = getattr(self.wrap_target, item)
AttributeError: 'Example' object has no attribute '__getstate__'

Example クラスに __*state__() を実装すれば解決...しない

では、エラーメッセージに習って Example クラスに __getstate__()__setstate__() を実装すれば解決するだろうか? 試してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))

    def __getstate__(self):
        """__getstate__() を明示的に定義する"""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """__setstate__() を明示的に定義する"""
        self.__dict__.update(state)


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # XXX: __getstate__() が __getattr__() 経由で呼ばれようとする
    s = pickle.dumps(w)

    # XXX: __setstate__() が __getattr__() 経由で呼ばれようとする
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

残念ながら、今度は以下のようなエラーになる。 そもそも SerDe したいのは Wrapper クラスのインスタンスなので Example クラスに実装しても解決できない。

$ python serde3.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: __getstate__
trapped undefined access: __setstate__
Traceback (most recent call last):
  File "serde3.py", line 56, in <module>
    main()
  File "serde3.py", line 50, in main
    restored_w = pickle.loads(s)
  File "serde3.py", line 17, in _wrapper
    func = getattr(self.wrap_target, item)
AttributeError: 'function' object has no attribute '__setstate__'

このときのエラーメッセージがまた分かりにくくて、どうして function オブジェクトでエラーになるんだ、となる。

Wrapper クラスに __*state__() を実装してみる

ということで、今度は Wrapper クラスの方に __getstate__()__setstate__() を実装してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper

    def __getstate__(self):
        """__getstate__() を明示的に定義する"""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """__setstate__() を明示的に定義する"""
        self.__dict__.update(state)


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # __getstate__() が明示的に定義されているため __getattr__() は呼ばれない
    s = pickle.dumps(w)

    # __setstate__() が明示的に定義されているため __getattr__() は呼ばれない
    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

今度は次のようにエラーにならず SerDe できた。 Wrapper クラスに __getstate__()__setstate__() が定義されているため、呼び出しが __getattr__() にトラップされることがない。

$ python serde4.py
trapped undefined access: greet
Hello, World!
trapped undefined access: greet
Hello, World!

サードパーティーのライブラリで問題が発生しているとき

先ほどのようにクラスにメソッドを定義して救えるのは、自分で定義したクラスで問題が発生している場合に限られる。 もし、サードパーティ製のライブラリで同様の問題が生じた場合には、どのような解決策があるだろうか。 幸いなことに Python は既存のクラスにも動的にメソッドを追加できる。

以下のサンプルコードでは SerDe する直前で対象のクラスに __getstate__()__setstate__() を動的に追加している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle


class Wrapper(object):
    """別のオブジェクトへのラッパーとして動作するクラス (SerDe される)"""

    def __init__(self, wrap_target):
        self.wrap_target = wrap_target

    def __getattr__(self, item):
        """未定義のアトリビュートへのアクセスをトラップする"""
        def _wrapper(*args, **kwargs):
            print('trapped undefined access:', item)
            func = getattr(self.wrap_target, item)
            return func(*args, **kwargs)
        return _wrapper


class Example(object):
    """Wrapper 経由で呼び出されるクラス"""

    def __init__(self, message):
        self.message = message

    def greet(self):
        print('Hello, {msg}!'.format(msg=self.message))


def main():
    o = Example('World')

    w = Wrapper(o)
    w.greet()

    # オブジェクトに __getstate__() を動的に追加する
    def __getstate__(self):
        return self.__dict__.copy()
    setattr(Wrapper, '__getstate__', __getstate__)

    # オブジェクトに __setstate__() を動的に追加する
    def __setstate__(self, state):
        self.__dict__.update(state)
    setattr(Wrapper, '__setstate__', __setstate__)

    s = pickle.dumps(w)

    restored_w = pickle.loads(s)

    restored_w.greet()


if __name__ == '__main__':
    main()

上記を実行してみよう。 ちゃんと SerDe できていることがわかる。

$ python serde5.py 
trapped undefined access: greet
Hello, World!
trapped undefined access: greet
Hello, World!

macOS で CH34x のシリアルコンソールを使う

Arduino などで使われていることがある CH34x のチップを macOS から使う方法について。

基本的には以下のリポジトリに詳細が載っている。

github.com

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84

インストール

もし過去に古いドライバを手動でインストールしたことがあるときは下部に記載したアンインストールを先に実行する。

Homebrew Cask を使ってドライバをインストールする。

$ brew cask reinstall wch-ch34x-usb-serial-driver 

マシンを再起動するか、あるいは以下のコマンドを実行してカーネルモジュールを読み込む。

$ sudo kextload /Library/Extensions/usbserial.kext

これで tty.wchusbserial から始まるデバイスファイルが見えるようになるはず。

$ ls /dev/tty.wchusbserial*
tty.wchusbserial141120

あとは一般的なシリアルデバイスとして screen なり pyserial などから使えば良い。

$ screen /dev/tty.wchusbserial141120 9600

手動で古いドライバを削除する

過去に古いドライバを手動でインストールしたことがあるときは、以下の手順にもとづいてアンインストールする。

まず、カーネルモジュールをアンロードする。

$ sudo kextunload /Library/Extensions/usbserial.kext
$ sudo kextunload /System/Library/Extensions/usb.kext

そして、カーネルモジュールのファイルを削除する。

$ sudo rm -rf /System/Library/Extensions/usb.kext
$ sudo rm -rf /Library/Extensions/usbserial.kext

いじょう。

Python: インポートするだけで Kivy が日本語を表示できるようになる japanize-kivy を作った

Python の GUI フレームワークである Kivy は、そのままだと日本語が表示できない。 そこで、インポートするだけで日本語を表示できるようにするパッケージ japanize-kivy を作った。

github.com

知っている人はピンと来るはずだけど名前や思想は以下のパッケージをインスパイアしている。

github.com

使った環境は次の通り。 パッケージがサポートする Python は 3.6 以上を想定している。

$ sw_vers  
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G84
$ python -V
Python 3.7.4

インストール

pip からインストールできる。

$ pip install japanize-kivy

試す

Python のインタプリタを起動する。

$ python

japanize_kivy パッケージをインポートする。

>>> import japanize_kivy

あとは日本語を含む Kivy のアプリケーションを用意する。

>>> from kivy.app import App
>>> from kivy.uix.boxlayout import BoxLayout
>>> from kivy.uix.label import Label
>>> class GreetingApp(App):
...     def build(self):
...         main_screen = BoxLayout()
...         label = Label(text='こんにちは、世界')
...         main_screen.add_widget(label)
...         return main_screen
... 
>>> GreetingApp().run()

以下のように日本語が表示できるようになる。

f:id:momijiame:20190730183416p:plain

インポートしないと、次のように日本語が豆腐になる。

f:id:momijiame:20190730183451p:plain

フォントのライセンスに関して

日本語を表示するためのフォントは IPAex ゴシックフォントを使わせてもらっている。 そのため、本パッケージを利用する上ではライセンスへの同意が必要となる。

次のようにするとライセンスが表示されるので、IPA への感謝と共に同意してほしい。

>>> japanize_kivy.show_license()

いじょう。

Python: Kivy と Matplotlib でデータセットの確認ツールを書いてみる

以前、このブログで Kivy で作った GUI に Matplotlib のグラフを埋め込む方法について書いた。

blog.amedama.jp

今回は、これを応用したツール作りをしてみる。 といっても、やっていることは単純で先の例にボタンを付けてインタラクティブにした程度にすぎない。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.14.5
BuildVersion:   18F132
$ python -V          
Python 3.7.4

下準備

下準備として必要なパッケージをインストールしておく。

$ pip install kivy matplotlib scikit-learn

Digit データセットの内容を表示してみる

今回書いてみたサンプルコードが次の通り。 内容としては scikit-learn に同梱されている Digit データセットの内容を表示させてみることにした。 ボタンを使って表示するデータを前後に進めたり戻したりできる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn import datasets
from kivy.app import App
from kivy.uix.boxlayout import BoxLayout
from kivy.lang import Builder
from kivy.garden.matplotlib.backend_kivyagg import FigureCanvasKivyAgg


kv_def = '''
<RootWidget>:
    orientation: 'vertical'

    GraphView:
        id: graph_view
        size_hint_y: 0.8

    BoxLayout:
        size_hint_y: 0.2

        Button:
            id: prev_button
            text: '< Prev'
            on_press: root.ids.graph_view.prev()

        Button:
            id: next_button
            text: 'Next >'
            on_press: root.ids.graph_view.next()

<GraphView>:
'''
Builder.load_string(kv_def)


class GraphView(BoxLayout):
    """Matplotlib のグラフを表示するウィジェット"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # データセットを読み込んでおく
        self.dataset = datasets.load_digits()
        # 表示するデータのインデックス
        self.cursor = 0

        # 描画領域を用意する
        self.fig, self.ax = plt.subplots()

        # 描画を初期化する
        self._update_view()

        # グラフをウィジェットとして追加する
        widget = FigureCanvasKivyAgg(self.fig)
        self.add_widget(widget)

    def _update_view(self):
        """描画を更新するメソッド"""
        # 以前の内容を消去する
        self.ax.clear()
        self.ax.axis('off')

        # データを取得する
        img_data = self.dataset.data[self.cursor]
        label = self.dataset.target[self.cursor]

        # データを描画する
        self.ax.imshow(img_data.reshape(8, 8),
                       cmap=cm.gray_r,
                       interpolation='nearest')
        title_msg = 'index={idx}, label={label}'.format(idx=self.cursor,
                                                        label=label)
        self.ax.set_title(title_msg, color='red')

        # 再描画する
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def next(self):
        """次へボタンを押したときのコールバック"""
        if self.cursor < len(self.dataset.data) - 1:
            self.cursor += 1
        self._update_view()

    def prev(self):
        """戻るボタンを押したときのコールバック"""
        if self.cursor > 0:
            self.cursor -= 1
        self._update_view()


class RootWidget(BoxLayout):
    pass


class ViewerApp(App):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.title = 'Digit dataset viewer'

    def build(self):
        root_widget = RootWidget()
        return root_widget


def main():
    # アプリケーションを開始する
    app = ViewerApp()
    # ここでスレッドがブロックする
    app.run()


if __name__ == '__main__':
    main()

上記を実行してみる。

$ python digitviewer.py

すると、次のような GUI が表示される。

f:id:momijiame:20190725060813g:plain

応用すればアノテーションに使うツールなんかも作れるだろうね。 いじょう。