CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: 未処理の例外が上がったときの処理をオーバーライドする

今回はだいぶダーティーな手法に関する話。 未処理の例外が上がったときに走るデフォルトの処理をオーバーライドしてしまう方法について。 あらかじめ断っておくと、どうしても必要でない限り、こんなことはやらない方が望ましい。 とはいえ、これによって助けられることもあるかも。

使った環境は次の通り。

$ sw_vers                               
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G1012
$ python -V                    
Python 3.7.5

もくじ

下準備

下準備として、Python のインタプリタを起動しておく。

$ python

デフォルトの挙動をオーバーライドする

try ~ except で捕捉されない例外があると、次のように例外の詳細とトレースバックが出力される。

>>> raise Exception('Oops!')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
Exception: Oops!

このときの挙動は sys.excepthook でフックされているので、このオブジェクトを上書きすることでオーバーライドできる。 例えば、実用性は皆無だけどただメッセージを出力するだけの処理に置き換えてみよう。

>>> import sys
>>> def myhook(type, value, traceback):
...     print('Hello, World!', file=sys.stderr)
... 
>>> sys.excepthook = myhook

例外を上げてみると、次のようにメッセージが表示されるだけになる。

>>> raise Exception('Oops!')
Hello, World!

関数のシグネチャについて

フックの関数のシグネチャについて、もうちょっと詳しく見てみよう。 以下のようにデバッグ用の関数をフックに指定する。

>>> def debughook(type, value, traceback):
...     print(type, value, traceback, file=sys.stderr)
... 
>>> sys.excepthook = debughook

試しに例外を上げてみると、次のようになった。 例外クラスの型、引数、トレースバックのオブジェクトが渡されるようだ。

>>> raise Exception('Oops!')
<class 'Exception'> Oops! <traceback object at 0x1024a6910>

スレッドを使うときの問題点について

なお、このフックはスレッドを使っているときに有効にならないという問題がある。

実際に試してみよう。 先ほどのデバッグ用のフックが有効な状態で、別のスレッドを起動する。 そして、スレッドの中で例外を上げるように細工してやろう。 すると、次のように普通のトレースバックが表示されてしまう。

>>> import threading
>>> def f():
...     raise Exception('Oops!')
... 
>>> threading.Thread(target=f).start()
>>> Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "<stdin>", line 2, in f
Exception: Oops!

上記のように、スレッドを使った場合にはフックが有効にならない。 この問題は Python 3.8 で追加された API によって解決できる。

$ python -V
Python 3.8.0

Python 3.8 では threading モジュールに excepthook というオブジェクトが追加されている。 このオブジェクトを上書きすることで処理をオーバーライドできるようになった。

>>> def threading_hook(args):
...     print('Hello, World!', args)
... 
>>> threading.excepthook = threading_hook
>>> 
>>> threading.Thread(target=f).start()
Hello, World! _thread.ExceptHookArgs(exc_type=<class 'Exception'>, exc_value=Exception('Oops!'), exc_traceback=<traceback object at 0x1033f4900>, thread=<Thread(Thread-2, started 123145518649344)>)

デフォルトの挙動に戻す

デフォルトのフックへの参照は sys.__excepthook__ にあるため、これを使えば挙動を元に戻せる。 なお、sys.__excepthook__ の方は絶対に変更しないこと。

>>> sys.excepthook = sys.__excepthook__
>>> raise Exception('Oops!')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
Exception: Oops!

試してないけど Jupyter とかでエラーになったときチャットに通知を送る、なんて用途に使えるかもね。

Python: 関数合成できる API を作ってみる

今回は普通の Python では満足できなくなってしまった人向けの話題。 dfplypipe といった一部のパッケージで採用されているパイプ処理や関数合成できる API を作る一つのやり方について。

使った環境は次の通り。

$ sw_vers                  
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G1012
$ python -V
Python 3.7.5

もくじ

カッコ以外で評価されるオブジェクトを作る

通常の Python では、関数やメソッドはカッコを使ってオブジェクトを評価する。 しかし、クラスを定義するとき特殊メソッドを使って演算子オーバーライドすることで、その枠に収まらないオブジェクトが作れる。

例えば、以下のように関数をラップするクラスを定義する。 特殊メソッドの __rrshift__() は、自身の「左辺」にある右ビットシフト演算子が評価されるときに呼び出される。 なお、別に右ビットシフト演算子を使う必然性はないので、別の演算子をオーバーライドしても構わない。

>>> class Pipe:
...     """関数をラップするクラス"""
...     def __init__(self, f):
...         # インスタンス化するとき関数を受け取る
...         self.f = f
...     def __rrshift__(self, other):
...         # 自身の左辺にある右ビットシフト演算子を評価するとき関数を実行する
...         return self.f(other)
... 

これを使って、例えば値を二乗するオブジェクトを作ってみよう。

>>> pow = Pipe(lambda x: x ** 2)

このオブジェクトに右ビットシフト演算子を使って値を渡すと、その内容が二乗される。

>>> 10 >> pow
100

他の関数も定義してつなげるとメソッドチェーンっぽいことができる。

>>> double = Pipe(lambda x: x * 2)
>>> 10 >> pow >> double
200

ただ、このままだと関数と関数だけをつないだときに例外になってしまう。 この場合は、左辺にあるオブジェクトの右辺にある右ビットシフト演算子が先に評価されているため。

>>> pow >> double
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsupported operand type(s) for >>: 'Pipe' and 'Pipe'

関数合成できるオブジェクトを作る

そこで、先ほどのクラスに手を加える。 以下のように、自身の右辺に渡された処理を自身のリストにキューイングしておけるようにする。

>>> import copy
>>> class Pipe:
...     """関数合成に対応したクラス"""
...     def __init__(self, f):
...         self.f = f
...         # 適用したい一連の関数を記録しておくリスト
...         self.pipes = []
...     def __rshift__(self, other):
...         """自身の右辺にある右ビットシフト演算子を評価したときに呼ばれる特殊メソッド"""
...         # 自身をコピーしておく
...         copied_self = copy.deepcopy(self)
...         # コピーした内容のリストに適用したい処理をキューイングする
...         copied_self.pipes.append(other)
...         # コピーした自身を返す
...         return copied_self
...     def __rrshift__(self, other):
...         """自身の左辺にある右ビットシフト演算子を評価したときに呼ばれる特殊メソッド"""
...         # まずは自身の関数を適用する
...         result = self.f(other)
...         # キューイングされていた関数を順番に適用していく
...         for pipe in self.pipes:
...             result = pipe.__rrshift__(result)
...         # 最終的な結果を返す
...         return result
... 

こうすると、関数同士をつないだ場合にもオブジェクトが返るようになる。

>>> pow = Pipe(lambda x: x ** 2)
>>> double = Pipe(lambda x: x * 2)
>>> pow >> double
<__main__.Pipe object at 0x102090950>

上記を変数に保存しておいて値を適用すると、ちゃんと本来の意図通りにチェーンされた結果が返ってくる。

>>> pow_double = pow >> double
>>> 10 >> pow_double
200

自身をディープコピーする理由について

ちなみに、先ほど右辺にある右ビットシフト演算子が評価されるときに自身のオブジェクトをディープコピーしていた。 もし、ディープコピーしないとどうなるだろうか。 実際にやってみよう。

>>> class Pipe:
...     def __init__(self, f):
...         self.f = f
...         self.pipes = []
...     def __rshift__(self, other):
...         # 自身をコピーせずに関数をキューイングする場合
...         self.pipes.append(other)
...         return self
...     def __rrshift__(self, other):
...         result = self.f(other)
...         for pipe in self.pipes:
...             result = pipe.__rrshift__(result)
...         return result
... 

コピーしない場合でも、ちゃんと Pipe オブジェクトは返ってくる。

>>> pow = Pipe(lambda x: x ** 2)
>>> double = Pipe(lambda x: x * 2)
>>> pow >> double
<__main__.Pipe object at 0x102090a50>

