CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: pandas のデータ型をキャストしてメモリを節約してみる

pandas の DataFrame は明示的にデータ型を指定しないと整数型や浮動小数点型のカラムを 64 ビットで表現する。 pandas の DataFrame は、表現に使うビット数が大きいと、メモリ上のオブジェクトのサイズも当然ながら大きくなる。 そこで、今回は DataFrame の各カラムに含まれる値を調べながら、より小さなビット数の表現にキャストすることでメモリの使用量を節約してみる。 なお、ネットを調べると既に同じような実装が見つかったけど、自分でスクラッチしてみた。

今回使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G95
$ python -V          
Python 3.7.4
$ pip list | grep -i pandas               
pandas          0.25.1 

下準備

まずは必要となるパッケージをインストールしておく。

$ pip install pandas tqdm seaborn

pandas のデータ型をキャストしてメモリを節約する

以下が DataFrame の各カラムのデータ型をキャストしてメモリを節約するサンプルコード。 例として 64 ビットで表現されているものの、実はもっと小さなビット数で表現できるデータを適用している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import logging

import numpy as np
import pandas as pd
from tqdm import tqdm


LOGGER = logging.getLogger(__name__)


def _fall_within_range(dtype_min_value, dtype_max_value, min_value, max_value):
    """データ型の表現できる範囲に収まっているか調べる関数"""
    if min_value < dtype_min_value:
        # 下限が越えている
        return False

    if max_value > dtype_max_value:
        # 上限が越えている
        return False

    # 範囲内に収まっている
    return True


def _cast(df, col_name, cast_candidates):
    # カラムに含まれる最小値と最大値を取り出す
    min_value, max_value = df[col_name].min(), df[col_name].max()

    for cast_type, (dtype_min_value, dtype_max_value) in cast_candidates.items():
        if df[col_name].dtype == cast_type:
            # 同じ型まで到達した時点で、キャストする意味はなくなる
            return

        if _fall_within_range(dtype_min_value, dtype_max_value, min_value, max_value):
            # キャストしたことをログに残す
            LOGGER.info(f'column {col_name} casted: {df[col_name].dtype.type} to {cast_type}')
            # 最も小さなビット数で表現できる型にキャストできたので終了
            df[col_name] = df[col_name].astype(cast_type)
            return


def _cast_func(df, col_name):
    col_type = df[col_name].dtype.type

    if issubclass(col_type, np.integer):
        # 整数型
        cast_candidates = {
            cast_type: (np.iinfo(cast_type).min, np.iinfo(cast_type).max)
            for cast_type in [np.int8, np.int16, np.int32]
        }
        return partial(_cast, cast_candidates=cast_candidates)

    if issubclass(col_type, np.floating):
        # 浮動小数点型
        cast_candidates = {
            cast_type: (np.finfo(cast_type).min, np.finfo(cast_type).max)
            for cast_type in [np.float16, np.float32]
        }
        return partial(_cast, cast_candidates=cast_candidates)

    # その他は未対応
    return None


def _memory_usage(df):
    """データフレームのサイズと接頭辞を返す"""
    units = ['B', 'kB', 'MB', 'GB']
    usage = float(df.memory_usage().sum())

    for unit in units:
        if usage < 1024:
            return usage, unit
        usage /= 1024

    return usage, unit


def shrink(df):
    # 元のサイズをログに記録しておく
    usage, unit = _memory_usage(df)
    LOGGER.info(f'original dataframe size: {usage:.0f}{unit}')

    for col_name in tqdm(df.columns):
        # 各カラムごとにより小さなビット数で表現できるか調べていく
        func = _cast_func(df, col_name)
        if func is None:
            continue
        func(df, col_name)

    # 事後のサイズをログに記録する
    usage, unit = _memory_usage(df)
    LOGGER.info(f'shrinked dataframe size: {usage:.0f}{unit}')


def main():
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_fmt,
                        level=logging.DEBUG)

    data = [
        (2147483648, 32768, 129, 0, 2.0e+308, 65510.0, 0.0, 'foo'),
        (2147483649, 32769, 130, 1, 2.1e+308, 65520.0, 0.1, 'bar'),
    ]

    df = pd.DataFrame(data)
    print(df.dtypes)

    shrink(df)
    print(df.dtypes)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python shrink.py
0      int64
1      int64
2      int64
3      int64
4    float64
5    float64
6    float64
7     object
dtype: object
2019-09-05 22:32:45,726 - __main__ - INFO - original dataframe size: 256B
  0%|                                                     | 0/8 [00:00<?, ?it/s]2019-09-05 22:32:45,742 - __main__ - INFO - column 1 casted: <class 'numpy.int64'> to <class 'numpy.int32'>
2019-09-05 22:32:45,743 - __main__ - INFO - column 2 casted: <class 'numpy.int64'> to <class 'numpy.int16'>
2019-09-05 22:32:45,744 - __main__ - INFO - column 3 casted: <class 'numpy.int64'> to <class 'numpy.int8'>
2019-09-05 22:32:45,746 - __main__ - INFO - column 5 casted: <class 'numpy.float64'> to <class 'numpy.float32'>
2019-09-05 22:32:45,748 - __main__ - INFO - column 6 casted: <class 'numpy.float64'> to <class 'numpy.float16'>
100%|███████████████████████████████████████████| 8/8 [00:00<00:00, 1066.51it/s]
2019-09-05 22:32:45,750 - __main__ - INFO - shrinked dataframe size: 202B
0      int64
1      int32
2      int16
3       int8
4    float64
5    float32
6    float16
7     object
dtype: object

上記の実行結果を確認すると、元は 256B だったデータサイズが 202B まで小さくなっている。 また、各カラムのデータ型も表現できる最小のビット数まで小さくなっていることもわかる。

もうちょっとちゃんとしたデータに対しても適用してみることにしよう。 seaborn からロードできる diamonds データセットを使ってみることにした。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import logging

import seaborn
import numpy as np
import pandas as pd
from tqdm import tqdm


LOGGER = logging.getLogger(__name__)


def _fall_within_range(dtype_min_value, dtype_max_value, min_value, max_value):
    """データ型の表現できる範囲に収まっているか調べる関数"""
    if min_value < dtype_min_value:
        # 下限が越えている
        return False

    if max_value > dtype_max_value:
        # 上限が越えている
        return False

    # 範囲内に収まっている
    return True


def _cast(df, col_name, cast_candidates):
    # カラムに含まれる最小値と最大値を取り出す
    min_value, max_value = df[col_name].min(), df[col_name].max()

    for cast_type, (dtype_min_value, dtype_max_value) in cast_candidates.items():
        if df[col_name].dtype == cast_type:
            # 同じ型まで到達した時点で、キャストする意味はなくなる
            return

        if _fall_within_range(dtype_min_value, dtype_max_value, min_value, max_value):
            # キャストしたことをログに残す
            LOGGER.info(f'column {col_name} casted: {df[col_name].dtype.type} to {cast_type}')
            # 最も小さなビット数で表現できる型にキャストできたので終了
            df[col_name] = df[col_name].astype(cast_type)
            return


def _cast_func(df, col_name):
    col_type = df[col_name].dtype.type

    if issubclass(col_type, np.integer):
        # 整数型
        cast_candidates = {
            cast_type: (np.iinfo(cast_type).min, np.iinfo(cast_type).max)
            for cast_type in [np.int8, np.int16, np.int32]
        }
        return partial(_cast, cast_candidates=cast_candidates)

    if issubclass(col_type, np.floating):
        # 浮動小数点型
        cast_candidates = {
            cast_type: (np.finfo(cast_type).min, np.finfo(cast_type).max)
            for cast_type in [np.float16, np.float32]
        }
        return partial(_cast, cast_candidates=cast_candidates)

    # その他は未対応
    return None


def _memory_usage(df):
    """データフレームのサイズと接頭辞を返す"""
    units = ['B', 'kB', 'MB', 'GB']
    usage = float(df.memory_usage().sum())

    for unit in units:
        if usage < 1024:
            return usage, unit
        usage /= 1024

    return usage, unit


def shrink(df):
    # 元のサイズをログに記録しておく
    usage, unit = _memory_usage(df)
    LOGGER.info(f'original dataframe size: {usage:.0f}{unit}')

    for col_name in tqdm(df.columns):
        # 各カラムごとにより小さなビット数で表現できるか調べていく
        func = _cast_func(df, col_name)
        if func is None:
            continue
        func(df, col_name)

    # 事後のサイズをログに記録する
    usage, unit = _memory_usage(df)
    LOGGER.info(f'shrinked dataframe size: {usage:.0f}{unit}')


def main():
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_fmt,
                        level=logging.DEBUG)

    df = seaborn.load_dataset('diamonds')

    print(df.dtypes)
    shrink(df)
    print(df.dtypes)


if __name__ == '__main__':
    main()

上記を実行した結果が次の通り。 元の DataFrame が 4MB だったのに対し、キャストした後は 2MB と半減していることがわかる。

$ python shrink.py
carat      float64
cut         object
color       object
clarity     object
depth      float64
table      float64
price        int64
x          float64
y          float64
z          float64
dtype: object
2019-09-05 22:38:13,848 - __main__ - INFO - original dataframe size: 4MB
  0%|          | 0/10 [00:00<?, ?it/s]2019-09-05 22:38:13,854 - __main__ - INFO - column carat casted: <class 'numpy.float64'> to <class 'numpy.float16'>
2019-09-05 22:38:13,858 - __main__ - INFO - column depth casted: <class 'numpy.float64'> to <class 'numpy.float16'>
2019-09-05 22:38:13,860 - __main__ - INFO - column table casted: <class 'numpy.float64'> to <class 'numpy.float16'>
2019-09-05 22:38:13,863 - __main__ - INFO - column price casted: <class 'numpy.int64'> to <class 'numpy.int16'>
2019-09-05 22:38:13,865 - __main__ - INFO - column x casted: <class 'numpy.float64'> to <class 'numpy.float16'>
2019-09-05 22:38:13,868 - __main__ - INFO - column y casted: <class 'numpy.float64'> to <class 'numpy.float16'>
2019-09-05 22:38:13,870 - __main__ - INFO - column z casted: <class 'numpy.float64'> to <class 'numpy.float16'>
100%|██████████| 10/10 [00:00<00:00, 538.64it/s]
2019-09-05 22:38:13,873 - __main__ - INFO - shrinked dataframe size: 2MB
carat      float16
cut         object
color       object
clarity     object
depth      float16
table      float16
price        int16
x          float16
y          float16
z          float16
dtype: object

大きなデータをオンメモリで扱うときは使える場面があるかもしれない。