CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LightGBM の cv() 関数の実装について

今回は LightGBM の cv() 関数について書いてみる。 LightGBM の cv() 関数は、一般的にはモデルの性能を評価する交差検証に使われる。 一方で、この関数から取り出した学習済みモデルを推論にまで使うユーザもいる。 今回は、その理由やメリットとデメリットについて書いてみる。

cv() 関数から取り出した学習済みモデルを使う理由とメリット・デメリットについて

一部のユーザの間では有名だけど、LightGBM の cv() 関数は各 Fold の決定木の増やし方に特色がある。 まず、LightGBM では決定木の集まりを Booster というオブジェクトで管理している。 Booster が内包する決定木の本数は、ラウンド (イテレーション) 数として認識できる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1930

ちなみに、train() 関数を使って得られるのは、この Booster というオブジェクト。 一般的に、train() 関数を使って自前で交差検証をするときは、この Booster を Fold 毎にひとつずつ学習させることになる。

一方で、cv() 関数では全 Fold を並列で、複数の Booster を一度に学習させる。 具体的には、すべての Fold で歩調を合わせながら、それぞれの Booster のラウンド数をひとつずつ増やしている。 このとき、検証用のデータに対するメトリックも、ラウンド (イテレーション) 毎に「全 Fold の平均」で計算される。 つまり、全 Fold の平均的なメトリックが悪化するタイミングで Early Stopping がかかる。

言いかえると、cv() 関数から得られる学習済みモデルは Booster が内包する決定木の本数がすべて同じに揃う。 それに対して、train() 関数を使ってひとつずつ Booster を学習する方法では、ラウンド数が Fold によってバラつくことになる。 バラつきが小さいときは良いけど、ときには大きく偏ることもあって、その際は性能の見積もりや推論に悪影響があると考えられる。 この点から、cv() 関数では Fold 毎の偏りを考慮した、ようするに無難なモデルを得られることが期待できる。 なお、各 Fold から複数の Booster が得られるので、推論するときは Averaging などで対応する。

また、ターゲットの情報を使った特徴量抽出やスタッキングをするときも、この点は都合が良い。 これらのユースケースでは、一般的にはリークを防ぐために Out-of-Fold で処理することになる。 となると、データの全体を使って学習することが難しいので、各 Fold ごとに学習したモデルを使えると使い勝手が良い。

と、ここまでメリットばかり説明してきたけど、もちろんデメリットもある。 前述したとおり、cv() 関数では各 Fold の Booster を同時に並列で学習させていく。 そのため、学習に使うデータやモデルを一度にメモリに載せることになる。 つまり、train() 関数を使って Booster をひとつずつ学習するときよりも、相対的にメモリの制約は厳しくなると考えられる。 また、他の Fold を使って補える部分もあるとはいえ Out-of-Fold したデータは学習に使えない点もデメリットとして挙げられる。

cv() 関数の実装について

ここからは LightGBM のコードを軽く追いかけてみよう。

はじめに、LightGBM のコアといえる部分は C++ で書かれている。 Python では、それを ctypes モジュールを使った Binding として呼び出している。

自身の Python 実行環境で LightGBM のインストール先パスがわかっているときは LightGBM の共有ライブラリを探してみると良い。 上記でいうコアは Python 実行環境の中に「lib_lightgbm.so」として存在している。

$ python -c "import site; print (site.getsitepackages())"
['/Users/amedama/.virtualenvs/py38/lib/python3.8/site-packages']
$ file  ~/.virtualenvs/py38/lib/python3.8/site-packages/lightgbm/lib_lightgbm.so
/Users/amedama/.virtualenvs/py38/lib/python3.8/site-packages/lightgbm/lib_lightgbm.so: Mach-O 64-bit dynamically linked shared library x86_64

上記の共有ライブラリは Python Binding の Booster クラスから呼ばれている。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1930

ctypes モジュールで読み込んだライブラリを _LIB として呼び出している部分がそれ。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/basic.py#L1988,L1991

そして、train() 関数や cv() 関数は、上記の Booster を学習させるためのラッパーになっている。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L18

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L394

cv() 関数に着目して読んでいくと、以下で全 Fold の Booster を同時に更新していることがわかる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L592

また、Early Stopping は全 Fold の平均的なメトリックを元に発火することが確認できる。

https://github.com/microsoft/LightGBM/blob/v3.0.0rc1/python-package/lightgbm/engine.py#L593,L609

いじょう。

Python: Null Importance を使った特徴量選択について

今回は特徴量選択 (Feature Selection) の手法のひとつとして使われることのある Null Importance を試してみる。 Null Importance というのは、目的変数をシャッフルして意味がなくなった状態で学習させたモデルから得られる特徴量の重要度を指す。 では、それを使ってどのように特徴量選択をするかというと、シャッフルしなかったときの重要度との比率をスコアとして計算する。 もし、シャッフルしたときの重要度が元となった重要度よりも小さくなっていれば、スコアは大きくなって特徴量に意味があるとみなせる。 一方で、シャッフルしたときの重要度が元とさほど変わらなければ、スコアは小さくなってその特徴量は単なるノイズに近い存在と判断できる。 あとはスコアに一定の閾値を設けたり、上位 N 件を取り出すことで特徴量選択ができるようだ。

今回使った環境は次のとおり。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.15.6
BuildVersion:   19G73
$ python -V                             
Python 3.8.3

下準備

下準備として、あらかじめ今回使うパッケージをインストールしておく。

$ pip install scikit-learn lightgbm matplotlib pandas tqdm

題材とするデータについて

題材としては scikit-learn が生成するダミーデータを用いる。 sklearn.datasets.make_classification() 関数を使うと、データの次元数や推論の難易度などを調整してダミーデータが作れる。

以下のサンプルコードでは二値分類のダミーデータを生成している。 このダミーデータは全体で 100 次元あるものの、先頭の 5 次元だけが推論する上で意味のあるデータとなっている。 試しにダミーデータを RandomForest で分類して、モデルに組み込みの特徴量の可視化してみよう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
from matplotlib import pyplot as plt


def main():
    # 疑似的な教師信号を作るためのパラメータ
    args = {
        # データ点数
        'n_samples': 1_000,
        # 次元数
        'n_features': 100,
        # その中で意味のあるもの
        'n_informative': 5,
        # 重複や繰り返しはなし
        'n_redundant': 0,
        'n_repeated': 0,
        # タスクの難易度
        'class_sep': 0.65,
        # 二値分類問題
        'n_classes': 2,
        # 生成に用いる乱数
        'random_state': 42,
        # 特徴の順序をシャッフルしない (先頭の次元が informative になる)
        'shuffle': False,
    }
    # ノイズを含んだ教師データを作る
    X, y = make_classification(**args)

    # 分類器にランダムフォレストを使う
    clf = RandomForestClassifier(n_estimators=100,
                                 random_state=42)

    # Stratified 5-Fold CV + ROC AUC でモデルを検証する
    folds = StratifiedKFold(n_splits=5,
                            shuffle=True,
                            random_state=42,
                            )
    # すべての次元を使った場合
    cv_all_result = cross_validate(clf, X, y,
                                  cv=folds,
                                  return_estimator=True,
                                  scoring='roc_auc',
                                  )
    # 意味のある特徴量だけを使った場合
    cv_ideal_result = cross_validate(clf,
                                     # 使う次元を先頭だけに絞り込む
                                     X[:, :args['n_informative']],
                                     y,
                                     return_estimator=False,
                                     scoring='roc_auc',
                                     )

    # それぞれの状況でのメトリックのスコア
    print('All used AUC:', np.mean(cv_all_result['test_score']))
    print('Ideal AUC:', np.mean(cv_ideal_result['test_score']))

    # 学習済みモデルを取り出す
    clfs = cv_all_result['estimator']

    # 特徴量の重要度を取り出す
    importances = [clf.feature_importances_ for clf in clfs]
    mean_importances = np.mean(importances, axis=0)
    std_importances = np.std(importances, axis=0)
    sorted_indices = np.argsort(mean_importances)[::-1]

    # 重要度の高い特徴量を表示する
    MAX_TOP_N = 10
    rank_n = min(X.shape[1], MAX_TOP_N)
    print('Feature importance ranking (TOP {rank_n})'.format(rank_n=rank_n))
    for rank, idx in enumerate(sorted_indices[:rank_n], start=1):
        params = {
            'rank': rank,
            'idx': idx,
            'importance': mean_importances[idx]
        }
        print('{rank}. feature {idx:02d}: {importance}'.format(**params))

    # 特徴量の重要度を可視化する
    plt.figure(figsize=(6, 8))
    plt.barh(range(rank_n),
             mean_importances[sorted_indices][:rank_n][::-1],
             color='g',
             ecolor='r',
             yerr=std_importances[sorted_indices][:rank_n][::-1],
             align='center')
    plt.yticks(range(rank_n), sorted_indices[:rank_n][::-1])
    plt.ylabel('Features')
    plt.xlabel('Importance')
    plt.grid()
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。 はじめに、すべての次元を使って学習した場合と、実際に意味のある次元 (先頭 5 次元) だけを使った場合の ROC AUC を出力している。 そして、それに続いて値の大きさで降順ソートした特徴量の重要度が表示される。

$ python rfimp.py
All used AUC: 0.8748692786445511
Ideal AUC: 0.947211653938503
Feature importance ranking (TOP 10)
1. feature 00: 0.0843898282531266
2. feature 03: 0.05920082172020726
3. feature 01: 0.05713024480668112
4. feature 04: 0.05027373664878948
5. feature 02: 0.017385737849956347
6. feature 55: 0.010379957110544775
7. feature 52: 0.009663802073627043
8. feature 87: 0.00907575944742235
9. feature 34: 0.008999985108968961
10. feature 82: 0.008999816679827738

先頭の 5 次元だけを使った場合の方が ROC AUC のスコアが高くなっていることがわかる。 また、RandomForest のモデルに組み込まれている特徴量の重要度を見ても先頭の 5 次元が上位にきていることが確認できる。 上記の重要度を棒グラフにプロットしたものは以下のようになる。

f:id:momijiame:20200805181014p:plain
RandomForest から得られる特徴量の重要度

この特徴量の重要度をそのまま使って特徴量選択をすることもできる。 たとえば上位 N 件を取り出した場合で交差検証のスコアを比較することが考えられる。 しかし、スコアの増加が小さな特徴量は単なるノイズかどうかの判断が難しい。 たとえば上記であれば「2」の特徴量は事前知識がなければノイズかどうか怪しいところ。

特徴量の重要度を Null Importance と比較する

そこで、特徴量の重要度をより細かく検証するために Null Importance と比較する。 前述したとおり Null Importance は目的変数をシャッフルして学習させたモデルから計算した特徴量の重要度となる。 Null Importance を基準として、元の特徴量の重要度がどれくらい大きいか比べることでノイズかどうかの判断がしやすくなる。

以下のサンプルコードでは Null Importance と元の特徴量の重要度を上位 10 件についてヒストグラムとしてプロットしている。 Null Importance は何回も計算するのが比較的容易なので、このようにヒストグラムで比較できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import sys
from itertools import chain

from tqdm import tqdm
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate


LOGGER = logging.getLogger(__name__)


def feature_importance(X, y):
    """特徴量の重要度を交差検証で計算する"""
    clf = RandomForestClassifier(n_estimators=100,
                                 random_state=42)

    folds = StratifiedKFold(n_splits=5,
                            shuffle=True,
                            random_state=42,
                            )

    cv_result = cross_validate(clf, X, y,
                               cv=folds,
                               return_estimator=True,
                               scoring='roc_auc',
                               n_jobs=-1,
                               )
    clfs = cv_result['estimator']

    importances = [clf.feature_importances_ for clf in clfs]
    mean_importances = np.mean(importances, axis=0)

    return mean_importances


def main():
    # 結構時間がかかるのでログを出す
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)

    args = {
        'n_samples': 1_000,
        'n_features': 100,
        'n_informative': 5,
        'n_redundant': 0,
        'n_repeated': 0,
        'class_sep': 0.65,
        'n_classes': 2,
        'random_state': 42,
        'shuffle': False,
    }
    X, y = make_classification(**args)

    # ベースとなる特徴量の重要度を計算する
    LOGGER.info('Starting base importance calculation')
    base_importance = feature_importance(X, y)

    # データのシャッフルに再現性をもたせる
    np.random.seed(42)

    # Null Importance を何度か計算する
    LOGGER.info('Starting null importance calculation')
    TRIALS_N = 20
    null_importances = []
    for _ in tqdm(range(TRIALS_N)):
        # 目的変数をシャッフルする
        y_permuted = np.random.permutation(y)
        # シャッフルした状態で特徴量の重要度を計算する
        null_importance = feature_importance(X, y_permuted)
        null_importances.append(null_importance)
    null_importances = np.array(null_importances)
    # 列と行を入れ替える
    transposed_null_importances = null_importances.transpose(1, 0)

    sorted_indices = np.argsort(base_importance)[::-1]
    sorted_base_importance = base_importance[sorted_indices]
    sorted_null_importance = transposed_null_importances[sorted_indices]

    # ベースとなる特徴量の重要度の上位と Null Importance を可視化する
    LOGGER.info('Starting visualization')
    HIST_ROWS = 3
    HIST_COLS = 3
    fig, axes = plt.subplots(HIST_ROWS, HIST_COLS,
                             figsize=(8, 18))
    for index, ax in enumerate(chain.from_iterable(axes)):
        # Null Importance をヒストグラムにする
        col_null_importance = sorted_null_importance[index]
        ax.hist(col_null_importance, label='Null Importance', color='b')
        # ベースとなる特徴量の重要度の場所に縦線を引く
        col_base_importance = sorted_base_importance[index]
        ax.axvline(col_base_importance, label='Base Importance', color='r')
        # グラフの体裁
        ax.set_xlabel('Importance')
        ax.set_ylabel('Frequency')
        ax.set_title(f'Feature: {sorted_indices[index]}')
        ax.legend()
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python nullimphist.py

すると、次のようなヒストグラムが得られる。

f:id:momijiame:20200805185435p:plain
Null Importance のヒストグラムとの比較

先頭の 5 次元については Null Importance の分布から大きく離れている。 一方で、それ以降の特徴量の重要度は Null Importance の分布の中に入ってしまっていることがわかる。 ここから、先頭の 5 次元以降はノイズに近い特徴量であると判断する材料になる。

Null Importance を使って特徴量を選択してみる

それでは、実際に特徴量の選択までやってみよう。 以下のサンプルコードでは、Null Importance と元の特徴量の重要度の比率からスコアを計算している。 そして、スコアの大きさを元に上位 N% の特徴量を取り出して交差検証のスコアを確認している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import sys

from tqdm import tqdm
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate


LOGGER = logging.getLogger(__name__)


def _cross_validate(X, y):
    clf = RandomForestClassifier(n_estimators=100,
                                 random_state=42)

    folds = StratifiedKFold(n_splits=5,
                            shuffle=True,
                            random_state=42,
                            )

    cv_result = cross_validate(clf, X, y,
                               cv=folds,
                               return_estimator=True,
                               scoring='roc_auc',
                               n_jobs=-1,
                               )
    return cv_result


def cv_mean_feature_importance(X, y):
    """特徴量の重要度を交差検証で計算する"""
    cv_result = _cross_validate(X, y)
    clfs = cv_result['estimator']
    importances = [clf.feature_importances_ for clf in clfs]
    mean_importances = np.mean(importances, axis=0)

    return mean_importances


def cv_mean_test_score(X, y):
    """OOF Prediction のメトリックを交差検証で計算する"""
    cv_result = _cross_validate(X, y)
    mean_test_score = np.mean(cv_result['test_score'])
    return mean_test_score


def main():
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)

    n_cols = 100
    args = {
        'n_samples': 1_000,
        'n_features': n_cols,
        'n_informative': 5,
        'n_redundant': 0,
        'n_repeated': 0,
        'class_sep': 0.65,
        'n_classes': 2,
        'random_state': 42,
        'shuffle': False,
    }
    X, y = make_classification(**args)
    columns = np.arange(n_cols)

    LOGGER.info('Starting base importance calculation')
    base_importance = cv_mean_feature_importance(X, y)

    np.random.seed(42)

    LOGGER.info('Starting null importance calculation')
    TRIALS_N = 20
    null_importances = []
    for _ in tqdm(range(TRIALS_N)):
        y_permuted = np.random.permutation(y)
        null_importance = cv_mean_feature_importance(X, y_permuted)
        null_importances.append(null_importance)
    null_importances = np.array(null_importances)

    # Null Importance を用いた特徴量選択の一例
    # Base Importance と Null Importance の値の比をスコアとして計算する
    criterion_percentile = 50  # 基準となるパーセンタイル (ここでは中央値)
    percentile_null_imp = np.percentile(null_importances,
                                        criterion_percentile,
                                        axis=0)
    null_imp_score = base_importance / (percentile_null_imp + 1e-6)
    # スコアの大きさで降順ソートする
    sorted_indices = np.argsort(null_imp_score)[::-1]

    # 上位 N% の特徴量を使って性能を比較してみる
    use_feature_importance_top_percentages = [100, 75, 50, 25, 10, 5, 1]

    mean_test_scores = []
    selected_cols_len = []
    for percentage in use_feature_importance_top_percentages:
        # スコアが上位のカラムを取り出す
        sorted_columns = columns[sorted_indices]
        num_of_features = int(n_cols * percentage / 100)
        selected_cols = sorted_columns[:num_of_features]
        X_selected = X[:, selected_cols]
        LOGGER.info(f'Null Importance score TOP {percentage}%')
        LOGGER.info(f'selected features: {selected_cols}')
        LOGGER.info(f'selected feature len: {len(selected_cols)}')
        selected_cols_len.append(len(selected_cols))

        mean_test_score = cv_mean_test_score(X_selected, y)
        LOGGER.info(f'mean test score: {mean_test_score}')
        mean_test_scores.append(mean_test_score)

    # 結果を可視化する
    _, ax1 = plt.subplots(figsize=(8, 4))
    ax1.plot(mean_test_scores, color='b', label='mean test score')
    ax1.set_xlabel('percentile')
    ax1.set_ylabel('mean test score')
    ax1.legend()
    ax2 = ax1.twinx()
    ax2.plot(selected_cols_len, color='r', label='selected features len')
    ax2.set_ylabel('selected features len')
    ax2.legend()
    plt.xticks(range(len(use_feature_importance_top_percentages)),
               use_feature_importance_top_percentages)
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python nullimpselect.py
INFO:__main__:Starting base importance calculation
INFO:__main__:Starting null importance calculation
100%|███████████████████████████████████████████| 20/20 [00:56<00:00,  2.82s/it]
INFO:__main__:Null Importance score TOP 100%
INFO:__main__:selected features: [ 0  3  1  4  2 55 13 52 34 87 67 18 29 56 82 94 23 75 58 89 54 47 97 61
 85 93 64 44 43 45 63  7 37 38 27 41 26 12 72 10 51 49 96 50 22 53 81 66
 78 69 60 79 16 90 19 80 24 17 36 71 15  8 39 20 98 40 11 86 84  6 74 73
  9 65 31 35 92 91 68 48 25 99 88 76 42 28  5 70 30 95 33 83 14 59 57 32
 21 46 77 62]
INFO:__main__:selected feature len: 100
INFO:__main__:mean test score: 0.8658391949639143
INFO:__main__:Null Importance score TOP 75%
INFO:__main__:selected features: [ 0  3  1  4  2 55 13 52 34 87 67 18 29 56 82 94 23 75 58 89 54 47 97 61
 85 93 64 44 43 45 63  7 37 38 27 41 26 12 72 10 51 49 96 50 22 53 81 66
 78 69 60 79 16 90 19 80 24 17 36 71 15  8 39 20 98 40 11 86 84  6 74 73
  9 65 31]
INFO:__main__:selected feature len: 75
INFO:__main__:mean test score: 0.8774203740902301
INFO:__main__:Null Importance score TOP 50%
INFO:__main__:selected features: [ 0  3  1  4  2 55 13 52 34 87 67 18 29 56 82 94 23 75 58 89 54 47 97 61
 85 93 64 44 43 45 63  7 37 38 27 41 26 12 72 10 51 49 96 50 22 53 81 66
 78 69]
INFO:__main__:selected feature len: 50
INFO:__main__:mean test score: 0.909876116663287
INFO:__main__:Null Importance score TOP 25%
INFO:__main__:selected features: [ 0  3  1  4  2 55 13 52 34 87 67 18 29 56 82 94 23 75 58 89 54 47 97 61
 85]
INFO:__main__:selected feature len: 25
INFO:__main__:mean test score: 0.932729764573096
INFO:__main__:Null Importance score TOP 10%
INFO:__main__:selected features: [ 0  3  1  4  2 55 13 52 34 87]
INFO:__main__:selected feature len: 10
INFO:__main__:mean test score: 0.9513031915436441
INFO:__main__:Null Importance score TOP 5%
INFO:__main__:selected features: [0 3 1 4 2]
INFO:__main__:selected feature len: 5
INFO:__main__:mean test score: 0.9629145987828075
INFO:__main__:Null Importance score TOP 1%
INFO:__main__:selected features: [0]
INFO:__main__:selected feature len: 1
INFO:__main__:mean test score: 0.6537074505769904

すると、以下のようなグラフが得られる。

f:id:momijiame:20200805185850p:plain
Null Importance を用いて特徴量選択の可視化

今回は先頭 5 次元が重要とわかっている。 上記のグラフをみても上位 5% (= 5 次元) が選択されるまではスコアが順調に伸びている様子が確認できる。 しかし上位 1% (= 1 次元) までいくと推論に使える特徴量まで削ってしまいスコアが落ちていることがわかる。 ちなみに、実際には上記のように比較するのではなくスコアに一定の閾値を設けたり上位 N 件を決め打ちで選択してしまう場合も多いようだ。

LightGBM + Pandas を使った場合のサンプル

続いては、よくある例としてモデルに LightGBM を使って、データが Pandas のデータフレームの場合も確認しておこう。 サンプルコードは次のとおり。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
import sys

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold
from matplotlib import pyplot as plt


LOGGER = logging.getLogger(__name__)


class ModelExtractionCallback(object):
    """see: https://blog.amedama.jp/entry/lightgbm-cv-model"""

    def __init__(self):
        self._model = None

    def __call__(self, env):
        self._model = env.model

    def _assert_called_cb(self):
        if self._model is None:
            raise RuntimeError('callback has not called yet')

    @property
    def cvbooster(self):
        self._assert_called_cb()
        return self._model


def _cross_validate(train_x, train_y, folds):
    """LightGBM で交差検証する関数"""
    lgb_train = lgb.Dataset(train_x, train_y)

    model_extraction_cb = ModelExtractionCallback()
    callbacks = [
        model_extraction_cb,
    ]

    lgbm_params = {
        'objective': 'binary',
        'metric': 'auc',
        'first_metric_only': True,
        'verbose': -1,
    }
    lgb.cv(lgbm_params,
           lgb_train,
           folds=folds,
           num_boost_round=1_000,
           early_stopping_rounds=100,
           callbacks=callbacks,
           verbose_eval=20,
           )
    return model_extraction_cb.cvbooster


def _predict_oof(cv_booster, train_x, train_y, folds):
    """学習済みモデルから Out-of-Fold Prediction を求める"""
    split = folds.split(train_x, train_y)
    fold_mappings = zip(split, cv_booster.boosters)
    oof_y_pred = np.zeros_like(train_y, dtype=float)
    for (_, val_index), booster in fold_mappings:
        val_train_x = train_x.iloc[val_index]
        y_pred = booster.predict(val_train_x,
                                 num_iteration=cv_booster.best_iteration)
        oof_y_pred[val_index] = y_pred
    return oof_y_pred


def cv_mean_feature_importance(train_x, train_y, folds):
    """交差検証したモデルを使って特徴量の重要度を計算する"""
    cv_booster = _cross_validate(train_x, train_y, folds)
    importances = cv_booster.feature_importance(importance_type='gain')
    mean_importance = np.mean(importances, axis=0)
    return mean_importance


def cv_mean_test_score(train_x, train_y, folds):
    """交差検証で OOF Prediction の平均スコアを求める"""
    cv_booster = _cross_validate(train_x, train_y, folds)
    # OOF Pred を取得する
    oof_y_pred = _predict_oof(cv_booster, train_x, train_y, folds)
    test_score = roc_auc_score(train_y, oof_y_pred)
    return test_score


def main():
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)

    n_cols = 100
    args = {
        'n_samples': 1_000,
        'n_features': n_cols,
        'n_informative': 5,
        'n_redundant': 0,
        'n_repeated': 0,
        'class_sep': 0.65,
        'n_classes': 2,
        'random_state': 42,
        'shuffle': False,
    }
    X, y = make_classification(**args)

    # Pandas のデータフレームにする
    col_names = [f'col{i}' for i in range(n_cols)]
    train_x = pd.DataFrame(X, columns=col_names)
    # インデックスが 0 から始まる連番とは限らないのでこういうチェックを入れた方が良い
    train_x.index = train_x.index * 10 + 10
    train_y = pd.Series(y, name='target')

    folds = StratifiedKFold(n_splits=5,
                            shuffle=True,
                            random_state=42,
                            )

    LOGGER.info('Starting base importance calculation')
    base_importance = cv_mean_feature_importance(train_x, train_y, folds)

    LOGGER.info('Starting null importance calculation')
    TRIALS_N = 20
    null_importances = []
    for _ in tqdm(range(TRIALS_N)):
        train_y_permuted = np.random.permutation(train_y)
        null_importance = cv_mean_feature_importance(train_x,
                                                     train_y_permuted,
                                                     folds)
        null_importances.append(null_importance)
    null_importances = np.array(null_importances)

    criterion_percentile = 50
    percentile_null_imp = np.percentile(null_importances,
                                        criterion_percentile,
                                        axis=0)
    null_imp_score = base_importance / (percentile_null_imp + 1e-6)
    sorted_indices = np.argsort(null_imp_score)[::-1]

    # 上位 N% の特徴量を使って性能を比較してみる
    use_feature_importance_top_percentages = [100, 75, 50, 25, 10, 5, 1]

    mean_test_scores = []
    percentile_selected_cols = []
    for percentage in use_feature_importance_top_percentages:
        sorted_columns = train_x.columns[sorted_indices]
        num_of_features = int(n_cols * percentage / 100)
        selected_cols = sorted_columns[:num_of_features]
        selected_train_x = train_x[selected_cols]
        LOGGER.info(f'Null Importance score TOP {percentage}%')
        LOGGER.info(f'selected features: {list(selected_cols)}')
        LOGGER.info(f'selected feature len: {len(selected_cols)}')
        percentile_selected_cols.append(selected_cols)

        mean_test_score = cv_mean_test_score(selected_train_x, train_y, folds)
        LOGGER.info(f'mean test_score: {mean_test_score}')
        mean_test_scores.append(mean_test_score)

    _, ax1 = plt.subplots(figsize=(8, 4))
    ax1.plot(mean_test_scores, color='b', label='mean test score')
    ax1.set_xlabel('Importance TOP n%')
    ax1.set_ylabel('mean test score')
    ax1.legend()
    ax2 = ax1.twinx()
    ax2.plot([len(cols) for cols in percentile_selected_cols],
             color='r', label='selected features len')
    ax2.set_ylabel('selected features len')
    ax2.legend()
    plt.xticks(range(len(use_feature_importance_top_percentages)),
               use_feature_importance_top_percentages)
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python lgbmnullimp.py

すると、次のようなヒストグラムが得られる。 こちらも、上位 5% を選択するまで徐々にスコアが上がって、上位 1% までいくとスコアが落ちていることがわかる。

f:id:momijiame:20200805190732p:plain
データに Pandas モデルに LightGBM を使った場合

めでたしめでたし。

Python: NumPy 配列の操作でメモリのコピーが生じているか調べる

パフォーマンスの観点からいえば、データをコピーする機会は少ないほど望ましい。 コンピュータのバスの帯域幅は有限なので、データをコピーするには時間がかかる。

NumPy の配列 (ndarray) には、メモリを実際に確保している配列と、それをただ参照しているだけのビュー (view) がある。 そして、配列への操作によって、メモリが確保されて新しい配列が作られるか、それとも単なるビューになるかは異なる。 今回は NumPy の配列を操作するときにメモリのコピーが生じているか調べる方法について。

使った環境は次のとおり。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.15.5
BuildVersion:   19F101
$ python -V                             
Python 3.7.7
$ pip list | grep -i numpy
numpy              1.19.0

下準備

あらかじめ NumPy をインストールしておく。

$ pip install numpy

Python のインタプリタを起動する。

$ python

サンプルとなる配列を用意する。

>>> import numpy as np
>>> a = np.arange(10)

flags を使った調べ方

はじめに ndarray#flags を使った調べ方から。 NumPy の配列には flags というアトリビュートがあって、ここから配列の情報がいくらか得られる。 この中に owndata という情報があって、これはオブジェクトのメモリが自身のものか、それとも別のオブジェクトを参照しているかを表す。

numpy.org

最初に作った配列については、このフラグが True にセットされている。

>>> a.flags.owndata
True

では、上記に対してスライスを使って配列の一部を取り出した場合にはどうだろうか。

>>> b = a[1:]

スライスで取り出した配列の場合、フラグは False にセットされている。

>>> b.flags.owndata
False

つまり、自身でメモリを確保しているのではなく、別のオブジェクトを参照しているだけ。

ndarray#base を使った調べ方

同じように ndarray#base を使って調べることもできそうだ。 このアトリビュートは、オブジェクトが別のオブジェクトのメモリに由来している場合に、そのオブジェクトへの参照が入る。

numpy.org

先ほどの例では、最初に作った配列は None になっている。

>>> a.base is None
True

一方で、スライスを使って取り出した配列は、元になった配列への参照が入っている。

>>> b.base is a
True

インプレース演算

ところで、インプレース演算の場合は ndarray#flagsndarray#base を使った判定ができないのかな、と思った。