しかし、右辺の右ビットシフト演算子が評価された時点で適用する処理のリストにキューイングされてしまう。

>>> pow.pipes
[<__main__.Pipe object at 0x102090b50>]

つまり、元のオブジェクトを変更してしまう。 本来なら二乗してほしいだけのオブジェクトで二倍も同時にされてしまうことになる。

>>> 10 >> pow
200

デコレータとして使う

ちなみに、ここまで作ってきたクラスはクラスデコレータとして使うこともできる。

ようするに、次のように関数をクラスでデコレートできる。

>>> @Pipe
... def triple(x):
...     return x * 3
... 
>>> @Pipe
... def half(x):
...     return x // 2
... 
>>> 10 >> triple >> half
15

デコレータの詳細については以下を参照のこと。

blog.amedama.jp

引数を受け取れるようにする

次に、適用される関数に引数を渡したくなる。 この場合、__call__() メソッドを実装して関数に引数を渡す形でオブジェクトを作り直すようにすると良い。

>>> class Pipe:
...     def __init__(self, f):
...         self.f = f
...         self.pipes = []
...     def __rshift__(self, other):
...         copied_self = copy.deepcopy(self)
...         copied_self.pipes.append(other)
...         return copied_self
...     def __rrshift__(self, other):
...         result = self.f(other)
...         for pipe in self.pipes:
...             result = pipe.__rrshift__(result)
...         return result
...     def __call__(self, *args, **kwargs):
...         """オブジェクトが実行されたときに呼ばれる特殊メソッド"""
...         # 実行されたときの引数を関数に渡すようにしたオブジェクトを返す
...         return Pipe(lambda x: self.f(x, *args, **kwargs))
... 

例えば、掛ける数を引数にした掛け算を実装してみよう。

>>> @Pipe
... def multiply(x, n):
...     return x * n
... 

これは、次のように使うことができる。

>>> 10 >> multiply(2) >> multiply(3)
60

おもしろいね。

参考プロジェクト

github.com

github.com

Python: dfply を使ってみる

R には、データフレームを関数型プログラミングっぽく操作できるようになる dplyr というパッケージがある。 今回紹介する dfply は、その API を Python に移植したもの。 実用性云々は別としても、なかなか面白い作りで参考になった。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G1012
$ python -V        
Python 3.7.5

もくじ

下準備

まずは下準備として dfply をインストールしておく。

$ pip install dfply

Python のインタプリタを起動する。

$ python

ちょっとお行儀が悪いけど dfply 以下をワイルドカードインポートしておく。

>>> from dfply import *

基本的な使い方

例えば dfply には diamonds データセットがサンプルとして組み込まれている。 これは、ダイヤモンドの大きさや色などの情報と付けられた値段が含まれる。

>>> diamonds.head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

上記では DataFrame#head() を使って先頭を取り出した。 dfply では、同じことを右ビットシフト用の演算子 (>>) と head() 関数を使って次のように表現する。

>>> diamonds >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

これだけでピンとくる人もいるだろうけど、上記はようするにメソッドチェーンと同じこと。 例えば head()tail() を組み合わせれば、途中の要素を取り出すことができる。

>>> diamonds >> head(4) >> tail(2)
   carat      cut color clarity  depth  table  price     x     y     z
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63

同じことを DataFrame 標準の API でやるとしたら、こうかな?

>>> diamonds.iloc[:4].iloc[2:]
   carat      cut color clarity  depth  table  price     x     y     z
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63

ちなみに head()tail() を組み合わせなくても row_slice() を使えば一発でいける。

>>> diamonds >> row_slice([2, 4])
   carat   cut color clarity  depth  table  price     x     y     z
2   0.23  Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
4   0.31  Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

列を選択する (select / drop)

ここまでは行を取り出していたけど、select() を使えば列を取り出せる。

>>> diamonds >> select(['carat', 'cut', 'price']) >> head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

同じことを DataFrame 標準の API でやろうとしたら、こうかな。

>>> diamonds[['carat', 'cut', 'price']].head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

select() とは反対に、それ以外を取り出したいときは drop() を使う。

>>> diamonds >> drop(['carat', 'cut', 'price']) >> head()
  color clarity  depth  table     x     y     z
0     E     SI2   61.5   55.0  3.95  3.98  2.43
1     E     SI1   59.8   61.0  3.89  3.84  2.31
2     E     VS1   56.9   65.0  4.05  4.07  2.31
3     I     VS2   62.4   58.0  4.20  4.23  2.63
4     J     SI2   63.3   58.0  4.34  4.35  2.75

また、dfply の特徴的な点として Intention というオブジェクトがある。 一般的には、最初から用意されている X というオブジェクトを使えば良い。

>>> X
<dfply.base.Intention object at 0x10cf4c6d0>

例えば、さっきの select() と同じことを Intention を使って次のように書ける。

>>> diamonds >> select(X.carat, X.cut, X.price) >> head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

これだけだと何が嬉しいのって感じだけど、Intention を使えば否定条件が書けたりもする。

>>> diamonds >> select(~X.carat, ~X.cut, ~X.price) >> head()
  color clarity  depth  table     x     y     z
0     E     SI2   61.5   55.0  3.95  3.98  2.43
1     E     SI1   59.8   61.0  3.89  3.84  2.31
2     E     VS1   56.9   65.0  4.05  4.07  2.31
3     I     VS2   62.4   58.0  4.20  4.23  2.63
4     J     SI2   63.3   58.0  4.34  4.35  2.75

また、select()drop() には、カラムの名前を使った絞り込みをする関数を渡せる。 例えば c から始まるカラムがほしければ starts_with() を使って次のように書ける。

>>> diamonds >> select(~starts_with('c')) >> head()
   depth  table  price     x     y     z
0   61.5   55.0    326  3.95  3.98  2.43
1   59.8   61.0    326  3.89  3.84  2.31
2   56.9   65.0    327  4.05  4.07  2.31
3   62.4   58.0    334  4.20  4.23  2.63
4   63.3   58.0    335  4.34  4.35  2.75

もし、DataFrame 標準の API で書くとしたら、こんな感じかな?

>>> diamonds[[col for col in diamonds.columns if col.startswith('c')]].head()
   carat      cut color clarity
0   0.23    Ideal     E     SI2
1   0.21  Premium     E     SI1
2   0.23     Good     E     VS1
3   0.29  Premium     I     VS2
4   0.31     Good     J     SI2

この他にも、色々とある。

>>> diamonds >> select(ends_with('e')) >> head()
   table  price
0   55.0    326
1   61.0    326
2   65.0    327
3   58.0    334
4   58.0    335
>>> diamonds >> select(contains('a')) >> head()
   carat clarity  table
0   0.23     SI2   55.0
1   0.21     SI1   61.0
2   0.23     VS1   65.0
3   0.29     VS2   58.0
4   0.31     SI2   58.0
>>> diamonds >> select(columns_between('color', 'depth')) >> head()
  color clarity  depth
0     E     SI2   61.5
1     E     SI1   59.8
2     E     VS1   56.9
3     I     VS2   62.4
4     J     SI2   63.3

ちなみに、これらを混ぜて select() に放り込むこともできる。

>>> diamonds >> select('cut', [X.depth, X.table], columns_from('y')) >> head()
       cut  depth  table     y     z
0    Ideal   61.5   55.0  3.98  2.43
1  Premium   59.8   61.0  3.84  2.31
2     Good   56.9   65.0  4.07  2.31
3  Premium   62.4   58.0  4.23  2.63
4     Good   63.3   58.0  4.35  2.75

順序を並び替える (arrange)

特定のカラムを基準にして順序を並び替えるときは arrange() 関数を使う。

>>> diamonds >> arrange(X.carat) >> head()
       carat      cut color clarity  depth  table  price     x     y     z
