CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PySpark で UDAF が作れない場合の回避策について

PySpark では、ごく最近まで UDAF (User Defined Aggregate Function: ユーザ定義集計関数) がサポートされていなかった。 Apache Spark 2.3 以降では Pandas UDF を使うことで UDAF に相当する処理を書くことができるようになっている。 今回は、それ以前のバージョンを使っているときに、同等の処理を書くための回避策について書いてみる。

使った環境は次のとおり。

$ cat /etc/redhat-release 
CentOS Linux release 7.7.1908 (Core)
$ uname -r
3.10.0-1062.18.1.el7.x86_64
$ hadoop version
Hadoop 2.9.2
Subversion https://git-wip-us.apache.org/repos/asf/hadoop.git -r 826afbeae31ca687bc2f8471dc841b66ed2c6704
Compiled by ajisaka on 2018-11-13T12:42Z
Compiled with protoc 2.5.0
From source with checksum 3a9939967262218aa556c684d107985
This command was run using /home/vagrant/hadoop-2.9.2/share/hadoop/common/hadoop-common-2.9.2.jar
$ pyspark --version
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.4.5
      /_/
                        
Using Scala version 2.11.12, OpenJDK 64-Bit Server VM, 1.8.0_242
Branch HEAD
Compiled by user centos on 2020-02-02T19:38:06Z
Revision cee4ecbb16917fa85f02c635925e2687400aa56b
Url https://gitbox.apache.org/repos/asf/spark.git
Type --help for more information.

下準備

はじめに、PySpark の REPL を起動しておく、

$ pyspark --master yarn

次のようにして、サンプルとなる DataFrame を用意する。 このデータは `'category`` カラムを使ってグループ化できる。

>>> data = [
...   ('A', 'Alice', 10),
...   ('A', 'Bob', 15),
...   ('A', 'Carol', 20),
...   ('B', 'Daniel', 25),
...   ('B', 'Ellie', 30),
...   ('C', 'Frank', 35),
... ]
>>> df = spark.createDataFrame(data, ('category', 'name', 'age'))

組み込みの集計関数について

はじめに、PySpark に組み込みで用意されている集計関数はどのように使うのかおさらいしておく。 たとえば、カテゴリーごとの平均値を計算してみよう。

たとえば年齢の平均を計算するときは、次のようにする。 DataFrame#groupBy() からは GroupedData というクラスのインスタンスが返る。 さらに、そのインスタンスに対して GroupedData#agg() を使って集計関数を適用する。

>>> df.groupBy('category').agg({'age': 'mean'}).show()
+--------+--------+                                                             
|category|avg(age)|
+--------+--------+
|       B|    27.5|
|       C|    35.0|
|       A|    15.0|
+--------+--------+

集計する処理を辞書と文字列で表す以外に pyspark.sql.functions を使う方法もある。 たとえば、上記と同じ処理を書いてみよう。

>>> from pyspark.sql import functions as F
>>> df.groupBy('category').agg(F.mean('age').alias('mean-age')).show()
+--------+--------+                                                             
|category|mean-age|
+--------+--------+
|       B|    27.5|
|       C|    35.0|
|       A|    15.0|
+--------+--------+

以上が組み込みの集計関数を使う方法になる。 基礎的な統計量などを計算するだけなら、これでも問題ないはず。 しかし、集計の処理は組み込みの関数だけで完結しない場合も多い。 そこで、ユーザ定義の関数で集計の処理をしたくなる。

UDAF 代わりの処理の書き方

前述したとおり Apache Spark 2.3 以降では Pandas UDF を使って UDAF に相当する処理が書ける。 ただし、ここではそれについての具体的な紹介はしない。 紹介するのは、Pandas UDF が使えない環境での回避策となる。

回避策のキモは pyspark.sql.functions.collect_list を使うところ。 この関数は GroupedData の特定カラムを、グループ単位でリストに入れて返すことができる。 言葉で説明するよりも、次のサンプルコードを見てもらった方が早いかもしれない。

>>> agg_df = df.groupBy('category').agg(F.collect_list('age').alias('grouped-age'))
>>> agg_df.show()
+--------+------------+                                                         
|category| grouped-age|
+--------+------------+
|       B|    [25, 30]|
|       C|        [35]|
|       A|[10, 15, 20]|
+--------+------------+

この状態までくれば、あとは単なる UDF (User Defined Function) を使って集計できる。 例として平均を計算する UDF を書いてみよう。

>>> def mean(values):
...     return sum(values) / len(values)
... 
>>> mean_udf = F.udf(mean)

上記の UDF をリストのカラムに対して適用する。

>>> agg_df.withColumn('mean-age', mean_udf('grouped-age')).show()
+--------+------------+--------+                                                
|category| grouped-age|mean-age|
+--------+------------+--------+
|       B|    [25, 30]|    27.5|
|       C|        [35]|    35.0|
|       A|[10, 15, 20]|    15.0|
+--------+------------+--------+

ばっちり。

参考

stackoverflow.com

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム