CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: dfply を使ってみる

R には、データフレームを関数型プログラミングっぽく操作できるようになる dplyr というパッケージがある。 今回紹介する dfply は、その API を Python に移植したもの。 実用性云々は別としても、なかなか面白い作りで参考になった。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G1012
$ python -V        
Python 3.7.5

もくじ

下準備

まずは下準備として dfply をインストールしておく。

$ pip install dfply

Python のインタプリタを起動する。

$ python

ちょっとお行儀が悪いけど dfply 以下をワイルドカードインポートしておく。

>>> from dfply import *

基本的な使い方

例えば dfply には diamonds データセットがサンプルとして組み込まれている。 これは、ダイヤモンドの大きさや色などの情報と付けられた値段が含まれる。

>>> diamonds.head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

上記では DataFrame#head() を使って先頭を取り出した。 dfply では、同じことを右ビットシフト用の演算子 (>>) と head() 関数を使って次のように表現する。

>>> diamonds >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

これだけでピンとくる人もいるだろうけど、上記はようするにメソッドチェーンと同じこと。 例えば head()tail() を組み合わせれば、途中の要素を取り出すことができる。

>>> diamonds >> head(4) >> tail(2)
   carat      cut color clarity  depth  table  price     x     y     z
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63

同じことを DataFrame 標準の API でやるとしたら、こうかな?

>>> diamonds.iloc[:4].iloc[2:]
   carat      cut color clarity  depth  table  price     x     y     z
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63

ちなみに head()tail() を組み合わせなくても row_slice() を使えば一発でいける。

>>> diamonds >> row_slice([2, 4])
   carat   cut color clarity  depth  table  price     x     y     z
2   0.23  Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
4   0.31  Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

列を選択する (select / drop)

ここまでは行を取り出していたけど、select() を使えば列を取り出せる。

>>> diamonds >> select(['carat', 'cut', 'price']) >> head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

同じことを DataFrame 標準の API でやろうとしたら、こうかな。

>>> diamonds[['carat', 'cut', 'price']].head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

select() とは反対に、それ以外を取り出したいときは drop() を使う。

>>> diamonds >> drop(['carat', 'cut', 'price']) >> head()
  color clarity  depth  table     x     y     z
0     E     SI2   61.5   55.0  3.95  3.98  2.43
1     E     SI1   59.8   61.0  3.89  3.84  2.31
2     E     VS1   56.9   65.0  4.05  4.07  2.31
3     I     VS2   62.4   58.0  4.20  4.23  2.63
4     J     SI2   63.3   58.0  4.34  4.35  2.75

また、dfply の特徴的な点として Intention というオブジェクトがある。 一般的には、最初から用意されている X というオブジェクトを使えば良い。

>>> X
<dfply.base.Intention object at 0x10cf4c6d0>

例えば、さっきの select() と同じことを Intention を使って次のように書ける。

>>> diamonds >> select(X.carat, X.cut, X.price) >> head()
   carat      cut  price
0   0.23    Ideal    326
1   0.21  Premium    326
2   0.23     Good    327
3   0.29  Premium    334
4   0.31     Good    335

これだけだと何が嬉しいのって感じだけど、Intention を使えば否定条件が書けたりもする。

>>> diamonds >> select(~X.carat, ~X.cut, ~X.price) >> head()
  color clarity  depth  table     x     y     z
0     E     SI2   61.5   55.0  3.95  3.98  2.43
1     E     SI1   59.8   61.0  3.89  3.84  2.31
2     E     VS1   56.9   65.0  4.05  4.07  2.31
3     I     VS2   62.4   58.0  4.20  4.23  2.63
4     J     SI2   63.3   58.0  4.34  4.35  2.75

また、select()drop() には、カラムの名前を使った絞り込みをする関数を渡せる。 例えば c から始まるカラムがほしければ starts_with() を使って次のように書ける。

>>> diamonds >> select(~starts_with('c')) >> head()
   depth  table  price     x     y     z
0   61.5   55.0    326  3.95  3.98  2.43
1   59.8   61.0    326  3.89  3.84  2.31
2   56.9   65.0    327  4.05  4.07  2.31
3   62.4   58.0    334  4.20  4.23  2.63
4   63.3   58.0    335  4.34  4.35  2.75

もし、DataFrame 標準の API で書くとしたら、こんな感じかな?

>>> diamonds[[col for col in diamonds.columns if col.startswith('c')]].head()
   carat      cut color clarity