たとえば配列を使った通常の加算 (__add__()) では、新しく配列が作られてメモリのコピーが生じる。

>>> c = a + 1
>>> c.flags.owndata
True
>>> c.base is None
True

一方で、インプレースの加算 (__iadd__()) を使ったときも、これまで紹介してきたアトリビュートは同じ見え方になる。

>>> a += 1
>>> a.flags.owndata
True
>>> a.base is None
True

では、メモリのコピーは生じているかというと生じていない。

NumPy の配列には __array_interface__ というアトリビュートがある。 その中にある data というキーからは、配列の最初の要素が格納されているメモリのアドレス情報が得られる。

>>> a.__array_interface__['data']
(140489713031584, False)

以下のように、インプレース演算をしてもアドレス情報に変化はない。

>>> a += 1
>>> a.__array_interface__['data']
(140489713031584, False)

つまり、新たにメモリは確保されていない。

numpy.org

なお、それ以外にも大きな配列を用意してベンチマークしたり、ソースコードを読んで調べることも考えられる。

ltrace(1) で共有ライブラリの呼び出しを追いかける

Linux システムでは ltrace(1) を使うことで共有ライブラリの呼び出しを調べることができる。 今回は、いくつかの例を用いて使い方についてざっくりと見ていく。

使った環境は次のとおり。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.4 LTS"
$ uname -r
4.15.0-111-generic
$ ltrace -V
ltrace version 0.7.3.
Copyright (C) 1997-2009 Juan Cespedes <cespedes@debian.org>.
This is free software; see the GNU General Public Licence
version 2 or later for copying conditions.  There is NO warranty.

下準備

あらかじめ、利用するパッケージをインストールしておく。 尚、ltrace 以外は使い方のサンプルとして取り上げているだけ。

$ sudo apt-get -y install ltrace libc-bin gcc wget python3-requests python3-numpy

自作のプログラムを使った例

まずは最も単純な例として、ハローワールドするだけのプログラムを試してみよう。

次のような C 言語のソースコードを用意する。

$ cat << 'EOF' > greet.c
#include <stdio.h>
#include <stdlib.h>


int main(void) {
    printf("Hello, World!\n");
    return EXIT_SUCCESS;
}
EOF

上記をビルドする。

$ gcc -Wall -o greet greet.c

これで ELF フォーマットの実行可能ファイルがえきる。

$ file greet
greet: ELF 64-bit LSB shared object, x86-64, version 1 (SYSV), dynamically linked, interpreter /lib64/ld-linux-x86-64.so.2, for GNU/Linux 3.2.0, BuildID[sha1]=1771a2c7c6ded094f15254590870089080689968, not stripped

次のとおり、実行するとただメッセージを出力して終了する。

$ ./greet 
Hello, World!

このように、ただハローワールドするだけのプログラムでも標準 C ライブラリ (libc) はダイナミックリンクされている。

$ ldd greet
    linux-vdso.so.1 (0x00007ffd31d9e000)
    libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f7081218000)
    /lib64/ld-linux-x86-64.so.2 (0x00007f708180b000)

このプログラムを ltrace(1) 経由で実行してみよう。 すると、共有ライブラリの呼び出しに関する情報が出力される。

$ ltrace ./greet > /dev/null
puts("Hello, World!")                                                                                          = 14
+++ exited (status 0) +++

puts(3) は標準 C ライブラリの提供している API のひとつで、詳細は man を参照のこと。

$ man 3 puts

たとえば -l オプションを指定すると特定の共有ライブラリの呼び出しだけを追跡できる。 ただし、上記のサンプルでは libc の API しか呼び出していないので関係ない。

$ ltrace -l libc.so.6 ./greet > /dev/null
greet->puts("Hello, World!")                                                                                   = 14
+++ exited (status 0) +++

echo(1) を使った例

次は echo(1) を使ってみよう。 こちらも標準 C ライブラリがダイナミックリンクされている。

$ echo "Hello, World"
Hello, World
$ ldd $(which echo)
    linux-vdso.so.1 (0x00007ffe8dddf000)
    libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f87d2478000)
    /lib64/ld-linux-x86-64.so.2 (0x00007f87d2a72000)

ltrace(1) 経由で実行すると、先ほどよりも色々な API が呼び出されていることがわかる。

$ ltrace echo "Hello, World"
getenv("POSIXLY_CORRECT")                                                                                      = nil
strrchr("echo", '/')                                                                                           = nil
setlocale(LC_ALL, "")                                                                                          = "C.UTF-8"

...

__freading(0x7fc11d6c3680, 0, 4, 2880)                                                                         = 0
fflush(0x7fc11d6c3680)                                                                                         = 0
fclose(0x7fc11d6c3680)                                                                                         = 0
+++ exited (status 0) +++

wget(1) を使った例 (LibSSL)

もうちょっと複雑な例として wget(1) を試してみよう。

ただし、次は libc ではなく libssl の呼び出しを追跡する。 次のように wget(1) は libssl をダイナミックリンクしている。

$ ldd $(which wget) | grep ssl
    libssl.so.1.1 => /usr/lib/x86_64-linux-gnu/libssl.so.1.1 (0x00007f85b9051000)

libssl の呼び出しに限定して wget(1)ltrace(1) 経由で実行してみよう。 たとえば Google のウェブサイトを取得させてみる。

$ ltrace -l libssl.so.1.1 wget -qO /dev/null https://google.co.jp
wget->OPENSSL_init_ssl(0, 0, 0x55af441e1ec0, 0)                                                                = 1
wget->OPENSSL_init_ssl(0x200002, 0, 0x7fffffff, 0)                                                             = 1
wget->TLS_client_method(0x7f9b550c891c, 129, 0, 0x55af42583904)                                                = 0x7f9b550c28e0

...

wget->SSL_pending(0x55af441efec0, 0x55af44201a30, 127, 0x55af441f5270)                                         = 2
wget->SSL_peek(0x55af441efec0, 0x55af44201a30, 127, 1)                                                         = 2
wget->SSL_read(0x55af441efec0, 0x55af44201a30, 2, 0x7f9b54447eb7)                                              = 2
+++ exited (status 0) +++

ちゃんと呼び出しがトレースできていることがわかる。

Python を使った例 (LibSSL)

次は Python のインタプリタについて、同じように libssl の呼び出しを追いかけてみる。 注意すべき点として、このような例では Python のバイナリは直接は LibSSL をリンクしていない。

$ ldd $(which python3)
    linux-vdso.so.1 (0x00007ffca4dd4000)
    libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f78ac216000)
    libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f78abff7000)
    libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f78abdf3000)
    libutil.so.1 => /lib/x86_64-linux-gnu/libutil.so.1 (0x00007f78abbf0000)
    libexpat.so.1 => /lib/x86_64-linux-gnu/libexpat.so.1 (0x00007f78ab9be000)
    libz.so.1 => /lib/x86_64-linux-gnu/libz.so.1 (0x00007f78ab7a1000)
    libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f78ab403000)
    /lib64/ld-linux-x86-64.so.2 (0x00007f78ac607000)

代わりに、同梱されている共有ライブラリが間接的にリンクしている。

$ python3 -c "import _ssl; print(_ssl.__file__)"
/usr/lib/python3.6/lib-dynload/_ssl.cpython-36m-x86_64-linux-gnu.so
$ ldd /usr/lib/python3.6/lib-dynload/_ssl.cpython-36m-x86_64-linux-gnu.so | grep ssl
    libssl.so.1.1 => /usr/lib/x86_64-linux-gnu/libssl.so.1.1 (0x00007f181f27f000)