31593    0.2  Premium     E     VS2   61.1   59.0    367  3.81  3.78  2.32
31597    0.2    Ideal     D     VS2   61.5   57.0    367  3.81  3.77  2.33
31596    0.2  Premium     F     VS2   62.6   59.0    367  3.73  3.71  2.33
31595    0.2    Ideal     E     VS2   59.7   55.0    367  3.86  3.84  2.30
31594    0.2  Premium     E     VS2   59.7   62.0    367  3.84  3.80  2.28

デフォルトは昇順なので、降順にしたいときは ascending オプションに False を指定する。

>>> diamonds >> arrange(X.carat, ascending=False) >> head()
       carat      cut color clarity  depth  table  price      x      y     z
27415   5.01     Fair     J      I1   65.5   59.0  18018  10.74  10.54  6.98
27630   4.50     Fair     J      I1   65.8   58.0  18531  10.23  10.16  6.72
27130   4.13     Fair     H      I1   64.8   61.0  17329  10.00   9.85  6.43
25999   4.01  Premium     J      I1   62.5   62.0  15223  10.02   9.94  6.24
25998   4.01  Premium     I      I1   61.0   61.0  15223  10.14  10.10  6.17

行でサンプリングする (sampling)

行をサンプリングするときは sampling() 関数を使う。 割合で指定したいときは frac オプションを指定する。

>>> diamonds >> sample(frac=0.01)
       carat        cut color clarity  depth  table  price     x     y     z
51269   0.72  Very Good     I     VS2   61.6   59.0   2359  5.71  5.75  3.53
49745   0.70       Good     G     SI1   61.8   62.0   2155  5.68  5.72  3.52
23252   1.40  Very Good     G     VS1   62.6   58.0  11262  7.03  7.07  4.41
36940   0.23  Very Good     D    VVS1   63.3   57.0    478  3.90  3.93  2.48
24644   1.79    Premium     I     VS1   62.6   59.0  12985  7.65  7.72  4.81
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
53913   0.80       Good     G     VS2   64.2   58.0   2753  5.84  5.81  3.74
20653   1.01       Good     D    VVS2   63.5   57.0   8943  6.32  6.35  4.02
17544   1.01       Good     F    VVS2   63.6   60.0   7059  6.36  6.31  4.03
45636   0.25  Very Good     G    VVS1   60.6   55.0    525  4.12  4.14  2.50
30774   0.35      Ideal     G     VS1   61.3   54.0    741  4.58  4.63  2.83

[539 rows x 10 columns]

具体的な行数は n オプションを指定すれば良い。

>>> diamonds >> sample(n=100)
       carat        cut color clarity  depth  table  price     x     y     
46135   0.41      Ideal     E    VVS1   61.1   56.0   1745  4.80  4.82  2.94
35405   0.32      Ideal     E     VS2   61.9   56.0    900  4.40  4.36  2.71
30041   0.33      Ideal     I      IF   61.5   56.0    719  4.43  4.47  2.74
313     0.61      Ideal     G      IF   62.3   56.0   2800  5.43  5.45  3.39
24374   0.34      Ideal     E     SI1   61.0   55.0    637  4.54  4.56  2.77
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
27244   2.20    Premium     H     SI2   62.7   58.0  17634  8.33  8.27  5.20
17487   1.05    Premium     F     VS2   62.6   58.0   7025  6.47  6.50  4.06
52615   0.77    Premium     H     VS2   59.4   60.0   2546  6.00  5.96  3.55
12670   1.07  Very Good     E     SI2   61.7   58.0   5304  6.54  6.56  4.04
16466   1.25      Ideal     G     SI1   62.5   54.0   6580  6.88  6.85  4.29

[100 rows x 10 columns]

内容が重複した行を取り除く (distinct)

重複した要素を取り除くときは dictinct() 関数を使う。

>>> diamonds >> distinct('color')
    carat        cut color clarity  depth  table  price     x     y     z
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
7    0.26  Very Good     H     SI1   61.9   55.0    337  4.07  4.11  2.53
12   0.22    Premium     F     SI1   60.4   61.0    342  3.88  3.84  2.33
25   0.23  Very Good     G    VVS2   60.4   58.0    354  3.97  4.01  2.41
28   0.23  Very Good     D     VS2   60.5   61.0    357  3.96  3.97  2.40

特定の条件に一致した行を取り出す (mask)

特定の条件に一致した行を取り出したいときは mask() 関数を使う。 Intention と組み合わせると、なかなか直感的に書ける。 例えば cut'Ideal' なものだけ取り出したいなら、こう。

>>> diamonds >> mask(X.cut == 'Ideal') >> head()
    carat    cut color clarity  depth  table  price     x     y     z
0    0.23  Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
11   0.23  Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46
13   0.31  Ideal     J     SI2   62.2   54.0    344  4.35  4.37  2.71
16   0.30  Ideal     I     SI2   62.0   54.0    348  4.31  4.34  2.68
39   0.33  Ideal     I     SI2   61.8   55.0    403  4.49  4.51  2.78

引数を増やすことでアンド条件にできる。 これは cut'Ideal' で、かつ carat1.0 以上のものを取り出す場合。

>>> diamonds >> mask(X.cut == 'Ideal', X.carat > 1.0) >> head()
     carat    cut color clarity  depth  table  price     x     y     z
653   1.01  Ideal     I      I1   61.5   57.0   2844  6.45  6.46  3.97
715   1.02  Ideal     H     SI2   61.6   55.0   2856  6.49  6.43  3.98
865   1.02  Ideal     I      I1   61.7   56.0   2872  6.44  6.49  3.99
918   1.02  Ideal     J     SI2   60.3   54.0   2879  6.53  6.50  3.93
992   1.01  Ideal     I      I1   61.5   57.0   2896  6.46  6.45  3.97

mask() 関数には filter_by() という名前のエイリアスもある。

>>> diamonds >> filter_by(X.cut == 'Ideal', X.carat > 1.0) >> head()
     carat    cut color clarity  depth  table  price     x     y     z
653   1.01  Ideal     I      I1   61.5   57.0   2844  6.45  6.46  3.97
715   1.02  Ideal     H     SI2   61.6   55.0   2856  6.49  6.43  3.98
865   1.02  Ideal     I      I1   61.7   56.0   2872  6.44  6.49  3.99
918   1.02  Ideal     J     SI2   60.3   54.0   2879  6.53  6.50  3.93
992   1.01  Ideal     I      I1   61.5   57.0   2896  6.46  6.45  3.97

複数のカラムを組み合わせたカラムを作る (mutate)

複数のカラムを組み合わせて新しい特徴量などのカラムを作るときは mutate() 関数が使える。

例えば xy のカラムを足した新たなカラムをデータフレームに追加したいときは、次のようにする。 引数の名前は追加するカラムの名前に使われる。

>>> diamonds >> mutate(x_plus_y=X.x+X.y) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  x_plus_y
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      7.93
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      7.73
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      8.12
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      8.43
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      8.69

もちろん、3 つ以上のカラムの組み合わせでも構わない。

>>> diamonds >> mutate(plus_xyz=X.x+X.y+X.z) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  plus_xyz
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43     10.36
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31     10.04
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31     10.43
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63     11.06
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75     11.44

また、一度に複数のカラムを作ることもできる。

>>> diamonds >> mutate(x_plus_y=X.x+X.y, x_minus_y=X.x-X.y) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  x_plus_y  x_minus_y
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      7.93      -0.03
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      7.73       0.05
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      8.12      -0.02
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      8.43      -0.03
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      8.69      -0.01

もし、作ったカラムだけがほしいときは transmute() 関数を使えば良い。

>>> diamonds >> transmute(x_plus_y=X.x+X.y, x_minus_y=X.x-X.y) >> head()
   x_plus_y  x_minus_y
0      7.93      -0.03
1      7.73       0.05
2      8.12      -0.02
3      8.43      -0.03
4      8.69      -0.01

カラムの名前を変更する (rename)