0   0.23    Ideal     E     SI2
1   0.21  Premium     E     SI1
2   0.23     Good     E     VS1
3   0.29  Premium     I     VS2
4   0.31     Good     J     SI2

この他にも、色々とある。

>>> diamonds >> select(ends_with('e')) >> head()
   table  price
0   55.0    326
1   61.0    326
2   65.0    327
3   58.0    334
4   58.0    335
>>> diamonds >> select(contains('a')) >> head()
   carat clarity  table
0   0.23     SI2   55.0
1   0.21     SI1   61.0
2   0.23     VS1   65.0
3   0.29     VS2   58.0
4   0.31     SI2   58.0
>>> diamonds >> select(columns_between('color', 'depth')) >> head()
  color clarity  depth
0     E     SI2   61.5
1     E     SI1   59.8
2     E     VS1   56.9
3     I     VS2   62.4
4     J     SI2   63.3

ちなみに、これらを混ぜて select() に放り込むこともできる。

>>> diamonds >> select('cut', [X.depth, X.table], columns_from('y')) >> head()
       cut  depth  table     y     z
0    Ideal   61.5   55.0  3.98  2.43
1  Premium   59.8   61.0  3.84  2.31
2     Good   56.9   65.0  4.07  2.31
3  Premium   62.4   58.0  4.23  2.63
4     Good   63.3   58.0  4.35  2.75

順序を並び替える (arrange)

特定のカラムを基準にして順序を並び替えるときは arrange() 関数を使う。

>>> diamonds >> arrange(X.carat) >> head()
       carat      cut color clarity  depth  table  price     x     y     z
31593    0.2  Premium     E     VS2   61.1   59.0    367  3.81  3.78  2.32
31597    0.2    Ideal     D     VS2   61.5   57.0    367  3.81  3.77  2.33
31596    0.2  Premium     F     VS2   62.6   59.0    367  3.73  3.71  2.33
31595    0.2    Ideal     E     VS2   59.7   55.0    367  3.86  3.84  2.30
31594    0.2  Premium     E     VS2   59.7   62.0    367  3.84  3.80  2.28

デフォルトは昇順なので、降順にしたいときは ascending オプションに False を指定する。

>>> diamonds >> arrange(X.carat, ascending=False) >> head()
       carat      cut color clarity  depth  table  price      x      y     z
27415   5.01     Fair     J      I1   65.5   59.0  18018  10.74  10.54  6.98
27630   4.50     Fair     J      I1   65.8   58.0  18531  10.23  10.16  6.72
27130   4.13     Fair     H      I1   64.8   61.0  17329  10.00   9.85  6.43
25999   4.01  Premium     J      I1   62.5   62.0  15223  10.02   9.94  6.24
25998   4.01  Premium     I      I1   61.0   61.0  15223  10.14  10.10  6.17

行でサンプリングする (sampling)

行をサンプリングするときは sampling() 関数を使う。 割合で指定したいときは frac オプションを指定する。

>>> diamonds >> sample(frac=0.01)
       carat        cut color clarity  depth  table  price     x     y     z
51269   0.72  Very Good     I     VS2   61.6   59.0   2359  5.71  5.75  3.53
49745   0.70       Good     G     SI1   61.8   62.0   2155  5.68  5.72  3.52
23252   1.40  Very Good     G     VS1   62.6   58.0  11262  7.03  7.07  4.41
36940   0.23  Very Good     D    VVS1   63.3   57.0    478  3.90  3.93  2.48
24644   1.79    Premium     I     VS1   62.6   59.0  12985  7.65  7.72  4.81
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
53913   0.80       Good     G     VS2   64.2   58.0   2753  5.84  5.81  3.74
20653   1.01       Good     D    VVS2   63.5   57.0   8943  6.32  6.35  4.02
17544   1.01       Good     F    VVS2   63.6   60.0   7059  6.36  6.31  4.03
45636   0.25  Very Good     G    VVS1   60.6   55.0    525  4.12  4.14  2.50
30774   0.35      Ideal     G     VS1   61.3   54.0    741  4.58  4.63  2.83

[539 rows x 10 columns]

具体的な行数は n オプションを指定すれば良い。

>>> diamonds >> sample(n=100)
       carat        cut color clarity  depth  table  price     x     y     
46135   0.41      Ideal     E    VVS1   61.1   56.0   1745  4.80  4.82  2.94
35405   0.32      Ideal     E     VS2   61.9   56.0    900  4.40  4.36  2.71
30041   0.33      Ideal     I      IF   61.5   56.0    719  4.43  4.47  2.74
313     0.61      Ideal     G      IF   62.3   56.0   2800  5.43  5.45  3.39
24374   0.34      Ideal     E     SI1   61.0   55.0    637  4.54  4.56  2.77
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
27244   2.20    Premium     H     SI2   62.7   58.0  17634  8.33  8.27  5.20
17487   1.05    Premium     F     VS2   62.6   58.0   7025  6.47  6.50  4.06
52615   0.77    Premium     H     VS2   59.4   60.0   2546  6.00  5.96  3.55
12670   1.07  Very Good     E     SI2   61.7   58.0   5304  6.54  6.56  4.04
16466   1.25      Ideal     G     SI1   62.5   54.0   6580  6.88  6.85  4.29

[100 rows x 10 columns]

内容が重複した行を取り除く (distinct)

重複した要素を取り除くときは dictinct() 関数を使う。

>>> diamonds >> distinct('color')
    carat        cut color clarity  depth  table  price     x     y     z
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
7    0.26  Very Good     H     SI1   61.9   55.0    337  4.07  4.11  2.53
12   0.22    Premium     F     SI1   60.4   61.0    342  3.88  3.84  2.33
25   0.23  Very Good     G    VVS2   60.4   58.0    354  3.97  4.01  2.41
28   0.23  Very Good     D     VS2   60.5   61.0    357  3.96  3.97  2.40

特定の条件に一致した行を取り出す (mask)

特定の条件に一致した行を取り出したいときは mask() 関数を使う。 Intention と組み合わせると、なかなか直感的に書ける。 例えば cut'Ideal' なものだけ取り出したいなら、こう。

>>> diamonds >> mask(X.cut == 'Ideal') >> head()
    carat    cut color clarity  depth  table  price     x     y     z
0    0.23  Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
11   0.23  Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46
13   0.31  Ideal     J     SI2   62.2   54.0    344  4.35  4.37  2.71
16   0.30  Ideal     I     SI2   62.0   54.0    348  4.31  4.34  2.68
39   0.33  Ideal     I     SI2   61.8   55.0    403  4.49  4.51  2.78

引数を増やすことでアンド条件にできる。 これは cut'Ideal' で、かつ carat1.0 以上のものを取り出す場合。

>>> diamonds >> mask(X.cut == 'Ideal', X.carat > 1.0) >> head()
     carat    cut color clarity  depth  table  price     x     y     z
653   1.01  Ideal     I      I1   61.5   57.0   2844  6.45  6.46  3.97
715   1.02  Ideal     H     SI2   61.6   55.0   2856  6.49  6.43  3.98
865   1.02  Ideal     I      I1   61.7   56.0   2872  6.44  6.49  3.99
918   1.02  Ideal     J     SI2   60.3   54.0   2879  6.53  6.50  3.93
992   1.01  Ideal     I      I1   61.5   57.0   2896  6.46  6.45  3.97

mask() 関数には filter_by() という名前のエイリアスもある。

>>> diamonds >> filter_by(X.cut == 'Ideal', X.carat > 1.0) >> head()
     carat    cut color clarity  depth  table  price     x     y     z
653   1.01  Ideal     I      I1   61.5   57.0   2844  6.45  6.46  3.97
715   1.02  Ideal     H     SI2   61.6   55.0   2856  6.49  6.43  3.98
865   1.02  Ideal     I      I1   61.7   56.0   2872  6.44  6.49  3.99
918   1.02  Ideal     J     SI2   60.3   54.0   2879  6.53  6.50  3.93
992   1.01  Ideal     I      I1   61.5   57.0   2896  6.46  6.45  3.97

複数のカラムを組み合わせたカラムを作る (mutate)

複数のカラムを組み合わせて新しい特徴量などのカラムを作るときは mutate() 関数が使える。

例えば xy のカラムを足した新たなカラムをデータフレームに追加したいときは、次のようにする。 引数の名前は追加するカラムの名前に使われる。

>>> diamonds >> mutate(x_plus_y=X.x+X.y) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  x_plus_y
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      7.93
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      7.73
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      8.12
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      8.43
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      8.69

もちろん、3 つ以上のカラムの組み合わせでも構わない。

>>> diamonds >> mutate(plus_xyz=X.x+X.y+X.z) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  plus_xyz
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43     10.36
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31     10.04
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31     10.43
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63     11.06
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75     11.44

また、一度に複数のカラムを作ることもできる。

>>> diamonds >> mutate(x_plus_y=X.x+X.y, x_minus_y=X.x-X.y) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  x_plus_y  x_minus_y
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      7.93      -0.03
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      7.73       0.05
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      8.12      -0.02
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      8.43      -0.03
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      8.69      -0.01

もし、作ったカラムだけがほしいときは transmute() 関数を使えば良い。

>>> diamonds >> transmute(x_plus_y=X.x+X.y, x_minus_y=X.x-X.y) >> head()
   x_plus_y  x_minus_y
0      7.93      -0.03
1      7.73       0.05
2      8.12      -0.02
3      8.43      -0.03
4      8.69      -0.01

カラムの名前を変更する (rename)

もし、カラムの名前を変えたくなったときは rename() 関数を使えば良い。 カラムの順番も入れ替わることがない。

>>> diamonds >> rename(new_x=X.x, new_y=X.y) >> head()
   carat      cut color clarity  depth  table  price  new_x  new_y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326   3.95   3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326   3.89   3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327   4.05   4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334   4.20   4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335   4.34   4.35  2.75

特定のグループ毎に集計する (group_by)

特定のグループ毎に何らかの集計をしたいときは group_by() 関数を使う。 ただし、一般的にイメージする SQL などのそれとは少し異なる。

例えば、ただ group_by() するだけではデータフレームに何も起きない。

>>> diamonds >> group_by(X.cut)
       carat        cut color clarity  depth  table  price     x     y     z
0       0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1       0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2       0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3       0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4       0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
...      ...        ...   ...     ...    ...    ...    ...   ...   ...   ...
53935   0.72      Ideal     D     SI1   60.8   57.0   2757  5.75  5.76  3.50
53936   0.72       Good     D     SI1   63.1   55.0   2757  5.69  5.75  3.61
53937   0.70  Very Good     D     SI1   62.8   60.0   2757  5.66  5.68  3.56
53938   0.86    Premium     H     SI2   61.0   58.0   2757  6.15  6.12  3.74
53939   0.75      Ideal     D     SI2   62.2   55.0   2757  5.83  5.87  3.64

[53940 rows x 10 columns]

では、どのように使うかというと、別の何らかの処理と組み合わせて使うことで真価を発揮する。 例えば、cut カラムごとに price の平均値を計算したい、という場合には次のようにする。

>>> diamonds >> group_by(X.cut) >> mutate(price_mean=mean(X.price)) >> head(3)
    carat        cut color clarity  depth  table  price     x     y     z   price_mean
8    0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49  4358.757764
91   0.86       Fair     E     SI2   55.1   69.0   2757  6.45  6.33  3.52  4358.757764
97   0.96       Fair     F     SI2   66.3   62.0   2759  6.27  5.95  4.07  4358.757764
2    0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31  3928.864452
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75  3928.864452
10   0.30       Good     J     SI1   64.0   55.0    339  4.25  4.28  2.73  3928.864452
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43  3457.541970
11   0.23      Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46  3457.541970
13   0.31      Ideal     J     SI2   62.2   54.0    344  4.35  4.37  2.71  3457.541970
1    0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31  4584.257704
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63  4584.257704
12   0.22    Premium     F     SI1   60.4   61.0    342  3.88  3.84  2.33  4584.257704
5    0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48  3981.759891
6    0.24  Very Good     I    VVS1   62.3   57.0    336  3.95  3.98  2.47  3981.759891
7    0.26  Very Good     H     SI1   61.9   55.0    337  4.07  4.11  2.53  3981.759891

上記を見てわかる通り、集計した処理が全ての行に反映されている。 いうなれば、これは SQL の WINDOW 関数に PartitionBy を指定した処理に相当している。 その証左として、例えば lead() 関数や lag() 関数が使える。

>>> diamonds >> group_by(X.cut) >> transmute(X.price, next=lead(X.price), prev=lag(X.price)) >> head(3)
          cut    next    prev  price
8        Fair  2757.0     NaN    337
91       Fair  2759.0   337.0   2757
97       Fair  2762.0  2757.0   2759
2        Good   335.0     NaN    327
4        Good   339.0   327.0    335
10       Good   351.0   335.0    339
0       Ideal   340.0     NaN    326
11      Ideal   344.0   326.0    340
13      Ideal   348.0   340.0    344
1     Premium   334.0     NaN    326
3     Premium   342.0   326.0    334
12    Premium   345.0   334.0    342
5   Very Good   336.0     NaN    336
6   Very Good   337.0   336.0    336
7   Very Good   338.0   336.0    337

ただし、ここで一つ気になることがある。 もし、途中からグループ化しない集計をしたいときは、どうしたら良いのだろうか。

例えば、次のように cut ごとに先頭 2 つの要素を取り出すとする。

>>> diamonds >> group_by(X.cut) >> head(2)
    carat        cut color clarity  depth  table  price     x     y     z
8    0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49
91   0.86       Fair     E     SI2   55.1   69.0   2757  6.45  6.33  3.52
2    0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
4    0.31       Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
0    0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
11   0.23      Ideal     J     VS1   62.8   56.0    340  3.93  3.90  2.46
1    0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
3    0.29    Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
5    0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48
6    0.24  Very Good     I    VVS1   62.3   57.0    336  3.95  3.98  2.47

もし、ここからさらに全体における先頭 1 つの要素を取り出したいときは、どうしたら良いだろう。あ ただ head() するだけだと、グループごとに先頭 1 要素が取り出されてしまう。

>>> diamonds >> group_by(X.cut) >> head(2) >> head(1)
   carat        cut color clarity  depth  table  price     x     y     z
8   0.22       Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49
2   0.23       Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
0   0.23      Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21    Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
5   0.24  Very Good     J    VVS2   62.8   57.0    336  3.94  3.96  2.48

この問題を解決するには ungroup() 関数を用いる。

>>> diamonds >> group_by(X.cut) >> head(2) >> ungroup() >> head(1)
   carat   cut color clarity  depth  table  price     x     y     z
8   0.22  Fair     E     VS2   65.1   61.0    337  3.87  3.78  2.49

色々な WINDOW 関数

いくつか dfply で使える WINDOW 関数を紹介しておく。

カラムの値が特定の範囲に収まるか真偽値を返すのが between() 関数。

>>> diamonds >> transmute(X.price, price_between=between(X.price, 330, 340)) >> head()
   price_between  price
0          False    326
1          False    326
2          False    327
3           True    334
4           True    335

同じ値は同じランクとして、間を空けずにランク付けするのが dense_rank() 関数。

>>> diamonds >> transmute(X.price, drank=dense_rank(X.price)) >> head()
   drank  price
0    1.0    326
1    1.0    326
2    2.0    327
3    3.0    334
4    4.0    335

同じ値は同じランクとして、間を空けてランク付けするのが min_rank() 関数。

>>> diamonds >> transmute(X.price, mrank=min_rank(X.price)) >> head()
   mrank  price
0    1.0    326
1    1.0    326
2    3.0    327
3    4.0    334
4    5.0    335

単純な行番号が row_number() 関数。

>>> diamonds >> transmute(X.price, rownum=row_number(X.price)) >> head()
   rownum  price
0     1.0    326
1     2.0    326
2     3.0    327
3     4.0    334
4     5.0    335

標準化したランク付けをするのが percent_rank() 関数。

>>> diamonds >> transmute(X.price, prank=percent_rank(X.price)) >> head()
      prank  price
0  0.000000    326
1  0.000000    326
2  0.000037    327
3  0.000056    334
4  0.000074    335

積算値を計算するのが cunsum() 関数。

>>> diamonds >> transmute(X.price, cumprice=cumsum(X.price)) >> head()
   cumprice  price
0       326    326
1       652    326
2       979    327
3      1313    334
4      1648    335

積算の平均値を計算するのが cummean() 関数。

>>> diamonds >> transmute(X.price, cummean=cummean(X.price)) >> head()
      cummean  price
0  326.000000    326
1  326.000000    326
2  326.333333    327
3  328.250000    334
4  329.600000    335

集計値を計算する (summarize)

一般的な group by と聞いて思い浮かべる処理は、むしろこちらの summarize() 関数の方だろう。

例えば、表全体の要約統計量として平均と標準偏差を計算してみよう。

>>> diamonds >> summarize(price_mean=X.price.mean(), price_std=X.price.std())
    price_mean    price_std
0  3932.799722  3989.439738

上記は Intention に生えているメソッドを使って計算したけど、以下のように関数を使うこともできる。

>>> diamonds >> summarize(price_mean=mean(X.price), price_std=sd(X.price))
    price_mean    price_std
0  3932.799722  3989.439738

また、group_by() と組み合わせて使うこともできる。 例えば cut ごとに統計量を計算してみよう。