このような状況では ltrace(1) がブレークポイントを仕掛けるべき共有ライブラリを自動では認識できない。 そのため -l オプションで明示的に対象の共有ライブラリを指定しないと呼び出しが表示されないようだ。

$ ltrace -l libssl.so.1.1 python3 -c "import requests; requests.get('https://google.co.jp')"
--- SIGCHLD (Child exited) ---
--- SIGCHLD (Child exited) ---
--- SIGCHLD (Child exited) ---
_ssl.cpython-36m-x86_64-linux-gnu.so->TLS_method(0xa74560, 0, 0, 0)                                            = 0x7f2b498dd1a0
_ssl.cpython-36m-x86_64-linux-gnu.so->SSL_CTX_new(0x7f2b498dd1a0, 0, 0, 0)                                     = 0x25de450
_ssl.cpython-36m-x86_64-linux-gnu.so->SSL_CTX_get_verify_callback(0x25de450, 0x7f2b475ad5bc, 0, 0x7f2b498f8718) = 0

...

_ssl.cpython-36m-x86_64-linux-gnu.so->SSL_CTX_free(0x25de450, 0x9d2800, -3, 0x27f0070)                         = 0
_ssl.cpython-36m-x86_64-linux-gnu.so->SSL_free(0x27dbaf0, 0x7f2b475d15a8, -2, 2)                               = 0
_ssl.cpython-36m-x86_64-linux-gnu.so->SSL_CTX_free(0x22e8060, 0x9d2800, -3, 0x27f0070)                         = 0
+++ exited (status 0) +++

Python / NumPy を使った例 (LibBLAS)

最後に、おまけとして Python の NumPy が依存している libblas の呼び出しを追いかけてみる。 libblas の実装は Reference BLAS や ATLAS、OpenBLAS、Intel MKL など色々とある。 今回は apt-get(8) を使って NumPy をインストールして Reference BLAS が実装の例になる。

先ほどと同じように、libblas は NumPy に同梱されている共有ライブラリが間接的にリンクしている。 NumPy には Python/C API で書かれた、次のような共有ライブラリがある。

$ dpkg -L python3-numpy | grep so$ | grep -v test
/usr/lib/python3/dist-packages/numpy/core/_dummy.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/core/multiarray.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/core/umath.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/fft/fftpack_lite.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/linalg/_umath_linalg.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/linalg/lapack_lite.cpython-36m-x86_64-linux-gnu.so
/usr/lib/python3/dist-packages/numpy/random/mtrand.cpython-36m-x86_64-linux-gnu.so

この中で ndarray を実装しているのが multiarray で、次のように libblas をリンクしている。

$ ldd /usr/lib/python3/dist-packages/numpy/core/multiarray.cpython-36m-x86_64-linux-gnu.so | grep blas
    libblas.so.3 => /usr/lib/x86_64-linux-gnu/libblas.so.3 (0x00007faf3e37a000)

試しに NumPy を使って配列のドット積を計算してみよう。 すると、内部的に LibBLAS の API を呼び出していることが確認できる。

$ ltrace -l libblas.so.3 python3 -c "import numpy as np; x = np.random.randn(100, 100); np.dot(x, x)"
multiarray.cpython-36m-x86_64-linux-gnu.so->cblas_dgemm(101, 111, 111, 100 <unfinished ...>
libblas.so.3->dgemm_(0x7ffc75477dd6, 0x7ffc75477dd7, 0x7ffc75477dc8, 0x7ffc75477dcc <unfinished ...>
libblas.so.3->lsame_(0x7ffc75477dd6, 0x7fd337a21140, 1, 1)                                                                  = 1
libblas.so.3->lsame_(0x7ffc75477dd7, 0x7fd337a21140, 1, 1)                                                                  = 1
<... dgemm_ resumed> )                                                                                                      = 0
<... cblas_dgemm resumed> )                                                                                                 = 0
+++ exited (status 0) +++

いじょう。

Python: LightGBM 開発環境メモ

最近 LightGBM にコントリビューションする機会を得たので、その際に調べたことの備忘録を残しておく。 現時点では、Python 周りの開発環境についてドキュメントは特に見当たらないようだった。 以下は CI 環境のスクリプトやエラーメッセージを読みながら雰囲気で作ったもの。 そのため、あまり信用せずオフィシャルなところは本家のリポジトリで確認してほしい。

使った環境は次のとおり。

$ sw_vers                            
ProductName:    Mac OS X
ProductVersion: 10.15.5
BuildVersion:   19F101
$ python -V                      
Python 3.7.7

下準備

拡張モジュールやドキュメントのビルドに必要なパッケージをあらかじめインストールしておく。

$ brew install cmake doxygen

そして、ソースコードのリポジトリをチェックアウトする。

$ git clone https://github.com/microsoft/LightGBM.git
$ cd LightGBM

Python のテストを通す

今回は手を入れるのが Python binding 部分だったので、ひとまず Python のユニットテストを通す必要がある。

まずはテストに必要なパッケージをインストールする。

$ pip install pandas psutil scipy

続いて、Python binding のあるディレクトリに移動する。

$ cd LightGBM/python-package

LightGBM の本体といえる拡張モジュール (lib_lightgbm.so) をビルドした上で、開発モードで LightGBM をインストールする。

$ python setup.py bdist develop

ユニットテストのあるディレクトリに移動する。

$ cd ../tests/python_package_test 

テストランナーを実行して、ユニットテストがパスすることを確認する。

$ python -m unittest discover -v
...
test_stacking_regressor (test_sklearn.TestSklearn) ... ok
test_xendcg (test_sklearn.TestSklearn) ... ok

----------------------------------------------------------------------
Ran 107 tests in 146.355s

OK (skipped=5)

ドキュメントをビルドする

続いて、API に修正があった場合にはドキュメントにも手を入れる必要がある。 LightGBM では Sphinx を使ってドキュメントを書いている。

まずはドキュメントのディレクトリに移動する。

$ cd ../../docs/

ドキュメントをビルドするのに必要なパッケージをインストールする。

$ pip install -r requirements.txt

あとは通常の Sphinx の手順でターゲットを指定してドキュメントをビルドする。

$ make html

うまくいけば _build ディレクトリ以下に成果物ができるので修正した内容を確認する。

$ ls _build/html   
Advanced-Topics.html        Python-Intro.html
C-API.html          Quick-Start.html
Development-Guide.html      README.html
Experiments.html        _images
FAQ.html            _modules
Features.html           _sources
GPU-Performance.html        _static
GPU-Targets.html        gcc-Tips.html
GPU-Tutorial.html       genindex.html
GPU-Windows.html        index.html
Installation-Guide.html     objects.inv
Parallel-Learning-Guide.html    pythonapi
Parameters-Tuning.html      search.html
Parameters.html         searchindex.js
Python-API.html

いじょう。

xargs(1) でシェル関数を使いたい

コマンドラインの処理を並列実行したいときなどに使う xargs(1) だけど、引数にシェル関数を使おうとすると少し工夫する必要がある。 工夫しない場合に失敗する理由から説明しているので、うまくいくやり方だけ知りたいときは下までスクロールしてもらえると。

使った環境は次のとおり。 シェルとしては bash を想定する。 なお、xargs(1) は GNU 版か BSD 版かは問わない。

$ sw_vers       
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G5033
$ bash --version 
GNU bash, version 3.2.57(1)-release (x86_64-apple-darwin18)
Copyright (C) 2007 Free Software Foundation, Inc.

とりあえず失敗させてみる

たとえば、次のように greet() という名前でシェル関数を定義しておく。

$ greet() { echo "Hello, $1"; }
$ greet "World"
Hello, World

そして、特に何も考えず上記で定義した関数を xargs(1) 経由で実行しようとすると、次のようにエラーとなる。

