CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: pandas でグループごとにデータをサンプリングする

取り扱うデータをサンプリングする機会は意外と多い。 ユースケースとしては、例えばデータが多すぎて扱いにくい場合や、グループごとに件数の偏りのある場合が挙げられる。 今回は pandas を使ってグループごとに特定の件数をサンプリングする方法について書いてみる。

使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.3
BuildVersion:   18D109
$ python -V                 
Python 3.7.2
$ pip list | grep pandas    
pandas          0.24.2 

下準備

まずは pandas をインストールしておく。

$ pip install pandas

続いて Python のインタプリタを起動する。

$ python

そして、サンプルとなる DataFrame を用意する。 今回は野菜とくだものが入ったデータをサンプルとして使うことにした。

>>> import pandas as pd
>>> data = [
...   ('Apple', 'Fruit'),
...   ('Beetroot', 'Vegetable'),
...   ('Carrot', 'Vegetable'),
...   ('Date', 'Fruit'),
...   ('Eggplant', 'Vegetable'),
...   ('Fig', 'Fruit'),
... ]
>>> df = pd.DataFrame(data, columns=['name', 'category'])

グループを加味しないでサンプリングする

今回用意したデータは野菜とくだものが 3 つずつ入っている。

>>> df
       name   category
0     Apple      Fruit
1  Beetroot  Vegetable
2    Carrot  Vegetable
3      Date      Fruit
4  Eggplant  Vegetable
5       Fig      Fruit

これをそのままランダムサンプリングすると、野菜とくだものが均等に取り出される確率が最も高い。 とはいえ、それはあくまで期待値としての話なので実際にはそうならない場合も多い。 次のように、グループを加味しない状態で DataFrame からサンプリングすると偏ることもある。

>>> df.sample(n=4)
       name   category
0     Apple      Fruit
3      Date      Fruit
5       Fig      Fruit
1  Beetroot  Vegetable

こうなると都合が悪い。

グループを加味してサンプリングする

続いてはグループについて加味した上で特定の件数をサンプリングしてみる。

まずは DataFrame#groupby()category ごとにグルーピングする。

>>> gdf = df.groupby('category')

これで得られるのは DataFrameGroupBy というクラスのインスタンス。

>>> gdf
<pandas.core.groupby.generic.DataFrameGroupBy object at 0x117e6e1d0>
>>> gdf.size()
category
Fruit        3
Vegetable    3
dtype: int64

DataFrameGroupByapply() メソッドを使うと、グループ単位の DataFrame に対して処理を実行できる。 結果は、全ての DataFrame が一枚に結合された状態で得られる。 つまり、DataFrameGroupBy#apply() の中で DataFrame に対して sample() メソッドを呼び出せば良い。 例えば野菜とくだものを 2 件ずつ取り出したいときは次のようにする。

>>> gdf.apply(lambda x: x.sample(n=2))
                 name   category
category                        
Fruit     0     Apple      Fruit
          5       Fig      Fruit
Vegetable 1  Beetroot  Vegetable
          4  Eggplant  Vegetable

これで、野菜とくだものが必ず毎回 2 件ずつ取り出せるようになった。

なお、ランダムサンプリングの結果を安定させたいときは random_state オプションを指定すれば良い。

>>> gdf.apply(lambda x: x.sample(n=2, random_state=42))
                 name   category
category                        
Fruit     0     Apple      Fruit
          3      Date      Fruit
Vegetable 1  Beetroot  Vegetable
          2    Carrot  Vegetable

ちなみに、元のグループの比率を維持してサンプリングしたいときは、次のように DataFrame の長さを用いると良い。

>>> gdf.apply(lambda x: x.sample(n=round(len(x) * 0.5)))
                 name   category
category                        
Fruit     0     Apple      Fruit
          3      Date      Fruit
Vegetable 4  Eggplant  Vegetable
          2    Carrot  Vegetable

MultiIndex を解除する

ちなみに、このやり方で作った DataFrame はインデックスが MultiIndex になっている。

>>> sampled_df = gdf.apply(lambda x: x.sample(n=2))
>>> sampled_df.index
MultiIndex(levels=[['Fruit', 'Vegetable'], [0, 1, 4, 5]],
           codes=[[0, 0, 1, 1], [3, 0, 1, 2]],
           names=['category', None])

これはこれでグループごとに要素を取り出しやすくて良い。

>>> sampled_df.loc['Fruit']
    name category
5    Fig    Fruit
0  Apple    Fruit

もし、元の DataFrame のフォーマットと合わせたいときには DataFrame#reset_index() を使う。

>>> sampled_df.reset_index(level='category', drop=True)
       name   category
5       Fig      Fruit
0     Apple      Fruit
1  Beetroot  Vegetable
4  Eggplant  Vegetable

サンプリングしたい件数が実在する件数よりも多いとき

先ほどのデータはグループごとに件数が均衡だった。 しかし、現実世界のデータではこのようにキレイな分布をするものはあまりない。 次は分布に偏りがあるときに生じる悩みについて考えてみる。

次のデータは、先ほどのデータからくだものを 2 つほど取り除いたものを扱う。 これで、データに含まれるくだものは Fig (いちじく) だけになった。

>>> lack_df = df[~df.name.isin(['Apple', 'Date'])]
>>> lack_df
       name   category
1  Beetroot  Vegetable
2    Carrot  Vegetable
4  Eggplant  Vegetable
5       Fig      Fruit

先ほどと同じように category の列でグルーピングする。

>>> lack_gdf = lack_df.groupby('category')
>>> lack_gdf.size()
category
Fruit        1
Vegetable    3
dtype: int64

この状態で、野菜とくだものを 2 件ずつ取り出したいという場合について考えてみる。 しかし、データの中に存在するくだものは 1 件しかないので 2 件というサンプル数に満たない。 実際に実行してみると、次のようなエラーになってしまう。

>>> lack_gdf.apply(lambda x: x.sample(n=2))
Traceback (most recent call last):
...
ValueError: Cannot take a larger sample than population when 'replace=False'

こうしたときに、データがサンプル数に満たないものは全てそのまま取り出したいのであれば、次のようにする。 具体的には、組み込み関数 min() を使って、DataFrame の長さとサンプル数の少ない方に合わせてしまう。

>>> lack_gdf.apply(lambda x: x.sample(n=min(2, len(x))))
                 name
category             
Fruit     5       Fig
Vegetable 1  Beetroot
          2    Carrot

もし、重複 (繰り返し) を許してでもサンプル数にあわせてほしいときは、次のように replace オプションに True を指定する。 こういった、本来ある件数よりも多くサンプリングする処理はオーバーサンプリングと呼ばれる。

>>> lack_gdf.apply(lambda x: x.sample(n=2, replace=True))
                 name
category             
Fruit     5       Fig
          5       Fig
Vegetable 1  Beetroot
          4  Eggplant

いじょう。