>>> diamonds >> group_by(X.cut) >> summarize(price_mean=mean(X.price), price_std=sd(X.price))
         cut   price_mean    price_std
0       Fair  4358.757764  3560.386612
1       Good  3928.864452  3681.589584
2      Ideal  3457.541970  3808.401172
3    Premium  4584.257704  4349.204961
4  Very Good  3981.759891  3935.862161

集計に使う関数は、組み込み以外のものを使うこともできる。 例えば numpy の関数を使ってみることに使用。

>>> import numpy as np
>>> diamonds >> group_by(X.cut) >> summarize(price_mean=np.mean(X.price), price_std=np.std(X.price))
         cut   price_mean    price_std
0       Fair  4358.757764  3559.280730
1       Good  3928.864452  3681.214352
2      Ideal  3457.541970  3808.312813
3    Premium  4584.257704  4349.047276
4  Very Good  3981.759891  3935.699276

平均や標準偏差の他にも、サイズや重複を除いたサイズを計算する関数なんかもある。

>>> diamonds >> group_by(X.cut) >> summarize(size=n(X.price), distinct_size=n_distinct(X.price))
         cut   size  distinct_size
0       Fair   1610           1267
1       Good   4906           3086
2      Ideal  21551           7281
3    Premium  13791           6014
4  Very Good  12082           5840

一度に計算したいときは、こんな感じでやればいいかな?

>>> stats = {
...     'iqr': IQR(X.price),
...     'max': colmax(X.price),
...     'q75': X.price.quantile(0.75),
...     'mean': mean(X.price),
...     'median': median(X.price),
...     'q25': X.price.quantile(0.25),
...     'min': colmin(X.price),
... }
>>> diamonds >> group_by(X.cut) >> summarize(**stats)
         cut      iqr    max      q75         mean  median      q25  min
0       Fair  3155.25  18574  5205.50  4358.757764  3282.0  2050.25  337
1       Good  3883.00  18788  5028.00  3928.864452  3050.5  1145.00  327
2      Ideal  3800.50  18806  4678.50  3457.541970  1810.0   878.00  326
3    Premium  5250.00  18823  6296.00  4584.257704  3185.0  1046.00  326
4  Very Good  4460.75  18818  5372.75  3981.759891  2648.0   912.00  336

各カラムに複数の集計する (summarize_each)

カラムと集計内容が複数あるときは summarize_each() 関数を使うと良い。

以下では、例として pricecarat に対して平均と標準偏差を計算している。

>>> diamonds >> summarize_each([np.mean, np.std], X.price, X.carat)
    price_mean    price_std  carat_mean  carat_std
0  3932.799722  3989.402758     0.79794   0.474007

もちろん、この処理も group_by と組み合わせることができる。

>>> diamonds >> group_by(X.cut) >> summarize_each([np.mean, np.std], X.price, X.carat)
         cut   price_mean    price_std  carat_mean  carat_std
0       Fair  4358.757764  3559.280730    1.046137   0.516244
1       Good  3928.864452  3681.214352    0.849185   0.454008
2      Ideal  3457.541970  3808.312813    0.702837   0.432866
3    Premium  4584.257704  4349.047276    0.891955   0.515243
4  Very Good  3981.759891  3935.699276    0.806381   0.459416

複数のデータフレームをカラム方向に結合する (join)

続いては複数のデータフレームを結合する処理について。

例に使うデータフレームを用意する。 微妙に行や列の内容がかぶっている。

>>> data = {
...     'name': ['alice', 'bob', 'carrol'],
...     'age': [20, 30, 40],
... }
>>> a = pd.DataFrame(data)
>>> 
>>> data = {
...     'name': ['alice', 'bob', 'daniel'],
...     'is_male': [False, True, True],
... }
>>> b = pd.DataFrame(data)

内部結合には inner_join() 関数を使う。

>>> a >> inner_join(b, by='name')
    name  age  is_male
0  alice   20    False
1    bob   30     True

外部結合には outer_join() を使う。

>>> a >> outer_join(b, by='name')
     name   age is_male
0   alice  20.0   False
1     bob  30.0    True
2  carrol  40.0     NaN
3  daniel   NaN    True

左外部結合には left_join() を使う。

>>> a >> left_join(b, by='name')
     name  age is_male
0   alice   20   False
1     bob   30    True
2  carrol   40     NaN

右外部結合には right_join() を使う。

>>> a >> right_join(b, by='name')
     name   age  is_male