もし、カラムの名前を変えたくなったときは rename() 関数を使えば良い。 カラムの順番も入れ替わることがない。

>>> diamonds >> rename(new_x=X.x, new_y=X.y) >> head()
   carat      cut color clarity  depth  table  price  new_x  new_y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326   3.95   3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326   3.89   3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327   4.05   4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334   4.20   4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335   4.34   4.35  2.75

特定のグループ毎に集計する (group_by)

特定のグループ毎に何らかの集計をしたいときは group_by() 関数を使う。 ただし、一般的にイメージする SQL などのそれとは少し異なる。

例えば、ただ group_by() するだけではデータフレームに何も起きない。

>>> diamonds >> group_by(X.cut)
       carat        cut color clarity  depth  table  price     x     y     z
0       0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1       0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2       0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3       0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4       0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
53935   0.72      Ideal     D     SI1   60.8   57.0   2757  5.75  5.76  3.50
53936   0.72       Good     D     SI1   63.1   55.0   2757  5.69  5.75  3.61
53937   0.70  Very Good     D     SI1   62.8   60.0   2757  5.66  5.68  3.56
53938   0.86    Premium     H     SI2   61.0   58.0   2757  6.15  6.12  3.74
53939   0.75      Ideal     D     SI2   62.2   55.0   2757  5.83  5.87  3.64

[53940 rows x 10 columns]

では、どのように使うかというと、別の何らかの処理と組み合わせて使うことで真価を発揮する。 例えば、cut カラムごとに price の平均値を計算したい、という場合には次のようにする。

>>> diamonds >> group_by(X.cut) >> mutate(price_mean=mean(X.price)) >> head(3)
    carat        cut color clarity  depth  table  price     x     y     z   price_mean
8    0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49  4358.757764
91   0.86       Fair     E     SI2   55.1   69.0   2757  6.45  6.33  3.52  4358.757764
97   0.96       Fair     F     SI2   66.3   62.0   2759  6.27  5.95  4.07  4358.757764
2    0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31  3928.864452
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75  3928.864452
10   0.30       Good     J     SI1   64.0   55.0    339  4.25  4.28  2.73  3928.864452
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43  3457.541970
11   0.23      Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46  3457.541970
13   0.31      Ideal     J     SI2   62.2   54.0    344  4.35  4.37  2.71  3457.541970
1    0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31  4584.257704
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63  4584.257704
12   0.22    Premium     F     SI1   60.4   61.0    342  3.88  3.84  2.33  4584.257704
5    0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48  3981.759891
6    0.24  Very Good     I    VVS1   62.3   57.0    336  3.95  3.98  2.47  3981.759891
7    0.26  Very Good     H     SI1   61.9   55.0    337  4.07  4.11  2.53  3981.759891

上記を見てわかる通り、集計した処理が全ての行に反映されている。 いうなれば、これは SQL の WINDOW 関数に PartitionBy を指定した処理に相当している。 その証左として、例えば lead() 関数や lag() 関数が使える。

>>> diamonds >> group_by(X.cut) >> transmute(X.price, next=lead(X.price), prev=lag(X.price)) >> head(3)
          cut    next    prev  price
8        Fair  2757.0     NaN    337
91       Fair  2759.0   337.0   2757
97       Fair  2762.0  2757.0   2759
2        Good   335.0     NaN    327
4        Good   339.0   327.0    335
10       Good   351.0   335.0    339
0       Ideal   340.0     NaN    326
11      Ideal   344.0   326.0    340
13      Ideal   348.0   340.0    344
1     Premium   334.0     NaN    326
3     Premium   342.0   326.0    334
12    Premium   345.0   334.0    342
5   Very Good   336.0     NaN    336
6   Very Good   337.0   336.0    336
7   Very Good   338.0   336.0    337

ただし、ここで一つ気になることがある。 もし、途中からグループ化しない集計をしたいときは、どうしたら良いのだろうか。

例えば、次のように cut ごとに先頭 2 つの要素を取り出すとする。

>>> diamonds >> group_by(X.cut) >> head(2)
    carat        cut color clarity  depth  table  price     x     y     z
8    0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49
91   0.86       Fair     E     SI2   55.1   69.0   2757  6.45  6.33  3.52
2    0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
11   0.23      Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46
1    0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
5    0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48
6    0.24  Very Good     I    VVS1   62.3   57.0    336  3.95  3.98  2.47

もし、ここからさらに全体における先頭 1 つの要素を取り出したいときは、どうしたら良いだろう。あ ただ head() するだけだと、グループごとに先頭 1 要素が取り出されてしまう。

>>> diamonds >> group_by(X.cut) >> head(2) >> head(1)
   carat        cut color clarity  depth  table  price     x     y     z
8   0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49
2   0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
0   0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
5   0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48

この問題を解決するには ungroup() 関数を用いる。

>>> diamonds >> group_by(X.cut) >> head(2) >> ungroup() >> head(1)
   carat   cut color clarity  depth  table  price     x     y     z
8   0.22  Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49

色々な WINDOW 関数

いくつか dfply で使える WINDOW 関数を紹介しておく。

カラムの値が特定の範囲に収まるか真偽値を返すのが between() 関数。

>>> diamonds >> transmute(X.price, price_between=between(X.price, 330, 340)) >> head()
   price_between  price
0          False    326
1          False    326
2          False    327
3           True    334
4           True    335

同じ値は同じランクとして、間を空けずにランク付けするのが dense_rank() 関数。

>>> diamonds >> transmute(X.price, drank=dense_rank(X.price)) >> head()
   drank  price
0    1.0    326
1    1.0    326
2    2.0    327
3    3.0    334
4    4.0    335

同じ値は同じランクとして、間を空けてランク付けするのが min_rank() 関数。

>>> diamonds >> transmute(X.price, mrank=min_rank(X.price)) >> head()
   mrank  price
0    1.0    326
1    1.0    326
2    3.0    327
3    4.0    334
4    5.0    335

単純な行番号が row_number() 関数。

>>> diamonds >> transmute(X.price, rownum=row_number(X.price)) >> head()
   rownum  price
0     1.0    326
1     2.0    326
2     3.0    327
3     4.0    334
4     5.0    335

標準化したランク付けをするのが percent_rank() 関数。

>>> diamonds >> transmute(X.price, prank=percent_rank(X.price)) >> head()
      prank  price
0  0.000000    326
1  0.000000    326
2  0.000037    327
3  0.000056    334
4  0.000074    335

積算値を計算するのが cunsum() 関数。

>>> diamonds >> transmute(X.price, cumprice=cumsum(X.price)) >> head()
   cumprice  price
0       326    326
1       652    326
2       979    327
3      1313    334
4      1648    335

積算の平均値を計算するのが cummean() 関数。

>>> diamonds >> transmute(X.price, cummean=cummean(X.price)) >> head()
      cummean  price
0  326.000000    326
1  326.000000    326
2  326.333333    327
3  328.250000    334
4  329.600000    335

集計値を計算する (summarize)

一般的な group by と聞いて思い浮かべる処理は、むしろこちらの summarize() 関数の方だろう。

例えば、表全体の要約統計量として平均と標準偏差を計算してみよう。

>>> diamonds >> summarize(price_mean=X.price.mean(), price_std=X.price.std())
    price_mean    price_std
0  3932.799722  3989.439738

上記は Intention に生えているメソッドを使って計算したけど、以下のように関数を使うこともできる。

>>> diamonds >> summarize(price_mean=mean(X.price), price_std=sd(X.price))
    price_mean    price_std
0  3932.799722  3989.439738

また、group_by() と組み合わせて使うこともできる。 例えば cut ごとに統計量を計算してみよう。

>>> diamonds >> group_by(X.cut) >> summarize(price_mean=mean(X.price), price_std=sd(X.price))
         cut   price_mean    price_std
0       Fair  4358.757764  3560.386612
1       Good  3928.864452  3681.589584
2      Ideal  3457.541970  3808.401172
3    Premium  4584.257704  4349.204961
4  Very Good  3981.759891  3935.862161

集計に使う関数は、組み込み以外のものを使うこともできる。 例えば numpy の関数を使ってみることに使用。

>>> import numpy as np
>>> diamonds >> group_by(X.cut) >> summarize(price_mean=np.mean(X.price), price_std=np.std(X.price))
         cut   price_mean    price_std
0       Fair  4358.757764  3559.280730
1       Good  3928.864452  3681.214352
2      Ideal  3457.541970  3808.312813
3    Premium  4584.257704  4349.047276
4  Very Good  3981.759891  3935.699276

平均や標準偏差の他にも、サイズや重複を除いたサイズを計算する関数なんかもある。

>>> diamonds >> group_by(X.cut) >> summarize(size=n(X.price), distinct_size=n_distinct(X.price))
         cut   size  distinct_size
0       Fair   1610           1267
1       Good   4906           3086
2      Ideal  21551           7281
3    Premium  13791           6014
4  Very Good  12082           5840

一度に計算したいときは、こんな感じでやればいいかな?

>>> stats = {
...     'iqr': IQR(X.price),
...     'max': colmax(X.price),
...     'q75': X.price.quantile(0.75),
...     'mean': mean(X.price),
...     'median': median(X.price),
...     'q25': X.price.quantile(0.25),
...     'min': colmin(X.price),
... }
>>> diamonds >> group_by(X.cut) >> summarize(**stats)
         cut      iqr    max      q75         mean  median      q25  min
0       Fair  3155.25  18574  5205.50  4358.757764  3282.0  2050.25  337
1       Good  3883.00  18788  5028.00  3928.864452  3050.5  1145.00  327
2      Ideal  3800.50  18806  4678.50  3457.541970  1810.0   878.00  326
3    Premium  5250.00  18823  6296.00  4584.257704  3185.0  1046.00  326
4  Very Good  4460.75  18818  5372.75  3981.759891  2648.0   912.00  336

各カラムに複数の集計する (summarize_each)

カラムと集計内容が複数あるときは summarize_each() 関数を使うと良い。

以下では、例として pricecarat に対して平均と標準偏差を計算している。

>>> diamonds >> summarize_each([np.mean, np.std], X.price, X.carat)
    price_mean    price_std  carat_mean  carat_std
0  3932.799722  3989.402758     0.79794   0.474007

もちろん、この処理も group_by と組み合わせることができる。

>>> diamonds >> group_by(X.cut) >> summarize_each([np.mean, np.std], X.price, X.carat)
         cut   price_mean    price_std  carat_mean  carat_std
0       Fair  4358.757764  3559.280730    1.046137   0.516244
1       Good  3928.864452  3681.214352    0.849185   0.454008
2      Ideal  3457.541970  3808.312813    0.702837   0.432866
3    Premium  4584.257704  4349.047276    0.891955   0.515243
4  Very Good  3981.759891  3935.699276    0.806381   0.459416

複数のデータフレームをカラム方向に結合する (join)

続いては複数のデータフレームを結合する処理について。

例に使うデータフレームを用意する。 微妙に行や列の内容がかぶっている。

>>> data = {
...     'name': ['alice', 'bob', 'carrol'],
...     'age': [20, 30, 40],
... }
>>> a = pd.DataFrame(data)
>>> 
>>> data = {
...     'name': ['alice', 'bob', 'daniel'],
...     'is_male': [False, True, True],
... }
>>> b = pd.DataFrame(data)

内部結合には inner_join() 関数を使う。

>>> a >> inner_join(b, by='name')
    name  age  is_male
0  alice   20    False
1    bob   30     True

外部結合には outer_join() を使う。

>>> a >> outer_join(b, by='name')
     name   age is_male
0   alice  20.0   False
1     bob  30.0    True
2  carrol  40.0     NaN
3  daniel   NaN    True

左外部結合には left_join() を使う。

>>> a >> left_join(b, by='name')
     name  age is_male
0   alice   20   False
1     bob   30    True
2  carrol   40     NaN

右外部結合には right_join() を使う。

>>> a >> right_join(b, by='name')
     name   age  is_male
0   alice  20.0    False
1     bob  30.0     True
2  daniel   NaN     True

複数のデータフレームを行方向に結合する (union / intersect / set_diff / bind_rows)

ここからは縦 (行) 方向の結合を扱う。 データフレームを追加しておく。

>>> data = {
...     'name': ['carrol', 'daniel'],
...     'age': [40, 50],
... }
>>> c = pd.DataFrame(data)

重複したものは除外して行方向にくっつけたいときは union() を使う。

>>> a >> union(c)
     name  age
0   alice   20
1     bob   30
2  carrol   40
1  daniel   50

両方のデータフレームにあるものだけくっつけたいなら intersect() を使う。

>>> a >> intersect(c)
     name  age
0  carrol   40

両方に存在しないものだけほしいときは set_diff() を使う。

>>> a >> set_diff(c)
    name  age
0  alice   20
1    bob   30

行と列を含む結合 (bind_rows)

行と列の両方を使って結合したいときは bind_rows() 関数を使う。 joininner を指定すると、両方にあるカラムだけを使って結合される。

>>> a >> bind_rows(b, join='inner')
     name
0   alice
1     bob
2  carrol
0   alice
1     bob
2  daniel

joinouter を指定したときは、存在しない行が NaN で埋められる。

>>> a >> bind_rows(b, join='outer')
    age is_male    name
0  20.0     NaN   alice
1  30.0     NaN     bob
2  40.0     NaN  carrol
0   NaN   False   alice
1   NaN    True     bob
2   NaN    True  daniel

dfply に対応した API を実装する

ここからは dfply に対応した API を実装する方法について書いていく。

pipe

最も基本となるのは @pipe デコレータで、これはデータフレームを受け取ってデータフレームを返す関数を定義する。 例えば、最も単純な処理として受け取ったデータフレームをそのまま返す関数を作ってみよう。

>>> @pipe
... def nop(df):
...     return df
... 

この関数も、ちゃんと dfply の API として機能する。

>>> diamonds >> nop() >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

次に、もう少し複雑な関数として、特定のカラムの値を 2 倍する関数を定義してみよう。 この中ではデータフレームのカラムの内容を上書きしている。

>>> @pipe
... def double(df, cols):
...     df[cols] = df[cols] * 2
...     return df
... 

使ってみると、ちゃんとカラムの値が 2 倍になっている。

>>> diamonds >> double(['carat', 'price']) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

カラムの内容を上書きしているということは、元のデータフレームの内容も書き換わっているのでは?と思うだろう。 しかし、確認すると元の値のままとなっている。

>>> diamonds >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

実は dfply では、右ビットシフト演算子が評価される度にデータフレームをディープコピーしている。 そのため、元のデータフレームが壊れることはない。

github.com

ただし、上記は大きなサイズのデータフレームを扱う上でパフォーマンス上の問題ともなる。 なぜなら、何らかの処理を評価するたびにメモリ上で大量のコピーが発生するため。 メモリのコピーは、大量のデータを処理する場合にスループットを高める上でボトルネックとなる。

Intention

ところで、先ほど定義した double() 関数は Intention を受け取ることができない。 試しに渡してみると、次のようなエラーになってしまう。

>>> diamonds >> double(X.carat, X.price) >> head()
Traceback (most recent call last):
...(snip)...
    return pipe(lambda x: self.function(x, *args, **kwargs))
TypeError: double() takes 2 positional arguments but 3 were given

配列として指定してもダメ。

>>> diamonds >> double(X.carat, X.price) >> head()
Traceback (most recent call last):
...(snip)...
    if len(arrays[i]) != len(arrays[i - 1]):
TypeError: __index__ returned non-int (type Intention)

上記がエラーになるのは、Intention を解決するのにデコレータの追加が必要なため。 具体的には symbolic_evaluation() を追加する。 こうすると、Intention が pandas.Series に解決した上で渡される。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_double(df, serieses):
...     for series in serieses:
...         df[series.name] = series * 2
...     return df
... 

上記を使ってみると、ちゃんと動作することがわかる。

>>> diamonds >> symbolic_double([X.carat, X.price]) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

この処理は、Intention を解決した上で Series として渡すだけなので、次のように任意の長さの引数として受け取ることもできる。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_double(df, *serieses):
...     for series in serieses:
...         df[series.name] = series * 2
...     return df
... 
>>> diamonds >> symbolic_double(X.carat, X.price) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

Intention 以外のオブジェクトを引数に受け取りたいときは、こんな感じ。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_multiply(df, n, serieses):
...     for series in serieses:
...         df[series.name] = series * n
...     return df
... 
>>> diamonds >> symbolic_multiply(3, [X.carat, X.price]) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.69    Ideal     E     SI2   61.5   55.0    978  3.95  3.98  2.43
1   0.63  Premium     E     SI1   59.8   61.0    978  3.89  3.84  2.31
2   0.69     Good     E     VS1   56.9   65.0    981  4.05  4.07  2.31
3   0.87  Premium     I     VS2   62.4   58.0   1002  4.20  4.23  2.63
4   0.93     Good     J     SI2   63.3   58.0   1005  4.34  4.35  2.75

ちなみに引数の eval_as_selectorTrue を指定すると、渡されるのが numpy 配列になる。 この配列はカラム名と同じ長さで、どのカラムが Intention によって指定されたかがビットマスクとして得られる。

>>> @pipe
... @symbolic_evaluation(eval_as_selector=True)
... def symbolic_double(df, *selected_masks):
...     # もし列の指定が入れ子になってるとしたらフラットに直す
...     selectors = np.array(list(flatten(selected_masks)))
...     selected_cols = [col for col, selected
...                      in zip(df.columns, np.any(selectors, axis=0))
...                      if selected]
...     df[selected_cols] = df[selected_cols] * 2
...     return df
... 
>>> diamonds >> symbolic_double(X.carat, X.price) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

WINDOW 関数を定義する

ただ、あんまり複雑な処理を単発の @pipe 処理で作るよりは、もっと小さな処理を組み合わせていく方が関数型プログラミングっぽくてキレイだと思う。 そこで、次は WINDOW 関数の作り方を扱う。

WINDOW 関数を定義したいときは、@make_symbolic をつけて Series を受け取る関数を作る。 例えばカラムの内容を 2 倍にする関数を作ってみよう。

>>> @make_symbolic
... def double(series):
...     return series * 2
... 

使ってみると、たしかに 2 倍になる。

>>> diamonds >> mutate(double_price=double(X.price)) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  double_price
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43           652
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31           652
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31           654
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63           668
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75           670

こちらの @make_symbolic も、Intention を解決して Series をインジェクトする以上の意味はない。 なので、次のように任意の長さのリストとして受け取ることもできる。

>>> @make_symbolic
... def add(*serieses):
...     return sum(serieses)
... 

上記は複数のカラムの内容を足し合わせる処理になっている。

>>> diamonds >> mutate(add_column=add(X.carat, X.price)) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  add_column
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      326.23
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      326.21
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      327.23
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      334.29
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      335.31

summarize 相当の処理を定義する

summarize 相当の関数は @group_delegation デコレータを使って作れる。

例えば要素数をカウントする関数を定義してみよう。

>>> @pipe
... @group_delegation
... def mycount(df):
...     return len(df)
... 

そのまま適用すれば、全体の要素数が得られる。

>>> diamonds >> mycount()
53940

group_by() とチェインすれば、グループ化した中での要素数が計算できる。

>>> diamonds >> group_by(X.cut) >> mycount()
cut
Fair          1610
Good          4906
Ideal        21551
Premium      13791
Very Good    12082
dtype: int64

一通り適用した関数を作るとき

ちなみに、ショートカット的な記述方法もあって、次のように @dfpipe デコレータを使うと...

>>> @dfpipe
... def myfunc(df):
...     return len(df)
... 

以下の 3 つのデコレータを組み合わせたのと同義になる。 WINDOW 関数は別として、いつもはこれを使っておけばとりあえず良いかもしれない。

>>> @pipe
... @group_delegation
... @symbolic_evaluation
... def myfunc(df):
...     return len(df)
... 

パフォーマンスに問題は抱えているけど、API はすごく面白いね。

子供が生まれました

このブログには、まれに技術系でないことも書くことがあり、今回もそれにあたります。 私事で恐縮ですが、先日子供が生まれました。 今のところ、健康に生まれて、順調に育っているようです。 この点は、本当に良かったと思います。

一方で自分自身に目を向けると、今後は一人の時間をコントロールすることが、さらに難しくなっていくと考えられます。 この点は、環境の変化に順応しつつ、なんとか工夫できるところを見つけていきたいです。 今後も、自分自身の成長と技術系コミュニティへの貢献に向けては、できるだけ精進していけたらと思います。


もみじあめの欲しいものリスト

もし、万が一にも気が向いたときにはよろしくお願いします。 頂けると、育児がはかどるものリストです。

Python: Optuna の LightGBMTuner で Stepwise Tuning を試す

先日の PyData.tokyo で発表されていた Optuna の LightGBMTuner だけど v0.18.0 でリリースされたらしい。 まだ Experimental (実験的) リリースでドキュメントも整備されていないけど、動くみたいなのでコードを眺めながら試してみた。

github.com

LightGBMTuner を使うことで、ユーザは LightGBM のハイパーパラメータを意識することなくチューニングできる。 チューニングには Stepwise Tuning という、特定のハイパーパラメータを一つずつ最適化していく手法が使われている。 これは、過去のコンペで実績のある手法らしい。 詳細については以下を参照のこと。

www.slideshare.net

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G1012
$ python -V
Python 3.7.5

もくじ

下準備

使うパッケージをインストールしておく。 なお、LightGBMTuner を動かす上で最低限必要なのは先頭から二つの Optuna と LightGBM だけ。

$ pip install optuna lightgbm seaborn scikit-learn sklearn-pandas category_encoders

使ってみる

今回は seaborn から読み込める Diamonds データセットを使って回帰のタスクを使う。 これは、それなりに行数のあるデータを使いたかったため。

以下が LightGBMTuner を使ってハイパーパラメータを最適化するサンプルコード。 基本的な使い方としては optuna.integration.lightgbm_tuner.train()lightgbm.train() の代わりに用いる。 これだけで透過的に、ハイパーパラメータが最適化された上で学習済みの Booster オブジェクトが返ってくる。 なお、今のところ lightgbm.cv() 相当の機能は実装されていないので、自分でデータを Holdout するなり CV する必要がある。 サンプルコードでは、比較用のためにデフォルトのパラメータで学習されたモデルのメトリック (MSE) も出力している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import category_encoders as ce
import seaborn as sns
import lightgbm as lgb
from optuna.integration import lightgbm_tuner
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn_pandas import DataFrameMapper


def main():
    # データセットを読み込む
    df = sns.load_dataset('diamonds')

    # ラベルエンコードする
    mapper = DataFrameMapper([
        ('cut', ce.OrdinalEncoder()),
        ('color', ce.OrdinalEncoder()),
        ('clarity', ce.OrdinalEncoder()),
    ], default=None, df_out=True)
    df = mapper.fit_transform(df)

    # 説明変数と目的変数に分ける
    X, y = df.drop('price', axis=1), df.price

    # Holt-out 検証用にデータを分割する
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)

    # 学習用データと検証用データに分割する
    X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train,
                                                shuffle=True,
                                                random_state=42)

    # LightGBM のデータセット表現にする
    lgb_train = lgb.Dataset(X_tr, y_tr)
    lgb_valid = lgb.Dataset(X_val, y_val, reference=lgb_train)

    # 学習用基本パラメータ
    lgb_params = {
        'objective': 'regression',
        'metric': 'rmse',
    }

    # Optuna でハイパーパラメータを Stepwise Optimization する
    tuned_booster = lightgbm_tuner.train(lgb_params, lgb_train,
                                         valid_sets=lgb_valid,
                                         num_boost_round=1000,
                                         early_stopping_rounds=100,
                                         verbose_eval=10,
                                         )

    # 比較用にデフォルトのパラメータを使ったモデルも用意する
    default_booster = lgb.train(lgb_params, lgb_train,
                                valid_sets=lgb_valid,
                                num_boost_round=1000,
                                early_stopping_rounds=100,
                                verbose_eval=10,
                                )

    # Optuna で最適化したモデルの Holt-out データに対するスコア
    y_pred_tuned = tuned_booster.predict(X_test)
    tuned_metric = mean_squared_error(y_test, y_pred_tuned)
    print('tuned model metric: ', tuned_metric)

    # デフォルトの Holt-out データに対するスコア
    y_pred_default = default_booster.predict(X_test)
    default_metric = mean_squared_error(y_test, y_pred_default)
    print('default model metric: ', default_metric)


if __name__ == '__main__':
    main()

上記を保存して実行してみよう。 time コマンドで実行時間も計測してみる。

$ time python lgbtune.py
...(snip)...
tuned model metric:  309501.36031006125
default model metric:  314903.9460911957
python lgbtune.py  324.61s user 6.12s system 298% cpu 1:50.82 total

ちゃんとチューニングしたモデルの方がデフォルトのパラメータより結果が良くなっている。 かつ、全体の実行時間も約 5 分で完了している。

これまでの経験から、ハイパーパラメータのチューニングはデフォルトのパラメータに勝つだけでも探索空間が広いとそれなりの時間を要する印象があった。 それを考えると LightGBMTuner (LightGBM + Stepwise Tuning) は短時間でベターな解を出してきているように感じる。

Kaggleで勝つデータ分析の技術

Kaggleで勝つデータ分析の技術

  • 作者: 門脇大輔,阪田隆司,保坂桂佑,平松雄司
  • 出版社/メーカー: 技術評論社
  • 発売日: 2019/10/09
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

Python: 広義の Target Encoding と Stacking は同じもの (と解釈できる)

おそらく、既に分かっている人には「知らなかったの?」とびっくりされる系の話なんだろうけど、今さら理解したので備忘録として残しておく。 結論から書くと、目的変数を用いた特徴量生成を広義の Target Encoding と定義した場合、Target Encoding と Stacking は同じものと解釈できる。 例えば、Target Mean Encoding は多項分布を仮定したナイーブベイズ分類器を用いた Stacking とやっていることは同じになる。 また、Target Encoding と Stacking が同じものであると解釈することで、周辺の知識についても理解しやすくなる。

Target Encoding について

Target Encoding は、データ分析コンペで用いられることがある特徴量生成 (Feature Extraction) の手法のこと。 一般的にはカテゴリ変数と目的変数について統計量を計算して、それを新たな特徴量として用いる。 統計量には平均値が使われることが多く、この点から平均値を使うものを Target Mean Encoding と限定して呼ぶこともある。

このエントリでは、上記のようにカテゴリ変数と目的変数、および関連する特徴量について統計量を扱うものを狭義の Target Encoding と定義する。 それに対し、目的変数を使った何らか (任意) の特徴量生成の手法を広義の Target Encoding と定義する。

きっかけについて

久しぶりにオライリーの「機械学習のための特徴量エンジニアリング」を読み返していたところ、以下のような記述があった。

5.2.2 ビンカウンティング

ビンカウンティングの考え方はとても簡単です。カテゴリ値をエンコードして特徴量として使用する代わりに、カテゴリごとに何らかの値を集計した統計量を利用します。カテゴリごとにター ゲットの値を集計して算出した条件付き確率は、そのような統計量の一例です。ナイーブベイズ分類器に精通している人はピンとくるはずです。なぜなら、ナイーブベイズ分類器では特徴量が互いに独立と考えてクラスの条件付き確率を求めたからです。

機械学習のための特徴量エンジニアリング ―その原理とPythonによる実践 (オライリー・ジャパン)

機械学習のための特徴量エンジニアリング ―その原理とPythonによる実践 (オライリー・ジャパン)

  • 作者: Alice Zheng,Amanda Casari,株式会社ホクソエム
  • 出版社/メーカー: オライリージャパン
  • 発売日: 2019/02/23
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

上記でビンカウンティングの一例として挙げられている処理は Target (Mean) Encoding を指している。 そして、やっていることはナイーブベイズ分類器を使って計算した条件付き確率と同じ、とある。 これは、Target Mean Encoding がナイーブベイズ分類器を使った Stacking である、とも解釈できる。

確かめてみよう

念のため、実際にコードで確認してみよう。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G103
$ python -V                      
Python 3.7.5

下準備

下準備として必要なパッケージをインストールしておく。

$ pip install scikit-learn pandas

Python のインタプリタを起動する。

$ python

次のようなサンプルデータを用意する。 とあるフルーツの名前をカテゴリ変数の特徴量として、それが美味しいかについて二値のラベルがついていると解釈してもらえれば。

>>> import pandas as pd
>>> 
>>> data = {
...     'category': ['apple', 'apple',
...                  'banana', 'banana', 'banana',
...                  'cherry', 'cherry', 'cherry', 'cherry',
...                  'durian'],
...     'label': [0, 1,
...               0, 0, 1,
...               0, 1, 1, 1,
...               1],
... }
>>> 
>>> df = pd.DataFrame(data=data)
>>> df
  category  label
0    apple      0
1    apple      1
2   banana      0
3   banana      0
4   banana      1
5   cherry      0
6   cherry      1
7   cherry      1
8   cherry      1
9   durian      1

Target Mean Encoding の計算

単純な Target Mean Encoding では、カテゴリ変数ごとの目的変数の平均値を計算する。 つまり、以下のようになる。

>>> greedy_ts = df.groupby('category').agg({'label': 'mean'})
>>> pd.merge(df, greedy_ts, on='category', right_index=True)
  category  label_x   label_y
0    apple        0  0.500000
1    apple        1  0.500000
2   banana        0  0.333333
3   banana        0  0.333333
4   banana        1  0.333333
5   cherry        0  0.750000
6   cherry        1  0.750000
7   cherry        1  0.750000
8   cherry        1  0.750000
9   durian        1  1.000000

なお、上記のように学習データ全体を使った計算方法を Greedy TS と呼ぶ。 Greedy TS はリークが生じるため、本来は Target Encoding するときには避けた方が良い。 ただし、今回はリークの説明がしたいわけではないので気にしない。 気になる人は末尾の参考文献のブログエントリを読んでもらえれば。

多項分布を仮定したナイーブベイズ分類器を用いた Stacking

続いては多項分布を仮定したナイーブベイズ分類器を使って Stacking してみる。

まずは scikit-learn のモデルから使いやすいように、特徴量を One-Hot エンコードしておく。

>>> from sklearn.preprocessing import OneHotEncoder
>>> 
>>> encoder = OneHotEncoder(sparse=False)
>>> X = encoder.fit_transform(df[['category']])
>>> y = df.label.values

それぞれのフルーツごとに対応した次元ができる。

>>> X
array([[1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]])
>>> y
array([0, 1, 0, 0, 1, 0, 1, 1, 1, 1])

多項分布を仮定したナイーブベイズ分類器を用意する。 Smoothing しないので alpha オプションには 0 を指定する。

>>> from sklearn.naive_bayes import MultinomialNB
>>> clf = MultinomialNB(alpha=0)

データ全体を学習させたら predict_proba() メソッドで推論する。

>>> clf.fit(X, y)
>>> y_pred_proba = clf.predict_proba(X)

得られた結果を、元のデータと連結してみよう。

>>> df.join(pd.Series(y_pred_proba[:, 1], name='y_pred_proba'))
  category  label  y_pred_proba
