CUBE SUGAR CONTAINER

技術系のこと書きます。

ClusterShell を使って複数のマシンを SSH で並列に操作する

複数のマシンを使って動作検証をしていると、ログインやコマンド入力の操作が煩雑になる。 また、複数のマシンに共通で必要な操作があったりすると手数もかさむ。 今回は、そういった問題を緩和できる ClusterShell について扱う。 ClusterShell を使うと、マシンをグループ化して SSH で並列に操作できる。

今回は、次のようなマシンの構成を扱う。 client には ClusterShell をインストールして、他のマシンを操作する。 masterworker[01] は名前通り異なる役割のマシンを想定して用意した。

  • client
    • 192.168.56.10
  • master
    • 192.168.56.20
  • worker1
    • 192.168.56.31
  • worker2
    • 192.168.56.32

上記のマシンは、あらかじめ Ubuntu 20.04 LTS を使って構築してある。 一応、末尾にはおまけとして Vagrant + VirtualBoxを使って仮想マシンを構築するための設定ファイルを用意した。

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.3 LTS
Release:    20.04
Codename:   focal
$ uname -srm
Linux 5.4.0-91-generic x86_64

ClusterShell のバージョンは次のとおり。

$ clush --version
clush 1.8.3

もくじ

下準備

client には、まず ClusterShell をインストールする。 sshpass と openssh-client は SSH のために入れておく。

$ sudo apt-get update
$ sudo apt-get -y install sshpass openssh-client clustershell

これで clush コマンドが使えるようになる。

$ clush --version
clush 1.8.3

次に、ホスト名を使って操作したいので /etc/hosts に IP アドレスとの対応関係を書き込んでおく。

$ cat << 'EOF' | sudo tee -a /etc/hosts >/dev/null
192.168.56.10 client
192.168.56.20 master
192.168.56.31 worker1
192.168.56.32 worker2
EOF

次に SSH でログインするための公開鍵を作成する。

$ ssh-keygen -t rsa -P '' -f $HOME/.ssh/id_rsa

client から、その他のホストに公開鍵を使ってログインできるように登録する。 ここの作業は環境構築に使ったツールやイメージなどによって少し変わる。 たとえばパスワード認証が無効になっているイメージだと、この操作では登録できない。 また、Vagrant で作った環境なのでパスワードが vagrant になっている。

$ sshpass -p "vagrant" \
    ssh-copy-id -i $HOME/.ssh/id_rsa.pub -o "StrictHostKeyChecking no" master
$ sshpass -p "vagrant" \
    ssh-copy-id -i $HOME/.ssh/id_rsa.pub -o "StrictHostKeyChecking no" worker1
$ sshpass -p "vagrant" \
    ssh-copy-id -i $HOME/.ssh/id_rsa.pub -o "StrictHostKeyChecking no" worker2

次に、ClusterShell にホスト名とグループの対応関係を登録する。 対応関係は /etc/clustershell/ 以下の設定ファイルで指定する。 設定ファイルは <group-name>: <hostname>,... というフォーマットになっている。 以下では all というグループに、操作対象となるすべてのホストを登録している。 そして、グループ mmaster を、グループ wworker1worker2 を登録している。

$ sudo cp /etc/clustershell/groups.d/local.cfg{,.orig}
$ cat << 'EOF' | sudo tee /etc/clustershell/groups.d/local.cfg >/dev/null
all: master,worker1,worker2
m: master
w: worker1,worker2
EOF

これで ClusterShell を使い始める準備ができた。

個別のホストを指定して操作する

特定のホストを指定してコマンドを実行したいときは -w オプションを使う。 ここでは、それぞれのホストにホスト名を設定した。

$ clush -w master "sudo hostnamectl set-hostname master"
$ clush -w worker1 "sudo hostnamectl set-hostname worker1"
$ clush -w worker2 "sudo hostnamectl set-hostname worker2"

グループを指定して操作する

先ほどの例であれば、別に ssh(1) を直接使って操作しても変わらなかった。 ClusterShell の本領はグループを指定して操作できることにある。 グループを指定するには -g オプションでグループ名を指定すれば良い。 また、-L オプションを指定すると、結果をホスト名のアルファベット順でソートできる。

試しに all グループに対して hostname コマンドを実行してみよう。

$ clush -g all -L hostname
master: master
worker1: worker1
worker2: worker2

上記から、操作対象のすべてのホストに hostname コマンドが実行されたことがわかる。

また、-g all はよく使うので -a オプションがエイリアスとして用意されている。

$ clush -a -L hostname
master: master
worker1: worker1
worker2: worker2

同じように、特定のグループを指定してコマンドを実行してみよう。 以下ではグループ mw を、それぞれ指定している。

$ clush -g m -L hostname
master: master
$ clush -g w -L hostname
worker1: worker1
worker2: worker2

ちゃんとグループに所属しているホストに対してコマンドが実行されていることがわかる。

グループはカンマ区切りで複数指定することもできる。 以下ではグループ mw に対して実行している。

$ clush -g m,w -L hostname
master: master
worker1: worker1
worker2: worker2

複数のホストにファイルをコピーする

ClusterShell では、複数のホストにファイルを scp(1) できる。 ファイルをコピーするには -c オプションでコピーしたいファイルを指定して、コピー先のディレクトリを --dest オプションで指定する。

以下では greet.txt というファイルを、すべてのホストに対して /tmp 以下にコピーしている。

$ echo "Hello, World" > greet.txt
$ clush -g all -c greet.txt --dest /tmp

コピーされたはずのパスを cat(1) すると、ちゃんとファイルがコピーされていることがわかる。

$ clush -g all -L "cat /tmp/greet.txt"
master: Hello, World
worker1: Hello, World
worker2: Hello, World

書き込みに特権が必要なファイルをコピーするときは、少し工夫が必要になる。 具体的には、一度特権が不要なディレクトリにコピーした上で、あらためて特権ユーザでファイルを移動するというもの。 たとえば /etc/hosts をコピーしてみよう。

$ clush -g all -c /etc/hosts --dest /var/tmp
$ clush -g all -L "sudo cp /var/tmp/hosts /etc/hosts"

たとえば master の内容を確認すると、ちゃんとコピーされたことがわかる。

$ clush -w master "cat /etc/hosts" 
master: 127.0.0.1 localhost
master: 127.0.1.1 vagrant
master: 
master: # The following lines are desirable for IPv6 capable hosts
master: ::1     ip6-localhost ip6-loopback
master: fe00::0 ip6-localnet
master: ff00::0 ip6-mcastprefix
master: ff02::1 ip6-allnodes
master: ff02::2 ip6-allrouters
master: 192.168.56.10 client
master: 192.168.56.20 master
master: 192.168.56.31 worker1
master: 192.168.56.32 worker2

まとめ

今回は ClusterShell を使うことで、複数のマシンを SSH で並列に操作する方法を扱った。

おまけ: 環境構築に使った Vagrantfile

今回の環境を作るのに使った Vagrantfile を以下に示す。

# -*- mode: ruby -*-
# vi: set ft=ruby :

# Vagrantfile API/syntax version. Don't touch unless you know what you're doing!
VAGRANTFILE_API_VERSION = "2"

Vagrant.configure(VAGRANTFILE_API_VERSION) do |config|

  machines = {
    "client" => "192.168.56.10",
    "master" => "192.168.56.20",
    "worker1" => "192.168.56.31",
    "worker2" => "192.168.56.32",
  }

  machines.each do |key, value|
    config.vm.define key do |machine|
      machine.vm.box = "bento/ubuntu-20.04"
      machine.vm.network "private_network", ip: value
      machine.vm.provider "virtualbox" do |vb|
        vb.cpus = "2"
        vb.memory = "1024"
      end
    end
  end

end

あとは以下で環境が用意できる。

$ vagrant up
$ vagrant ssh client

いじょう。

chroot について

今回は、Unix の古典的な機能のひとつである chroot について扱う。 chroot を使うと、特定のプロセスにおけるルートディレクトリを、ルートディレクトリ以下にある別のディレクトリに変更できる。 今回扱うのはコマンドラインツールとしての chroot(8) と、システムコールとしての chroot(2) になる。

使った環境は次のとおり。

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.4 LTS
Release:    20.04
Codename:   focal
$ uname -srm
Linux 5.4.0-104-generic aarch64
$ chroot --version
chroot (GNU coreutils) 8.30
Copyright (C) 2018 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <https://gnu.org/licenses/gpl.html>.
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.

Written by Roland McGrath.
$ gcc --version
gcc (Ubuntu 9.4.0-1ubuntu1~20.04) 9.4.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
$ ldd --version
ldd (Ubuntu GLIBC 2.31-0ubuntu9.7) 2.31
Copyright (C) 2020 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
Written by Roland McGrath and Ulrich Drepper.

もくじ

下準備

chroot(8) は coreutils パッケージに含まれているのでインストールしておく。 また、chroot(2) を呼び出すコードをビルドするために build-essential をインストールする。

$ sudo apt-get -y install coreutils build-essential

chroot(8) の動作を試す

まずはコマンドラインツールとしての chroot(8) から動作を確認していく。

はじめに、chroot(8) したプロセスでルートディレクトリになるディレクトリを用意する。 ディレクトリは mktemp(1) を使ってテンポラリディレクトリとして作る。

$ ROOTFS=$(mktemp -d)
$ echo ${ROOTFS}
/tmp/tmp.GuMwStXLLO

chroot(8) した上で起動するプログラムとして bash(1) をコピーしておく。 このとき --parents オプションを使ってディレクトリ構造ごとコピーしてやる。

$ cp -avL --parents $(which bash) ${ROOTFS}
/usr -> /tmp/tmp.GuMwStXLLO/usr
/usr/bin -> /tmp/tmp.GuMwStXLLO/usr/bin
'/usr/bin/bash' -> '/tmp/tmp.GuMwStXLLO/usr/bin/bash'

さらに、bash(1) の動作に必要な共有ライブラリをコピーする。 動作に必要な共有ライブラリは ldd(1) の出力から得られる。

$ ldd $(which bash) | grep -o "/lib.*\.[0-9]\+" | xargs -I {} cp -avL --parents {} ${ROOTFS}
/lib -> /tmp/tmp.GuMwStXLLO/lib
/lib/aarch64-linux-gnu -> /tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu
'/lib/aarch64-linux-gnu/libtinfo.so.6' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libtinfo.so.6'
'/lib/aarch64-linux-gnu/libdl.so.2' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libdl.so.2'
'/lib/aarch64-linux-gnu/libc.so.6' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libc.so.6'
'/lib/ld-linux-aarch64.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/ld-linux-aarch64.so.1'

上記の作業を、必要なプログラムそれぞれについてやっていく。 手作業でひとつひとつやると大変なのでループを回して処理する。 ここでは例として ls, mkdir, mount をコピーした。

$ CMDS=ls,mkdir,mount
$ IFS=","
$ for CMD in ${CMDS}
> do
>   cp -avL --parents $(which ${CMD}) ${ROOTFS}
>   ldd $(which ${CMD}) | grep -o "/lib.*\.[0-9]\+" | xargs -I {} cp -avL --parents {} ${ROOTFS}
> done
'/usr/bin/ls' -> '/tmp/tmp.GuMwStXLLO/usr/bin/ls'
'/lib/aarch64-linux-gnu/libselinux.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libselinux.so.1'
'/lib/aarch64-linux-gnu/libc.so.6' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libc.so.6'
'/lib/ld-linux-aarch64.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/ld-linux-aarch64.so.1'
'/lib/aarch64-linux-gnu/libpcre2-8.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpcre2-8.so.0'
'/lib/aarch64-linux-gnu/libdl.so.2' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libdl.so.2'
'/lib/aarch64-linux-gnu/libpthread.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpthread.so.0'
'/usr/bin/mkdir' -> '/tmp/tmp.GuMwStXLLO/usr/bin/mkdir'
'/lib/aarch64-linux-gnu/libselinux.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libselinux.so.1'
'/lib/aarch64-linux-gnu/libc.so.6' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libc.so.6'
'/lib/ld-linux-aarch64.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/ld-linux-aarch64.so.1'
'/lib/aarch64-linux-gnu/libpcre2-8.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpcre2-8.so.0'
'/lib/aarch64-linux-gnu/libdl.so.2' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libdl.so.2'
'/lib/aarch64-linux-gnu/libpthread.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpthread.so.0'
'/usr/bin/mount' -> '/tmp/tmp.GuMwStXLLO/usr/bin/mount'
'/lib/aarch64-linux-gnu/libmount.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libmount.so.1'
'/lib/aarch64-linux-gnu/libc.so.6' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libc.so.6'
'/lib/ld-linux-aarch64.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/ld-linux-aarch64.so.1'
'/lib/aarch64-linux-gnu/libblkid.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libblkid.so.1'
'/lib/aarch64-linux-gnu/libselinux.so.1' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libselinux.so.1'
'/lib/aarch64-linux-gnu/libpcre2-8.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpcre2-8.so.0'
'/lib/aarch64-linux-gnu/libdl.so.2' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libdl.so.2'
'/lib/aarch64-linux-gnu/libpthread.so.0' -> '/tmp/tmp.GuMwStXLLO/lib/aarch64-linux-gnu/libpthread.so.0'