0   alice  20.0    False
1     bob  30.0     True
2  daniel   NaN     True

複数のデータフレームを行方向に結合する (union / intersect / set_diff / bind_rows)

ここからは縦 (行) 方向の結合を扱う。 データフレームを追加しておく。

>>> data = {
...     'name': ['carrol', 'daniel'],
...     'age': [40, 50],
... }
>>> c = pd.DataFrame(data)

重複したものは除外して行方向にくっつけたいときは union() を使う。

>>> a >> union(c)
     name  age
0   alice   20
1     bob   30
2  carrol   40
1  daniel   50

両方のデータフレームにあるものだけくっつけたいなら intersect() を使う。

>>> a >> intersect(c)
     name  age
0  carrol   40

両方に存在しないものだけほしいときは set_diff() を使う。

>>> a >> set_diff(c)
    name  age
0  alice   20
1    bob   30

行と列を含む結合 (bind_rows)

行と列の両方を使って結合したいときは bind_rows() 関数を使う。 joininner を指定すると、両方にあるカラムだけを使って結合される。

>>> a >> bind_rows(b, join='inner')
     name
0   alice
1     bob
2  carrol
0   alice
1     bob
2  daniel

joinouter を指定したときは、存在しない行が NaN で埋められる。

>>> a >> bind_rows(b, join='outer')
    age is_male    name
0  20.0     NaN   alice
1  30.0     NaN     bob
2  40.0     NaN  carrol
0   NaN   False   alice
1   NaN    True     bob
2   NaN    True  daniel

dfply に対応した API を実装する

ここからは dfply に対応した API を実装する方法について書いていく。

pipe

最も基本となるのは @pipe デコレータで、これはデータフレームを受け取ってデータフレームを返す関数を定義する。 例えば、最も単純な処理として受け取ったデータフレームをそのまま返す関数を作ってみよう。

>>> @pipe
... def nop(df):
...     return df
... 

この関数も、ちゃんと dfply の API として機能する。

>>> diamonds >> nop() >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

次に、もう少し複雑な関数として、特定のカラムの値を 2 倍する関数を定義してみよう。 この中ではデータフレームのカラムの内容を上書きしている。

>>> @pipe
... def double(df, cols):
...     df[cols] = df[cols] * 2
...     return df
... 

使ってみると、ちゃんとカラムの値が 2 倍になっている。

>>> diamonds >> double(['carat', 'price']) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

カラムの内容を上書きしているということは、元のデータフレームの内容も書き換わっているのでは?と思うだろう。 しかし、確認すると元の値のままとなっている。

>>> diamonds >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

実は dfply では、右ビットシフト演算子が評価される度にデータフレームをディープコピーしている。 そのため、元のデータフレームが壊れることはない。

github.com

ただし、上記は大きなサイズのデータフレームを扱う上でパフォーマンス上の問題ともなる。 なぜなら、何らかの処理を評価するたびにメモリ上で大量のコピーが発生するため。 メモリのコピーは、大量のデータを処理する場合にスループットを高める上でボトルネックとなる。

Intention

ところで、先ほど定義した double() 関数は Intention を受け取ることができない。 試しに渡してみると、次のようなエラーになってしまう。

>>> diamonds >> double(X.carat, X.price) >> head()
Traceback (most recent call last):
...(snip)...
    return pipe(lambda x: self.function(x, *args, **kwargs))
TypeError: double() takes 2 positional arguments but 3 were given

配列として指定してもダメ。

>>> diamonds >> double(X.carat, X.price) >> head()
Traceback (most recent call last):
...(snip)...
    if len(arrays[i]) != len(arrays[i - 1]):
TypeError: __index__ returned non-int (type Intention)

上記がエラーになるのは、Intention を解決するのにデコレータの追加が必要なため。 具体的には symbolic_evaluation() を追加する。 こうすると、Intention が pandas.Series に解決した上で渡される。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_double(df, serieses):
...     for series in serieses:
...         df[series.name] = series * 2
...     return df
... 

上記を使ってみると、ちゃんと動作することがわかる。

>>> diamonds >> symbolic_double([X.carat, X.price]) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

この処理は、Intention を解決した上で Series として渡すだけなので、次のように任意の長さの引数として受け取ることもできる。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_double(df, *serieses):
...     for series in serieses:
...         df[series.name] = series * 2
...     return df
... 
>>> diamonds >> symbolic_double(X.carat, X.price) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