0    apple      0      0.500000
1    apple      1      0.500000
2   banana      0      0.333333
3   banana      0      0.333333
4   banana      1      0.333333
5   cherry      0      0.750000
6   cherry      1      0.750000
7   cherry      1      0.750000
8   cherry      1      0.750000
9   durian      1      1.000000

多項分布ナイーブベイズ分類器から得れた特徴量は、先ほど手作業で作った Target Mean Encoding の特徴量と一致している。

上記から、Target (Mean) Encoding と Stacking のつながりが見えてくる。 GBDT や NN などを用いた Stacking も、既存の特徴量と目的変数から新たな (メタ) 特徴量を作るという点で、広義の Target Encoding とやっていることは変わらない。 この点を理解することで、次のようなことを考えた。

Stacking で OOF Prediction する理由を説明しやすい

学習データ全体を使って Stacking するとリークが生じることが知られている。 この原理は、Target Encoding がリークを起こす仕組みと変わらない。 特徴量を付与する対象の行をモデルの学習データに含めることは、Target Mean Encoding で Greedy TS を計算するのと同じことになる。 もし Stacking でリークする理由がイメージしにくかったとしても、より単純な Target Mean Encoding を例に挙げれば理解しやすい。 それを防ぐ方法として Holdout TS (OOF Prediction) がある理由も分かりやすいはず。 これは Stacking が何段になっても、Target Encoding を複数回やっているのと同じことなので分割方法を使い回さなければいけない理由も直感的に理解できる。

コードを共通化できる可能性がある

これは Target Encoding で、こっちは Stacking というように、別々の概念としてコードを書く必要がなくなるかもしれない。 例えば目的変数を使う特徴量生成と、使わない特徴量生成くらいのざっくりした概念として扱えるとうれしい。 もしコードが共通化できるのであれば、パイプラインを作る観点で有用と考えられる。

それぞれで用いられている手法のお互いへの応用も可能では

両者が同じものだとすると、それぞれで用いられている手法を互いに応用できる可能性が出てくる。

例えば Target Encoding のリークを防ぐ手法として Ordered TS という計算方法が提案されている。 Target Encoding と Stacking が同一だとすれば、Ordered TS の計算方法を Stacking にも応用できるのではないか。 Ordered TS を用いると、Holdout TS よりもリークしにくいのに加えて、計算量の削減にもなると考えられる。

Holdout TS では、分割数を  k とした場合、計算量は  O((k - 1) N) になる。 それに対し、Ordered TS では  O(N) になるはずなので。

ただ Ordered TS は履歴が十分に貯まるまでポンコツな結果が出てしまう問題があるので、実用的かどうかは分からない。

Stacking と Target Mean Encoding で上位にくるモデルの違いについて

Stacking では、一般的に下位の層には GBDT や NN といった表現力の高いモデルを用いる。 そして、上位の層では過学習を防ぐために線形モデルなど単純なモデルを使われることが多い。

それに対し、Target Mean Encoding を Stacking と解釈した場合、下位の層がナイーブベイズ分類器という単純なモデルになっている。 そのため上位には表現力の高い GBDT や NN が使われることになる。

このように、両者を同一視した場合、表現力によって上位と下位のモデルが組み合わせになっていることも納得できる。

参考文献

機械学習のための特徴量エンジニアリング ―その原理とPythonによる実践 (オライリー・ジャパン)

機械学習のための特徴量エンジニアリング ―その原理とPythonによる実践 (オライリー・ジャパン)

  • 作者: Alice Zheng,Amanda Casari,株式会社ホクソエム
  • 出版社/メーカー: オライリージャパン
  • 発売日: 2019/02/23
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

Kaggleで勝つデータ分析の技術

Kaggleで勝つデータ分析の技術

  • 作者: 門脇大輔,阪田隆司,保坂桂佑,平松雄司
  • 出版社/メーカー: 技術評論社
  • 発売日: 2019/10/09
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

blog.amedama.jp

CatBoost: unbiased boosting with categorical features (PDF)

trap コマンドを使ったシェルスクリプトのエラーハンドリング

今回は、シェルの組み込みコマンドの trap を使ったシェルスクリプトのエラーハンドリングについて。 シェルの組み込みコマンド trap は、特定のシグナルやコマンドの返り値が非ゼロとなったときに実行する処理を指定できる。

trap コマンドは、次のようにして使う。 以下の <arg> が実行する処理で、<sigspec> が反応させたいシグナルや状況となる。

$ trap <arg> <sigspec>

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G103
$ bash -version
GNU bash, version 3.2.57(1)-release (x86_64-apple-darwin18)
Copyright (C) 2007 Free Software Foundation, Inc.

コマンドが非ゼロの返り値を返したときのハンドリング

以下のサンプルコードでは、コマンドの返り値が非ゼロになったときに標準エラー出力に "ERROR" という文字列を表示する。 このサンプルコードでは、最後に実行結果として非ゼロを返す false コマンドを実行しているため、必ずエラーハンドラが実行される。 trap コマンドの <arg> には関数 error_handler() を指定してあって、<sigspec> には非ゼロが返ったときに反応する ERR を指定している。

#!/usr/bin/env bash

# 実行した処理を標準出力に記録する
set -x

# エラーになったときに実行したい関数
function error_handler() {
  # 何か起きたことを標準エラー出力に書く
  echo "ERROR" >&2
  # スクリプトを終了する
  exit 1
}

# コマンドの返り値が非ゼロのときハンドラを実行するように指定する
trap error_handler ERR

# 例として非ゼロを返すコマンドを実行する
false

上記を適当な名前で保存して実行してみよう。

$ bash errhandle.sh 
+ trap error_handler ERR
+ false
++ error_handler
++ echo ERROR
ERROR
++ exit 1

ちゃんとエラーハンドラが発火して "ERROR" という文字列が表示されていることがわかる。

ちなみに、trap コマンドで指定されたハンドラは、スクリプトの中で set -E されていたとしても発火する。 set -E は、コマンドの返り値が非ゼロになった時点でスクリプトの実行を止めるという指定になる。

#!/usr/bin/env bash

# コマンドの返り値が非ゼロになった時点で止める
set -E

# 実行した処理を標準出力に記録する
set -x

# エラーになったときに実行したい関数
function error_handler() {
  # 何か起きたことを標準エラー出力に書く
  echo "ERROR" >&2
  # スクリプトを終了する
  exit 1
}

# コマンドの返り値が非ゼロのときハンドラを実行するように指定する
trap error_handler ERR

# 例として非ゼロを返すコマンドを実行する
false

実行結果は先ほどと変わらない。

$ bash errhandle.sh
+ trap error_handler ERR
+ false
++ error_handler
++ echo ERROR
ERROR
++ exit 1

プロセスが特定のシグナルを受信したときのハンドリング

同様に、プロセスが特定のシグナルを受信したときのハンドリングについても確認しておく。

以下のサンプルコードでは SIGINT シグナルを受信したときにハンドラが発火するように trap コマンドで指定している。 ハンドラでは "SIGINT" という文字列を標準エラー出力に表示する。 スクリプトは 2 秒のスリープをはさみながら、無限ループで SIGINT を待ち受ける。

#!/usr/bin/env bash

# SIGINT を受け取ったら実行するハンドラ
function sigint_handler() {
  echo "SIGINT" >&2
  exit 0
}

# SIGINT を受け取ったときにハンドラを実行する
trap sigint_handler SIGINT

# 無限ループで SIGINT を待つ
while true;
do
  echo "press Ctrl+C to stop"
  sleep 2
done

上記も名前をつけて保存したら実行してみよう。 SIGINT はキーボードの Ctrl + C キーを使って送れる。

$ bash trapsigint.sh
press Ctrl+C to stop
press Ctrl+C to stop
press Ctrl+C to stop
^CSIGINT

どうやら、ちゃんとハンドラが発火しているようだ。

いじょう。