$ echo "World" | xargs -I {} greet {}
xargs: greet {}: No such file or directory

コマンドが失敗する理由について

上記でコマンドが失敗する理由は複数ある。 まず、xargs(1) の基本的な動作原理は、プロセスを fork(2) して exec*(2) することに相当する 1 。 この動作原理から、少なくとも xargs(1) の引数は実行可能ファイルが先頭に来る必要があることがわかる。 シェル関数はあくまでシェルの中で利用できるものなので、まずはここが問題となる。

では、次のように bash をインラインで実行すればどうだろうか。 だいぶ確信には迫っているものの、まだ足りないのでエラーになる。

$ echo "World" | xargs -I {} bash -c "greet {}"
bash: greet: command not found

原因の 2 つ目は、xargs(1) 経由で実行したシェルの中でシェル関数が有効ではない点。 そのため、次のように export することでサブプロセスのシェルでもシェル関数が有効であるようにしなければいけない。

$ export -f greet

ようするに、xargs(1) は関係なくインラインで bash を実行したときに、ちゃんとシェル関数が使えるようになっている必要がある。

$ bash -c "greet 'World'"
Hello, World

うまくいくやり方

ということで、うまくいくやり方は次のとおり。

$ greet() { echo "Hello, $1"; }
$ export -f greet
$ echo "World" | xargs -I {} bash -c "greet {}"
Hello, World

いじょう。


  1. 少なくとも GNU 版の実装はそうなっているようだった

Python: 画像データをフーリエ変換して周波数領域で扱ってみる

フーリエ変換は音声データに対して用いられることが多い手法だけど、画像データにも応用が効く。 音声データの場合、フーリエ変換を使うことで時間領域の情報を周波数領域の情報に直せる。 それに対し、画像データでは空間領域の情報を周波数領域の情報に直すことになる。 つまり、画像データの濃淡が複数の波形の合成によって作られていると見なす。 今回は、画像データをフーリエ変換して、周波数領域の情報にフィルタをかけたり元に戻したりして遊んでみる。

使った環境は次のとおり。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G5033
$ python -V        
Python 3.7.7

下準備

あらかじめ、必要なパッケージをインストールしておく。

$ pip install pillow matplotlib

周波数領域の情報を可視化してみる

最初に、画像データをフーリエ変換して周波数領域で可視化してみる。 画像データは、二次元の NumPy 配列として読み込んだ上で np.fft.fft2() に渡すことでフーリエ変換できる。 また、周波数領域のデータは np.fft.ifft2() を使うことで空間領域のデータに戻せる。

以下のサンプルコードでは、元の画像と、周波数領域でのパワースペクトル、そして逆変換することで元に戻した画像を可視化している。 なお、読み込む画像は適当に用意しよう。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from PIL import Image
import numpy as np
from matplotlib import pyplot as plt


def main():
    # 画像を読み込む
    img = Image.open('lena.png')
    # グレイスケールに変換する
    gray_img = img.convert('L')
    # NumPy 配列にする
    f_xy = np.asarray(gray_img)

    # 2 次元高速フーリエ変換で周波数領域の情報を取り出す
    f_uv = np.fft.fft2(f_xy)
    # 画像の中心に低周波数の成分がくるように並べかえる
    shifted_f_uv = np.fft.fftshift(f_uv)

    # パワースペクトルに変換する
    magnitude_spectrum2d = 20 * np.log(np.absolute(shifted_f_uv))

    # 元の並びに直す
    unshifted_f_uv = np.fft.fftshift(shifted_f_uv)
    # 2 次元逆高速フーリエ変換で空間領域の情報に戻す
    i_f_xy = np.fft.ifft2(unshifted_f_uv).real  # 実数部だけ使う

    # 上記を画像として可視化する
    fig, axes = plt.subplots(1, 3, figsize=(8, 4))
    # 枠線と目盛りを消す
    for ax in axes:
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
    # 元画像
    axes[0].imshow(f_xy, cmap='gray')
    axes[0].set_title('Input Image')
    # 周波数領域のパワースペクトル
    axes[1].imshow(magnitude_spectrum2d, cmap='gray')
    axes[1].set_title('Magnitude Spectrum')
    # FFT -> IFFT した画像
    axes[2].imshow(i_f_xy, cmap='gray')
    axes[2].set_title('Reversed Image')
    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記をファイルに保存して実行する。

$ python imgfft2d.py

すると、次のようなグラフが得られる。

f:id:momijiame:20200616180009p:plain
FFT -> IFFT

真ん中のプロットが、元となった画像データの周波数領域での表現になる。 この表現では、中心に近いほど低い周波数・遠いほど高い周波数の成分を含む。 白い部分ほど成分が多いことを示しているため、この画像は低周波の成分が比較的多いように見える。 また、右の画像を見ると、ちゃんと周波数領域のデータから空間領域のデータに復元できていることがわかる。

ローパスフィルタ (Low Pass Filter) をかけて復元してみる

先ほど述べた通り、周波数領域の表現では中心に近いほど低い周波数・遠いほど高い周波数の成分を含んでいる。 つまり、中心部分だけを取り出すような演算をすると、画像を構成する波形にローパスフィルタをかけることができる。

以下のサンプルコードでは、周波数領域のデータに対して中心部分を取り出すフィルタをかけることで低周波成分だけを抽出した。 そして、フィルタしたデータを逆フーリエ変換することで画像に戻している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from PIL import Image
from PIL import ImageDraw
import numpy as np
from matplotlib import pyplot as plt