Intention 以外のオブジェクトを引数に受け取りたいときは、こんな感じ。

>>> @pipe
... @symbolic_evaluation()
... def symbolic_multiply(df, n, serieses):
...     for series in serieses:
...         df[series.name] = series * n
...     return df
... 
>>> diamonds >> symbolic_multiply(3, [X.carat, X.price]) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.69    Ideal     E     SI2   61.5   55.0    978  3.95  3.98  2.43
1   0.63  Premium     E     SI1   59.8   61.0    978  3.89  3.84  2.31
2   0.69     Good     E     VS1   56.9   65.0    981  4.05  4.07  2.31
3   0.87  Premium     I     VS2   62.4   58.0   1002  4.20  4.23  2.63
4   0.93     Good     J     SI2   63.3   58.0   1005  4.34  4.35  2.75

ちなみに引数の eval_as_selectorTrue を指定すると、渡されるのが numpy 配列になる。 この配列はカラム名と同じ長さで、どのカラムが Intention によって指定されたかがビットマスクとして得られる。

>>> @pipe
... @symbolic_evaluation(eval_as_selector=True)
... def symbolic_double(df, *selected_masks):
...     # もし列の指定が入れ子になってるとしたらフラットに直す
...     selectors = np.array(list(flatten(selected_masks)))
...     selected_cols = [col for col, selected
...                      in zip(df.columns, np.any(selectors, axis=0))
...                      if selected]
...     df[selected_cols] = df[selected_cols] * 2
...     return df
... 
>>> diamonds >> symbolic_double(X.carat, X.price) >> head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.46    Ideal     E     SI2   61.5   55.0    652  3.95  3.98  2.43
1   0.42  Premium     E     SI1   59.8   61.0    652  3.89  3.84  2.31
2   0.46     Good     E     VS1   56.9   65.0    654  4.05  4.07  2.31
3   0.58  Premium     I     VS2   62.4   58.0    668  4.20  4.23  2.63
4   0.62     Good     J     SI2   63.3   58.0    670  4.34  4.35  2.75

WINDOW 関数を定義する

ただ、あんまり複雑な処理を単発の @pipe 処理で作るよりは、もっと小さな処理を組み合わせていく方が関数型プログラミングっぽくてキレイだと思う。 そこで、次は WINDOW 関数の作り方を扱う。

WINDOW 関数を定義したいときは、@make_symbolic をつけて Series を受け取る関数を作る。 例えばカラムの内容を 2 倍にする関数を作ってみよう。

>>> @make_symbolic
... def double(series):
...     return series * 2
... 

使ってみると、たしかに 2 倍になる。

>>> diamonds >> mutate(double_price=double(X.price)) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  double_price
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43           652
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31           652
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31           654
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63           668
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75           670

こちらの @make_symbolic も、Intention を解決して Series をインジェクトする以上の意味はない。 なので、次のように任意の長さのリストとして受け取ることもできる。

>>> @make_symbolic
... def add(*serieses):
...     return sum(serieses)
... 

上記は複数のカラムの内容を足し合わせる処理になっている。

>>> diamonds >> mutate(add_column=add(X.carat, X.price)) >> head()
   carat      cut color clarity  depth  table  price     x     y     z  add_column
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43      326.23
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31      326.21
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31      327.23
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63      334.29
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75      335.31

summarize 相当の処理を定義する

summarize 相当の関数は @group_delegation デコレータを使って作れる。

例えば要素数をカウントする関数を定義してみよう。

>>> @pipe
... @group_delegation
... def mycount(df):
...     return len(df)
... 

そのまま適用すれば、全体の要素数が得られる。

>>> diamonds >> mycount()
53940

group_by() とチェインすれば、グループ化した中での要素数が計算できる。

>>> diamonds >> group_by(X.cut) >> mycount()
cut
Fair          1610
Good          4906
Ideal        21551
Premium      13791
Very Good    12082
dtype: int64

一通り適用した関数を作るとき

ちなみに、ショートカット的な記述方法もあって、次のように @dfpipe デコレータを使うと...

>>> @dfpipe
... def myfunc(df):
...     return len(df)
... 

以下の 3 つのデコレータを組み合わせたのと同義になる。 WINDOW 関数は別として、いつもはこれを使っておけばとりあえず良いかもしれない。

>>> @pipe
... @group_delegation
... @symbolic_evaluation
... def myfunc(df):
...     return len(df)
... 

パフォーマンスに問題は抱えているけど、API はすごく面白いね。