準備が終わったところで、満を持して chroot(8) する。 第 1 引数は chroot(8) したプロセスでルートディレクトリになるディレクトリ。 第 2 引数は chroot(8) した上で起動するコマンドの場所。

$ sudo chroot ${ROOTFS} $(which bash)

実行すると、先ほどとは異なるシェルとして bash(1) が立ち上がる。 試しにルートディレクトリを ls(1) してみよう。 あきらかに、普段とは表示されるディレクトリの数が異なる。 lib と usr しかない。 とはいえ、これは先ほどコピーしたファイルのあったディレクトリなので心当たりは十分にあるはず。

# ls /
lib  usr

ここで、試しに proc ファイルシステムをマウントしてみよう。 ディレクトリを用意してマウントする。

# mkdir -p /proc
# mount -t proc proc /proc

すると、次のようにちゃんとマウントできる。 なお、chroot(8) ではルートディレクトリを切り替えるだけなので、PID (プロセス識別子) の名前空間はシステムと共有している。

# ls /proc
1     1448  178  240  474  501    623  70  80   98     diskstats    kallsyms     mdstat    schedstat  thread-self
10    1449  179  287  475  522    625  71  81   acpi    driver       kcore    meminfo   scsi       timer_list
104   1458  18    288  476  574   627  72  83   buddyinfo  execdomains  key-users    misc       self       tty
11    15    19    3    477  576   629  73  84   bus     fb       keys     modules   slabinfo   uptime
12    16    2     361  486  6 632  74  841  cgroups     filesystems  kmsg     mounts    softirqs   version
1359  1685  20    374  488  612   646  75  842  cmdline     fs       kpagecgroup  net       stat       version_signature
1376  1686  21    380  489  615   648  76  85   consoles    interrupts   kpagecount   pagetypeinfo  swaps      vmallocinfo
1377  1691  22    392  496  616   666  77  86   cpuinfo     iomem        kpageflags   partitions    sys        vmstat
14    17    23    4    499  621   673  78  9    crypto  ioports      loadavg      pressure  sysrq-trigger  zoneinfo
143   177   24    473  500  622   686  8     95   devices   irq          locks    sched_debug   sysvipc

確認が終わったらシェルを終了しよう。 これで chroot(8) を呼び出した元のプロセスに戻れる。

# exit
exit

Ubuntu 21.10 のルートファイルシステムに chroot(8) してみる

次は、試しに他の GNU/Linux ディストリビューションのルートファイルシステムを展開して chroot(8) してみよう。 今、システムとして使っているのが Ubuntu 20.04 LTS なので、Ubuntu 21.10 を使うことにした。

まずは Ubuntu 21.10 の、ルートファイルシステムをアーカイブしたファイルをダウンロードして展開する。 CPU の命令セットが違うとダウンロードするファイルが異なる点に注意する。

$ ISA=$(uname -m | sed -e "s/x86_64/amd64/" -e "s/aarch64/arm64/")
$ mkdir -p /tmp/ubuntu-impish-${ISA}
$ wget -O - https://cdimage.ubuntu.com/ubuntu-base/releases/21.10/release/ubuntu-base-21.10-base-${ISA}.tar.gz | tar zxvf - -C /tmp/ubuntu-impish-${ISA}

次のように /tmp 以下にファイルが展開された。

$ ls /tmp/ubuntu-impish-${ISA}/
bin  boot  dev  etc  home  lib  media  mnt  opt  proc  root  run  sbin  srv  sys  tmp  usr  var

ここで、展開されたディレクトリに対して chroot(8) してみよう。

$ sudo chroot /tmp/ubuntu-impish-${ISA} /usr/bin/bash

これで Ubuntu 21.10 のルートファイルシステムが、プロセスのルートディレクトリになった。 例えば /etc 以下にある lsb-release ファイルを表示すると Ubuntu 21.10 のものになっている。 bash のバージョンも Ubuntu 21.04 LTS の 5.0 系ではなく 5.1 系になっている。

# cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=21.10
DISTRIB_CODENAME=impish
DISTRIB_DESCRIPTION="Ubuntu 21.10"
# bash --version
GNU bash, version 5.1.8(1)-release (aarch64-unknown-linux-gnu)
Copyright (C) 2020 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>

This is free software; you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.

システムは Ubuntu 20.04 LTS なのに、なんだか Ubuntu 21.10 を使っているような気分になる。 一方で、uname(1) から得られるカーネルのバージョンは Ubuntu 20.04 LTS のまま。

# uname -r
5.4.0-104-generic

これは、単に chroot(8) でルートディレクトリを入れ替えているだけなので当たり前。 Linux コンテナ技術は基本的にカーネルを共有するので Docker などを使っていても、この点は変わらない 1

chroot(2) の動作を試す

続いてはシステムコールとしての chroot(2) の動作を試してみる。

以下のサンプルコードでは、第 1 引数で指定されたパスに chroot(2) した上で bash(1) を起動している。

#define _XOPEN_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>


int main(int argc, char *argv[]) {
    // 引数の長さをチェックする
    if (argc < 2) {
        fprintf(stderr, "Please specify the path to change root\n");
        exit(EXIT_FAILURE);
    }

    // chroot(2) したいディレクトリにカレントワーキングディレクトリを変更する
    if (chdir(argv[1]) != 0) {
        fprintf(stderr, "Failed to change directory: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // カレントワーキングディレクトリに chroot(2) する
    if (chroot(".") != 0) {
        fprintf(stderr, "Failed to change root: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // シェルを起動し直す
     char* const args[] = {"bash", NULL};
    if (execvp(args[0], args) != 0) {
        fprintf(stderr, "Failed to exec \"%s\": %s\n", args[0], strerror(errno));
        exit(EXIT_FAILURE);
    }

    return EXIT_SUCCESS;
}

上記をコンパイルする。

$ gcc --std=c11 --static -Wall chroot.c 
$ file a.out 
a.out: ELF 64-bit LSB executable, ARM aarch64, version 1 (GNU/Linux), statically linked, BuildID[sha1]=362e58fceadfa88e4ef8f7becdb06350922b9930, for GNU/Linux 3.7.0, not stripped

できたバイナリに第 1 引数として Ubuntu 21.10 のディレクトリを指定して実行する。

$ sudo ./a.out /tmp/ubuntu-impish-${ISA}/

すると、次のようにちゃんと Ubuntu 21.10 のルートファイルシステムがルートディレクトリになっている。 つまり chroot(8) を使ったときと同じ結果になった。

# ls /
bin  boot  dev  etc  home  lib  media  mnt  opt  proc  root  run  sbin  srv  sys  tmp  usr  var
# cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=21.10
DISTRIB_CODENAME=impish
DISTRIB_DESCRIPTION="Ubuntu 21.10"

ひとしきり確認したら環境から抜ける。

# exit

chroot した環境から脱獄 (jail break) してみる

実は chroot で隔離したファイルシステムは、プロセスに CAP_SYS_CHROOT のケーパビリティがあると脱獄できることが知られている。 これはあくまで chroot(2) の仕様であって、不具合や脆弱性ではないらしい。 では、実際に脱獄できるのか確かめてみよう。

以下にサンプルコードを示す。 このコードをビルドしたバイナリを chroot した環境で実行することで脱獄する。 コードでは "foo" という名前でディレクトリを作って、そこに chroot(2) している。 その上で chdir(2) を何度も呼び出して、その後でまたカレントワーキングディレクトリに対して chroot(2) している。 そして、最後に bash(1) を呼び出している。 やっていることは実にシンプル。

#define _XOPEN_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>


int main(int argc, char *argv[]) {
    // chroot した環境で実行されることを想定している

    // 適当にサブディレクトリを作る
    if (mkdir("foo", 755) != 0) {
        // すでに同名のパスがあるときはエラーを無視する
        if (errno != EEXIST) {
            fprintf(stderr, "Failed to create a new directory: %s\n", strerror(errno));
            exit(EXIT_FAILURE);
    }
    }

    // 作成したサブディレクトリに chroot(2) する
    // chroot(2) は pwd を変更しない
    // rootfs が pwd よりも下のディレクトリになる
    if (chroot("foo") != 0) {
        fprintf(stderr, "Failed to change root: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // chdir(2) は rootfs に到達するまで pwd から早退パスで移動できる
    // ただし、現状 rootfs は pwd よりも下にあるので決して到達しない
    // 元々のルートディレクトリまでさかのぼってしまう
    for (int i = 0; i < 1024; i++) {
        if (chdir("..") != 0) {
            fprintf(stderr, "Failed to change directory: %s\n", strerror(errno));
            exit(EXIT_FAILURE);
        }
    }

    // ルートディレクトリまでいってから chroot(2) すると脱獄できる
    if (chroot(".") != 0) {
        fprintf(stderr, "Failed to change root: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // 脱獄した上でシェルを起動し直す
    char* const args[] = {"bash", NULL};
    if (execvp(args[0], args) != 0) {
        fprintf(stderr, "Failed to exec \"%s\": %s\n", args[0], strerror(errno));
        exit(EXIT_FAILURE);
    }

    return EXIT_SUCCESS;
}

上記の概念的な説明は man 2 chroot に書かれているし、巷にもいくつか解説が見つかる。

man7.org

ざっくり説明すると、どうやら chroot(2) がプロセスのカレントワーキングディレクトリを変更しないところがポイントらしい。 サブディレクトリに chroot(2) すると、プロセスのルートディレクトリはサブディレクトリになるが、カレントワーキングディレクトリは元のまま変更されない。 つまり、カレントワーキングディレクトリよりもルートディレクトリの方が下位のディレクトリにあるという、なんだか変な状況になる。 そして、カレントワーキングディレクトリから相対パスで chdir(2) する場合、ルートディレクトリに至るまで上位のディレクトリに移動できるらしい。 しかし、ルートディレクトリはカレントワーキングディレクトリよりも下位にあるため、決してそこに至ることはなく本来の隔離される前のルートディレクトリまで到達してしまう。 そこで改めて chroot(2) すると、晴れてプロセスのルートディレクトリが変更されて脱獄成功、ということらしい。

理屈は分かったので、実際に試してみよう。 上記をコンパイルする。

$ gcc --std=c11 --static -Wall jailbreak.c

できたバイナリを、先ほど展開した Ubuntu 21.10 のルートファイルシステムに放り込む。

$ cp a.out /tmp/ubuntu-impish-${ISA}/

上記のディレクトリを指定して chroot(8) する。

$ sudo chroot /tmp/ubuntu-impish-${ISA} /usr/bin/bash

Ubuntu 21.10 のファイルシステムに隔離されたことを確認する。

# cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=21.10
DISTRIB_CODENAME=impish
DISTRIB_DESCRIPTION="Ubuntu 21.10"

ここでおもむろに先ほどコピーしたバイナリを実行してみる。

# /a.out

一見すると変化はないが /etc/lsb-release を確認すると隔離前のファイルシステムに参照できている。 つまり、脱獄できた。

# cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.4 LTS"

このような脱獄を防ぐには、根本的には chroot(2) の代わりに pivot_root(2) を使う必要があるようだ。

まとめ

今回は chroot をコマンドラインツールとシステムコールを使って試してみた。


  1. 念の為に補足しておくと、一般的な Linux コンテナ仮想化の実装ではデフォルトで chroot(2) ではなく pivot_root(2) が使われる

Python: Prophet で単変量の時系列予測を試す

Prophet は Meta (旧 Facebook) が中心となって開発している OSS の時系列予測フレームワーク。 目的変数のトレンド、季節性、イベントや外部説明変数を加味した時系列予測を簡単にできることが特徴として挙げられる。 使い所としては、精度はさほど追求しない代わりにとにかく手軽に予測がしたい、といった場面が考えられる。 また、扱うデータセットについても単変量に近いシンプルなものが得意そう。 なお、今回は扱うデータセットの都合からイベントや外部説明変数の追加に関しては扱わない。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 12.2.1
BuildVersion:   21D62
$ uname -srm
Darwin 21.3.0 arm64
$ python -V
Python 3.9.10
$ pip list | grep -i prophet                       
prophet                  1.0.1

もくじ

下準備

あらかじめ Prophet をインストールしておく。 その他に、データセットの読み込みなどに必要なパッケージもインストールしておく。

$ pip install prophet scikit-learn seaborn pmdarima

flights データセットで試してみる

まずは、航空機の旅客数を扱った有名な flights データセットで試してみる。

その前に、どういったデータかをグラフにプロットして確認しておく。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt


def main():
    # flights データセットを読み込む
    df = sns.load_dataset('flights')

    # カラムが年と月で分かれているのでマージする
    df['year-month'] = pd.to_datetime(df['year'].astype(str) + '-' + df['month'].astype(str),
                                      format='%Y-%b')

    # プロットする
    plt.plot(df['year-month'], df['passengers'])
    plt.show()


if __name__ == '__main__':
    main()

上記に適当な名前をつけて保存したら実行する。

$ python plotflights.py 

すると、次のような折れ線グラフが得られる。

f:id:momijiame:20220309181506p:plain
flightsデータセット

上記からトレンドや季節成分の存在が確認できる。

それでは、次に Prophet を使って予測してみよう。 以下のサンプルコードでは、データを時系列でホールドアウトして、末尾を Prophet で予測している。 Prophet のデフォルトでは、時系列のカラムを ds という名前で、目的変数のカラムを y という名前にすることになっている。 また、実際の値と予測をプロットしたものと、データをトレンドと季節成分に分離したものをグラフとしてプロットしている。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from prophet import Prophet


def main():
    # flights データセットを読み込む
    df = sns.load_dataset('flights')

    # カラムが年と月で分かれているのでマージする
    df['year-month'] = pd.to_datetime(df['year'].astype(str) + '-' + df['month'].astype(str),
                                      format='%Y-%b')

    # Prophet が仮定するカラム名に変更する
    # タイムスタンプ: ds
    # 目的変数: y
    rename_mappings = {
        'year-month': 'ds',
        'passengers': 'y',
    }
    df.rename(columns=rename_mappings,
              inplace=True)

    # 不要なカラムを落とす
    df.drop(['year', 'month'],
            axis=1,
            inplace=True)

    # 時系列の順序で学習・検証用データをホールアウトする
    train_df, eval_df = train_test_split(df,
                                         shuffle=False,
                                         random_state=42,
                                         test_size=0.3)

    # 学習用データを使って学習する
    m = Prophet()
    m.fit(train_df)

    # 検証用データを予測する
    forecast = m.predict(eval_df.drop(['y'],
                                      axis=1))

    # 真の値との誤差を MAE で求める
    mae = mean_absolute_error(forecast['yhat'],
                              eval_df['y'])
    print(f'MAE: {mae:.05f}')

    # 実際のデータと予測をプロットする
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(df['ds'], df['y'], color='y')
    m.plot(forecast, ax=ax)
    # トレンドと季節成分をプロットする
    m.plot_components(forecast)

    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。 実際の値と予測値の、ホールドアウトデータでの乖離を MAE で出力している。

$ python predflights.py

...

MAE: 35.21388

実際の値と予測をプロットしたグラフは次のとおり。 青い実線が予測値、薄い青色で示された範囲は 95% 信頼区間らしい。

f:id:momijiame:20220309181947p:plain
flightsデータセットの実測値と予測結果

トレンドはつかめているものの、実際の値よりも振幅は小さくなっていることがわかる。

トレンドと季節成分は次のように分離された。 Prophet は特に指定しない限り、季節成分を自動で検出してくれる。 以下では年次でのトレンドが自動的に検出されたことが確認できる。

f:id:momijiame:20220309182038p:plain
flightsデータセットのトレンドと季節成分

先ほどのモデルでは、振幅が小さいことで実際の値とのズレが大きくなってしまっていた。 これは、季節成分の計算がデフォルトで加算モードになっていたことが理由として考えられる。 つまり、時間が進むごとに目的変数が大きくなると共に振幅も大きくなることが上手く表現できていなかった。 そこで、次は季節成分の計算を乗法モードに変更してみる。 トレンド成分にかけ算で季節成分をのせてやれば、振幅がだんだんと大きくなっていく様子が表現できるはず。

以下のサンプルコードではモデルに seasonality_mode='multiplicative' を指定することで季節成分の計算を乗法モードにしている。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from prophet import Prophet


def main():
    df = sns.load_dataset('flights')

    df['year-month'] = pd.to_datetime(df['year'].astype(str) + '-' + df['month'].astype(str),
                                      format='%Y-%b')

    rename_mappings = {
        'year-month': 'ds',
        'passengers': 'y',
    }
    df.rename(columns=rename_mappings,
              inplace=True)

    df.drop(['year', 'month'],
            axis=1,
            inplace=True)

    train_df, eval_df = train_test_split(df,
                                         shuffle=False,
                                         random_state=42,
                                         test_size=0.3)

    # 季節成分の計算を加算モードから乗法モードに変更する
    m = Prophet(seasonality_mode='multiplicative')
    m.fit(train_df)

    forecast = m.predict(eval_df.drop(['y'],
                                      axis=1))

    mae = mean_absolute_error(forecast['yhat'],
                              eval_df['y'])
    print(f'MAE: {mae:.05f}')

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(df['ds'], df['y'], color='y')
    m.plot(forecast, ax=ax)
    m.plot_components(forecast)

    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。 かなり MAE が改善したことが確認できる。

$ python multiflights.py

...

MAE: 22.31301

実際の値と予測値をグラフで確認しても、次のように当てはまりが良くなっている。

f:id:momijiame:20220309182757p:plain
flightsデータセットの実測値と予測結果 (乗法モード)

wineind データセットで試してみる

もうひとつ、ワインの生産量を示すデータセット (wineind) で試してみよう。

先ほどと同じように、まずはデータセットを可視化する。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
from pmdarima import datasets
from matplotlib import pyplot as plt


def main():
    # wineind データセットを読み込む
    series = datasets.load_wineind(as_series=True)
    df = series.to_frame(name='bottles')

    # プロットする
    df.plot()
    plt.show()


if __name__ == '__main__':
    main()

上記を実行する。

$ python plotwineind.py

得られるグラフは次のとおり。 季節成分は確認できるものの、単調な増加トレンドがあるわけではないようだ。

f:id:momijiame:20220309183422p:plain
wineindデータセット

先ほどと同じように、データをホールドアウトして予測してみよう。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
from pmdarima import datasets
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from prophet import Prophet


def main():
    series = datasets.load_wineind(as_series=True)
    df = series.to_frame(name='bottles')

    df.reset_index(inplace=True)
    df['index'] = pd.to_datetime(df['index'],
                                 format='%b %Y')

    rename_mappings = {
        'index': 'ds',
        'bottles': 'y',
    }
    df.rename(columns=rename_mappings,
              inplace=True)

    train_df, eval_df = train_test_split(df,
                                         shuffle=False,
                                         random_state=42,
                                         test_size=0.3)

    m = Prophet(seasonality_mode='multiplicative')
    m.fit(train_df)

    forecast = m.predict(eval_df.drop(['y'],
                                      axis=1))

    mae = mean_absolute_error(forecast['yhat'],
                              eval_df['y'])
    print(f'MAE: {mae:.05f}')

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(df['ds'], df['y'], color='y')
    m.plot(forecast, ax=ax)
    m.plot_components(forecast)

    plt.show()


if __name__ == '__main__':
    main()

上記を実行する。

$ python predwineind.py

...

MAE: 2443.18595

得られた予測は次のとおり。 今回は、先ほどよりも実際の値と予測が一致していない。 中には実際の値が 95% 信頼区間の外に出てしまっているものもある。

f:id:momijiame:20220309183655p:plain
wineindデータセットの実測値と予測結果

トレンドと季節成分は次のとおり。 今度は単調な下降トレンドと認識されているようだ。 もちろん、これらの結果は学習させる範囲にも大きく依存する。

f:id:momijiame:20220309183724p:plain
wineindデータセットのトレンドと季節成分

まとめ

今回は Prophet を使って時系列の予測を試してみた。 ごくシンプルな時系列データで、なるべく簡単にトレンドや季節成分を加味した予測をしたいときには選択肢の一つとして考えられるかもしれない。

Lima を使って Apple Silicon 版の Mac で x86-64 (Intel on ARM) な仮想マシンを扱う

Apple Silicon 版の Mac を使っていても、依然として成果物をデプロイする先は ISA が x86-64 (amd64) のマシンであることが多い。 となると、どうしても x86-64 の環境を使って作業をしたい場面が出てくる。 もちろん、IaaS を利用してリモートにマシンを立ち上げれば良いんだけど、簡単な検証なら手元で手軽に済ませたい。 今回は、そんなニーズを埋めてくれるかもしれない Lima を使ってみる。

Lima を使うと、Apple Silicon 版の Mac 上で ISA が x86-64 の Linux 仮想マシンを手軽に立ち上げることができる 1。 ただし、バックエンドは QEMU のソフトウェアエミュレーション (qemu-system-x86_64) なので、ネイティブな環境に比べるとパフォーマンスは大きく劣る。

使った環境は次のとおり。

$ sw_vers   
ProductName:    macOS
ProductVersion: 12.2.1
BuildVersion:   21D62
$ uname -srm              
Darwin 21.3.0 arm64
$ lima --version          
limactl version 0.8.3

もくじ

下準備

あらかじめ、Homebrew で Lima をインストールする。

$ brew install lima

インストールすると limactl コマンドが使えるようになる。

$ limactl --version        
limactl version 0.8.3

仮想マシンを立ち上げる

Lima は YAML 形式の設定ファイルを元に仮想マシンを作成する。 以下のサンプルでは ISA が x86-64 で、ディストリビューションが Ubuntu 20.04 LTS の仮想マシンを定義している。

$ cat << 'EOF' > focal-amd64.yaml
arch: "x86_64"
images:
- location: "https://cloud-images.ubuntu.com/focal/current/focal-server-cloudimg-amd64.img"
  arch: "x86_64"
EOF

設定ファイルができたら limactl validate コマンドで形式が正しいかチェックしておこう。

$ limactl validate focal-amd64.yaml
INFO[0000] "focal-amd64.yaml": OK

上記の設定ファイルを元に limactl start コマンドで仮想マシンを起動する。 --tty=false オプションは、つけない場合に設定ファイルをエディタで編集した上で起動するようになる。

$ limactl start --tty=false focal-amd64.yaml

上記を実行すると、初回はイメージファイルのダウンロードや仮想マシンの作成とセットアップが走る。 環境にもよるけど、この作業には数分かかるので気長に待つ。

ちなみに設定ファイルの項目やデフォルトの値は以下を参照すると良い。

github.com

また、limactl start コマンドで YAML ファイルではなく単純に仮想マシンの名前を指定した場合にも、上記のデフォルトの設定ファイルをベースに仮想マシンが作られる。 ここでも、--tty=false をつけなければ、デフォルトの設定ファイルをインタラクティブに編集しながら仮想マシンが定義できる。 現在 (2022-02) のデフォルトのディストリビューションは Ubuntu 21.10 のようだ。

$ limactl start impish

仮想マシンが作成できると limactl list コマンドに確認できるようになる。

$ limactl list                              
NAME           STATUS     SSH                ARCH      CPUS    MEMORY    DISK      DIR
focal-amd64    Running    127.0.0.1:50191    x86_64    4       4GiB      100GiB    /Users/amedama/.lima/focal-amd64

仮想マシンを操作する

仮想マシンが起動したら limactl shell コマンドで仮想マシンにログインしてシェルが取れる。

$ limactl shell focal-amd64

ログインできたら uname -r コマンドで仮想マシンの ISA を確認してみよう。 ちゃんと x86_64 と表示されるはず。

$ uname -m
x86_64

そして、次のとおり仮想マシンが Ubuntu 20.04 LTS であることがわかる。

$ uname -sr
Linux 5.4.0-99-generic
$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.3 LTS
Release:    20.04
Codename:   focal

Lima には他にも色々と機能があるけど、とりあえず今回はそんな感じで。


  1. 反対に、Intel 版の Mac 上で ISA が ARM64 の Linux 仮想マシンを立ち上げることもできる

Python: xfeat を使った特徴量エンジニアリング

今回は PFN が公開している OSS の xfeat を使った特徴量エンジニアリングについて見ていく。 xfeat には次のような特徴がある。

  • 多くの機能が scikit-learn の Transformer 互換の API で提供されている
  • 多くの機能が CuPy / CuDF に対応しているため CUDA 環境で高いパフォーマンスが得られる
  • 多くの機能がデータフレームを入力としてデータフレームを出力とした API になっている

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 12.2.1
BuildVersion:   21D62
$ uname -srm
Darwin 21.3.0 arm64
$ conda -V                         
conda 4.10.3
$ python -V        
Python 3.9.10
$ pip list | grep xfeat
xfeat           0.1.1

もくじ

下準備

AppleSilicon 版の Mac を使う場合、Python の実行環境に Miniforge を使う。 これは xfeat が依存しているいくつかのパッケージが、まだ pip からインストールできないため。

blog.amedama.jp

あらかじめ Miniforge を使って仮想環境を作る。

$ conda create -y -n venv python=3.9
$ conda activate venv

pip からインストールできない依存パッケージの LightGBM と PyArrow をインストールする。 ついでに、今回のサンプルコードで使用する scikit-learn と seaborn も入れておく。

$ conda install -y lightgbm pyarrow scikit-learn seaborn

最後に xfeat をインストールする。 現時点 (2022-02-20)では xfeat が依存しているパッケージの ml_metrics が setuptools v58 以降の環境でインストールできない。 そこで setuptools を v58 未満にダウングレードする必要がある。

$ pip install -U "setuptools<58"
$ pip install xfeat

ちなみに、Intel 版の Mac であれば以下だけでインストールできる。

$ pip install -U "setuptools<58"
$ pip install xfeat scikit-learn seaborn

インストールが終わったら Python のインタプリタを起動する。

$ python

今回、データセットには seaborn に同梱されている diamonds をサンプルとして用いる。 このデータセットにはカテゴリ変数と連続変数の両方が含まれているので、特徴量エンジニアリングの説明に都合が良い。

>>> import seaborn as sns
>>> df = sns.load_dataset('diamonds')
>>> df.head()
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
>>> df.dtypes
carat       float64
cut        category
color      category
clarity    category
depth       float64
table       float64
price         int64
x           float64
y           float64
z           float64
dtype: object

最後に xfeat をインポートしたら準備は終わり。

>>> import xfeat

特定の種類の変数を取り出す

まずはデータフレームから特定の種類の変数だけを取り出す機能から見ていく。 この機能を使うとデータの型に基づいてカテゴリ変数や連続変数のカラムだけを選択できる。 この機能は、後述する Pipeline の機能と組み合わせると便利に使える。

SelectCategorical

まず、SelectCategorical を使うとカテゴリ変数だけ取り出すことができる。

>>> select_cat = xfeat.SelectCategorical()
>>> select_cat.fit_transform(df)
             cut color clarity
0          Ideal     E     SI2
1        Premium     E     SI1
2           Good     E     VS1
3        Premium     I     VS2
4           Good     J     SI2
...          ...   ...     ...
53935      Ideal     D     SI1
53936       Good     D     SI1
53937  Very Good     D     SI1
53938    Premium     H     SI2
53939      Ideal     D     SI2

[53940 rows x 3 columns]

SelectNumerical

同様に SelectNumerical を使うと連続変数が取り出せる。

>>> select_num = xfeat.SelectNumerical()
>>> select_num.fit_transform(df)
       carat  depth  table  price     x     y     z
0       0.23   61.5   55.0    326  3.95  3.98  2.43
1       0.21   59.8   61.0    326  3.89  3.84  2.31
2       0.23   56.9   65.0    327  4.05  4.07  2.31
3       0.29   62.4   58.0    334  4.20  4.23  2.63
4       0.31   63.3   58.0    335  4.34  4.35  2.75
...      ...    ...    ...    ...   ...   ...   ...
53935   0.72   60.8   57.0   2757  5.75  5.76  3.50
53936   0.72   63.1   55.0   2757  5.69  5.75  3.61
53937   0.70   62.8   60.0   2757  5.66  5.68  3.56
53938   0.86   61.0   58.0   2757  6.15  6.12  3.74
53939   0.75   62.2   55.0   2757  5.83  5.87  3.64

[53940 rows x 7 columns]

カテゴリ変数のエンコード

次にカテゴリ変数のエンコードに使える機能を見ていこう。

LabelEncoder

LabelEncoder を使うとラベルエンコードができる。 scikit-learn の sklearn.preprocessing.LabelEncoder と比べるとデータフレームをそのまま入れられるメリットがある。 デフォルトでは、入力したデータフレームに _le というサフィックスがついたカラムが追加される 1

>>> label_encoder = xfeat.LabelEncoder()
>>> label_encoder.fit_transform(df[['cut']])
             cut  cut_le
0          Ideal       0
1        Premium       1
2           Good       2
3        Premium       1
4           Good       2
...          ...     ...
53935      Ideal       0
53936       Good       2
53937  Very Good       3
53938    Premium       1
53939      Ideal       0

[53940 rows x 2 columns]

元のカラムがいらないときは、output_suffix オプションに空文字を入れると上書きできる。

>>> label_encoder = xfeat.LabelEncoder(output_suffix='')
>>> label_encoder.fit_transform(df[['cut']])
       cut
0        0
1        1
2        2
3        1
4        2
...    ...
53935    0
53936    2
53937    3
53938    1
53939    0

[53940 rows x 1 columns]

また、複数のカラムが含まれるデータフレームを渡せば、複数のカテゴリ変数を一度にエンコードできる。

>>> label_encoder.fit_transform(df[['cut', 'color', 'clarity']])
       cut  color  clarity
0        0      0        0
1        1      0        1
2        2      0        2
3        1      1        3
4        2      2        0
...    ...    ...      ...
53935    0      6        1
53936    2      6        1
53937    3      6        1
53938    1      3        0
53939    0      6        0

[53940 rows x 3 columns]

この点は、前述した SelectCategoricalPipeline の機能を組み合わせると上手く動作してくれる。

>>> pipe = xfeat.Pipeline([
...     xfeat.SelectCategorical(),
...     xfeat.LabelEncoder(output_suffix=''),
... ])
>>> pipe.fit_transform(df)
       cut  color  clarity
0        0      0        0
1        1      0        1
2        2      0        2
3        1      1        3
4        2      2        0
...    ...    ...      ...
53935    0      6        1
53936    2      6        1
53937    3      6        1
53938    1      3        0
53939    0      6        0

[53940 rows x 3 columns]

未知のデータが含まれていた際の振る舞いは、デフォルトでは -1 が入る。

>>> label_encoder = xfeat.LabelEncoder()
>>> label_encoder.fit(df[['cut']])
>>> data = {
...     'cut': ['Ideal', 'Premium', 'Very Good', 'Good', 'Fair', 'Unknown', 'Unseen'],
... }
>>> import pandas as pd
>>> new_df = pd.DataFrame(data)
>>> label_encoder.transform(new_df)
         cut  cut_le
0      Ideal       0
1    Premium       1
2  Very Good       3
3       Good       2
4       Fair       4
5    Unknown      -1
6     Unseen      -1

この振る舞いは unseen オプションを使って変更できる。 たとえば n_unique を指定すると、これまでに出現したラベルの値から連続した値が割り当てられる。

>>> label_encoder = xfeat.LabelEncoder(unseen='n_unique')
>>> label_encoder.fit(df[['cut']])
>>> label_encoder.transform(new_df)
         cut  cut_le
0      Ideal       0
1    Premium       1
2  Very Good       3
3       Good       2
4       Fair       4
5    Unknown       5
6     Unseen       5

CountEncoder

同様に CountEncoder を使うとカウントエンコードができる。 これは同じ値がデータの中にいくつ含まれるかを特徴量として生成する。

>>> count_encoder = xfeat.CountEncoder()
>>> count_encoder.fit_transform(df[['cut']])
             cut  cut_ce
0          Ideal   21551
1        Premium   13791
2           Good    4906
3        Premium   13791
4           Good    4906
...          ...     ...
53935      Ideal   21551
53936       Good    4906
53937  Very Good   12082
53938    Premium   13791
53939      Ideal   21551

[53940 rows x 2 columns]

ConcatCombination

ConcatCombination では、変数の値を連結することで新しいカテゴリ変数を作り出せる。 drop_origin オプションを True に指定すると、元の特徴量を落としたデータフレームになる。 また、組み合わせる数は r オプションで指定する。

>>> concat_combi = xfeat.ConcatCombination(drop_origin=True, r=2)
>>> concat_combi.fit_transform(df[['cut', 'color', 'clarity']].astype(str))
      cutcolor_combi cutclarity_combi colorclarity_combi
0             IdealE         IdealSI2               ESI2
1           PremiumE       PremiumSI1               ESI1
2              GoodE          GoodVS1               EVS1
3           PremiumI       PremiumVS2               IVS2
4              GoodJ          GoodSI2               JSI2
...              ...              ...                ...
53935         IdealD         IdealSI1               DSI1
53936          GoodD          GoodSI1               DSI1
53937     Very GoodD     Very GoodSI1               DSI1
53938       PremiumH       PremiumSI2               HSI2
53939         IdealD         IdealSI2               DSI2

[53940 rows x 3 columns]

上記でカラムの型を str にキャストしているのは category 型のままだと例外になるため。

>>> concat_combi.fit_transform(df[['cut', 'color', 'clarity']])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
...
    raise TypeError(
TypeError: Cannot setitem on a Categorical with a new category (_NaN_), set the categories first

TargetEncoder

TargetEncoder を使うとターゲットエンコードができる。 データの分割方法は fold というオプションで指定できる。

>>> from sklearn.model_selection import KFold
>>> folds = KFold(n_splits=2, shuffle=False)
>>> target_encoder = xfeat.TargetEncoder(fold=folds, target_col='price')
>>> target_encoder.fit_transform(df[['cut', 'price']])
             cut  price       cut_te
0          Ideal    326  1561.627930
1        Premium    326  1949.366089
2           Good    327  1773.722046
3        Premium    334  1949.366089
4           Good    335  1773.722046
...          ...    ...          ...
53935      Ideal   2757  6113.638184
53936       Good   2757  5552.538574
53937  Very Good   2757  5893.533203
53938    Premium   2757  6663.661621
53939      Ideal   2757  6113.638184

[53940 rows x 3 columns]

なお、Target Encoding の詳細は下記のエントリに書いたことがある。

blog.amedama.jp

連続変数のエンコード

続いて連続変数のエンコードについて見ていく。

ArithmeticCombinations

ArithmeticCombinations を使うと、複数のカラムを四則演算するといった特徴量が計算できる。 たとえば、以下では 2 つのカラムを足し合わせた特徴量を生成している。

>>> add_combi = xfeat.ArithmeticCombinations(drop_origin=True, operator='+', r=2)
>>> add_combi.fit_transform(df[['x', 'y', 'z']])
       xy_combi  xz_combi  yz_combi
0          7.93      6.38      6.41
1          7.73      6.20      6.15
2          8.12      6.36      6.38
3          8.43      6.83      6.86
4          8.69      7.09      7.10
...         ...       ...       ...
53935     11.51      9.25      9.26
53936     11.44      9.30      9.36
53937     11.34      9.22      9.24
53938     12.27      9.89      9.86
53939     11.70      9.47      9.51

[53940 rows x 3 columns]

独自の加工をするエンコーダを作る

その他、LambdaEncoder を使うことで、自分で加工方法を定義したエンコーダを指定することもできる。 以下では例として値を 2 倍するエンコーダを作っている。

>>> double_encoder = xfeat.LambdaEncoder(lambda x: x * 2, drop_origin=False, output_suffix='_double')
>>> double_encoder.fit_transform(df[['x', 'y']])
          x     y  x_double  y_double
0      3.95  3.98      7.90      7.96
1      3.89  3.84      7.78      7.68
2      4.05  4.07      8.10      8.14
3      4.20  4.23      8.40      8.46
4      4.34  4.35      8.68      8.70
...     ...   ...       ...       ...
53935  5.75  5.76     11.50     11.52
53936  5.69  5.75     11.38     11.50
53937  5.66  5.68     11.32     11.36
53938  6.15  6.12     12.30     12.24
53939  5.83  5.87     11.66     11.74

[53940 rows x 4 columns]

集約特徴量

特定のカラムの値を Group By のキーにして要約統計量を計算するような特徴量は aggregation() 関数を使って計算できる。 この API は scikit-learn の Transformer 互換になっていない点に注意が必要。 以下では cut をキーにして、x, y, z の値にいくつかの統計量を計算している。

>>> df_agg, agg_cols = xfeat.aggregation(df,
...                                      group_key='cut',
...                                      group_values=['x', 'y', 'z'],
...                                      agg_methods=['sum', 'min', 'max', 'mean', 'median'],
...                                      )

結果はタプルで得られる。 最初の要素にはデータフレームが入っている。

>>> df_agg
       carat        cut color clarity  depth  table  ...  agg_mean_x_grpby_cut  agg_mean_y_grpby_cut  agg_mean_z_grpby_cut  agg_median_x_grpby_cut  agg_median_y_grpby_cut  agg_median_z_grpby_cut
0       0.23      Ideal     E     SI2   61.5   55.0  ...              5.507451              5.520080              3.401448                    5.25                    5.26                    3.23
1       0.21    Premium     E     SI1   59.8   61.0  ...              5.973887              5.944879              3.647124                    6.11                    6.06                    3.72
2       0.23       Good     E     VS1   56.9   65.0  ...              5.838785              5.850744              3.639507                    5.98                    5.99                    3.70
3       0.29    Premium     I     VS2   62.4   58.0  ...              5.973887              5.944879              3.647124                    6.11                    6.06                    3.72
4       0.31       Good     J     SI2   63.3   58.0  ...              5.838785              5.850744              3.639507                    5.98                    5.99                    3.70
...      ...        ...   ...     ...    ...    ...  ...                   ...                   ...                   ...                     ...                     ...                     ...
53935   0.72      Ideal     D     SI1   60.8   57.0  ...              5.507451              5.520080              3.401448                    5.25                    5.26                    3.23
53936   0.72       Good     D     SI1   63.1   55.0  ...              5.838785              5.850744              3.639507                    5.98                    5.99                    3.70
53937   0.70  Very Good     D     SI1   62.8   60.0  ...              5.740696              5.770026              3.559801                    5.74                    5.77                    3.56
53938   0.86    Premium     H     SI2   61.0   58.0  ...              5.973887              5.944879              3.647124                    6.11                    6.06                    3.72
53939   0.75      Ideal     D     SI2   62.2   55.0  ...              5.507451              5.520080              3.401448                    5.25                    5.26                    3.23

[53940 rows x 25 columns]

タプルの二番目の要素には、生成されたカラム名の入ったリストが入っている。

>>> from pprint import pprint
>>> pprint(agg_cols)
['agg_sum_x_grpby_cut',
 'agg_sum_y_grpby_cut',
 'agg_sum_z_grpby_cut',
 'agg_min_x_grpby_cut',
 'agg_min_y_grpby_cut',
 'agg_min_z_grpby_cut',
 'agg_max_x_grpby_cut',
 'agg_max_y_grpby_cut',
 'agg_max_z_grpby_cut',
 'agg_mean_x_grpby_cut',
 'agg_mean_y_grpby_cut',
 'agg_mean_z_grpby_cut',
 'agg_median_x_grpby_cut',
 'agg_median_y_grpby_cut',
 'agg_median_z_grpby_cut']

特徴量選択

ここまでは、主に特徴量エンジニアリングの中でも特徴量抽出 (Feature Extraction) の機能を見てきた。 ここからは特徴量選択 (Feature Selection) の機能を見ていく。

DuplicatedFeatureEliminator

DuplicatedFeatureEliminator を使うと、重複した特徴量を削除できる。 たとえば、次のようにまったく同じ値の入ったカラムが xx2 として含まれるデータフレームがあるとする。

>>> new_df = df[['x']].copy()
>>> new_df['x2'] = df['x']
>>> new_df
          x    x2
0      3.95  3.95
1      3.89  3.89
2      4.05  4.05
3      4.20  4.20
4      4.34  4.34
...     ...   ...
53935  5.75  5.75
53936  5.69  5.69
53937  5.66  5.66
53938  6.15  6.15
53939  5.83  5.83

[53940 rows x 2 columns]

重複した特徴量は、どちらかさえあれば予測には十分なはず。 DuplicatedFeatureEliminator を使うと、片方だけ残して特徴量を削除できる。

>>> dup_eliminator = xfeat.DuplicatedFeatureEliminator()
>>> dup_eliminator.fit_transform(new_df)
          x
0      3.95
1      3.89
2      4.05
3      4.20
4      4.34
...     ...
53935  5.75
53936  5.69
53937  5.66
53938  6.15
53939  5.83

[53940 rows x 1 columns]

ConstantFeatureEliminator

同様に ConstantFeatureEliminator を使うと定数になっている特徴量を削除できる。 たとえば、すべての値が 1 になっている a というカラムの入ったデータフレームを用意する。

>>> new_df = df[['x']].copy()
>>> new_df['a'] = 1
>>> new_df
          x  a
0      3.95  1
1      3.89  1
2      4.05  1
3      4.20  1
4      4.34  1
...     ... ..
53935  5.75  1
53936  5.69  1
53937  5.66  1
53938  6.15  1
53939  5.83  1

[53940 rows x 2 columns]

分散のない特徴量は予測に寄与しないはず。 DuplicatedFeatureEliminator を使うと、そのような特徴量を削除できる。

>>> const_eliminator = xfeat.ConstantFeatureEliminator()
>>> const_eliminator.fit_transform(new_df)
          x
0      3.95
1      3.89
2      4.05
3      4.20
4      4.34
...     ...
53935  5.75
53936  5.69
53937  5.66
53938  6.15
53939  5.83

[53940 rows x 1 columns]

SpearmanCorrelationEliminator

SpearmanCorrelationEliminator を使うと、高い相関を持った特徴量を削除できる。 たとえば、あるカラムに定数を加えただけのカラムを含んだデータフレームを用意する。

>>> new_df = df[['x']].copy()
>>> new_df['x2'] = df['x'] + 0.1
>>> new_df
          x    x2
0      3.95  4.05
1      3.89  3.99
2      4.05  4.15
3      4.20  4.30
4      4.34  4.44
...     ...   ...
53935  5.75  5.85
53936  5.69  5.79
53937  5.66  5.76
53938  6.15  6.25
53939  5.83  5.93

[53940 rows x 2 columns]

上記の特徴量は相関係数が 1.0 になっている。

>>> new_df.corr()
      x   x2
x   1.0  1.0
x2  1.0  1.0

極端に相関係数の高い特徴量も、予測においては片方があれば十分と考えられる。 SpearmanCorrelationEliminator を使うと片方だけ残して削除できる。

>>> corr_eliminator = xfeat.SpearmanCorrelationEliminator()
>>> corr_eliminator.fit_transform(new_df)
          x
0      3.95
1      3.89
2      4.05
3      4.20
4      4.34
...     ...
53935  5.75
53936  5.69
53937  5.66
53938  6.15
53939  5.83

[53940 rows x 1 columns]

GBDTFeatureSelector

GBDTFeatureSelector を使うと GBDT (Gradient Boosting Decision Tree) を用いて、特徴量の重要度に基づいた特徴量選択ができる。 なお、ここでいう GBDT としては LightGBM が使われている。

まず、LightGBM はカテゴリ変数をそのままだと受け付けないので、一旦ラベルエンコードしておく。

>>> pipe = xfeat.Pipeline([
...     xfeat.LabelEncoder(input_cols=['cut', 'color', 'clarity'],
...                        output_suffix=''),
... ])
>>> df = pipe.fit_transform(df)

ここでは threshold オプションに 0.5 を指定することで、重要と考えられる特徴量を 50% 残してみよう。 この値はハイパーパラメータなので、実際にはいくつかの値を試して予測精度や計算量のバランスを取っていく必要がある

>>> lgbm_params = {
...     'objective': 'regression',
...     'metric': 'rmse',
...     'verbosity': -1,
... }
>>> lgbm_fit_params = {
...     'num_boost_round': 1_000,
... }
>>> gbdt_selector = xfeat.GBDTFeatureSelector(target_col='price',
...                                           threshold=0.5,
...                                           lgbm_params=lgbm_params,
...                                           lgbm_fit_kwargs=lgbm_fit_params,
...                                           )
>>> selected_df = gbdt_selector.fit_transform(df)

上記を実行すると carat, y, depth, z という 4 つのカラムが選ばれた。 これらが、GBDT で予測する場合には重要となる特徴量の上位 50% ということ。

>>> selected_df
       carat     y  depth     z
0       0.23  3.98   61.5  2.43
1       0.21  3.84   59.8  2.31
2       0.23  4.07   56.9  2.31
3       0.29  4.23   62.4  2.63
4       0.31  4.35   63.3  2.75
...      ...   ...    ...   ...
53935   0.72  5.76   60.8  3.50
53936   0.72  5.75   63.1  3.61
53937   0.70  5.68   62.8  3.56
53938   0.86  6.12   61.0  3.74
53939   0.75  5.87   62.2  3.64

[53940 rows x 4 columns]

GBDTFeatureExplorer

GBDTFeatureSelector を使う場合、残す特徴量の割合を threshold オプションとして自分で指定する必要があった。 予測精度が最も高いものが欲しい場合であれば、GBDTFeatureExplorer を使うことで自動で探索させることもできる。

この機能は少し複雑なのでスクリプトにした。 以下にサンプルコードを示す。 以下では diamonds データセットを使って、price カラムを目的変数に RMSE のメトリックで回帰問題として解いている。 最初に xfeat を使って特徴量抽出をしており、機械的に 265 次元まで増やしている。 そして、予測精度が最も良くなるように GBDTFeatureExplorer を使って特徴量選択している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from functools import partial

import pandas as pd
import numpy as np
import seaborn as sns
import xfeat
import lightgbm as lgb
import optuna
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split


def _lgbm_cv(train_x, train_y):
    """LightGBM を使った交差検証のヘルパー関数"""
    lgbm_params = {
        'objective': 'regression',
        'metric': 'rmse',
        'verbosity': -1,
    }
    train_dataset = lgb.Dataset(data=train_x,
                                label=train_y)
    folds = KFold(n_splits=5,
                  shuffle=True,
                  random_state=42)
    cv_result = lgb.cv(lgbm_params,
                       train_dataset,
                       num_boost_round=1_000,
                       folds=folds,
                       return_cvbooster=True,
                       )
    return cv_result


def _rmse(y_true, y_pred):
    """RMSE を計算するヘルパー関数"""
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    return rmse


def _evaluate(train_x, train_y, test_x, test_y):
    """学習用データの CV とテストデータの誤差を確認するヘルパー関数"""
    cv_result = _lgbm_cv(train_x, train_y)

    cv_rmse_mean = cv_result['rmse-mean'][-1]
    print(f'CV RMSE: {cv_rmse_mean}')

    cvbooster = cv_result['cvbooster']
    y_preds = cvbooster.predict(test_x)
    y_pred = np.mean(y_preds, axis=0)
    test_rmse = _rmse(y_pred, test_y)
    print(f'Test RMSE: {test_rmse}')


def objective(df, selector, trial):
    """Optuna の目的関数"""
    # 次に試行する特徴量の組み合わせを得る
    selector.set_trial(trial)
    selector.fit(df)
    input_cols = selector.get_selected_cols()

    # 選択された特徴量から得られる Local CV のスコアを計算する
    train_x = df[input_cols].drop(['price'], axis=1)
    train_y = df['price']
    cv_result = _lgbm_cv(train_x, train_y)
    # スコアの平均を返す
    mean_score = cv_result['rmse-mean'][-1]
    return mean_score


def main():
    # データセットを読み込む
    df = sns.load_dataset('diamonds')

    # ベースとなるカラム毎の変数の種類
    categorical_cols = ['cut', 'color', 'clarity']
    numerical_cols = ['carat', 'depth', 'table', 'x', 'y', 'z']
    target_col = 'price'

    # fillna の問題があるので str にキャストする
    df = df.astype({
        cat_col: str for cat_col in categorical_cols
    })

    # カテゴリ変数の前処理
    pipe = xfeat.Pipeline([
        # カテゴリ同士の組み合わせ
        xfeat.ConcatCombination(r=2),
        # ラベルエンコード
        xfeat.LabelEncoder(output_suffix=''),
    ])
    cat_df = pipe.fit_transform(df[categorical_cols])
    # カテゴリ変数のリストを更新する
    categorical_cols = cat_df.columns.tolist()
    # 元のデータフレームと結合する
    df = pd.concat([cat_df, df[numerical_cols + [target_col]]], axis=1)
    print(f'add combination features: {len(df.columns)}')

    # カテゴリ変数を中心にした集約特徴量
    for cat_col in categorical_cols:
        df, _ = xfeat.aggregation(df,
                                  group_key=cat_col,
                                  group_values=numerical_cols,
                                  agg_methods=[
                                      'sum',
                                      'min',
                                      'max',
                                      'mean',
                                      'median',
                                  ],
                                  )
    print(f'add aggregation features: {len(df.columns)}')

    # 最終的な評価をするためにデータをホールドアウトしておく
    train_df, test_df = train_test_split(df,
                                         test_size=0.35,
                                         shuffle=True,
                                         random_state=42)

    folds = KFold(n_splits=5,
                  shuffle=True,
                  random_state=42)
    pipe = xfeat.Pipeline([
        # カウントエンコード
        xfeat.CountEncoder(input_cols=categorical_cols),
        # ターゲットエンコード
        xfeat.TargetEncoder(input_cols=categorical_cols,
                            target_col=target_col,
                            fold=folds),
        # 組み合わせ特徴量
        xfeat.ArithmeticCombinations(input_cols=numerical_cols,
                                     operator='+',
                                     r=2,
                                     output_suffix='_plus'),
        xfeat.ArithmeticCombinations(input_cols=numerical_cols,
                                     operator='*',
                                     r=2,
                                     output_suffix='_mul'),
        xfeat.ArithmeticCombinations(input_cols=numerical_cols,
                                     operator='-',
                                     r=2,
                                     output_suffix='_minus'),
        xfeat.ArithmeticCombinations(input_cols=numerical_cols,
                                     operator='/',
                                     r=2,
                                     output_suffix='_div'),
    ])
    train_df = pipe.fit_transform(train_df)
    test_df = pipe.transform(test_df)
    print(f'add some features: {len(train_df.columns)}')

    # 選択前のスコアを計算しておく
    train_x, train_y = train_df.drop(target_col, axis=1), train_df[target_col]
    test_x, test_y = test_df.drop(target_col, axis=1), test_df[target_col]
    _evaluate(train_x, train_y, test_x, test_y)

    # 学習用データセットを使って特徴量選択をする
    lgbm_params = {
        'objective': 'regression',
        'metric': 'rmse',
        'verbosity': -1,
    }
    fit_params = {
        'num_boost_round': 1_000,
    }
    selector = xfeat.GBDTFeatureExplorer(input_cols=train_df.columns.tolist(),
                                         target_col=target_col,
                                         fit_once=True,
                                         threshold_range=(0.1, 1.0),
                                         lgbm_params=lgbm_params,
                                         lgbm_fit_kwargs=fit_params,
                                         )

    # メトリックのスコアが良くなる特徴量の組み合わせを探索する
    study = optuna.create_study(direction='minimize')
    # 最適化する
    study.optimize(partial(objective, train_df, selector),
                   n_trials=10,
                   )

    # 探索で見つかった最善の組み合わせを取り出す
    selector.from_trial(study.best_trial)
    selected_cols = selector.get_selected_cols()

    # 特徴量の数をどれだけ減らせたか
    print(f'selected features: {len(selected_cols)}')

    # 選択後のスコアを計算する
    train_x = train_df[selected_cols].drop(target_col, axis=1)
    test_x = test_df[selected_cols].drop(target_col, axis=1)
    _evaluate(train_x, train_y, test_x, test_y)


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python example.py 
add combination features: 13
add aggregation features: 193
add some features: 265
CV RMSE: 544.5696268027966
Test RMSE: 515.6280330475462
[I 2022-02-20 18:44:37,774] A new study created in memory with name: no-name-5d80802a-6e8c-48b1-bf1e-f2843ca3eadd
[I 2022-02-20 18:45:11,947] Trial 0 finished with value: 545.1447126832893 and parameters: {'GBDTFeatureSelector.threshold': 0.8839849802341044}. Best is trial 0 with value: 545.1447126832893.
[I 2022-02-20 18:45:39,267] Trial 1 finished with value: 546.669714703891 and parameters: {'GBDTFeatureSelector.threshold': 0.7619871719011998}. Best is trial 0 with value: 545.1447126832893.
[I 2022-02-20 18:45:56,204] Trial 2 finished with value: 539.8527700066101 and parameters: {'GBDTFeatureSelector.threshold': 0.4415951458583348}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:46:22,874] Trial 3 finished with value: 544.7853629837771 and parameters: {'GBDTFeatureSelector.threshold': 0.7281870389232026}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:46:33,338] Trial 4 finished with value: 614.75354170118 and parameters: {'GBDTFeatureSelector.threshold': 0.16941573919931688}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:46:47,241] Trial 5 finished with value: 558.923413959607 and parameters: {'GBDTFeatureSelector.threshold': 0.26183305989448025}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:47:09,066] Trial 6 finished with value: 541.8006025367876 and parameters: {'GBDTFeatureSelector.threshold': 0.5902889999794974}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:47:22,243] Trial 7 finished with value: 545.7942530448582 and parameters: {'GBDTFeatureSelector.threshold': 0.30621508734500646}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:47:35,899] Trial 8 finished with value: 543.8972261792727 and parameters: {'GBDTFeatureSelector.threshold': 0.344652063624768}. Best is trial 2 with value: 539.8527700066101.
[I 2022-02-20 18:48:02,356] Trial 9 finished with value: 546.669714703891 and parameters: {'GBDTFeatureSelector.threshold': 0.8269956186739581}. Best is trial 2 with value: 539.8527700066101.
selected features: 219
CV RMSE: 546.669714703891
Test RMSE: 515.1565138588921

上記では、元の 265 次元から 219 次元まで特徴量が削減されている。 特徴量が削減されると、計算量が削減できることから一回の実験にかかる時間も減らすことができる。 一方で、予測精度についてはホールドアウトしたテストデータに対してはほとんど変化していない。 この結果は、GBDT の場合は予測にあまり寄与しない特徴量が含まれていても、さほど予測性能に悪影響を及ぼさないという経験則と一致している。

まとめ

今回は xfeat を使った特徴量エンジニアリングのやり方について見てきた。


  1. 内部的にデータフレームはコピーされるため元のデータフレームが変更されるわけではない

dbt (data build tool) を使ってデータをテストする

ソフトウェアエンジニアリングの世界では、自動化されたテストを使ってコードの振る舞いを検証するのが当たり前になっている。 同じように、データエンジニアリングの世界でも、自動化されたテストを使ってデータの振る舞いを検証するのが望ましい。

データをテストするのに使える OSS のフレームワークも、いくつか存在する。 今回は、その中でも dbt (data build tool) を使ってデータをテストする方法について見ていく。 dbt 自体はデータのテストを主目的としたツールではないものの、テストに関する機能も備えている。

また、dbt には WebUI を備えたマネージドサービスとしての dbt Cloud と、CLI で操作するスタンドアロン版の dbt Core がある。 今回扱うのは後者の dbt Core になる。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 12.2
BuildVersion:   21D49
$ uname -srm                       
Darwin 21.3.0 arm64
$ python -V          
Python 3.9.10
$ pip list | grep dbt
dbt-core                 1.0.1
dbt-extractor            0.4.0
dbt-postgres             1.0.1

もくじ

下準備

今回は、公開されているデータセットをローカルの PostgreSQL に取り込んで、それを dbt でテストしていく。

PostgreSQL をセットアップする

まずは Homebrew を使って PostgreSQL をインストールしよう。 また、データセットをダウンロードするために wget も入れておく。

$ brew install postgresql wget

PostgreSQL のサービスを起動する。

$ brew services start postgresql
 brew services list            
Name       Status  User    File
postgresql started amedama ~/Library/LaunchAgents/homebrew.mxcl.postgresql.plist

データセットとしては seaborn が利用している taxis データセットを使う。 ダウンロードして /tmp に保存する。

$ wget https://raw.githubusercontent.com/mwaskom/seaborn-data/master/taxis.csv -P /tmp

上記のデータセットに合う形でテーブルの定義を作る。

$ cat << 'EOF' | psql -d postgres
CREATE TABLE IF NOT EXISTS public.taxis (
  id SERIAL NOT NULL,
  pickup TIMESTAMP NOT NULL,
  dropoff TIMESTAMP NOT NULL,
  passengers INT NOT NULL,
  distance FLOAT NOT NULL,
  fare FLOAT NOT NULL,
  tip FLOAT NOT NULL,
  tolls FLOAT NOT NULL,
  total FLOAT NOT NULL,
  color TEXT NOT NULL,
  payment TEXT,
  pickup_zone TEXT,
  dropoff_zone TEXT,
  pickup_borough TEXT,
  dropoff_borough TEXT
);
EOF

先ほどダウンロードした CSV ファイルの内容を上記のテーブルに取り込む。

$ cat << 'EOF' | psql -d postgres
COPY public.taxis (
  pickup,
  dropoff,
  passengers,
  distance,
  fare,
  tip,
  tolls,
  total,
  color,
  payment,
  pickup_zone,
  dropoff_zone,
  pickup_borough,
  dropoff_borough
)
FROM '/tmp/taxis.csv'
WITH (
  FORMAT csv,
  HEADER true
)
EOF

また、データベースを操作するためのユーザ (ROLE) を alice という名前で追加しておく。

$ cat << 'EOF' | psql -d postgres
CREATE ROLE
  alice
WITH
  LOGIN
  PASSWORD 'wonderland'
EOF

ちなみに、ユーザにパスワードは設定してるけど、実は無くても問題はない。 Homebrew でインストールした場合、デフォルトでローカルからの接続が trust になっているため。 パスワードなしでも接続できる。

$ cat /opt/homebrew/var/postgres/pg_hba.conf | sed -e "/^#/d" -e "/^$/d"
local   all             all                                     trust
host    all             all             127.0.0.1/32            trust
host    all             all             ::1/128                 trust
local   replication     all                                     trust
host    replication     all             127.0.0.1/32            trust
host    replication     all             ::1/128                 trust

追加したユーザにテーブルを操作する権限をつけておく。

$ cat << 'EOF' | psql -d postgres
GRANT
  ALL PRIVILEGES
ON
  TABLE public.taxis
TO
  alice
EOF

次のように、ユーザとパスワードを指定してデータを見られることを確認しておく。

$ echo "SELECT * FROM public.taxis LIMIT 5" | psql -d postgres --user alice --password             
Password: 
 id |       pickup        |       dropoff       | passengers | distance | fare | tip  | tolls | total | color  |   payment   |      pickup_zone      |     dropoff_zone      | pickup_borough | dropoff_borough 
----+---------------------+---------------------+------------+----------+------+------+-------+-------+--------+-------------+-----------------------+-----------------------+----------------+-----------------
  1 | 2019-03-23 20:21:09 | 2019-03-23 20:27:24 |          1 |      1.6 |    7 | 2.15 |     0 | 12.95 | yellow | credit card | Lenox Hill West       | UN/Turtle Bay South   | Manhattan      | Manhattan
  2 | 2019-03-04 16:11:55 | 2019-03-04 16:19:00 |          1 |     0.79 |    5 |    0 |     0 |   9.3 | yellow | cash        | Upper West Side South | Upper West Side South | Manhattan      | Manhattan
  3 | 2019-03-27 17:53:01 | 2019-03-27 18:00:25 |          1 |     1.37 |  7.5 | 2.36 |     0 | 14.16 | yellow | credit card | Alphabet City         | West Village          | Manhattan      | Manhattan
  4 | 2019-03-10 01:23:59 | 2019-03-10 01:49:51 |          1 |      7.7 |   27 | 6.15 |     0 | 36.95 | yellow | credit card | Hudson Sq             | Yorkville West        | Manhattan      | Manhattan
  5 | 2019-03-30 13:27:42 | 2019-03-30 13:37:14 |          3 |     2.16 |    9 |  1.1 |     0 |  13.4 | yellow | credit card | Midtown East          | Yorkville West        | Manhattan      | Manhattan
(5 rows)

これでデータベースの準備は整った。

dbt をインストールする

続いては肝心の dbt をインストールする。 dbt は Python で開発されているので、pip を使ってインストールできる。

dbt では、接続するデータベースごとにアダプタと呼ばれるパッケージを切り替えて対応する。 たとえば PostgreSQL なら dbt-postgres というアダプタを使えば良い。 これが、もしバックエンドに BigQuery を使うなら dbt-bigquery を使うことになる。 いずれも、依存関係に本体の dbt-core が入っているので一緒にインストールできる。

$ pip install dbt-postgres

インストールすると dbt コマンドが使えるようになる。

$ dbt --version                                  
installed version: 1.0.1
   latest version: 1.0.1

Up to date!

Plugins:
  - postgres: 1.0.1

これで必要な下準備がすべて整った。

dbt からデータベースに接続する

まずは dbt からデータベースに接続する部分を確認していく。

dbt では dbt_project.yml という設定ファイルが必要になるので、まずは作る。 name はプロジェクト名、version はプロジェクトのバージョン番号、config-version は YAML のコンフィグ形式のバージョン番号を表している。 profile というのは、dbt がデータベースに接続するやり方のことをプロファイルと呼んでいて、それにつけた名前のこと。

$ cat << 'EOF' >> dbt_project.yml
name: taxis
version: 0.0.1
config-version: 2
profile: postgres_taxis
EOF

ちなみに dbt init コマンドを使えば一通りの設定ファイルとディレクトリ構造の入ったボイラープレートを展開することもできる。 ここらへんは好みで。

$ dbt init

次は、上記の設定ファイルで指定した postgres_taxis という名前のプロファイルを用意しよう。 データベースへの接続方法は、パスワードなど秘匿しておきたい情報も多い。 そのため、デフォルトではプロジェクトのディレクトリとは分離して、ホームディレクトリ以下を読むことになっている 1

~/.dbt というディレクトリに profiles.yml という名前で YAML ファイルを作る。 そして、postgres_taxis というプロファイルを定義する。 その下にある target では、デフォルトで使う接続先の環境を指定している。 これは、同じ用途のデータベースであっても、一般的には役割によって複数の環境を用意することになるため。 たとえば、本番 (prod)、ステージング (stg)、開発 (dev) といったように。 その中で、デフォルトで使用するものを指定している。 なお、接続先はコマンドラインオプションの --target で切り替えることができる。 target に指定した名前は、outputs 以下にある設定と対応する。 ここでは local という名前で設定した。 typepostgres を指定することで、アダプタとして dbt-postgres の実装が使われることになる。

$ mkdir -p ~/.dbt
$ cat << 'EOF' > ~/.dbt/profiles.yml                                             
postgres_taxis:
  target: local
  outputs:
    local:
      type: postgres
      threads: 1
      host: localhost
      port: 5432
      user: alice
      pass: wonderland
      dbname: postgres
      schema: dbt_alice
EOF

ちなみに、PostgreSQL ではテーブルの階層構造が <database>.<schema>.<table> という 3 層構造になっている。 上記で dbnameschema を指定しているのは、このため。 ただし、上記で指定している schema は、このプロファイルで使う作業用のスキーマの名前になっている。 つまり、アクセスできるスキーマがこれに限られるわけではない。

さて、プロファイルができたら、まずはデータベースへの接続が上手くいくことを確認しよう。 これには dbt debug コマンドを使う。

$ dbt debug
15:14:54  Running with dbt=1.0.1
dbt version: 1.0.1
python version: 3.9.10
python path: /Users/amedama/.virtualenvs/py39/bin/python
os info: macOS-12.2-arm64-arm-64bit
Using profiles.yml file at /Users/amedama/.dbt/profiles.yml
Using dbt_project.yml file at /Users/amedama/Documents/temporary/dbt-example/dbt_project.yml

Configuration:
  profiles.yml file [OK found and valid]
  dbt_project.yml file [OK found and valid]

Required dependencies:
 - git [OK found]

Connection:
  host: localhost
  port: 5432
  user: alice
  database: postgres
  schema: dbt_alice
  search_path: None
  keepalives_idle: 0
  sslmode: None
  Connection test: [OK connection ok]

All checks passed!

どうやら、無事に接続できたようだ。 データベースに接続できることを確認するのも、ある意味で健全性のテストと言えるかもしれない。

source freshness をテストする

さて、データベースに接続できることが分かった。 次の一手としては、すでにデータベースに取り込まれているテーブルをソース (source) として定義する。 ソースを定義しておくと、別の場所から参照できたり、それに対してテストが書けたりする。

テストの観点は色々とあるけど、まずは source freshness を確認してみよう。 これは、ソースの特定のカラムに含まれる最新のタイムスタンプが、現在時刻からどれくらい離れているかを検証するもの。 たとえば DWH へのデータの取り込みが何らかの理由で遅延したり、あるいは停止しているのを見つけるのに利用できる。

ソースを定義するには models というディレクトリを作って、そこに YAML の設定ファイルを追加する。 さっきも似たような作業があったと思うけど、基本的に dbt はユーザから見える部分のほとんどが YAML と SQL で成り立っている。 sources 以下に name でスキーマを指定して、その下に tables でテーブルを指定する。 以下であれば postgres.public.taxis という階層構造のテーブルを定義していることになる。 そして、その下にある freshnessloaded_at_field という項目で source freshness の設定をする。

$ mkdir -p models
$ cat << 'EOF' > models/taxis.yml
version: 2

sources:
  - name: public
    tables:
      - name: taxis
        freshness:
          warn_after:
            count: 1
            period: hour
          error_after:
            count: 1
            period: day
        loaded_at_field: pickup::timestamp
EOF

上記の設定では、pickup というカラムの最新の時刻が現在時刻から 1h 以上離れると警告、1d 以上離れるとエラーになる。 カラムの時刻は UTC を基準にする点に注意が必要。 つまり、JST を使っている場合には、UTC に変換する必要がある 2

設定できたら dbt source freshness コマンドを実行しよう。 これで、ソースに含まれる最新のタイムスタンプと現在時刻が比較される。

$ dbt source freshness
15:16:57  Running with dbt=1.0.1
15:16:57  Partial parse save file not found. Starting full parse.
15:16:57  Found 0 models, 0 tests, 0 snapshots, 0 analyses, 165 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
15:16:57  
15:16:57  Concurrency: 1 threads (target='local')
15:16:57  
15:16:57  1 of 1 START freshness of public.taxis.......................................... [RUN]
15:16:57  1 of 1 ERROR STALE freshness of public.taxis.................................... [ERROR STALE in 0.02s]
15:16:57  
15:16:57  Done.

当たり前だけど、実行は失敗する。 サンプルデータのタイムスタンプは、最新のレコードでも 2019 年になっている。 1h とか 1d なんて単位ではない離れ方をしている。

$ echo "SELECT MAX(pickup), NOW() FROM public.taxis" | psql -d postgres

         max         |              now              
---------------------+-------------------------------
 2019-03-31 23:43:45 | 2022-02-04 00:26:35.521179+09
(1 row)

試しに現在時刻との乖離が小さなデータを追加してみよう。 時刻の計算に GNU date を使いたいので Homebrew を使って coreutils をインストールする。

$ brew install coreutils

pickup が現在時刻から 2h 前のレコードを追加する。 これなら実行したときエラーではなく警告になるはずだ。 他のカラムについては適当な値で埋めた。

$ PICKUP=$(TZ=UTC gdate "+%Y-%m-%d %H:%M:%S" --date '2 hour ago')
$ DROPOFF=$(TZ=UTC gdate "+%Y-%m-%d %H:%M:%S" --date '1 hour ago')
$ cat << EOF | psql -d postgres
INSERT INTO taxis (
  pickup,
  dropoff,
  passengers,
  distance,
  fare,
  tip,
  tolls,
  total,
  color,
  payment,
  pickup_zone,
  dropoff_zone,
  pickup_borough,
  dropoff_borough
) VALUES (
  '${PICKUP}',
  '${DROPOFF}',
  1,
  1.5,
  7,
  2.0,
  0,
  12.5,
  'yellow',
  'credit card',
  'Lenox Hill West',
  'UN/Turtle Bay South',
  'Manhattan',
  'Manhattan'
);
EOF

先ほどと同じコマンドを実行すると、今度はたしかに警告 (WARN) に変わっている。

$ dbt source freshness
15:31:35  Running with dbt=1.0.1
15:31:35  Found 0 models, 0 tests, 0 snapshots, 0 analyses, 165 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
15:31:35  
15:31:35  Concurrency: 1 threads (target='local')
15:31:35  
15:31:35  1 of 1 START freshness of public.taxis.......................................... [RUN]
15:31:35  1 of 1 WARN freshness of public.taxis........................................... [WARN in 0.01s]
15:31:35  Done.

同じ要領で pickup と現在時刻の差が 1h 未満のデータを入れてみよう。

$ PICKUP=$(TZ=UTC gdate "+%Y-%m-%d %H:%M:%S" --date '15 min ago')
$ DROPOFF=$(TZ=UTC gdate "+%Y-%m-%d %H:%M:%S" --date '10 min ago')
$ cat << EOF | psql -d postgres
INSERT INTO taxis (
  pickup,
  dropoff,
  passengers,
  distance,
  fare,
  tip,
  tolls,
  total,
  color,
  payment,
  pickup_zone,
  dropoff_zone,
  pickup_borough,
  dropoff_borough
) VALUES (
  '${PICKUP}',
  '${DROPOFF}',
  3,
  2.16,
  9,
  1.1,
  0,
  13.4,
  'yellow',
  'cash',
  'Midtown East',
  'Yorkville West',
  'Manhattan',
  'Manhattan'
);
EOF

今度は実行が成功 (PASS) した。

$ dbt source freshness
15:33:36  Running with dbt=1.0.1
15:33:36  Found 0 models, 0 tests, 0 snapshots, 0 analyses, 165 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
15:33:36  
15:33:36  Concurrency: 1 threads (target='local')
15:33:36  
15:33:36  1 of 1 START freshness of public.taxis.......................................... [RUN]
15:33:36  1 of 1 PASS freshness of public.taxis........................................... [PASS in 0.01s]
15:33:36  Done.

これで source freshness の確認ができるようになった。

generic test を使ってテストを書く

さて、続いてはデータの中身を見るテストを書いていこう。 dbt を使ってテストを書くやり方には generic test と singular test の 2 つがある。 まずは、より汎用性の高い generic test から見ていこう。

generic test を使うには YAML の設定を追加するだけで良い。 次の設定では、ソースの taxis テーブルに含まれるいくつかのカラムに対してテストを用意している。 それぞれの名前から内容はなんとなく分かるはずだけど、念の為に書いておくと次のようなルールになっている。

  • id カラムは一意で NULL の値がないこと
  • pickup カラムは NULL の値がないこと
  • color カラムは yellowgreen の値だけあること
$ cat << 'EOF' > models/taxis.yml
version: 2

sources:
  - name: public
    tables:
      - name: taxis
        columns:
          - name: id
            tests:
              - unique
              - not_null
          - name: pickup
            tests:
              - not_null
          - name: color
            tests:
              - accepted_values:
                  values: ['yellow', 'green']
EOF

なお、先ほど確認した source freshness の設定は、簡単のために上記の設定ファイルからは省いた。

設定ファイルを作ったら dbt test コマンドでテストを実行する。

$ dbt test
09:38:01  Running with dbt=1.0.1
09:38:01  Found 0 models, 4 tests, 0 snapshots, 0 analyses, 165 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
09:38:01  
09:38:01  Concurrency: 1 threads (target='local')
09:38:01  
09:38:01  1 of 4 START test source_accepted_values_public_taxis_color__yellow__green...... [RUN]
09:38:01  1 of 4 PASS source_accepted_values_public_taxis_color__yellow__green............ [PASS in 0.03s]
09:38:01  2 of 4 START test source_not_null_public_taxis_id............................... [RUN]
09:38:01  2 of 4 PASS source_not_null_public_taxis_id..................................... [PASS in 0.01s]
09:38:01  3 of 4 START test source_not_null_public_taxis_pickup........................... [RUN]
09:38:01  3 of 4 PASS source_not_null_public_taxis_pickup................................. [PASS in 0.03s]
09:38:01  4 of 4 START test source_unique_public_taxis_id................................. [RUN]
09:38:01  4 of 4 PASS source_unique_public_taxis_id....................................... [PASS in 0.02s]
09:38:01  
09:38:01  Finished running 4 tests in 0.19s.
09:38:01  
09:38:01  Completed successfully
09:38:01  
09:38:01  Done. PASS=4 WARN=0 ERROR=0 SKIP=0 TOTAL=4

テストが実行されて成功 (PASS) した。

ちなみに、それぞれのテストケースは、いずれも SQL を使って実現されている。 先ほどの実行で、どのような SQL が発行されているかは、デフォルトで logs ディレクトリに生成されるログを読むと分かる。 テストは原則として「失敗するときに一致するレコードが出る SQL」になっている。 つまり、先ほどテストが成功したということは「発行した SQL に一致するレコードがなかった」ことを意味する。

また、上記ではソースに対してテストを書いたけど、モデルなど dbt に登場するその他のオブジェクトに対しても同じ要領でテストが書ける。 ただし、今回は簡単のためにソースに対してだけテストを書いていく。

テスト用マクロの入ったパッケージを利用する

先ほどはカラムの値が一意か NULL がないかといったテストを書いた。 ただ、dbt Core が組み込みで用意しているテスト用のマクロは多くない。 早々に「こういうテストが書きたいのに!」という場面が出てくる。 そんなときは、お目当てのテストに使えるマクロが入ったパッケージがないか探してみるのが良い。 dbt には dbt Hub というリポジトリに登録されているパッケージをインストールする仕組みがある。

たとえば dbt-utils というパッケージをインストールしてみよう。 このパッケージはテスト専用ではないものの、テストに使えるマクロがいくつか用意されている。 インストールしたいパッケージは packages.yml という名前の設定ファイルに書く。

$ cat << 'EOF' > packages.yml                
packages:
  - package: dbt-labs/dbt_utils
    version: 0.8.0
EOF

バージョンの指定は必須なので dbt Hub を見て記述する。

hub.getdbt.com

設定ファイルの用意ができたら dbt deps コマンドを実行する。

$ dbt deps

すると、デフォルトで dbt_packages というディレクトリにパッケージがインストールされる。

$ ls dbt_packages 
dbt_utils

このディレクトリには、パッケージに対応するリポジトリの内容がそのままダウンロードされている。 実にシンプルな仕組み。

$ ls dbt_packages/dbt_utils 
CHANGELOG.md        README.md       dbt_project.yml     etc         macros
LICENSE         RELEASE.md      docker-compose.yml  integration_tests   run_test.sh

パッケージをダウンロードできたら、実際に dbt-utils に含まれるテスト用のマクロを使ってみよう。 たとえば accepted_range というマクロを使うとカラムが取りうる値の範囲を指定できる。 以下では tip カラムの値が 0.0 ~ 33.2 になることを確認している。

$ cat << 'EOF' > models/taxis.yml
version: 2

sources:
  - name: public
    tables:
      - name: taxis
        columns:
          - name: tip
            tests:
              - dbt_utils.accepted_range:
                  min_value: 0.0
                  max_value: 33.2
                  inclusive: true
EOF

次のとおり tip カラムの値は 0.0 ~ 33.2 になっている。

$ echo "SELECT MAX(tip) AS max_tip, MIN(tip) AS min_tip FROM public.taxis" | psql -d postgres
 max_tip | min_tip 
---------+---------
    33.2 |       0
(1 row)

テストを実行してみよう。

$ dbt test
14:56:08  Running with dbt=1.0.1
14:56:08  Unable to do partial parsing because a project dependency has been added
14:56:09  Found 0 models, 1 test, 0 snapshots, 0 analyses, 352 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
14:56:09  
14:56:09  Concurrency: 1 threads (target='local')
14:56:09  
14:56:09  1 of 1 START test dbt_utils_source_accepted_range_public_taxis_tip__True__33_2__0_0 [RUN]
14:56:09  1 of 1 PASS dbt_utils_source_accepted_range_public_taxis_tip__True__33_2__0_0... [PASS in 0.03s]
14:56:09  
14:56:09  Finished running 1 test in 0.08s.
14:56:09  
14:56:09  Completed successfully
14:56:09  
14:56:09  Done. PASS=1 WARN=0 ERROR=0 SKIP=0 TOTAL=1

ちゃんと成功した。

ちなみに、テスト用のマクロが色々と入っているパッケージとしては dbt-expectations というのがある。

hub.getdbt.com

custom generic test を定義する (引数なし)

探しても使えそうなマクロが見つからないときは、独自の generic test を自分で書くこともできる。 公式のドキュメントでは custom generic test と呼んでいる。

試しに、カラムの値が正の値かをテストするマクロを定義してみよう。 前述したとおり、書くのは「失敗するときに一致するレコードが出る SQL」となる。 ただし、純粋な SQL ではなくて Jinja2 というテンプレートエンジンの構文を使って書いていく。

custom generic test を定義するには、tests/generic ディレクトリ以下に SQL を記述する。 以下では is_positive という名前で定義しており、カラムの値が 0 未満のレコードを抽出する。

$ mkdir -p tests/generic 
$ cat << 'EOF' > tests/generic/is_positive.sql 
{% test is_positive(model, column_name) %}

select
  *
from
  {{ model }}
where
  {{ column_name }} < 0

{% endtest %}
EOF

上記のテストを使ってみよう。 passengers カラムの内容をチェックする。

$ cat << 'EOF' > models/taxis.yml
version: 2

sources:
  - name: public
    tables:
      - name: taxis
        columns:
          - name: passengers
            tests:
              - is_positive
EOF

$ dbt test

custom generic test を定義する (引数あり)

先ほどの custom generic test には、追加の引数がなかった。 次は追加の引数があるものを定義してみよう。

以下では、特定の値よりも大きな値がカラムに含まれないことを確認する custom generic test を定義している。 引数にはしきい値を表す value と、境界値を含むかを表したフラグの inclusive をつけている。

$ cat << 'EOF' > tests/generic/max.sql 
{% test max(model, column_name, value, inclusive=true) %}

select
  *
from
  {{ model }}
where
  {{ column_name }} > {{- "=" if inclusive }} {{ value }}

{% endtest %}
EOF

上記を使ってみよう。 以下では passengers の最大値が 6 であることを確認している。 なお、inclusive のデフォルト値は false で上書きしている。

$ cat << 'EOF' > models/taxis.yml
version: 2

sources:
  - name: public
    tables:
      - name: taxis
        columns:
          - name: passengers
            tests:
              - max:
                  value: 6
                  inclusive: false
EOF

テストを実行すると成功する。

$ dbt test
15:27:22  Running with dbt=1.0.1
15:27:22  Found 0 models, 1 test, 0 snapshots, 0 analyses, 354 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
15:27:22  
15:27:22  Concurrency: 1 threads (target='local')
15:27:22  
15:27:22  1 of 1 START test source_max_public_taxis_passengers__False__6.................. [RUN]
15:27:22  1 of 1 PASS source_max_public_taxis_passengers__False__6........................ [PASS in 0.02s]
15:27:22  
15:27:22  Finished running 1 test in 0.06s.
15:27:22  
15:27:22  Completed successfully
15:27:22  
15:27:22  Done. PASS=1 WARN=0 ERROR=0 SKIP=0 TOTAL=1

ちなみに inclusive オプションを true にしたり、あるいは削ってデフォルト値にするとテストは失敗する。 これは passengers カラムの最大値が 6 のため。

$ echo "SELECT MAX(passengers) FROM public.taxis" | psql -d postgres 
 max 
-----
   6
(1 row)

singular test を使ってテストを書く

generic test は汎用的なテストをマクロとして定義した上で、それを色々な場所から YAML の設定で使うものだった。 一方で、特定の用途に特化したテストを書きたいときもあるはず。 そんなときは単発の SQL を実行するだけの singular test を使うと良い。

singular test を書くときは tests ディレクトリ以下に、直接 SQL のファイルを記述する。 以下では postgres.public.taxispassengers カラムに正の値しかないことを確認するテストを書いている。 ようするに、先ほど引数なしの custom generic test で検証した内容をベタ書きしているだけ。

$ cat << 'EOF' > tests/taxis_passengers_positive.sql 
select
  *
from
  {# sources に定義してあるテーブルの名前を取得できる #}
  {{ source('public', 'taxis') }}
where
  passengers < 0
EOF

純粋に singular test の結果だけ見たいので、先ほどの custom generic test は設定ファイルから削る。

$ cat << 'EOF' > models/taxis.yml                               
version: 2

sources:
  - name: public
    tables:
      - name: taxis
EOF

テストを実行してみよう。

$ dbt test
15:36:50  Running with dbt=1.0.1
15:36:50  Found 0 models, 1 test, 0 snapshots, 0 analyses, 354 macros, 0 operations, 0 seed files, 1 source, 0 exposures, 0 metrics
15:36:50  
15:36:50  Concurrency: 1 threads (target='local')
15:36:50  
15:36:50  1 of 1 START test taxis_passengers_positive..................................... [RUN]
15:36:50  1 of 1 PASS taxis_passengers_positive........................................... [PASS in 0.02s]
15:36:50  
15:36:50  Finished running 1 test in 0.06s.
15:36:50  
15:36:50  Completed successfully
15:36:50  
15:36:50  Done. PASS=1 WARN=0 ERROR=0 SKIP=0 TOTAL=1

ちゃんと実行されて成功したことがわかる。

まとめ

今回は dbt を使ってデータをテストする方法について書いた。

参考

docs.getdbt.com

docs.getdbt.com


  1. 変更したいときはコマンドラインオプションを使って場所を指定できる

  2. タイムゾーンを UTC に変換するやり方は公式ドキュメントでいくつか紹介されている

Linux の PID Namespace について

Linux のコンテナ仮想化を構成する機能の一つに Namespace (名前空間) がある。 Namespace は、カーネルのリソースを隔離して扱うための仕組みで、リソース毎に色々とある。 今回は、その中でも PID (Process Identifier) を隔離する PID Namespace を扱ってみる。

使った環境は次のとおり。

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.3 LTS
Release:    20.04
Codename:   focal
$ uname -rm
5.4.0-96-generic aarch64
$ unshare --version
unshare from util-linux 2.34
$ gcc --version
gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
$ ps --version
ps from procps-ng 3.3.16

PID と PID Namespace

まず、PID の前提について確認しておく。 PID は 1 から始まって、新しいプロセスができるたびにインクリメントされた整数がプロセスの識別子として付与される。 PID の数が割り当てられる上限に達したときは、また 1 に戻って空いている数字が割り当てられる。 たとえば Ubuntu 20.04 LTS であれば、上限は以下に設定されているようだ。

$ cat /proc/sys/kernel/pid_max 
4194304

プロセスが一意に識別できなくなるため、同じ PID を持ったプロセスが複数できることはない。 しかし、コンテナ仮想化においては、コンテナの中では独立した PID を見せたい。 そこで、PID Namespace によって、PID のリソースをシステムから隔離して扱うことができる。 言いかえると、コンテナの中で PID が 1 から始まるのは PID Namespace によって実現されている。

下準備

下準備として、あらかじめ unshare(1) のために util-linux をインストールしておく。 また、unshare(2) を呼び出すコードをビルドするために build-essential をインストールする。

$ sudo apt-get update
$ sudo apt-get install -y util-linux build-essential

unshare(1) を使って PID Namespace を使ってみる

まずはコマンドラインツールの unshare(1) から PID Namespace を使ってみよう。

unshare(1) を使って PID Namespace を新たに作るには --pid オプションを使う。 また、同時に --fork オプションと --mount-proc オプションもつけた方が良い。 この理由は後ほど説明する。 起動するプログラムとしては bash を指定しておこう。

$ sudo unshare --pid --fork --mount-proc bash

起動した bash では、PID が 1 になっていることがわかる。 また、ps(1) でも PID が 1 から振り直されていることが確認できる。 ちゃんとシステムから独立した PID が利用できているようだ。

# echo $$
1
# ps
    PID TTY          TIME CMD
      1 pts/0    00:00:00 bash
      8 pts/0    00:00:00 ps

さて、それでは先ほど --pid とは別で追加で指定したオプションについて見ていこう。 まずは --fork オプションから。 このオプションをつけないと何が起こるだろうか。 一旦、先ほど起動した bash は終了した上で、改めて unshare(1) を使おう。 今度は --pid オプションだけつける。

$ sudo unshare --pid bash
bash: fork: Cannot allocate memory

すると fork: Cannot allocate memory というエラーになってしまう。 とはいえ、一応エラーにはなりつつも bash は起動しているようだ。 しかし、何をするにしても同じ fork: Cannot allocate memory というエラーになってしまう。

# ls
bash: fork: Cannot allocate memory

このエラーの原因は、次の stackoverflow の質問に詳しい解説がある。

stackoverflow.com

かいつまんで説明すると、こういうことらしい。 まず、unshare(1) から起動した bash は、新たに作成した PID Namespace には所属しない。 代わりに、bash が最初に (fork(2) によって) 生成するサブプロセスが、新たに作成した PID Namespace に所属することになる。 新たに作成した PID Namespace では、PID が 1 から始まるため、bash のサブプロセスが PID 1 になる。 しかし、bash のサブプロセスは直後に終了するため、PID 1 のプロセスがいなくなる。 Linux において PID 1 のプロセスは特別な意味を持つことから、それがいなくなることで上記のエラーが生じているらしい。

一応、サブプロセスを起動しないタイプのシェルを利用すればエラーは出ない。 たとえば Ubuntu 20.04 LTS の sh は dash を使っているらしいので、指定してみよう。 しかし、その場合はそもそも起動したシェルが新しい PID Namespace に所属していないので何も意味がない。

$ sudo unshare --pid sh
# echo $$
1136

ということで --fork オプションが必要な理由がわかった。 続いては --mount-proc オプションについて。 今度はこのオプションを付けないで実行してみよう。

$ sudo unshare --pid --fork bash

一見すると何も問題なさそうに見えるけど、ps(1) なんかを呼び出すと随分と大きな数字が見える。 そもそも、隔離して見えないはずの unshare(1) の PID が見えているのはどうしたことか。

# ps
    PID TTY          TIME CMD
    959 pts/0    00:00:00 sudo
    960 pts/0    00:00:00 unshare
    961 pts/0    00:00:00 bash
    968 pts/0    00:00:00 ps

これは、/proc ファイルシステムが、PID Namespace を隔離する前の状態のままであることが原因。 要するにシステムの状態が見えたままということ。 Ubuntu 20.04 LTS の ps (=procps-ng) は /proc ファイルシステムを見ているので、さもありなん。

# ls /proc | egrep ^[0-9] | sort -n | tail -n 5
974
989
990
991
992

この状態は /proc ファイルシステムをマウントし直せば解消できる。

# mount -t proc proc /proc
# ls /proc | egrep ^[0-9] | sort -n | tail -n 5
1
20
21
22
23

つまり、これこそが --mount-proc オプションがやっていたこと、というわけだ。

ちなみに、上記のように新しいプロセスでマウントし直すやり方を取ると、そのままではマウントのプロパゲーションが起こってしまう。 この振る舞いについては以下のエントリで説明している。

blog.amedama.jp

要するに、上記のようなことをしたければ --mount オプションをつける必要がある。 あるいは、次のコマンドを使って事前にプロパゲーションを無効にしても良い。

$ sudo mount --make-private /

unshare(2) を使って PID Namespace を使ってみる

さて、続いては unshare(2) から PID Namespace を扱ってみよう。 ソースコードは、以下のコマンドと等価なものにする。

$ sudo unshare --pid --fork --mount-proc bash

早速だけどサンプルコードを以下に示す。 まず、unshare(2) で Mount Namespace と PID Namespace を新たに作成している。 その上で fork(2) で子プロセスを作っている。 この子プロセスが新しく作った PID Namespace に所属することになる。 その上で、マウントのプロパゲーションを無効にした上で /proc ファイルシステムをマウントし直している。 そして、最後にシェルを起動している。

#define _GNU_SOURCE

#include <sched.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <sys/wait.h>
#include <sys/types.h>
#include <sys/mount.h>

int main(int argc, char *argv[]) {
    // Mount & PID Namespace を作成する
    if (unshare(CLONE_NEWPID | CLONE_NEWNS) != 0) {
        fprintf(stderr, "Failed to create a new PID namespace: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // fork(2) で子プロセスを作る
    pid_t pid = fork();
    if (pid < 0) {
        fprintf(stderr, "Failed to fork a new process: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    if (pid != 0) {
        // 親プロセスは wait(2) で子プロセスの完了を待つ
        wait(NULL);
        exit(EXIT_SUCCESS);
    }

    // 以降は子プロセスの処理

    // ルート以下のマウントプロパゲーションを再帰的に無効にする
    if (mount("none", "/", NULL, MS_REC | MS_PRIVATE, NULL) != 0) {
        fprintf(stderr, "cannot change root filesystem propagation: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // mount(2) で /proc ファイルシステムをマウントする
    if (mount("proc", "/proc", "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, NULL) != 0) {
        fprintf(stderr, "Failed to mount /proc: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    // execvp(3) でシェルを起動する
    char* const args[] = {"bash", NULL};
    if (execvp(args[0], args) != 0) {
        fprintf(stderr, "Failed to exec \"%s\": %s\n", args[0], strerror(errno));
        exit(EXIT_FAILURE);
    }

    return EXIT_SUCCESS;
}

上記に適当な名前をつけて保存したらビルドして実行する。

$ gcc -std=c11 -Wall example.c
$ sudo ./a.out

シェルの PID を確認すると、ちゃんと 1 になっている。

# echo $$
1

ps(1) を実行しても、PID がリセットされていることがわかる。

# ps
    PID TTY          TIME CMD
      1 pts/0    00:00:00 bash
      8 pts/0    00:00:00 ps

PID Namespace が異なっているのは /proc ファイルシステムの以下を見ても確認できる。

# ls -l /proc/self/ns/pid
lrwxrwxrwx 1 root root 0 Jan 23 01:29 /proc/self/ns/pid -> 'pid:[4026532130]'
# exit
$ ls -l /proc/self/ns/pid
lrwxrwxrwx 1 ubuntu ubuntu 0 Jan 23 01:29 /proc/self/ns/pid -> 'pid:[4026531836]'

いじょう。

参考

man7.org

man7.org

man7.org