PySpark では、ごく最近まで UDAF (User Defined Aggregate Function: ユーザ定義集計関数) がサポートされていなかった。 Apache Spark 2.3 以降では Pandas UDF を使うことで UDAF に相当する処理を書くことができるようになっている。 今回は、それ以前のバージョンを使っているときに、同等の処理を書くための回避策について書いてみる。
使った環境は次のとおり。
$ cat /etc/redhat-release CentOS Linux release 7.7.1908 (Core) $ uname -r 3.10.0-1062.18.1.el7.x86_64 $ hadoop version Hadoop 2.9.2 Subversion https://git-wip-us.apache.org/repos/asf/hadoop.git -r 826afbeae31ca687bc2f8471dc841b66ed2c6704 Compiled by ajisaka on 2018-11-13T12:42Z Compiled with protoc 2.5.0 From source with checksum 3a9939967262218aa556c684d107985 This command was run using /home/vagrant/hadoop-2.9.2/share/hadoop/common/hadoop-common-2.9.2.jar $ pyspark --version Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.5 /_/ Using Scala version 2.11.12, OpenJDK 64-Bit Server VM, 1.8.0_242 Branch HEAD Compiled by user centos on 2020-02-02T19:38:06Z Revision cee4ecbb16917fa85f02c635925e2687400aa56b Url https://gitbox.apache.org/repos/asf/spark.git Type --help for more information.
下準備
はじめに、PySpark の REPL を起動しておく、
$ pyspark --master yarn
次のようにして、サンプルとなる DataFrame を用意する。 このデータは `'category`` カラムを使ってグループ化できる。
>>> data = [ ... ('A', 'Alice', 10), ... ('A', 'Bob', 15), ... ('A', 'Carol', 20), ... ('B', 'Daniel', 25), ... ('B', 'Ellie', 30), ... ('C', 'Frank', 35), ... ] >>> df = spark.createDataFrame(data, ('category', 'name', 'age'))
組み込みの集計関数について
はじめに、PySpark に組み込みで用意されている集計関数はどのように使うのかおさらいしておく。 たとえば、カテゴリーごとの平均値を計算してみよう。
たとえば年齢の平均を計算するときは、次のようにする。
DataFrame#groupBy()
からは GroupedData
というクラスのインスタンスが返る。
さらに、そのインスタンスに対して GroupedData#agg()
を使って集計関数を適用する。
>>> df.groupBy('category').agg({'age': 'mean'}).show() +--------+--------+ |category|avg(age)| +--------+--------+ | B| 27.5| | C| 35.0| | A| 15.0| +--------+--------+
集計する処理を辞書と文字列で表す以外に pyspark.sql.functions
を使う方法もある。
たとえば、上記と同じ処理を書いてみよう。
>>> from pyspark.sql import functions as F >>> df.groupBy('category').agg(F.mean('age').alias('mean-age')).show() +--------+--------+ |category|mean-age| +--------+--------+ | B| 27.5| | C| 35.0| | A| 15.0| +--------+--------+
以上が組み込みの集計関数を使う方法になる。 基礎的な統計量などを計算するだけなら、これでも問題ないはず。 しかし、集計の処理は組み込みの関数だけで完結しない場合も多い。 そこで、ユーザ定義の関数で集計の処理をしたくなる。
UDAF 代わりの処理の書き方
前述したとおり Apache Spark 2.3 以降では Pandas UDF を使って UDAF に相当する処理が書ける。 ただし、ここではそれについての具体的な紹介はしない。 紹介するのは、Pandas UDF が使えない環境での回避策となる。
回避策のキモは pyspark.sql.functions.collect_list
を使うところ。
この関数は GroupedData
の特定カラムを、グループ単位でリストに入れて返すことができる。
言葉で説明するよりも、次のサンプルコードを見てもらった方が早いかもしれない。
>>> agg_df = df.groupBy('category').agg(F.collect_list('age').alias('grouped-age')) >>> agg_df.show() +--------+------------+ |category| grouped-age| +--------+------------+ | B| [25, 30]| | C| [35]| | A|[10, 15, 20]| +--------+------------+
この状態までくれば、あとは単なる UDF (User Defined Function) を使って集計できる。 例として平均を計算する UDF を書いてみよう。
>>> def mean(values): ... return sum(values) / len(values) ... >>> mean_udf = F.udf(mean)
上記の UDF をリストのカラムに対して適用する。
>>> agg_df.withColumn('mean-age', mean_udf('grouped-age')).show() +--------+------------+--------+ |category| grouped-age|mean-age| +--------+------------+--------+ | B| [25, 30]| 27.5| | C| [35]| 35.0| | A|[10, 15, 20]| 15.0| +--------+------------+--------+
ばっちり。
参考
入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム
- 作者:Tomasz Drabas,Denny Lee
- 発売日: 2017/11/22
- メディア: 単行本(ソフトカバー)