def main():
    # 画像を読み込む
    img = Image.open('lena.png')
    # グレイスケールに変換する
    gray_img = img.convert('L')
    # NumPy 配列にする
    f_xy = np.asarray(gray_img)

    # 2 次元高速フーリエ変換で周波数領域の情報を取り出す
    f_uv = np.fft.fft2(f_xy)
    # 画像の中心に低周波数の成分がくるように並べかえる
    shifted_f_uv = np.fft.fftshift(f_uv)

    # フィルタ (ローパス) を用意する
    x_pass_filter = Image.new(mode='L',  # 8-bit pixels, black and white
                              size=(shifted_f_uv.shape[0],
                                    shifted_f_uv.shape[1]),
                              color=0,  # default black
                              )
    # 中心に円を描く
    draw = ImageDraw.Draw(x_pass_filter)
    # 円の半径
    ellipse_r = 50
    # 画像の中心
    center = (shifted_f_uv.shape[0] // 2,
              shifted_f_uv.shape[1] // 2)
    # 円の座標
    ellipse_pos = (center[0] - ellipse_r,
                   center[1] - ellipse_r,
                   center[0] + ellipse_r,
                   center[1] + ellipse_r)
    draw.ellipse(ellipse_pos, fill=255)
    # フィルタ
    filter_array = np.asarray(x_pass_filter)

    # フィルタを適用する
    filtered_f_uv = np.multiply(shifted_f_uv, filter_array)

    # パワースペクトルに変換する
    magnitude_spectrum2d = 20 * np.log(np.absolute(filtered_f_uv))

    # 元の並びに直す
    unshifted_f_uv = np.fft.fftshift(filtered_f_uv)
    # 2 次元逆高速フーリエ変換で空間領域の情報に戻す
    i_f_xy = np.fft.ifft2(unshifted_f_uv).real  # 実数部だけ使う

    # 上記を画像として可視化する
    fig, axes = plt.subplots(1, 4, figsize=(12, 4))
    # 枠線と目盛りを消す
    for ax in axes:
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
    # 元画像
    axes[0].imshow(f_xy, cmap='gray')
    axes[0].set_title('Input Image')
    # フィルタ画像
    axes[1].imshow(filter_array, cmap='gray')
    axes[1].set_title('Filter Image')
    # フィルタされた周波数領域のパワースペクトル
    axes[2].imshow(magnitude_spectrum2d, cmap='gray')
    axes[2].set_title('Filtered Magnitude Spectrum')
    # FFT -> Band-pass Filter -> IFFT した画像
    axes[3].imshow(i_f_xy, cmap='gray')
    axes[3].set_title('Reversed Image')
    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python imglpf.py

すると、次のようなグラフが得られる。 分かりにくいかもしれないけど、復元した画像はちょっとぼやけている。 また、ぼやけ方も縞模様になっていて波を感じるものとなった。 これは、画像を構成するすべての周波数の中から、高周波の成分がフィルタによって取り除かれたことで生じている。

f:id:momijiame:20200616180053p:plain
FFT -> LPF -> IFFT

中心のごく一部だけを取り出している (つまり、多くの情報が失われている) のに、それっぽい画像になるのはなんとも面白い。 なお、フィルタに使う円の半径を小さくすれば、それだけ高周波の成分が少なくなって、より分かりやすくぼやける。

ハイパスフィルタ (High Pass Filter) をかけて復元してみる

次は、同様に高周波成分だけを取り出すハイパスフィルタをかけてみよう。 つまり、中心だけを取り除くことになる。 先ほどとのコードの違いは、適用するフィルタの周波数特性が違うだけ。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from PIL import Image
from PIL import ImageDraw
import numpy as np
from matplotlib import pyplot as plt


def main():
    # 画像を読み込む
    img = Image.open('lena.png')
    # グレイスケールに変換する
    gray_img = img.convert('L')
    # NumPy 配列にする
    f_xy = np.asarray(gray_img)

    # 2 次元高速フーリエ変換で周波数領域の情報を取り出す
    f_uv = np.fft.fft2(f_xy)
    # 画像の中心に低周波数の成分がくるように並べかえる
    shifted_f_uv = np.fft.fftshift(f_uv)

    # フィルタ (ハイパス) を用意する
    x_pass_filter = Image.new(mode='L',  # 8-bit pixels, black and white
                              size=(shifted_f_uv.shape[0],
                                    shifted_f_uv.shape[1]),
                              color=255,  # default white
                              )
    # 中心に円を描く
    draw = ImageDraw.Draw(x_pass_filter)
    # 円の半径
    ellipse_r = 50
    # 画像の中心
    center = (shifted_f_uv.shape[0] // 2,
              shifted_f_uv.shape[1] // 2)
    # 円の座標
    ellipse_pos = (center[0] - ellipse_r,
                   center[1] - ellipse_r,
                   center[0] + ellipse_r,
                   center[1] + ellipse_r)
    draw.ellipse(ellipse_pos, fill=0)
    # フィルタ
    filter_array = np.asarray(x_pass_filter)

    # フィルタを適用する
    filtered_f_uv = np.multiply(shifted_f_uv, filter_array)

    # パワースペクトルに変換する
    magnitude_spectrum2d = 20 * np.log(np.absolute(filtered_f_uv))

    # 元の並びに直す
    unshifted_f_uv = np.fft.fftshift(filtered_f_uv)
    # 2 次元逆高速フーリエ変換で空間領域の情報に戻す
    i_f_xy = np.fft.ifft2(unshifted_f_uv).real  # 実数部だけ使う

    # 上記を画像として可視化する
    fig, axes = plt.subplots(1, 4, figsize=(12, 4))
    # 枠線と目盛りを消す
    for ax in axes:
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
    # 元画像
    axes[0].imshow(f_xy, cmap='gray')
    axes[0].set_title('Input Image')
    # フィルタ画像
    axes[1].imshow(filter_array, cmap='gray')
    axes[1].set_title('Filter Image')
    # フィルタされた周波数領域のパワースペクトル
    axes[2].imshow(magnitude_spectrum2d, cmap='gray')
    axes[2].set_title('Filtered Magnitude Spectrum')
    # FFT -> Band-pass Filter -> IFFT した画像
    axes[3].imshow(i_f_xy, cmap='gray')
    axes[3].set_title('Reversed Image')
    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を保存して実行する。

$ python imghpf.py

すると、今度は次のようなグラフが得られる。 またもや分かりにくいけど、復元した画像は元の画像の輪郭だけがぼんやりと浮かび上がったものとなっている。 これが画像を構成する高周波の成分で、先ほどの復元画像から取り除かれたものと解釈できる。

f:id:momijiame:20200616180116p:plain
FFT -> HPF -> IFFT

カラー画像を扱う場合

なお、カラー画像を扱う場合、それぞれのチャネルごとに処理する必要がある。 以下のサンプルコードでは、最初のサンプルコードを RGB のチャネルでそれぞれ適用している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from itertools import chain

from PIL import Image
import numpy as np
from matplotlib import pyplot as plt


def main():
    # 画像を読み込む
    img = Image.open('lena.png')
    # NumPy 配列にする
    f_xy_rgb = np.asarray(img)

    # 画像として可視化する
    fig, axes = plt.subplots(3, 5, figsize=(12, 6))
    # 枠線と目盛りを消す
    for ax in chain.from_iterable(axes):
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

    # 元画像
    axes[1][0].imshow(f_xy_rgb)
    axes[1][0].set_title('Input Image (RGB)')

    # 逆変換した画像を入れる領域を用意しておく
    i_f_xy_rgb = np.empty_like(f_xy_rgb)
    # RGB ごとに処理する
    channel_names = ('R', 'G', 'B')
    channel_indices = range(f_xy_rgb.shape[2])
    zipped_channel_info = zip(channel_names, channel_indices)
    for channel_name, channel_index in zipped_channel_info:
        # 各チャネルごとに配列として取り出す
        f_xy = f_xy_rgb[:, :, channel_index]

        # 各チャネルの元画像
        axes[channel_index][1].imshow(f_xy, cmap='gray')
        axes[channel_index][1].set_title(f'Input Image ({channel_name})')

        # 2 次元高速フーリエ変換で周波数領域の情報を取り出す
        f_uv = np.fft.fft2(f_xy)
        # 画像の中心に低周波数の成分がくるように並べかえる
        shifted_f_uv = np.fft.fftshift(f_uv)
        # パワースペクトルに変換する
        magnitude_spectrum2d = 20 * np.log(np.absolute(shifted_f_uv))
        # 元の並びに直す
        unshifted_f_uv = np.fft.fftshift(shifted_f_uv)
        # 2 次元逆高速フーリエ変換で空間領域の情報に戻す
        i_f_xy = np.fft.ifft2(unshifted_f_uv).real  # 実数部だけ使う

        # 周波数領域のパワースペクトル
        axes[channel_index][2].imshow(magnitude_spectrum2d, cmap='gray')
        axes[channel_index][2].set_title(f'Magnitude Spectrum ({channel_name})')

        # FFT -> Band-pass Filter -> IFFT した画像
        axes[channel_index][3].imshow(i_f_xy, cmap='gray')
        axes[channel_index][3].set_title(f'Reversed Image ({channel_name})')

        # 逆変換したチャネルを保存しておく
        i_f_xy_rgb[:, :, channel_index] = i_f_xy

    # 逆変換した RGB 画像
    axes[1][4].imshow(i_f_xy_rgb)
    axes[1][4].set_title('Reversed Image (RGB)')

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を保存して実行する。

$ python imgfftc.py

すると、次のようなグラフが得られる。

f:id:momijiame:20200616180139p:plain
FFT (RGB) -> IFFT (RGB)

ちゃんと RGB それぞれのチャネルの情報から、カラー画像が復元できていることがわかる。

めでたしめでたし。