CUBE SUGAR CONTAINER

技術系のこと書きます。

S3 互換オブジェクトストレージの OSS - MinIO を試す

MinIO は Amazon S3 互換のオブジェクトストレージを提供する OSS のひとつ。 たとえばオンプレ環境でオブジェクトストレージを構築したいときや、手元で S3 を扱うアプリケーションの動作確認をするときなんかに使える。 今回はそんな MinIO を AWS CLI と Python クライアントの boto3 から使ってみる。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.4
BuildVersion:   20F71
$ minio -v    
minio version RELEASE.2021-05-26T00-22-46Z
$ python -V
Python 3.9.5
$ aws --version
aws-cli/2.2.6 Python/3.9.5 Darwin/20.5.0 source/x86_64 prompt/off
$ pip list | grep -i boto3  
boto3                     1.17.83

もくじ

下準備

今回は Homebrew から MinIO をインストールして使う。 クライアントとして awscli と boto3 も入れておく。

$ brew install minio awscli
$ pip install boto3

インストールできたら作業用のディレクトリを指定して minio server コマンドを実行する。 これで MinIO のサーバが立ち上がる。

$ mkdir -p /tmp/minio
$ minio server /tmp/minio

立ち上がると 9000 番ポートを Listen し始める。

$ lsof -i:9000
COMMAND   PID    USER   FD   TYPE             DEVICE SIZE/OFF NODE NAME
minio   13739 amedama   14u  IPv6 0x62631bfc177de1b3      0t0  TCP *:cslistener (LISTEN)

ブラウザでローカルホストの 9000 番ポートにアクセスすると管理用の Web UI が見える。

$ open http://localhost:9000/

f:id:momijiame:20210529145812p:plain
MinIO の管理用 Web UI

アカウントはデフォルトだと Access Key と Secret Key がどちらも minioadmin でログインできる。 デフォルトのアカウントを変更したいときはサーバを立ち上げるときに以下の環境変数で指定する。

  • Access Key

    • MINIO_ROOT_USER または MINIO_ACCESS_KEY
  • Secret Key

    • MINIO_ROOT_PASSWORD または MINIO_SECRET_KEY

AWS CLI から操作する

はじめに AWS CLI から操作してみよう。 まずは認証情報を環境変数で設定しておく。

$ export AWS_ACCESS_KEY_ID=minioadmin
$ export AWS_SECRET_ACCESS_KEY=minioadmin

あとは aws コマンドのオプションとして --endpoint-url に MinIO が動作してる http://localhost:9000 を指定するだけ。

$ aws --endpoint-url http://localhost:9000 s3 ls

特にエラーにならず上記が実行できれば大丈夫。

サンプルとなるバケットを example-bucket という名前で作成してみる。

$ aws s3 --endpoint-url http://localhost:9000 mb s3://example-bucket
make_bucket: example-bucket

作成すると、ちゃんと ls でバケットが見えるようになった。

$ aws --endpoint-url http://localhost:9000 s3 ls
2021-05-29 15:25:34 example-bucket

続いてはファイルをバケットにコピーしてみる。

$ echo "Hello, World" > /tmp/greet.txt
$ aws --endpoint-url http://localhost:9000 s3 cp /tmp/greet.txt s3://example-bucket
upload: ../../tmp/greet.txt to s3://example-bucket/greet.txt

ちゃんとアップロードできた。

$ aws --endpoint-url http://localhost:9000 s3 ls s3://example-bucket
2021-05-29 15:26:29         13 greet.txt

ファイルに深いプレフィックスをつけてコピーしたいときは ls--recursive オプションをつけると再帰的に確認できる。

$ aws --endpoint-url http://localhost:9000 s3 cp /tmp/greet.txt s3://example-bucket/folder/subfolder/ 
upload: ../../tmp/greet.txt to s3://example-bucket/folder/subfolder/greet.txt
$ aws --endpoint-url http://localhost:9000 s3 ls s3://example-bucket --recursive
2021-05-29 15:29:37         13 folder/subfolder/greet.txt
2021-05-29 15:26:29         13 greet.txt

上記は / を区切りにした階層構造があるように見えるけど、これはあくまでファイル名に / 区切りのプレフィックスがついているに過ぎない。 つまり、インタフェース的に階層構造があるように見せているだけ、という点には留意する必要がある。 階層構造のように見えたとしても、バケット以下の構造はあくまでもフラットな名前空間になっている。

標準入出力経由でファイルをコピーすることもできる。

$ echo "Hello, World" | aws --endpoint-url http://localhost:9000 s3 cp - s3://example-bucket/stdin/greet.txt
$ aws --endpoint-url http://localhost:9000 s3 cp s3://example-bucket/stdin/greet.txt -
Hello, World

ファイルを削除するときは rm コマンドを使う。

$ aws --endpoint-url http://localhost:9000 s3 rm s3://example-bucket/stdin/greet.txt
delete: s3://example-bucket/stdin/greet.txt

バケットの削除は、入っているファイルをすべて削除すれば rb コマンドからできる。 ただし、今回は後段の boto3 が残っているので省略する。

boto3 から操作する

続いては Python クライアントの boto3 からアクセスしてみる。

まずは Python のインタプリタを起動する。

$ python

boto3 パッケージをインポートする。

>>> import boto3

エンドポイントや認証情報を与えてクライアントを作る。

>>> s3_client = boto3.client('s3',
...                          use_ssl=False,
...                          endpoint_url='http://localhost:9000',
...                          aws_access_key_id='minioadmin',
...                          aws_secret_access_key='minioadmin')

バケットのリストを確認すると、先ほど AWS CLI で作成したものが確認できる。

>>> response = s3_client.list_buckets()
>>> response['Buckets']
[{'Name': 'example-bucket', 'CreationDate': datetime.datetime(2021, 5, 29, 6, 25, 34, 96000, tzinfo=tzutc())}]

試しに新しくバケットを作ってみよう。

>>> s3_client.create_bucket(Bucket='boto3-bucket')
{'ResponseMetadata': {'RequestId': '168376A910A8A588', 'HostId': '', 'HTTPStatusCode': 200, 'HTTPHeaders': {'accept-ranges': 'bytes', 'content-length': '0', 'content-security-policy': 'block-all-mixed-content', 'location': '/boto3-bucket', 'server': 'MinIO', 'vary': 'Origin', 'x-amz-request-id': '168376A910A8A588', 'x-xss-protection': '1; mode=block', 'date': 'Sat, 29 May 2021 06:45:59 GMT'}, 'RetryAttempts': 0}, 'Location': '/boto3-bucket'}

確認すると、新しくバケットができている。

>>> response = s3_client.list_buckets()
>>> from pprint import pprint
>>> pprint(response['Buckets'])
[{'CreationDate': datetime.datetime(2021, 5, 29, 6, 45, 59, 285000, tzinfo=tzutc()),
  'Name': 'boto3-bucket'},
 {'CreationDate': datetime.datetime(2021, 5, 29, 6, 25, 34, 96000, tzinfo=tzutc()),
  'Name': 'example-bucket'}]

いくつかやり方はあるけど、ここでは upload_fileobj() 関数を使ってファイルをアップロードしてみる。

>>> import io
>>> f = io.BytesIO(b'Hello, World')
>>> s3_client.upload_fileobj(f, 'boto3-bucket', 'greet.txt')

ちゃんとアップロードできた。

>>> response = s3_client.list_objects(Bucket='boto3-bucket')
>>> pprint(response['Contents'])
[{'ETag': '"82bb413746aee42f89dea2b59614f9ef"',
  'Key': 'greet.txt',
  'LastModified': datetime.datetime(2021, 5, 29, 6, 47, 55, 783000, tzinfo=tzutc()),
  'Owner': {'DisplayName': 'minio',
            'ID': '02d6176db174dc93cb1b899f7c6078f08654445fe8cf1b6ce98d8855f66bdbf4'},
  'Size': 12,
  'StorageClass': 'STANDARD'}]

今度は download_fileobj() 関数を使ってファイルをダウンロードしてみよう。

>>> f = io.BytesIO()
>>> s3_client.download_fileobj(Bucket='example-bucket', Key='greet.txt', Fileobj=f)

ちゃんと中身が確認できた。

>>> f.seek(0)
0
>>> f.read()
b'Hello, World\n'

いじょう。

iproute2 の ip-netns(8) を使わずに Network Namespace を操作する

今回は、iproute2 の ip-netns(8) を使わずに、Linux の Network Namespace を操作する方法について書いてみる。 目的は、namespaces(7) について、より深い理解を得ること。

使った環境は次のとおり。

$ cat /etc/lsb-release 
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.2 LTS"
$ uname -r
5.4.0-1043-gcp

もくじ

下準備

あらかじめ、必要なパッケージをインストールしておく。

$ sudo apt-get update
$ sudo apt-get -y install iproute2 util-linux gcc

前提知識

Linux の namespaces(7) は、プロセスが利用するリソースを分離するための仕組み。 典型的には、Linux のコンテナ仮想化を実現するために用いられている。 今回はタイトルに Network Namespace と入れたものの、分離できるのは何も Network に限らない。

プロセスが利用している Namespace の情報は procfs から /proc/<pid>/ns で確認できる。 現在のプロセスであれば、自身の pid を確認するまでもなく /proc/self/ns を見れば良い。

$ ls -alF /proc/self/ns
total 0
dr-x--x--x 2 amedama amedama 0 May 21 12:41 ./
dr-xr-xr-x 9 amedama amedama 0 May 21 12:41 ../
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 cgroup -> 'cgroup:[4026531835]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 ipc -> 'ipc:[4026531839]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 mnt -> 'mnt:[4026531840]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 net -> 'net:[4026531992]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 pid -> 'pid:[4026531836]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 pid_for_children -> 'pid:[4026531836]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 user -> 'user:[4026531837]'
lrwxrwxrwx 1 amedama amedama 0 May 21 12:41 uts -> 'uts:[4026531838]'

これらのファイルの実体はシンボリックリンクで、参照先として表示されている謎の数字は inode 番号を示している。 つまり、Namespace は inode 番号が識別子になっている。 上記であれば、/proc/self/ns/net がプロセスが利用している Network Namespace の識別子を表している。

$ file /proc/self/ns/net
/proc/self/ns/net: symbolic link to net:[4026531992]
$ stat -L /proc/self/ns/net
  File: /proc/self/ns/net
  Size: 0          Blocks: 0          IO Block: 4096   regular empty file
Device: 4h/4d   Inode: 4026531992  Links: 1
Access: (0444/-r--r--r--)  Uid: (    0/    root)   Gid: (    0/    root)
Access: 2021-05-21 12:42:07.565311760 +0000
Modify: 2021-05-21 12:42:07.565311760 +0000
Change: 2021-05-21 12:42:07.565311760 +0000
 Birth: -

unshare(1) / nsenter(1) / mount(8) を使って操作する

さて、前提知識の確認が終わったところで、実際に ip-netns(8) を使わずに Network Namespace を操作してみよう。 まずは、ip-netns(8) 以外のコマンドラインツールで操作する方法を試す。

新しく namespaces(7) を作るコマンドとしては unshare(1) が使える。 --net オプションを指定すると、コマンドで新たに起動するプロセスが利用する Network Namespace を確保できる。 以下では新しい Network Namespace を使って bash(1) を起動している。

$ sudo unshare --net bash

起動したシェルで確認すると、たしかに /proc/<pid>/ns 以下のファイルの inode 番号が変わっていることが分かる。

# file /proc/self/ns/net
/proc/self/ns/net: symbolic link to net:[4026532254]

ip-link(8) を使ってみるデバイスの状況を確認すると、DOWN したループバックデバイスしか無いことが分かる。 どうやら、ちゃんと Network Namespace が新しく作られたようだ。

# ip link show
1: lo: <LOOPBACK> mtu 65536 qdisc noop state DOWN mode DEFAULT group default qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00

ただ、この状況で ip-netns(8) を使ってみても何も表示されない。 新しく Network Namespace ができたというのに、どうしてだろう。

# ip netns list

というのも、実は ip-netns(8) の list サブコマンドは、/var/run/netns 以下にあるファイルを見ているだけに過ぎない。 上記で何も表示されないということは、ここに何もファイルがないということ。

# ls /var/run/netns

たしかに何も表示されない。 そもそも、ip-netns(8) を使ったことがない環境であれば、ディレクトリすらできていないことだろう。

ここでおもむろに /var/run/netns 以下にファイルを作って、/proc/self/ns/net--bind オプションつきでマウントしてみよう。

# touch /var/run/netns/example
# mount --bind /proc/self/ns/net /var/run/netns/example

すると、ip-netns(8) の list サブコマンドに、作ったファイルと同じ内容が見られる。

# ip netns list
example

上記は、ちゃんと ip-netns(8) から使うことができる。 一旦、unshare(1) で作ったシェルのプロセスから抜けて、ip-netns(8) の exec サブコマンドを実行してみよう。

# exit
$ sudo ip netns exec example bash -c 'ls -alF /proc/self/ns/net'
lrwxrwxrwx 1 root root 0 May 21 12:49 /proc/self/ns/net -> 'net:[4026532254]'

上記から、ちゃんと使えることがわかる。 というのも、これは実のところ ip-netns(8) が内部的にやっているのとほぼ同じことをやっているため。

先ほど /var/run/netns 以下に作ったファイルは nsenter(1) から利用することもできる。 このコマンドは既存の Namespace に切り替えるために用いる。 --net オプションにファイルを指定して、シェルを起動してみよう。

$ sudo nsenter --net=/var/run/netns/example bash

起動したシェルから確認すると、ちゃんと Namespace が切り替わっていることがわかる。

# ls -alF /proc/self/ns/net
lrwxrwxrwx 1 root root 0 May 21 12:53 /proc/self/ns/net -> 'net:[4026532254]'

ちなみに、ip-netns(8) から利用するときには mount(8) を使わなくてもシンボリックリンクを張るだけで代用できる。 次のように、$$ を使って自身の pid を置換しつつ、Namespace を表したファイルからシンボリックリンクを張ってみよう。

# ln -s /proc/$$/ns/net /var/run/netns/symlink

起動したシェルから抜けた上で確認すると、ちゃんと ip-netns(8) のリストに表示されると共に、使えることがわかる。

# exit
$ ip netns list
symlink
example
$ sudo ip netns exec example bash -c 'ls -alF /proc/self/ns/net'
lrwxrwxrwx 1 root root 0 May 21 12:58 /proc/self/ns/net -> 'net:[4026532254]'

このテクニックは Docker や Mininet などが作る Network Namespace を ip-netns(8) から操作したいときにも有効。

unshare(2) / setns(2) / mount(2) を使って操作する

さて、ip-netns(8) 以外のコマンドラインツールから操作できることがわかったところで、続いてはシステムコールを使ってみる。 というか、先ほど使った一連のコマンドラインツールも、内部的にはこれらの API を叩いていた。

早速だけど、以下にサンプルコードを示す。 このサンプルコードでは、次のような処理をしている。

  • unshare(2) で Network Namespace を新しく作る
  • mount(2) で /proc/self/ns/net/var/run/netns 以下に syscall-example という名前でマウントする
  • /proc/self/ns/net の中身を表示する
#define _GNU_SOURCE
#include <sched.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <sys/mount.h>

int main(int argc, char *argv[]) {
    if (unshare(CLONE_NEWNET) < 0) {
        fprintf(stderr, "Failed to create a new network namespace: %s\n", strerror(errno));
        exit(-1);
    }

    const char *netns_path = "/var/run/netns/syscall-example";
    const int fd = open(netns_path, O_RDONLY | O_CREAT | O_EXCL, 0);
    if (fd < 0) {
        fprintf(stderr, "Cannot create namespace file \"%s\": %s\n",
            netns_path, strerror(errno));
        return EXIT_FAILURE;
    }
    close(fd);

    const char *proc_path = "/proc/self/ns/net";
    if (mount(proc_path, netns_path, "none", MS_BIND, NULL) < 0) {
        fprintf(stderr, "Failed to bind %s -> %s: %s\n",
            proc_path, netns_path, strerror(errno));
    }

    const char *cmd = "file";
    char* const args[] = {"file", "/proc/self/ns/net", NULL};
    if (execvp(cmd, args) < 0) {
        fprintf(stderr, "Failed to exec \"%s\": %s\n", cmd, strerror(errno));
        exit(-1);
    }
    return EXIT_SUCCESS;
}

上記に nsadd.c という名前をつけてビルドする。

$ gcc -o nsadd.o nsadd.c

実行すると、/proc/self/ns/net が新しい識別子になっていることがわかる。

$ sudo ./nsadd.o 
/proc/self/ns/net: symbolic link to net:[4026532315]

ip-netns(8) からも、ちゃんと使える。

$ ip netns list
syscall-example
symlink
example
$ sudo ip netns exec syscall-example bash -c 'ls -alF /proc/self/ns/net'
lrwxrwxrwx 1 root root 0 May 21 13:10 /proc/self/ns/net -> 'net:[4026532315]'

続いては、上記で作った Network Namespace を示すファイルを利用するサンプルコード。 次のような処理をしている。

  • /var/run/netns 以下のファイルを open(2) で開く
  • 上記で得られたファイルディスクリプタを setns(2) に渡して Namespace を切り替える
  • /proc/self/ns/net の中身を表示する
#define _GNU_SOURCE
#include <sched.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>

int main(int argc, char *argv[]) {
    const char *mounted_path = "/var/run/netns/syscall-example";
    const int fd = open(mounted_path, O_RDONLY | O_CLOEXEC);
    if (fd < 0) {
        fprintf(stderr, "Cannot open mounted path\"%s\": %s\n",
            mounted_path, strerror(errno));
        return EXIT_FAILURE;
    }

    if (setns(fd, CLONE_NEWNET) < 0) {
        fprintf(stderr, "failed to setup the network namespace \"%s\": %s\n",
            mounted_path, strerror(errno));
        close(fd);
        return EXIT_FAILURE;
    }

    const char *cmd = "file";
    char* const args[] = {"file", "/proc/self/ns/net", NULL};
    if (execvp(cmd, args) < 0) {
        fprintf(stderr, "Failed to exec \"%s\": %s\n", cmd, strerror(errno));
        exit(-1);
    }
    return EXIT_SUCCESS;
}

上記に nsexec.c という名前をつけてビルドする。

$ gcc -o nsexec.o nsexec.c 

実行すると、ちゃんと Network Namespace が切り替わっていることがわかる。

$ sudo ./nsexec.o 
/proc/self/ns/net: symbolic link to net:[4026532315]

いじょう。

参考

git.kernel.org

Python: Streamlit を使って手早く WebUI 付きのプロトタイプを作る

Streamlit は、ざっくり言うと主にデータサイエンス領域において WebUI 付きのアプリケーションを手早く作るためのソフトウェア。 使い所としては、ひとまず動くものを見せたかったり、少人数で試しに使うレベルのプロトタイプを作るフェーズに適していると思う。 たとえば、Jupyter で提供すると複数人で使うのに難があるし、かといって Flask や Django を使って真面目に作るほどではない、くらいのとき。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.3.1
BuildVersion:   20E241
$ python -V
Python 3.8.9

もくじ

下準備

まずは必要なパッケージをインストールしておく。 本当に必要なのは streamlit のみ。 watchdog はパフォーマンスのために入れる。 matplotlib についてはグラフを可視化するときに使うため入れておく。 click はスクリプトに引数を渡すサンプルのため。

$ pip install streamlit watchdog matplotlib click

インストールすると streamlit コマンドが使えるようになる。

$ streamlit version
Streamlit, version 0.81.0

必要に応じて Streamlit の設定ファイルを用意する。 以下は、初回の実行時に確認される e-mail アドレスのスキップと、利用に関する統計情報を送信しない場合の設定。 なお、これは別にやらなくても初回の実行時に案内が出る。

$ mkdir -p ~/.streamlit
$ cat << 'EOF' > ~/.streamlit/credentials.toml 
[general]
email = ""
EOF
$ cat << 'EOF' > ~/.streamlit/config.toml
[browser]
    gatherUsageStats = false
EOF

基本的な使い方

まずはもっとも基本的な使い方から見ていく。 以下は streamlit.write() 関数を使って任意のオブジェクトを WebUI に表示するサンプルコード。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # Streamlit が対応している任意のオブジェクトを可視化する (ここでは文字列)
    st.write('Hello, World!')


if __name__ == '__main__':
    main()

上記を適当な名前で保存したら streamlit run サブコマンドで指定して実行する。

$ streamlit run example.py

すると、デフォルトでは 8501 ポートで Streamlit のアプリケーションサーバが起動する。 ブラウザで開いて結果を確認しよう。

$ open http://localhost:8501

すると、次のように「Hello, World!」という表示のある Web ページが確認できる。

f:id:momijiame:20210505001417p:plain

やっていることは静的な文字列を表示しているだけとはいえ、Pure Python なスクリプトをちょっと書くだけで Web ページが表示できた。

なお、Streamlit はデフォルトだと実行するホストの全 IP アドレスを Listen するので注意しよう。 ループバックアドレスだけに絞りたいときは以下のようにする。

$ streamlit run --server.address localhost example.py

ちなみに先ほど使った streamlit.write() 関数は色々なオブジェクトを可視化するのに使うことができる。 現時点で対応しているものをざっと書き出してみると次のとおり。

  • サードパーティー製パッケージ関連
    • Pandas の DataFrame オブジェクト
    • Keras の Model オブジェクト
    • SymPy の表現式 (LaTeX)
    • グラフ描画系
      • Matplotlib
      • Altair
      • Vega Lite
      • Plotly
      • Bokeh
      • PyDeck
      • Graphviz
  • 標準的な Python のオブジェクト
    • 例外オブジェクト
    • 関数オブジェクト
    • モジュールオブジェクト
    • 辞書オブジェクト

その他、任意のオブジェクトは str() 関数に渡したのと等価な結果が得られる。

基本的な書式

続いて、Streamlit に備わっている基本的な書式をいくつか試してみる。 アプリケーションのタイトルやヘッダ、マークダウンテキストや数式など。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # タイトル
    st.title('Application title')
    # ヘッダ
    st.header('Header')
    # 純粋なテキスト
    st.text('Some text')
    # サブレベルヘッダ
    st.subheader('Sub header')
    # マークダウンテキスト
    st.markdown('**Markdown is available **')
    # LaTeX テキスト
    st.latex(r'\bar{X} = \frac{1}{N} \sum_{n=1}^{N} x_i')
    # コードスニペット
    st.code('print(\'Hello, World!\')')
    # エラーメッセージ
    st.error('Error message')
    # 警告メッセージ
    st.warning('Warning message')
    # 情報メッセージ
    st.info('Information message')
    # 成功メッセージ
    st.success('Success message')
    # 例外の出力
    st.exception(Exception('Oops!'))
    # 辞書の出力
    d = {
        'foo': 'bar',
        'users': [
            'alice',
            'bob',
        ],
    }
    st.json(d)


if __name__ == '__main__':
    main()

先ほどの Python ファイルに上書きすると、Streamlit はファイルの変更を検知して自動的に読み込み直してくれる。 アプリケーションを表示しているブラウザはリロードするか、変更が生じた際に自動で読み込むか問うボタンが右上に出てくる。

f:id:momijiame:20210506011250p:plain

プレースホルダー

続いて扱うのはプレースホルダーという機能。 かなり地味なので、この時点で紹介する点に違和感があるかもしれない。 とはいえ、地味なりに多用する機能なので先に説明しておく。

プレースホルダーは、任意のオブジェクトを表示するための入れ物みたいなオブジェクト。 言葉よりも実際に使った方が分かりやすいと思うので以下にサンプルを示す。 プレースホルダーを用意して、後からそこにオブジェクトを書き込む、みたいな使い方をする。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # プレースホルダーを用意する
    placeholder1 = st.empty()
    # プレースホルダーに文字列を書き込む
    placeholder1.write('Hello, World')

    placeholder2 = st.empty()
    # コンテキストマネージャとして使えば出力先をプレースホルダーにできる
    with placeholder2:
        # 複数回書き込むと上書きされる
        st.write(1)
        st.write(2)
        st.write(3)  # この場合は最後に書き込んだものだけ見える


if __name__ == '__main__':
    main()

上記を実行した結果は次のとおり。 プレースホルダーの内容は上書きされるので、特に何もしなければ最後に書きこまれた内容が見える。

f:id:momijiame:20210506013310p:plain

プレースホルダーを応用するとアニメーション的なこともできる。 以下のサンプルコードではスリープを挟みながらプレースホルダーの内容を書きかえることで動きのあるページを作っている。

# -*- coding: utf-8 -*-

import time

import streamlit as st


def main():
    status_area = st.empty()

    # カウントダウン
    count_down_sec = 5
    for i in range(count_down_sec):
        # プレースホルダーに残り秒数を書き込む
        status_area.write(f'{count_down_sec - i} sec left')
        # スリープ処理を入れる
        time.sleep(1)

    # 完了したときの表示
    status_area.write('Done!')
    # 風船飛ばす
    st.balloons()


if __name__ == '__main__':
    main()

上記を実行すると秒数のカウントダウンが確認できる。

f:id:momijiame:20210506013737p:plain

プログレスバーを使った処理の進捗の可視化

ちなみに先ほどのようなカウントダウンをするような処理だとプログレスバーを使うこともできる。 以下のサンプルコードでは 0.1 秒ごとにプログレスバーの数値を増やしていくページができる。

# -*- coding: utf-8 -*-

import time

import streamlit as st


def main():
    status_text = st.empty()
    # プログレスバー
    progress_bar = st.progress(0)

    for i in range(100):
        status_text.text(f'Progress: {i}%')
        # for ループ内でプログレスバーの状態を更新する
        progress_bar.progress(i + 1)
        time.sleep(0.1)

    status_text.text('Done!')
    st.balloons()


if __name__ == '__main__':
    main()

上記を実行すると、以下のようにプログレスバーが表示される。

f:id:momijiame:20210506014133p:plain

基本的な可視化

ここまでの内容だと、面白いけど何が便利なのかイマイチよく分からないと思う。 そこで、ここからはもう少し実用的な話に入っていく。 具体的には、いくつかグラフなどを可視化する方法について見ていこう。

組み込みのグラフ描画機能

Streamlit には組み込みのグラフ描画機能がある。 この機能を使うと NumPy の配列や Pandas のデータフレームなどをサクッとグラフにできる。 以下のサンプルコードでは折れ線グラフ、エリアチャート、バーチャートの 3 種類を試している。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # ランダムな値でデータフレームを初期化する
    data = {
        'x': np.random.random(20),
        'y': np.random.random(20) - 0.5,
        'z': np.random.random(20) - 1.0,
    }
    df = pd.DataFrame(data)
    # 折れ線グラフ
    st.subheader('Line Chart')
    st.line_chart(df)
    # エリアチャート
    st.subheader('Area Chart')
    st.area_chart(df)
    # バーチャート
    st.subheader('Bar Chart')
    st.bar_chart(df)


if __name__ == '__main__':
    main()

上記からは次のようなグラフが得られる。

f:id:momijiame:20210506012655p:plain

グラフにデータを動的に追加することもできる。 これにはグラフを描画する関数を実行して得られるオブジェクトに add_rows() メソッドを使えば良い。 以下のサンプルコードでは、折れ線グラフに 0.5 秒間隔で 10 回までデータを追加している。

# -*- coding: utf-8 -*-

import time

import streamlit as st
import numpy as np


def main():
    # 折れ線グラフ (初期状態)
    x = np.random.random(size=(10, 2))
    line_chart = st.line_chart(x)

    for i in range(10):
        # 折れ線グラフに 0.5 秒間隔で 10 回データを追加する
        additional_data = np.random.random(size=(5, 2))
        line_chart.add_rows(additional_data)
        time.sleep(0.5)


if __name__ == '__main__':
    main()

上記を確認すると、0.5 秒間隔でグラフにデータが追加されていく様子が確認できる。 こういったアニメーション効果を手軽に導入できるのは Streamlit の強みだと思う。

f:id:momijiame:20210506012934p:plain

ちなみに気づいたかもしれないけどブラウザをリロードするごとにプロットされる結果は変わる。 これは Streamlit がページを表示するときに、スクリプトを上から順に実行するように処理しているため。 つまり、ブラウザをリロードする毎にスクリプトのコードが評価され直しているように考えれば良い。

Matplotlib

続いては Matplotlib のグラフを描画してみよう。 Streamlit では Matplotlib の Figure オブジェクトを書き出すことでグラフを描画できる。 以下のサンプルコードではランダムに生成した値をヒストグラムにプロットしている。

# -*- coding: utf-8 -*-

import streamlit as st
import numpy as np
from matplotlib import pyplot as plt


def main():
    # 描画領域を用意する
    fig = plt.figure()
    ax = fig.add_subplot()
    # ランダムな値をヒストグラムとしてプロットする
    x = np.random.normal(loc=.0, scale=1., size=(100,))
    ax.hist(x, bins=20)
    # Matplotlib の Figure を指定して可視化する
    st.pyplot(fig)


if __name__ == '__main__':
    main()

上記からは次のような画面が得られる。

f:id:momijiame:20210509171628p:plain

先ほどと同じように、データを更新しながらグラフを描画し直すサンプルも書いてみる。 以下のサンプルコードでは、プレースホルダを使って描画されるグラフの内容を更新している。

# -*- coding: utf-8 -*-

import time

import streamlit as st
import numpy as np
from matplotlib import pyplot as plt


def main():
    # グラフを書き出すためのプレースホルダを用意する
    plot_area = st.empty()
    fig = plt.figure()
    ax = fig.add_subplot()
    x = np.random.normal(loc=.0, scale=1., size=(100,))
    ax.plot(x)
    # プレースホルダにグラフを書き込む
    plot_area.pyplot(fig)

    # 折れ線グラフに 0.5 秒間隔で 10 回データを追加する
    for i in range(10):
        # グラフを消去する
        ax.clear()
        # データを追加する
        additional_data = np.random.normal(loc=.0, scale=1., size=(10,))
        x = np.concatenate([x, additional_data])
        # グラフを描画し直す
        ax.plot(x)
        # プレースホルダに書き出す
        plot_area.pyplot(fig)
        time.sleep(0.5)


if __name__ == '__main__':
    main()

上記を実行すると、一定間隔でデータが追加されながらグラフの描画も更新される。

f:id:momijiame:20210509171822p:plain

Pandas

グラフではないけど Pandas のデータフレームを Jupyter で可視化するときと同じように表示できる。 データフレームを出力するときは streamlit.dataframe()streamlit.table() という 2 種類の関数がある。 前者は行や列の要素が多いときにスクロールバーを使って表示する一方で、後者はすべてをいっぺんに表示する。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # Pandas のデータフレームを可視化してみる
    data = {
        # ランダムな値で初期化する
        'x': np.random.random(20),
        'y': np.random.random(20),
    }
    df = pd.DataFrame(data)
    # データフレームを書き出す
    st.dataframe(df)
    # st.write(df)  でも良い
    # スクロールバーを使わず一度に表示したいとき
    st.table(df)


if __name__ == '__main__':
    main()

上記からは以下のような表示が得られる。

f:id:momijiame:20210509172138p:plain

画像

画像を表示するときは streamlit.image() 関数を使う。 以下のサンプルコードではランダムに生成した NumPy 配列を、カラー画像として可視化している。

# -*- coding: utf-8 -*-

import streamlit as st
import numpy as np


def main():
    x = np.random.random(size=(400, 400, 3))
    # NumPy 配列をカラー画像として可視化する
    st.image(x)


if __name__ == '__main__':
    main()

上記からは以下のような表示が得られる。

f:id:momijiame:20210509172356p:plain

地図

地図上にプロットすることもできる。 地図に散布図を描きたいときは streamlit.map() 関数を使えば良い。 以下のサンプルコードでは、東京を中心とした地図にランダムな点をプロットしている。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # 東京のランダムな経度・緯度を生成する
    data = {
        'lat': np.random.randn(100) / 100 + 35.68,
        'lon': np.random.randn(100) / 100 + 139.75,
    }
    map_data = pd.DataFrame(data)
    # 地図に散布図を描く
    st.map(map_data)


if __name__ == '__main__':
    main()

上記からは以下のような表示が得られる。

f:id:momijiame:20210509172513p:plain

Streamlit がサポートしている可視化の機能は他にも色々とあるけど、とりあえず一旦はここまでで切り上げる。

キャッシュ機構

ここまでのサンプルコードは、ブラウザをリロードすると表示される内容が変わるものが多かった。 それはスクリプトの内容が毎回、評価し直されているのと同じ状態のため。 ただ、それだと困る場面も多い。 たとえば、時間のかかる処理が毎回評価され直すと、パフォーマンスに深刻な影響がある。 そんなときは Streamlit のキャッシュ機構を使うと良い。

キャッシュ機構を使うには streamlit.cache デコレータを使えば良い。 以下のサンプルコードでは、cached_data() 関数をデコレータで修飾している。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


# 関数の出力をキャッシュする
@st.cache
def cached_data():
    data = {
        'x': np.random.random(20),
        'y': np.random.random(20),
    }
    df = pd.DataFrame(data)
    return df


def main():
    # リロードしても同じ結果が得られる
    df = cached_data()
    st.dataframe(df)


if __name__ == '__main__':
    main()

上記はオンメモリで結果がキャッシュされるため、ブラウザをリロードしても表示が変わることがない。 その他、キャッシュ機構の詳しい解説は以下のドキュメントに記載されている。

docs.streamlit.io

ウィジェット

ここまでのサンプルには、ユーザからの入力を受け付けるものがなかった。 ここからは、ウィジェットを使ってインタラクティブなページを作る方法について書く。

ボタン

まずは最も基本的なウィジェットとしてボタンを扱う。 このボタン、Streamlit のウィジェットの考え方が、他の UI フレームワークと違うことがよく分かって面白い。

ボタンは streamlit.button() 関数を使って配置できる。 以下のサンプルコードは、ボタンを押すことで表示される内容が変わるものとなっている。 興味深いのは、ボタンにイベントハンドラなどの類が一切設定されていないこと。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # データフレームを書き出す
    data = np.random.randn(20, 3)
    df = pd.DataFrame(data, columns=['x', 'y', 'z'])
    st.dataframe(df)
    # リロードボタン
    st.button('Reload')


if __name__ == '__main__':
    main()

上記を実行すると以下のような表示が得られる。 実際、ボタンを押すと表示内容が変わるはず。

f:id:momijiame:20210509173746p:plain

ポイントは、Streamlit は毎回スクリプトを評価し直すように動作するところ。 つまり、ウィジェットで何らかのイベントが起こったら、Streamlit はページの内容を丸ごと評価し直すと考えれば良い。 上記のサンプルコードは、ボタンが押されるイベントによって、表示が丸ごと変わったわけだ。

ウィジェットは、一番最後の試行 (評価) のときに、ウィジェットがどのような状態になったかを返す場合がある。 ボタンも同様で、最後の試行でボタンが押されたか・押されていないかを真偽値 (bool) で返す。

ウィジェットの特性を利用すると、ウィジェットを設置する関数から返ってくる値を使ってインタラクティブな画面が作れる。 以下のサンプルコードでは、2 つのボタンを設置して、押されたボタンに対応するメッセージを表示している。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    if st.button('Top button'):
        # 最後の試行で上のボタンがクリックされた
        st.write('Clicked')
    else:
        # クリックされなかった
        st.write('Not clicked')

    if st.button('Bottom button'):
        # 最後の試行で下のボタンがクリックされた
        st.write('Clicked')
    else:
        # クリックされなかった
        st.write('Not clicked')


if __name__ == '__main__':
    main()

上記を実行すると、以下のような表示が得られる。 ボタンを押すと、表示が更新されて、押されたボタンに対応するメッセージが表示されるはず。

f:id:momijiame:20210509174659p:plain

チェックボックス

チェックボックスは、最後の試行でチェックされたか・されなかったかを元に処理を分岐できる。 以下のサンプルコードでは、チェックされたときだけデータフレームを表示している。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # チェックボックスにチェックが入っているかで処理を分岐する
    if st.checkbox('Show'):
        # チェックが入っているときはデータフレームを書き出す
        data = np.random.randn(20, 3)
        df = pd.DataFrame(data, columns=['x', 'y', 'z'])
        st.dataframe(df)


if __name__ == '__main__':
    main()

上記を実行すると、以下のような表示が得られる。 チェックボックスをチェックしたときだけデータフレームが表示される。

f:id:momijiame:20210509175353p:plain

ラジオボタン

同様に、最後の試行でチェックされたアイテムを元に処理をできるラジオボタン。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    selected_item = st.radio('Which do you like?',
                             ['Dog', 'Cat'])
    if selected_item == 'Dog':
        st.write('Wan wan')
    else:
        st.write('Nya- nya-')


if __name__ == '__main__':
    main()

上記を実行して得られる表示は以下のとおり。

f:id:momijiame:20210511182308p:plain

セレクトボックス

できることは基本的にラジオボタンと変わらないセレクトボックス。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    selected_item = st.selectbox('Which do you like?',
                                 ['Dog', 'Cat'])
    st.write(f'Selected: {selected_item}')


if __name__ == '__main__':
    main()

上記を実行して得られる表示は以下のとおり。

f:id:momijiame:20210511182423p:plain

単一のアイテムを選択するセレクトボックスの他に、複数のアイテムを選択できるマルチセレクトもある。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    selected_items = st.multiselect('What are your favorite characters?',
                                    ['Miho Nishizumi',
                                     'Saori Takebe',
                                     'Hana Isuzu',
                                     'Yukari Akiyama',
                                     'Mako Reizen',
                                     ])
    st.write(f'Selected: {selected_items}')


if __name__ == '__main__':
    main()

上記から得られる表示は以下のとおり。

f:id:momijiame:20210511182536p:plain

スライダー

スライダーは特定の範囲の中から値を選択するのに使える。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    age = st.slider(label='Your age',
                    min_value=0,
                    max_value=130,
                    value=30,
                    )
    st.write(f'Selected: {age}')


if __name__ == '__main__':
    main()

f:id:momijiame:20210511182710p:plain

デフォルトの値にタプルなどで 2 つの要素を指定すると、レンジを入力できるようになる。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    min_value, max_value = st.slider(label='Range selected',
                                     min_value=0,
                                     max_value=100,
                                     value=(40, 60),
                                     )
    st.write(f'Selected: {min_value} ~ {max_value}')


if __name__ == '__main__':
    main()

f:id:momijiame:20210511182804p:plain

ちなみに整数以外にも日付とかを指定するのにも使える。 ただ、そんなに使いやすいとは思えない。 日付とか時間は後述する専用のウィジェットを使った方が良いと思う。

# -*- coding: utf-8 -*-

from datetime import date

import streamlit as st


def main():
    birthday = st.slider('When is your birthday?',
                         min_value=date(1900, 1, 1),
                         max_value=date.today(),
                         value=date(2000, 1, 1),
                         format='YYYY-MM-DD',
                         )
    st.write('Birthday: ', birthday)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511182923p:plain

Date / Time インプット

日付や時間を扱う専用のウィジェットが続いて紹介する Date / Time インプット。

まずは Date インプットから。

# -*- coding: utf-8 -*-

from datetime import date

import streamlit as st


def main():
    birthday = st.date_input('When is your birthday?',
                             min_value=date(1900, 1, 1),
                             max_value=date.today(),
                             value=date(2000, 1, 1),
                             )
    st.write('Birthday: ', birthday)


if __name__ == '__main__':
    main()

ウィジェットをクリックするとカレンダーで日付を指定できるので使いやすい。

f:id:momijiame:20210511183139p:plain

Time インプットは一日の中の時間を指定できる。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    time = st.time_input(label='Your input:')
    st.write('input: ', time)


if __name__ == '__main__':
    main()

こちらもウィジェットをクリックすると時間のセレクタが表示されて使いやすい。

f:id:momijiame:20210511183244p:plain

文字列入力

一行の文字列の入力にはテキストインプットが使える。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    text = st.text_input(label='Message', value='Hello, World!')
    st.write('input: ', text)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511183447p:plain

同様に、複数行に渡る文字列を入力するときはテキストエリアを用いる。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    text = st.text_area(label='Multi-line message', value='Hello, World!')
    st.write('input: ', text)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511183536p:plain

数字入力

数字を入力するときはナンバーインプットを使う。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    n = st.number_input(label='What is your favorite number?',
                        value=42,
                        )
    st.write('input: ', n)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511183634p:plain

デフォルト値を浮動小数点型にすれば、小数を入力できる。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    n = st.number_input(label='What is your favorite number?',
                        value=3.14,
                        )
    st.write('input: ', n)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511183726p:plain

ファイルアップローダ

ファイルアップローダを使うと、クライアントのファイルをアプリケーションに渡すことができる。 以下のサンプルコードでは、渡されたファイルに含まれるテキストを UTF-8 として表示している。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    f = st.file_uploader(label='Upload file:')
    st.write('input: ', f)

    if f is not None:
        # XXX: 信頼できないファイルは安易に評価しないこと
        data = f.getvalue()
        text = data.decode('utf-8')
        st.write('contents: ', text)


if __name__ == '__main__':
    main()

適当なテキストファイルを使って動作確認してみよう。

$ echo "Hello, World" > ~/Downloads/greet.txt

ウィジェットをクリックしてファイルを選択すると、以下のように中身が表示される。

f:id:momijiame:20210511183950p:plain

受け取れるオブジェクトは streamlit.UploadedFile という、オープン済みのファイルライクオブジェクトになる。

カラーピッカー

ちょっと変わり種だけどカラーピッカーも用意されている。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    c = st.color_picker(label='Select color:')
    st.write('input: ', c)


if __name__ == '__main__':
    main()

f:id:momijiame:20210511184147p:plain

フロー制御

ウィジェットが色々とあると、ユーザの入力のバリデーションも考えることになる。 ここではフロー制御をするための機能を紹介する。

特定の条件に満たないときに処理を停止するサンプルコードを以下に示す。 このサンプルではテキストインプットに何か文字列が入っていないときに警告メッセージを出して処理を停止している。 処理の停止には streamlit.stop() 関数を使う。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    name = st.text_input(label='your name:')

    # バリデーション処理
    if len(name) < 1:
        st.warning('Please input your name')
        # 条件を満たないときは処理を停止する
        st.stop()

    st.write('Hello,', name, '!')


if __name__ == '__main__':
    main()

テキストインプットに何も入力されていない状態では、以下のように警告メッセージだけが表示されることになる。

f:id:momijiame:20210511184433p:plain

テキストインプットに文字列を入力すると、警告メッセージが消えて正常系の表示に切り替わる。

f:id:momijiame:20210511184601p:plain

レイアウトを調整する

ここからは画面のレイアウトを調整するための機能を見ていく。

カラム

はじめに紹介するのはカラム。 これは、ようするに画面を縦方向に分割して異なる内容を表示できるもの。

カラムを作るには streamlit.beta_columns() 関数を使う。 以下のサンプルコードでは画面を 3 列に分割している。 関数の返り値をコンテキストマネージャとして使うとデフォルトの出力先として使うこともできるし、オブジェクトに直接書き込むこともできる。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # カラムを追加する
    col1, col2, col3 = st.beta_columns(3)

    # コンテキストマネージャとして使う
    with col1:
        st.header('col1')

    with col2:
        st.header('col2')

    with col3:
        st.header('col3')

    # カラムに直接書き込むこともできる
    col1.write('This is column 1')
    col2.write('This is column 2')
    col3.write('This is column 3')


if __name__ == '__main__':
    main()

上記を実行して得られる表示は以下のとおり。

f:id:momijiame:20210511184758p:plain

コンテナ

続いて扱うのはコンテナ。 これは、不可視な仕切りみたいなもの。

以下のサンプルコードではコンテナの内と外にオブジェクトを書き込んで、結果を確認している。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    # コンテナを追加する
    container = st.beta_container()

    # コンテキストマネージャとして使うことで出力先になる
    with container:
        st.write('This is inside the container')
    # これはコンテナの外への書き込み
    st.write('This is outside the container')

    # コンテナに直接書き込むこともできる
    container = st.beta_container()
    container.write('1')
    st.write('2')
    # 出力順は後だがレイアウト的にはこちらが先に現れる
    container.write('3')


if __name__ == '__main__':
    main()

f:id:momijiame:20210511185155p:plain

入れ子にすることもできて、たとえば以下のサンプルコードではプレースホルダにコンテナを追加して、さらにそこにカラムを追加している。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    placeholder = st.empty()
    # プレースホルダにコンテナを追加する
    container = placeholder.beta_container()
    # コンテナにカラムを追加する
    col1, col2 = container.beta_columns(2)
    # それぞれのカラムに書き込む
    with col1:
        st.write('Hello, World')
    with col2:
        st.write('Konnichiwa, Sekai')


if __name__ == '__main__':
    main()

f:id:momijiame:20210511185307p:plain

エキスパンダ

デフォルトでは折りたたまれて非表示な領域を作るのにエキスパンダが使える。

# -*- coding: utf-8 -*-

import streamlit as st


def main():
    with st.beta_expander('See details'):
        st.write('Hidden item')


if __name__ == '__main__':
    main()

上記を実行して、以下はエキスパンダを展開した状態。

f:id:momijiame:20210511185400p:plain

サイドバー

ウィジェットやオブジェクトの表示をサイドバーに配置することもできる。 使い方は単純で、サイドバーに置きたいなと思ったら sidebar をつけて API を呼び出す。

以下のサンプルコードでは、サイドバーにボタンを配置している。 前述したとおり、streamlit.button()streamlit.sidebar.button() に変えるだけ。 同様に、streamlit.sidebar.dataframe() のように間に sidebar をはさむことで大体の要素はサイドバーに置ける。

# -*- coding: utf-8 -*-

import streamlit as st
import pandas as pd
import numpy as np


def main():
    # サイドバーにリロードボタンをつける
    st.sidebar.button('Reload')
    # サイドバーにデータフレームを書き込む
    data = np.random.randn(20, 3)
    df = pd.DataFrame(data, columns=['x', 'y', 'z'])
    st.sidebar.dataframe(df)


if __name__ == '__main__':
    main()

上記を実行すると、以下のようにサイドバーに要素が設置されることが確認できる。

f:id:momijiame:20210512222116p:plain

オブジェクトの docstring を表示する

Streamlit はスクリプトの変更を検出して自動でリロードしてくれるため、基本的には WebUI を見ながら開発していくことになる。 そんなとき、この関数またはメソッドの使い方なんだっけ?みたいな場面では streamlit.help() を使うと良い。 オブジェクトの docstring を表示してくれる。

# -*- coding: utf-8 -*-

import pandas as pd

import streamlit as st


def main():
    st.help(pd.DataFrame)


if __name__ == '__main__':
    main()

まあ自動補完とかドキュメント表示をサポートしてる IDE なんかで開発するときは、あんまり使わないかもしれないけど。

f:id:momijiame:20210512222544p:plain

単一のスクリプトで複数のアプリケーションを扱う

Streamlit は基本的に複数のページから成るアプリケーションを作ることができない。 では、複数のアプリケーションを単一のスクリプトで扱うことができないか、というとそうではない。 これは、ウィジェットの状態に応じて表示するアプリケーションを切り替えてやることで実現できる。

以下のサンプルコードでは、セレクトボックスの状態に応じて実行する関数を切り替えている。 それぞれの関数が、それぞれのアプリケーションになっていると考えてもらえれば良い。

# -*- coding: utf-8 -*-

import streamlit as st


def render_gup():
    """GuP のアプリケーションを処理する関数"""
    character_and_quotes = {
        'Miho Nishizumi': 'パンツァーフォー',
        'Saori Takebe': 'やだもー',
        'Hana Isuzu': '私この試合絶対勝ちたいです',
        'Yukari Akiyama': '最高だぜ!',
        'Mako Reizen': '以上だ',
    }
    selected_items = st.multiselect('What are your favorite characters?',
                                    list(character_and_quotes.keys()))
    for selected_item in selected_items:
        st.write(character_and_quotes[selected_item])


def render_aim_for_the_top():
    """トップ!のアプリケーションを処理する関数"""
    selected_item = st.selectbox('Which do you like more in the series?',
                                 [1, 2])
    if selected_item == 1:
        st.write('me too!')
    else:
        st.write('2 mo ii yo ne =)')


def main():
    # アプリケーション名と対応する関数のマッピング
    apps = {
        '-': None,
        'GIRLS und PANZER': render_gup,
        'Aim for the Top! GunBuster': render_aim_for_the_top,
    }
    selected_app_name = st.sidebar.selectbox(label='apps',
                                             options=list(apps.keys()))

    if selected_app_name == '-':
        st.info('Please select the app')
        st.stop()

    # 選択されたアプリケーションを処理する関数を呼び出す
    render_func = apps[selected_app_name]
    render_func()


if __name__ == '__main__':
    main()

上記を実行して得られる表示を以下に示す。

f:id:momijiame:20210512223221p:plain

f:id:momijiame:20210512223230p:plain

f:id:momijiame:20210512223239p:plain

ちなみに、呼び出す関数も 1 つのスクリプトにまとまっている必要はない。 別のモジュールに切り出して、スクリプトではそれをインポートして使うこともできる。 それならコードの見通しもさほど悪くはならないはず。

スクリプトでコマンドライン引数を受け取る

Streamlit のスクリプトにコマンドライン引数を渡したいときもある。 ここでは、そのやり方を紹介する。

Argparse

まずは Python の標準ライブラリにある Argparse を使う場合。 スクリプトを書く時点では特に Streamlit かどうかを意識する必要はない。 一般的な使い方と同じように引数を設定してパースして使うだけ。

# -*- coding: utf-8 -*-

import argparse

import streamlit as st


def main():
    parser = argparse.ArgumentParser(description='parse argument example')
    # --message または -m オプションで文字列を受け取る
    parser.add_argument('--message', '-m', type=str, default='World')
    # 引数をパースする
    args = parser.parse_args()
    # パースした引数を表示する
    st.write(f'Hello, {args.message}!')


if __name__ == '__main__':
    main()

ただ、使う時点ではちょっと注意点がある。 スクリプトの後ろにオプションをつけると Streamlit の引数として認識されてしまう。

$ streamlit run example.py -m Sekai
Usage: streamlit run [OPTIONS] TARGET [ARGS]...
Try 'streamlit run --help' for help.

Error: no such option: -m

そこで -- を使って区切って、スクリプトに対する引数であることを明示的に示す。

$ streamlit run example.py -- -m Sekai

Click

続いてサードパーティ製のパッケージである Click を使う場合。 Click は純粋なコマンドラインパーサ以外の機能もあることから、スクリプトを記述する時点から注意点がある。 具体的には、デコレータで修飾したオブジェクトを呼び出すときに standalone_modeFalse に指定する。 こうすると、デフォルトでは実行が完了したときに exit() してしまう振る舞いを抑制できる。

# -*- coding: utf-8 -*-

import streamlit as st
import click


@click.command()
@click.option('--message', '-m', type=str, default='World')
def main(message):
    # パースした引数を表示する
    st.write(f'Hello, {message}!')


if __name__ == '__main__':
    # click.BaseCommand.main() メソッドが呼ばれる
    # デフォルトの動作では返り値を戻さずに exit してしまう
    # スタンドアロンモードを無効にすることで純粋なコマンドラインパーサとして動作する
    main(standalone_mode=False)

実行するときに Streamlit のオプションとの間に -- で区切りが必要なのは Argparse のときと同じ。

$ streamlit run example.py -- -m Sekai

参考

docs.streamlit.io

click.palletsprojects.com

Python: LightGBM の学習に使うデータ量と最適なイテレーション数の関係性について

XGBoost は同じデータセットとパラメータを用いた場合、学習に使うデータの量 (行数) と最適なイテレーション数が線形な関係にあることが経験的に知られている 1。 今回は、それが同じ GBDT (Gradient Boosting Decision Tree) の一手法である LightGBM にも適用できる経験則なのかを実験で確認する。

使った環境は次のとおり。

$ sw_vers          
ProductName:    macOS
ProductVersion: 11.2.3
BuildVersion:   20D91
$ python -V           
Python 3.9.2
$ pip list | grep -i lightgbm
lightgbm        3.2.0

もくじ

下準備

あらかじめ、必要なパッケージをインストールしておく。

$ pip install lightgbm scikit-learn seaborn

実験

以下に、実験用のサンプルコードを示す。 サンプルコードでは、sklearn.datasets.make_classification() を使って生成した擬似的な二値分類用のデータセットを使っている。 生成したデータセットから、一定の割合で学習用のデータを無作為抽出して、LightGBM のモデルを学習したときの特性を確認している。 なお、性能の評価は念のため Nested Validation (outer: stratified hold-out, inner: stratified 5-fold cv) にしている。 outer の予測には inner で学習させたモデルで averaging している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import time

import numpy as np
import pandas as pd
import lightgbm as lgb
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import log_loss


def main():
    # 疑似的な教師信号を作るためのパラメータ
    dist_args = {
        # データ点数
        'n_samples': 100_000,
        # 次元数
        'n_features': 100,
        # その中で意味のあるもの
        'n_informative': 20,
        # 重複や繰り返しはなし
        'n_redundant': 0,
        'n_repeated': 0,
        # タスクの難易度
        'class_sep': 0.65,
        # 二値分類問題
        'n_classes': 2,
        # 生成に用いる乱数
        'random_state': 42,
        # 特徴の順序をシャッフルしない (先頭の次元が informative になる)
        'shuffle': False,
    }
    # 教師データを作る
    x, y = make_classification(**dist_args)
    # Nested Validation (stratified hold-out -> stratified 5 fold cv)
    train_x, test_x, train_y, test_y = train_test_split(x, y,
                                                        test_size=0.3,
                                                        stratify=y,
                                                        shuffle=True,
                                                        random_state=42,
                                                        )
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # 学習用のパラメータ
    lgb_params = {
        # タスク設定
        'objective': 'binary',
        # メトリック
        'metric': 'binary_logloss',
        # 乱数シード
        'seed': 42,
    }

    # 乱数シードを設定する
    np.random.seed(42)

    sampled_rows = []
    best_iterations = []
    test_metrics = []
    learning_times = []
    sampling_rates = np.arange(0.1, 1.0 + 1e-2, 0.1)
    for sampling_rate in sampling_rates:
        train_len = len(train_x)
        sampled_len = int(train_len * sampling_rate)
        sampled_rows.append(sampled_len)

        # 重複なしで無作為抽出する (本当はここも Stratified にした方が良い)
        sampled_indices = np.random.choice(np.arange(train_len),
                                           size=sampled_len,
                                           replace=False)
        sampled_train_x = train_x[sampled_indices]
        sampled_train_y = train_y[sampled_indices]
        train_dataset = lgb.Dataset(sampled_train_x, sampled_train_y)

        # 交差検証
        start_time = time.time()
        cv_result = lgb.cv(params=lgb_params,
                           train_set=train_dataset,
                           num_boost_round=10_000,
                           early_stopping_rounds=100,
                           verbose_eval=100,
                           folds=folds,
                           return_cvbooster=True,
                           )
        end_time = time.time()
        learning_time = end_time - start_time
        learning_times.append(learning_time)

        cvbooster = cv_result['cvbooster']
        best_iterations.append(cvbooster.best_iteration)

        # Fold Averaging でテストデータのメトリックを計算する
        pred_y_folds = cvbooster.predict(test_x)
        pred_y_avg = np.array(pred_y_folds).mean(axis=0)
        test_metric = log_loss(test_y, pred_y_avg)
        test_metrics.append(test_metric)

    # 生の値
    data = {
        'sampling_rates': sampling_rates,
        'sampled_rows': sampled_rows,
        'best_iterations': best_iterations,
        'learning_times': learning_times,
        'test_metrics': test_metrics,
    }
    df = pd.DataFrame(data)
    print(df)

    # グラフにプロットする
    fig = plt.figure(figsize=(8, 12))
    ax1 = fig.add_subplot(3, 1, 1)
    sns.lineplot(data=df,
                 x='sampling_rates',
                 y='best_iterations',
                 label='best iteration',
                 ax=ax1,
                 )
    ax1.grid()
    ax1.legend()
    ax2 = fig.add_subplot(3, 1, 2)
    sns.lineplot(data=df,
                 x='sampling_rates',
                 y='learning_times',
                 label='learning time (sec)',
                 ax=ax2,
                 )
    ax2.grid()
    ax2.legend()
    ax3 = fig.add_subplot(3, 1, 3)
    sns.lineplot(data=df,
                 x='sampling_rates',
                 y='test_metrics',
                 label='test metric (logloss)',
                 ax=ax3,
                 )
    ax3.grid()
    ax3.legend()

    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。 計算リソースにもよるけど、それなりに時間がかかるはず。

$ python lgbiter.py 
[LightGBM] [Info] Number of positive: 2764, number of negative: 2836
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.003497 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 25500

...

[1500] cv_agg's binary_logloss: 0.118727 + 0.00302704
[1600]    cv_agg's binary_logloss: 0.118301 + 0.00281247
[1700] cv_agg's binary_logloss: 0.117938 + 0.00278925
   sampling_rates  sampled_rows  best_iterations  learning_times  test_metrics
0             0.1          7000              342        6.734761      0.189618
1             0.2         14000              634       12.412657      0.157727
2             0.3         21000              849       18.421927      0.134406
3             0.4         28000             1018       22.645187      0.129939
4             0.5         35000             1162       27.784236      0.122941
5             0.6         42000             1327       33.731716      0.115750
6             0.7         49000             1567       42.821615      0.113003
7             0.8         56000             1614       48.171218      0.109459
8             0.9         63000             1650       60.064258      0.107337
9             1.0         70000             1681       63.199017      0.104814

完了すると、以下のようなグラフが得られる。

f:id:momijiame:20210403001951p:plain
学習に使うデータ量と最適なイテレーション数の関係性

グラフから、LightGBM においても学習に使うデータ量と最適なイテレーション数は概ね線形な関係にあることが確認できた。 また、学習に使うデータ量と学習にかかる時間に関しても概ね線形な関係にあることが見て取れる。 一方で、学習に使うデータが増えても予測精度は非線形な改善にとどまっており、この点も直感には反していない。

いじょう。

Kaggleコンペティション チャレンジブック

Kaggleコンペティション チャレンジブック

Python: LightGBM の決定木を可視化して分岐を追ってみる

今回は、LightGBM が構築するブースターに含まれる決定木を可視化した上で、その分岐を追いかけてみよう。 その過程を通して、LightGBM の最終的な出力がどのように得られているのかを確認してみよう。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.2.3
BuildVersion:   20D91
$ python -V
Python 3.9.2

もくじ

下準備

まずは動作に必要なパッケージをインストールする。

決定木の可視化のために graphviz を、並列計算のために OpenMP を入れておく。

$ brew install graphviz libomp

そして、Python のパッケージを入れる。

$ pip install lightgbm scikit-learn graphviz matplotlib

二値分類問題 (乳がんデータセット)

まずは乳がんデータセットを使って二値分類問題を扱ってみよう。

LightGBM には、lightgbm.Booster オブジェクトに含まれる決定木を可視化する API として lightgbm.plot_tree() という関数が用意されている。 使うときは、tree_index オプションにイテレーション番号を指定することで、そのイテレーションで作成された決定木がグラフとして得られる。 以下のサンプルコードでは、学習させた lightgbm.Booster の先頭にある決定木をグラフにプロットした。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from pprint import pprint

import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt


def main():
    # 乳がんデータセットを使う
    dataset = datasets.load_breast_cancer()
    x, y = dataset.data, dataset.target
    # ホールドアウト
    train_x, eval_x, train_y, eval_y = train_test_split(x, y,
                                                        stratify=y,
                                                        shuffle=True,
                                                        random_state=42)
    # LightGBM のデータセット表現にする
    lgb_train = lgb.Dataset(train_x, train_y,
                            feature_name=list(dataset.feature_names))
    lgb_eval = lgb.Dataset(eval_x, eval_y, reference=lgb_train)
    # 学習パラメータ
    lgb_params = {
        'objective': 'binary',
        'metric': "binary_logloss",
        'verbose': -1,
        'seed': 42,
        'deterministic': True,
    }
    # 学習する
    booster = lgb.train(params=lgb_params,
                        train_set=lgb_train,
                        valid_sets=[lgb_train, lgb_eval],
                        num_boost_round=1_000,
                        early_stopping_rounds=50,
                        verbose_eval=10,
                        )

    # 検証用データの先頭の情報を出力する
    head_row = eval_x[0]
    pprint(dict(zip(dataset.feature_names, head_row)))
    # 1 本目の決定木だけを使って予測してみる
    single_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=1)
    print(f'single tree pred: {single_tree_pred}')
    # 2 本目までの決定木を使って予測してみる
    double_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=2)
    print(f'double tree pred: {double_tree_pred}')

    # 先頭の決定木を可視化してみる
    rows = 2
    cols = 1
    # 表示する領域を準備する
    fig = plt.figure(figsize=(12, 6))
    # 一本ずつプロットしていく
    for i in range(rows * cols):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.set_title(f'Booster index: {i}')
        lgb.plot_tree(booster=booster,
                      tree_index=i,
                      show_info='internal_value',
                      ax=ax,
                      )
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。 実行すると、検証用データの先頭行の説明変数と、それを学習済みモデルで予測させたときのスコアが出力される。 なお、以下は一例であり環境が変わると出力は異なる可能性がある。 (LightGBM の seeddeterministic オプションは、学習に CPU を使った上で完全に同一の環境でのみ結果が同じになることを仮定できる)

$ python binary.py 
Training until validation scores don't improve for 50 rounds
[10]  training's binary_logloss: 0.247243 valid_1's binary_logloss: 0.266197
[20]  training's binary_logloss: 0.116632 valid_1's binary_logloss: 0.158874
[30]  training's binary_logloss: 0.0581821    valid_1's binary_logloss: 0.113181
[40]  training's binary_logloss: 0.0286961    valid_1's binary_logloss: 0.0965949
[50]  training's binary_logloss: 0.0140411    valid_1's binary_logloss: 0.0985209
[60]  training's binary_logloss: 0.00667688   valid_1's binary_logloss: 0.10083
[70]  training's binary_logloss: 0.00317889   valid_1's binary_logloss: 0.104945
[80]  training's binary_logloss: 0.00160051   valid_1's binary_logloss: 0.115742
[90]  training's binary_logloss: 0.00082228   valid_1's binary_logloss: 0.129502
Early stopping, best iteration is:
[45]  training's binary_logloss: 0.0201096    valid_1's binary_logloss: 0.0943608
{'area error': 28.62,
 'compactness error': 0.01561,
 'concave points error': 0.009199,
 'concavity error': 0.01977,
 'fractal dimension error': 0.003629,
 'mean area': 493.8,
 'mean compactness': 0.1117,
 'mean concave points': 0.02995,
 'mean concavity': 0.0388,
 'mean fractal dimension': 0.06623,
 'mean perimeter': 82.51,
 'mean radius': 12.75,
 'mean smoothness': 0.1125,
 'mean symmetry': 0.212,
 'mean texture': 16.7,
 'perimeter error': 2.495,
 'radius error': 0.3834,
 'smoothness error': 0.007509,
 'symmetry error': 0.01805,
 'texture error': 1.003,
 'worst area': 624.1,
 'worst compactness': 0.1979,
 'worst concave points': 0.08045,
 'worst concavity': 0.1423,
 'worst fractal dimension': 0.08557,
 'worst perimeter': 93.63,
 'worst radius': 14.45,
 'worst smoothness': 0.1475,
 'worst symmetry': 0.3071,
 'worst texture': 21.74}
single tree pred: [0.66326872]
double tree pred: [0.69607225]

今回は、以下のようなグラフが得られた。

f:id:momijiame:20210319004235p:plain
乳がんデータセットを学習したモデルの先頭にある決定木

検証用データの先頭行が、どのリーフに落ちるのかを決定木から確認してみよう。 まずは、Booster index: 0 から。

最初の条件は worst perimeter <= 112.800 になっていて、データは 93.63 なので yes に分岐する。 次は worst concave points <= 0.146 で、0.08045なので yes に分岐する。 以下、同様に area error <= 34.75028.62 なので yesworst texture <= 30.04521.74 なので yesmean radius <= 13.87512.75 なので yesmean radius <= 12.31012.75 なので no。 最終的に leaf 7 に落ちて、内部的なスコアは 0.678 になった。

さて、この 0.678 という値は、先頭の決定木だけを使って予測した 0.66326872 というスコアとは少し乖離がある。 これは当然のことで、実際には内部的なスコアにシグモイド関数がかかるため。

Python のインタプリタを別で起動して確認してみよう。

$ python

次のようにシグモイド関数を定義する。

>>> import numpy as np
>>> def sigmoid(x):
...     return 1. / (1. + np.exp(-x))
... 

leaf 7 の内部的なスコアをシグモイド関数にかけると、最終的な予測とほぼ同じ値が得られる。 微妙にズレているのはグラフに出力するときの値がデフォルトだと小数点 3 桁で丸められているから。

>>> sigmoid(0.678)
0.6632921720482895

同じように 2 本 (イテレーション) 目の決定木も確認してみよう。 2 本目は分岐の詳細は省略するけど、最終的に leaf 0 に落ちて 0.151 というスコアになる。 2 本目までの決定木を使った予測は、各決定木から得られる内部的なスコアを足してシグモイド関数にかければ良い。

>>> sigmoid(0.678 + 0.151)
0.6961434435735563

サンプルコードから得られた 0.69607225 というスコアと、ほぼ同じ結果が得られることがわかる。 以下、同様にすべてのイテレーションで作られた決定木のスコアを足し合わせていくことで最終的な結果が得られる。

回帰問題 (ボストンデータセット)

続いてはボストンデータセットを使って回帰問題を扱ってみる。 以下のサンプルコードは先ほどとやっていることはほとんど同じ。 問題が二値分類から回帰になっているだけ。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from pprint import pprint

import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt


def main():
    # ボストンデータセットを使う
    dataset = datasets.load_boston()
    x, y = dataset.data, dataset.target
    # ホールドアウト
    train_x, eval_x, train_y, eval_y = train_test_split(x, y,
                                                        shuffle=True,
                                                        random_state=42)
    # LightGBM のデータセット表現にする
    lgb_train = lgb.Dataset(train_x, train_y,
                            feature_name=list(dataset.feature_names))
    lgb_eval = lgb.Dataset(eval_x, eval_y, reference=lgb_train)
    # 学習パラメータ
    lgb_params = {
        'objective': 'regression',
        'metric': "rmse",
        'verbose': -1,
        'seed': 42,
        'deterministic': True,
    }
    # 学習する
    booster = lgb.train(params=lgb_params,
                        train_set=lgb_train,
                        valid_sets=[lgb_train, lgb_eval],
                        num_boost_round=1_000,
                        early_stopping_rounds=50,
                        verbose_eval=10,
                        )

    # 検証用データの先頭の情報を出力する
    head_row = eval_x[0]
    pprint(dict(zip(dataset.feature_names, head_row)))
    # 1 本目の決定木だけを使って予測してみる
    single_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=1)
    print(f'single tree pred: {single_tree_pred}')
    # 2 本目までの決定木を使って予測してみる
    double_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=2)
    print(f'double tree pred: {double_tree_pred}')

    # 先頭の決定木を可視化してみる
    rows = 2
    cols = 1
    # 表示する領域を準備する
    fig = plt.figure(figsize=(12, 6))
    # 一本ずつプロットしていく
    for i in range(rows * cols):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.set_title(f'Booster index: {i}')
        lgb.plot_tree(booster=booster,
                      tree_index=i,
                      show_info='internal_value',
                      ax=ax,
                      )
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python regression.py 
Training until validation scores don't improve for 50 rounds
[10]  training's rmse: 4.71003    valid_1's rmse: 4.87178
[20]  training's rmse: 3.18599    valid_1's rmse: 3.9085
[30]  training's rmse: 2.68799    valid_1's rmse: 3.66692
[40]  training's rmse: 2.3669 valid_1's rmse: 3.55354
[50]  training's rmse: 2.13835    valid_1's rmse: 3.41701
[60]  training's rmse: 1.96456    valid_1's rmse: 3.38303
[70]  training's rmse: 1.81511    valid_1's rmse: 3.35055
[80]  training's rmse: 1.68986    valid_1's rmse: 3.34631
[90]  training's rmse: 1.59394    valid_1's rmse: 3.34022
[100] training's rmse: 1.49454    valid_1's rmse: 3.30722
[110] training's rmse: 1.41423    valid_1's rmse: 3.30035
[120] training's rmse: 1.33056    valid_1's rmse: 3.28779
[130] training's rmse: 1.25246    valid_1's rmse: 3.26555
[140] training's rmse: 1.19406    valid_1's rmse: 3.25197
[150] training's rmse: 1.13264    valid_1's rmse: 3.24115
[160] training's rmse: 1.07332    valid_1's rmse: 3.23656
[170] training's rmse: 1.02584    valid_1's rmse: 3.22715
[180] training's rmse: 0.983137   valid_1's rmse: 3.2189
[190] training's rmse: 0.940608   valid_1's rmse: 3.22034
[200] training's rmse: 0.898673   valid_1's rmse: 3.22073
[210] training's rmse: 0.862312   valid_1's rmse: 3.22325
[220] training's rmse: 0.827644   valid_1's rmse: 3.22175
[230] training's rmse: 0.795422   valid_1's rmse: 3.22017
Early stopping, best iteration is:
[184] training's rmse: 0.966589   valid_1's rmse: 3.21349
{'AGE': 84.1,
 'B': 395.5,
 'CHAS': 0.0,
 'CRIM': 0.09178,
 'DIS': 2.6463,
 'INDUS': 4.05,
 'LSTAT': 9.04,
 'NOX': 0.51,
 'PTRATIO': 16.6,
 'RAD': 5.0,
 'RM': 6.416,
 'TAX': 296.0,
 'ZN': 0.0}
single tree pred: [23.08757859]
double tree pred: [23.24927528]

以下のようなグラフが得られる。

f:id:momijiame:20210319010222p:plain
ボストンデータセットを学習したモデルの先頭にある決定木

先ほどと同じように決定木の分岐を追いかけてみる。 分岐を辿ると、1 本目の決定木は leaf 10 に落ちて 23.088 というスコアが得られる。 これは先頭 1 本だけを使った予測と同じ値になっており、回帰では内部的なスコアがそのまま最終的な出力となることがわかる。

同様に 2 本目の分岐を辿ると leaf 11 に落ちて 0.162 というスコアになる。 2 本目までを使った予測は、両方の決定木のスコアを足し合わせることで得られる。

>>> 23.088 + 0.162
23.25

多値分類問題 (あやめデータセット)

続いてはあやめデータセットを使って多値分類問題を扱う。 基本的にはこれまでと変わらないけど、多値分類問題は内部的に作られる決定木の数が多い。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from pprint import pprint

import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt


def main():
    # あやめデータセットを使う
    dataset = datasets.load_iris()
    x, y = dataset.data, dataset.target
    # ホールドアウト
    train_x, eval_x, train_y, eval_y = train_test_split(x, y,
                                                        stratify=y,
                                                        shuffle=True,
                                                        random_state=42)
    # LightGBM のデータセット表現にする
    lgb_train = lgb.Dataset(train_x, train_y,
                            feature_name=list(dataset.feature_names))
    lgb_eval = lgb.Dataset(eval_x, eval_y, reference=lgb_train)
    # 学習パラメータ
    lgb_params = {
        'objective': 'multiclass',
        'metric': 'softmax',
        'num_class': 3,
        'verbose': -1,
        'seed': 42,
        'deterministic': True,
    }
    # 学習する
    booster = lgb.train(params=lgb_params,
                        train_set=lgb_train,
                        valid_sets=[lgb_train, lgb_eval],
                        num_boost_round=1_000,
                        early_stopping_rounds=50,
                        verbose_eval=10,
                        )

    # 検証用データの先頭の情報を出力する
    head_row = eval_x[0]
    pprint(dict(zip(dataset.feature_names, head_row)))
    # 1 本目の決定木だけを使って予測してみる
    single_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=1)
    print(f'single tree pred: {single_tree_pred}')
    # 2 本目までの決定木を使って予測してみる
    double_tree_pred = booster.predict(data=[head_row],
                                       num_iteration=2)
    print(f'double tree pred: {double_tree_pred}')

    # 先頭の決定木を可視化してみる
    rows = 2
    cols = 3
    # 表示する領域を準備する
    fig = plt.figure(figsize=(14, 6))
    # 一本ずつプロットしていく
    for i in range(rows * cols):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.set_title(f'Booster index: {i}')
        lgb.plot_tree(booster=booster,
                      tree_index=i,
                      show_info='internal_value',
                      ax=ax,
                      )
    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python multiclass.py
Training until validation scores don't improve for 50 rounds
[10]  training's multi_logloss: 0.287552  valid_1's multi_logloss: 0.366785
[20]  training's multi_logloss: 0.119475  valid_1's multi_logloss: 0.232667
[30]  training's multi_logloss: 0.0678466 valid_1's multi_logloss: 0.234263
[40]  training's multi_logloss: 0.0346539 valid_1's multi_logloss: 0.270086
[50]  training's multi_logloss: 0.016588  valid_1's multi_logloss: 0.350929
[60]  training's multi_logloss: 0.00939384    valid_1's multi_logloss: 0.384578
[70]  training's multi_logloss: 0.00567975    valid_1's multi_logloss: 0.414652
Early stopping, best iteration is:
[26]  training's multi_logloss: 0.0854995 valid_1's multi_logloss: 0.216513
{'petal length (cm)': 1.3,
 'petal width (cm)': 0.2,
 'sepal length (cm)': 4.4,
 'sepal width (cm)': 3.2}
single tree pred: [[0.4084366 0.2957817 0.2957817]]
double tree pred: [[0.47189453 0.26405274 0.26405274]]

以下のようなグラフが得られる。

f:id:momijiame:20210319010334p:plain
あやめデータセットを学習したモデルの先頭にある決定木

LightGBM では、多値分類問題を扱う際に「クラス数 x イテレーション数」本の決定木が作られる。 先頭にある「クラス数」本の決定木が、各クラスの出力を得るのに使われる。 今回の例でいえばあやめの品種は 3 種類なので、先頭の 3 本が 1 イテレーション目のそれぞれの品種に対応することになる。 同様に、4 ~ 6 本目が 2 イテレーション目のそれぞれの品種に対応する。 ようするに、上記のグラフでいうと縦に並んでいる決定木がそれぞれの品種 (クラス) に対応しているということ。

今回も決定木の分岐を追いかけてみよう。 1 本目の決定木は leaf 0 に落ちて -0.884 というスコアになる。 同様に、2 本目は leaf 0 に落ちて -1.207 になる。 3 本目は leaf 2 に落ちて -1.207 になる。

さて、1 イテレーション目の内部的なスコアは [-0.884, -1.207, -1.207] になった。 これを 1 イテレーション目の最終的な出力である [0.4084366 0.2957817 0.2957817] にするにはソフトマックス関数にかける。

以下のように、ソフトマックス関数を定義する。

>>> def softmax(x):
...     return np.exp(x) / np.sum(np.exp(x))
...

内部的なスコアをソフトマックス関数にかけてみよう。

>>> softmax([-0.884, -1.207, -1.207])
array([0.40850546, 0.29574727, 0.29574727])

すると、最終的な出力とほぼ同じ値になった。 2 イテレーション目以降の処理は、これまでと同じなので省略する。 足すだけ。

まとめ

LightGBM の決定木を可視化して分岐を追いかけることで最終的な予測がどのように得られるのかを確認できた。

Python: ipywidgets で Jupyter に簡単な UI を作る

Jupyter を使ってデータを可視化していると、似たようなグラフを何度も描くことがある。 そんなとき、変数の値を変更しながらグラフを描画するセルを実行しまくるのは効率があまりよくない。 そこで、今回は ipywidgets を使って簡単な UI を作ることで、Jupyter でインタラクティブな操作ができるようにしてみる。 グラフの描画には、今回は主に Matplotlib を使うことを想定している。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.2.2
BuildVersion:   20D80
$ python -V
Python 3.9.2
$ pip list | grep widgets   
ipywidgets          7.6.3
jupyterlab-widgets  1.0.0
widgetsnbextension  3.5.1

下準備

下準備として、必要なパッケージをインストールしておく。

$ pip install ipywidgets jupyterlab matplotlib

そして、JupyterLab を起動する。

$ jupyter-lab

ウィジェットを作る

ここからは、コードを JupyterLab のセル内で実行することを想定する。

まずは、ipywidgets パッケージを widgets という名前でインポートしておこう。 名前を変更しているのは、公式のサンプルコードがこのやり方をしているため。

import ipywidgets as widgets

たとえば、もっとも単純なサンプルとしてボタンを作ってみよう。 はじめに、widgets.Button クラスをインスタンス化する。

button = widgets.Button(description='Click me')

上記でインスタンス化したボタンを表示するには、たとえばセルの最後で評価させる方法がある。

button

f:id:momijiame:20210302194334p:plain
widgets.Button

一方で、上記のやり方はウィジェットが複数あったり、複雑なパターンを扱いにくい。 そのため IPython.display.display() 関数の引数にウィジェットを渡していくやり方がおそらく分かりやすいと思う。 JupyterLab だとインポートしないで使えるっぽいけど、念のためインポートした上でボタンを可視化するサンプルコードを以下に示す。

from IPython.display import display
display(button)

もちろん、見え方は先ほどと同じ。

ちなみに、ボタン以外にもウィジェットはたくさんある。 一覧はとても紹介しきれないので、以下の公式ページを見てもらいたい。

ipywidgets.readthedocs.io

イベントハンドラと標準出力

さて、先ほど作ったボタンは、押しても何も起こらない。 何も起こらない UI を作っても意味がないので、次はウィジェットにイベントハンドラを登録しよう。

たとえば、widgets.Button なら widgets.Button#on_click() というメソッドで、クリックされた時に発火するイベントハンドラを登録できる。 試しに、イベントハンドラの中で print() 関数を呼んでみよう。

button = widgets.Button(description='Click me')

def on_click_callback(clicked_button: widgets.Button) -> None:
    """ボタンが押されたときに発火するイベントハンドラ"""
    print('Clicked')  # イベントハンドラ内で print() 関数を呼んでみる

# ボタンにイベントハンドラを登録する
button.on_click(on_click_callback)
display(button)

さて、上記を実行して表示されたボタンをクリックしてみても、実は何も表示されない。 おや?と思いながらブラウザの下方に目を移すと、"Log" というペインに通知が出てくるはず。 このペインを開くと、print() 関数で出力した内容がそこに表示されていることが分かる。

f:id:momijiame:20210302195653p:plain
イベントハンドラ内で print() するとデフォルトではログに残る

これはこれで悪くはないけど、毎回ログのペインを確認しながら作業するのも微妙な感じ。 できればセルの出力に表示させたいので `widgets.Output`` というウィジェットを使う。 このウィジェットは ipywidgets を使っていると、かなり登場機会が多い。 後ほどグラフを描画するときにもお世話になる。

widgets.Output はいくつかの使い方がある。 以下では、コンテキストマネージャとして使っている。 コンテキストマネージャのスコープ内で print() 関数を呼ぶと、出力先が widgets.Output の描画エリアに向く。

button = widgets.Button(description='Click me')
# 標準出力を表示するエリアを用意する
output = widgets.Output(layour={'border': '1px solid black'})

def on_click_callback(clicked_button: widgets.Button) -> None:
    # コンテキストマネージャとして使う
    with output:
        # スコープ内の標準出力は Output に書き出される
        print('Clicked')

button.on_click(on_click_callback)
# Output も表示対象に入れる
display(button, output)

f:id:momijiame:20210302200533p:plain
widgets.Output

もう一つの使い方は、デコレータとしてコールバック関数をラップするやり方。 これだと、そのコールバック関数内でのデフォルトの標準出力が widgets.Output に向く。

button = widgets.Button(description='Click me')
output = widgets.Output(layout={'border': '1px solid black'})

# デコレータとして使うとデフォルトの向け先になる
@output.capture()
def on_click_callback(b: widgets.Button) -> None:
    print('Clicked')

button.on_click(on_click_callback)
display(button, output)

得られる結果は同じ。

ちなみに widgets.Output は、そのままだと内容が追記されていく。 もし内容を消去したいときは widgets.Output#clear_output() メソッドを呼び出せば良い。 以下のサンプルコードでは、ボタンをクリックしたタイミングで前回の内容を消去しつつ、時刻を表示している。

from datetime import datetime

button = widgets.Button(description='Click me')
output = widgets.Output(layour={'border': '1px solid black'})

def on_click_callback(clicked_button: widgets.Button) -> None:
    with output:
        # 表示エリアの内容を消去する
        output.clear_output()
        print(f'Clicked at {datetime.now()}')

button.on_click(on_click_callback)
display(button, output)

# 手動でイベントを発生させる
button.click()

f:id:momijiame:20210302200635p:plain
widgets.Output#clear_output()

デコレータの使い方のときは Output#capture() の引数で clear_output オプションを有効にすると良い。 これで、コールバック関数が呼ばれる毎に出力内容がクリアされる。

from datetime import datetime

button = widgets.Button(description='Click me')
output = widgets.Output(layout={'border': '1px solid black'})

# 関数が呼ばれる度に出力をクリアする
@output.capture(clear_output=True)
def on_click_callback(b: widgets.Button) -> None:
    print(f'Clicked at {datetime.now()}')

button.on_click(on_click_callback)
display(button, output)

button.click()

複数のウィジェットを連携させる

さて、実際に UI を書いていくと、複数のウィジェットを連携させることが多い。 多くのウィジェットは、選択されている値を value というアトリビュートで読み出すことができる。

以下のサンプルコードでは、widgets.Button を押したタイミングで widgets.Select で選択されているアイテムを widgets.Output に表示させている。

button = widgets.Button(description='Click me')
# 値を選択するセレクタ
select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
output = widgets.Output(layour={'border': '1px solid black'})

@output.capture()
def on_click_callback(clicked_button: widgets.Button) -> None:
    # セレクタで選択されているアイテムを使う
    print(f'Selected item: {select.value}')

button.on_click(on_click_callback)
display(select, button, output)

f:id:momijiame:20210302201325p:plain
別のウィジェットの内容を読み取る

ウィジェットをグローバルスコープに置かない

複数のウィジェットを扱うようになると、それらがグローバルスコープにあるとコードがどんどんスパゲッティになっていく。 複数のセルに複数のウィジェットを置くと特にやばい。 そのため、以下のように一連のウィジェットは関数スコープの中で作るようにした方が良いと思う。

def show_widgets():
    """ウィジェットを設定する関数"""
    button = widgets.Button(description='Click me')
    select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
    output = widgets.Output(layour={'border': '1px solid black'})

    @output.capture()
    def on_click_callback(clicked_button: widgets.Button) -> None:
        print(f'Selected item: {select.value}')

    button.on_click(on_click_callback)
    display(select, button, output)

# ウィジェットを表示する
show_widgets()

ただ、これだと関数スコープ内にあるウィジェットが GC に拾われないか心配だったけど、とりあえず大丈夫そう。 どこに参照が生き残るのかはちょっと気になるね。

もしどうしても気になるようなら、以下のように widgets.VBox という複数のウィジェットをまとめるウィジェットを使うのはどうだろう。 これなら、少なくともグローバルスコープにウィジェットの参照が残るので、GC に拾われないことは担保できるはず。

def show_widgets() -> widgets.VBox:
    """ウィジェットを設定する関数"""
    button = widgets.Button(description='Click me')
    select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
    output = widgets.Output(layour={'border': '1px solid black'})

    @output.capture()
    def on_click_callback(clicked_button: widgets.Button) -> None:
        print(f'Selected item: {select.value}')

    button.on_click(on_click_callback)
    # 一連のウィジェットを VBox にまとめて返す
    return widgets.VBox([button, select, output])

# ウィジェットを表示する
box = show_widgets()
display(box)

値の変更を監視する

先ほど使ったセレクタのように、値を選択したり入力する系のウィジェットは、入力値が変更されたタイミングでイベントを発火させたいことが多い。 そのような場合、ウィジェットによっては observe() というメソッドでイベントハンドラを登録できる。 以下のサンプルコードでは、widgets.Select で選んだ内容を widgets.Output に表示している。

from traitlets.utils.bunch import Bunch

def show_widgets():
    select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
    output = widgets.Output(layour={'border': '1px solid black'})

    @output.capture()
    def on_value_change(change: Bunch) -> None:
        # 値が変更されたイベントを扱う
        if change['name'] == 'value':
            output.clear_output()
            # 変更前と変更後の値を出力する
            old_value = change['old']
            new_value = change['new']
            print(f'value changed: {old_value} -> {new_value}')

    # 値の変更を監視する
    select.observe(on_value_change)
    display(select, output)

show_widgets()

f:id:momijiame:20210302204328p:plain
値の変更を契機にイベントを発火する

ただ、実際に UI を作ってみると複数のウィジェットが連携することも多い。 その場合は、個別に observe() すると煩雑になりがち。 そういったときは ipywidgets.interactive() を使った方がコードの見通しが良くなると思う。 以下では、IntSliderSelect の両方の値の変更を監視している。

def show_widgets():
    slider = widgets.IntSlider(value=50, min=1, max=100, description='slider:')
    select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
    output = widgets.Output(layour={'border': '1px solid black'})

    @output.capture(clear_output=True)
    def on_value_change(select_value: str, slider_value: int) -> None:
        print(f'value changed: {select_value=}, {slider_value=}')

    # 複数のウィジェットの変更を一度に監視できる
    widgets.interactive(on_value_change, select_value=select, slider_value=slider)
    display(select, slider, output)

show_widgets()

f:id:momijiame:20210302205749p:plain
ipywidgets.interactive() で複数のウィジェットを監視する

ウィジェットの配置を工夫する

デフォルトでは display() 関数に渡された順序で、垂直にウィジェットが配置されていく。 しかし、それだと操作が分かりにくいこともあるので配置を工夫する方法について書く。

たとえば、ウィジェットを横に並べたいときは widgets.Box または widgets.HBox でウィジェットをまとめると良い。 以下のサンプルコードではスライダーとセレクタを widgets.Box を使って横に並べている。 なお、既に登場しているとおり widgets.VBox を使うと縦に並べることができる。

def show_widgets():
    slider = widgets.IntSlider(value=50, min=1, max=100, description='slider:')
    select = widgets.Select(options=['Apple', 'Banana', 'Cherry'])
    output = widgets.Output(layour={'border': '1px solid black'})

    @output.capture(clear_output=True)
    def on_value_change(select_value: str, slider_value: int) -> None:
        print(f'value changed: {select_value=}, {slider_value=}')

    widgets.interactive(on_value_change, select_value=select, slider_value=slider)

    # 横に並べるときはウィジェットを Box や HBox にまとめる
    box = widgets.Box([slider, select])
    display(box, output)

show_widgets()

f:id:momijiame:20210302210339p:plain
widgets.Box でウィジェットを横に並べる

他にも widgets.GridBoxwidgets.Layout を組み合わせてグリッドレイアウトを作ったり。

def show_widgets():
    labels = [widgets.Label(str(i)) for i in range(8)]
    # グリッドレイアウト
    grid_box = widgets.GridBox(labels,
                               layout=widgets.Layout(grid_template_columns="repeat(3, 100px)"))
    display(grid_box)

show_widgets()

f:id:momijiame:20210302210505p:plain
グリッドレイアウト

widgets.Tab を使えばタブを使った UI も作れる。

def show_widgets(num_of_tabs: int = 5):
    # タブ毎のウィジェット
    contents = [widgets.Label(f'This is tab {i}') for i in range(num_of_tabs)] 
    tab = widgets.Tab(children=contents)
    # タブのタイトルを設定する
    for i in range(num_of_tabs):
        tab.set_title(i, f'tab {i}')
    display(tab)

show_widgets()

f:id:momijiame:20210302210541p:plain
widgets.Tab

Matplotlib と連携させる

さて、やっとかって感じだけど Matplotlib との連携について書いていく。 基本的にはこれまでの延長線上にある。 ポイントは、display() 関数で Matplotlib の Figure オブジェクトを描画するところ。 このとき、描画先が widgets.Output オブジェクトになるようにスコープ内で呼んでやれば良い。

以下のサンプルコードではボタンを押す度にグラフの描画を更新している。

from matplotlib import pyplot as plt
import numpy as np


def show_widgets():
    button = widgets.Button(description='Refresh')
    # グラフの描画領域としての Output を用意する
    output = widgets.Output()
    # アクティブな Axes オブジェクトを取得する
    ax = plt.gca()

    # NOTE: デコレータを使っても問題はない
    # @output.capture(clear_output=True, wait=True)
    def on_click(b: widgets.Button) -> None:
        # 前回の描画内容をクリアする
        ax.clear()
        # 描画し直す
        rand_x = np.random.randn(100)
        rand_y = np.random.randn(100)
        ax.plot(rand_x, rand_y, '+')
        #  Output に書き出す
        with output:
            output.clear_output(wait=True)
            display(ax.figure)

    button.on_click(on_click)
    display(button, output)

    # セルの出力に描画されるのを抑制するために一旦アクティブな Figure を破棄する
    plt.close()

    # 最初の 1 回目の描画を手動でトリガーする
    button.click()
    
show_widgets()

f:id:momijiame:20210304222745p:plain
Matplotlib の描画をウィジェットのイベントで更新する

もうひとつサンプルコードを示す。 こちらでは、widgets.IntRangeSlider の値が変更されたタイミングでグラフを描画し直している。 また、先ほどとの違いとしてプロットされるグラフのサイズを大きくしている。

from __future__ import annotations

def show_widgets():
    MIN, MAX, STEPS = 1, 11, 1000
    range_slider = widgets.IntRangeSlider(value=[2, 4], min=MIN, max=MAX, step=1, description='plot range:')
    output = widgets.Output()
    # サイズを指定する場合
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot()

    @output.capture(clear_output=True, wait=True)
    def on_value_change(selected_range: tuple(int, int)) -> None:
        # サイン波を作る
        x = np.linspace(MIN, MAX, num=STEPS)
        y = np.sin(x)
        # 選択範囲を取り出す
        selected_lower, selected_upper = selected_range
        lower = (selected_lower - MIN) * (STEPS // (MAX - MIN))
        upper = (selected_upper - MIN) * (STEPS // (MAX - MIN))
        # 前回の描画内容をクリアする
        ax.clear()
        # 描画する
        ax.plot(x[lower:upper], y[lower:upper])
        # Output に書き出す
        display(ax.figure)

    widgets.interactive(on_value_change, selected_range=range_slider)
    display(range_slider, output)

    plt.close()

show_widgets()

f:id:momijiame:20210304224020p:plain
ウィジェットの変更を元にグラフを描画する

別のウィジェットのイベントでウィジェットの値を更新する

あとはもはや蛇足っぽいけど、あるウィジェットのイベントを契機に別のウィジェットの値を変更するのもよくあるよねってことで。 以下のサンプルコードでは widgets.Button をクリックすると widgets.Text に入力された内容が widgets.Select に追加されていく。

def show_widgets():
    text = widgets.Text()
    select = widgets.Select(options=[])
    output = widgets.Output(layour={'border': '1px solid black'})
    button = widgets.Button(description='Add')

    def on_click_callback(b: widgets.Button) -> None:
        # テキストの入力を選択肢として追加する
        select.options = list(select.options) + [text.value]

    button.on_click(on_click_callback)
    display(text, button, select)

show_widgets()

f:id:momijiame:20210302211057p:plain
別のウィジェットの変更を元にウィジェットを変更する

複数のウィジェットで値を同期する

あとはあるウィジェットと別のウィジェットの値を同期させる、みたいなことも widgets.jslink() 関数を使ってできる。 以下は時系列の情報を表示させるのに便利な widgets.Playwidgets.IntSlider の値を同期させている。

def show_widgets():
    # アニメーション制御
    play = widgets.Play(
        value=50,
        min=1,
        max=100,
        step=1,
        interval=500,  # 更新間隔 (ミリ秒)
        description="play:",
    )
    slider = widgets.IntSlider(value=50, min=1, max=100, description='slider:')
    output = widgets.Output(layour={'border': '1px solid black'})

    # ウィジェットの値を連動させる
    widgets.jslink((play, 'value'), (slider, 'value'))

    @output.capture(clear_output=True)
    def on_value_change(slider_value: int) -> None:
        print(f'value changed: {slider_value=}')

    widgets.interactive(on_value_change, slider_value=slider)
    display(play, slider, output)

show_widgets()

f:id:momijiame:20210302211230p:plain
ウィジェット同士の値を同期させる

とりあえず、そんな感じで。

Python: TensorFlow/Keras で Word2Vec の SGNS を実装してみる

以前のエントリで、Word2Vec の CBOW (ContinuousBagOfWords) モデルを TensorFlow/Keras で実装した。 CBOW は、コンテキスト (周辺語) からターゲット (入力語) を推定する多値分類のタスクが考え方のベースになっている。

blog.amedama.jp

今回扱うのは、CBOW と対を成すモデルの Skip Gram をベースにした SGNS (Skip Gram with Negative Sampling) になる。 Skip Gram では、CBOW とは反対にターゲット (入力語) からコンテキスト (周辺語) を推定する多値分類のタスクを扱う。 ただし、with Negative Sampling と付くことで、タスクを多値分類から二値分類にして計算量を削減している。 SGNS では、ターゲットとコンテキストを入力にして、それらが共起 (Co-occurrence) するか否かを推定することになる。 コーパスを処理して実際に共起する単語ペアを正例、出現頻度を元にランダムにサンプルした単語ペアを共起していない負例としてモデルに与える。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.2.1
BuildVersion:   20D74
$ python -V  
Python 3.8.7

下準備

まずは、必要なパッケージをインストールする。

$ pip install tensorflow gensim scipy tqdm

そして、コーパスとして PTB (Penn Treebank) データセットをダウンロードしておく。

$ wget https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt

サンプルコード

早速だけど、サンプルコードを以下に示す。 いくらかマジックナンバーがコードに残ってしまっていて、あんまりキレイではないかも。 各エポックの終了時には、WordSim353 データセットを使って単語間類似度で単語埋め込みを評価している。 また、学習が終わった後には、いくつかの単語で類似する単語や類推語の結果を確認している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import re
from itertools import count
from typing import Iterable
from typing import Iterator
from functools import reduce
from functools import partial
from collections import Counter

import numpy as np  # type: ignore
import tensorflow as tf  # type: ignore
from tensorflow.keras import Model  # type: ignore
from tensorflow.keras.layers import Embedding  # type: ignore
from tensorflow import Tensor  # type: ignore
from tensorflow.data import Dataset  # type: ignore
from tensorflow.keras.optimizers import Adam  # type: ignore
from tensorflow.keras.layers import Dense  # type: ignore
from tensorflow.keras.models import Sequential  # type: ignore
from tensorflow.keras.callbacks import Callback  # type: ignore
from tensorflow.keras.layers import Dot  # type: ignore
from tensorflow.keras.layers import Flatten  # type: ignore
from tensorflow.keras.losses import BinaryCrossentropy  # type: ignore
from gensim.test.utils import datapath  # type: ignore
from scipy.stats import pearsonr  # type: ignore
from tqdm import tqdm  # type: ignore


class SkipGramWithNegativeSampling(Model):
    """Word2Vec の SGNS モデルを実装したクラス"""

    def __init__(self, vocab_size: int, embedding_size: int):
        super().__init__()

        # ターゲット (入力語) の埋め込み
        self.target_embedding = Embedding(input_dim=vocab_size,
                                          input_shape=(1, ),
                                          output_dim=embedding_size,
                                          name='word_embedding',
                                          )
        # コンテキスト (周辺語) の埋め込み
        self.context_embedding = Embedding(input_dim=vocab_size,
                                           input_shape=(1, ),
                                           output_dim=embedding_size,
                                           name='context_embedding',
                                           )

        self.dot = Dot(axes=1)
        self.output_layer = Sequential([
            Flatten(),
            Dense(1, activation='sigmoid'),
        ])

    def call(self, inputs: Tensor) -> Tensor:
        # ターゲットのベクトルを取り出す
        target_label = inputs[:, 0]
        target_vector = self.target_embedding(target_label)
        # コンテキストのベクトルを取り出す
        context_label = inputs[:, 1]
        context_vector = self.context_embedding(context_label)
        # ターゲットとコンテキストの内積を計算する
        x = self.dot([target_vector, context_vector])
        # 共起したか・していないかを二値の確率にする
        prediction = self.output_layer(x)
        return prediction


def cosine_similarity_one_to_one(x, y):
    """1:1 のコサイン類似度"""
    nx = x / np.sqrt(np.sum(x ** 2))
    ny = y / np.sqrt(np.sum(y ** 2))
    return np.dot(nx, ny)


class WordSimilarity353Callback(Callback):
    """WordSim353 データセットを使って単語間の類似度を評価するコールバック"""

    def __init__(self, word_id_table: dict[str, int]):
        super().__init__()

        self.word_id_table = word_id_table
        self.model = None

        # 評価用データを読み込む
        self.eval_data = []
        wordsim_filepath = datapath('wordsim353.tsv')
        with open(wordsim_filepath, mode='r') as fp:
            # 最初の 2 行はヘッダなので読み飛ばす
            fp.readline()
            fp.readline()
            for line in fp:
                word1, word2, sim_score = line.strip().split('\t')
                self.eval_data.append((word1, word2, float(sim_score)))

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        # モデルから学習させたレイヤーの重みを取り出す
        model_layers = {layer.name: layer for layer in self.model.layers}
        embedding_layer = model_layers['word_embedding']
        word_vectors = embedding_layer.weights[0].numpy()

        # 評価用データセットに含まれる単語間の類似度を計算する
        labels = []
        preds = []
        for word1, word2, sim_score in self.eval_data:
            # Out-of-Vocabulary な単語はスキップ
            if word1 not in self.word_id_table or word2 not in self.word_id_table:
                continue

            # コサイン類似度を計算する
            word1_vec = word_vectors[self.word_id_table[word1]]
            word2_vec = word_vectors[self.word_id_table[word2]]
            pred = cosine_similarity_one_to_one(word1_vec, word2_vec)
            preds.append(pred)
            # 正解ラベル
            labels.append(sim_score)

        # ピアソンの相関係数を求める
        r_score = pearsonr(labels, preds)[0]
        print(f'Pearson\'s r score with WordSim353: {r_score}')


def load_corpus(filepath: str) -> Iterator[str]:
    """テキストファイルからコーパスを読み出す"""
    with open(filepath, mode='r') as fp:
        for line in fp:
            # 改行コードは取り除く
            yield line.rstrip()


def sentences_to_words(sentences: Iterable[str], lower: bool = True) -> Iterator[list[str]]:
    """文章を単語に分割する"""
    for sentence in sentences:
        if lower:
            sentence = sentence.lower()
        words = re.split('\\W+', sentence)
        yield [word for word in words if len(word) > 0]  # 空文字は取り除く


def word_id_mappings(sentences: Iterable[Iterable[str]]) -> dict[str, int]:
    """単語を ID に変換する対応テーブルを作る"""
    counter = count(start=0)

    word_to_id = {}
    for sentence in sentences:
        for word in sentence:

            if word in word_to_id:
                # 登録済みの単語はスキップする
                continue

            # 単語の識別子を採番する
            word_id = next(counter)
            word_to_id[word] = word_id

    return word_to_id


def words_to_ids(sentences: Iterable[list[str]], word_to_id: dict[str, int]) -> Iterator[list[int]]:
    # 単語を対応するインデックスに変換する
    for words in sentences:
        # NOTE: Out-of-Vocabulary への対応がない
        yield [word_to_id[word] for word in words]


def extract_contexts(word_ids: Tensor, window_size: int) -> Tensor:
    """コンテキストの単語をラベル形式で得る"""
    target_ids = word_ids[:-window_size]
    context_ids = word_ids[window_size:]
    # ウィンドウサイズ分ずらした Tensor 同士をくっつける
    co_occurrences = tf.transpose([target_ids, context_ids])
    # 逆順でも共起したのは同じ
    reversed_co_occurrences = tf.transpose([context_ids, target_ids])
    concat_co_occurrences = tf.concat([co_occurrences,
                                       reversed_co_occurrences],
                                      axis=0)
    # ラベル (正例なので 1)
    labels = tf.ones_like(concat_co_occurrences[:, 0],
                          dtype=tf.int8)
    return concat_co_occurrences, labels


def positive_pipeline(ds: Dataset, window_size: int) -> Dataset:
    """正例を供給するパイプライン"""

    ctx_ds_list = []
    for window in range(1, window_size + 1):
        partialed = partial(extract_contexts, window_size=window)
        # ウィンドウサイズごとに共起した単語を抽出する
        mapped_ds = ds.map(partialed,
                           num_parallel_calls=tf.data.AUTOTUNE,
                           deterministic=False)
        ctx_ds_list.append(mapped_ds)

    # すべての Dataset をつなげる
    context_ds = reduce(lambda l_ds, r_ds: l_ds.concatenate(r_ds),
                        ctx_ds_list)
    return context_ds


def word_frequency(sentences: Iterable[Iterable[str]], word_id_table: dict[str, int]) -> dict[str, int]:
    """単語の出現頻度を調べる"""
    counter = Counter(word for words in sentences for word in words)
    id_count = {word_id_table[word]: count for word, count in counter.items()}
    # ID 順でソートされた出現頻度
    sorted_freq = np.array([count for _, count in sorted(id_count.items(), key=lambda x: x[0])],
                           dtype=np.int32)
    return sorted_freq


def noisy_word_pairs(word_proba: list[float], eps: float = 1e-6) -> Iterator[Tensor]:
    """単語の出現頻度を元にネガティブサンプルの単語ペアを生成するジェネレータ関数"""
    p = tf.constant(word_proba) + eps
    logits = tf.math.log([p, p])
    while True:
        word_pair = tf.random.categorical(logits, num_samples=2**12)
        word_pair_t = tf.transpose(word_pair)
        # ラベル (負例なので 0)
        labels = tf.zeros_like(word_pair_t[:, 0], dtype=tf.int8)
        yield word_pair_t, labels


def negative_pipeline(sentences: Iterable[Iterable[str]], word_id_table: dict[str, int]) -> Dataset:
    """負例を供給するパイプライン"""
    # 単語の出現頻度からサンプリングテーブルを求める
    word_freq = word_frequency(sentences, word_id_table)
    word_proba = word_freq / np.sum(word_freq)
    # 0.75 乗することで、出現頻度の低い単語をちょっとだけ選ばれやすくする
    ADJUST_FACTOR = 0.75
    adjusted_word_proba = np.power(word_proba, ADJUST_FACTOR)
    adjusted_word_proba /= np.sum(adjusted_word_proba)
    # 単語の出現頻度を元にノイジーワードペアを生成する
    negative_ds = Dataset.from_generator(lambda: noisy_word_pairs(adjusted_word_proba),
                                         (tf.int32, tf.int8))

    return negative_ds


def batched_concat(pos_tensor: Tensor, neg_tensor: Tensor) -> Tensor:
    """正例と負例を直列に結合する関数"""
    pos_word_pairs, pos_labels = pos_tensor
    neg_word_pairs, neg_labels = neg_tensor
    word_pairs = tf.concat((pos_word_pairs, neg_word_pairs), axis=0)
    labels = tf.concat((pos_labels, neg_labels), axis=0)
    return word_pairs, labels


def skip_grams_with_negative_sampling_dataset(positive_ds: Dataset,
                                              negative_ds: Dataset,
                                              negative_sampling_ratio: int):
    """データセットで共起した単語ペアを正例、出現頻度を元にランダムに選んだ単語ペアを負例として供給するパイプライン"""
    positive_batch_size = 1024  # 正例の供給単位
    batched_pos_ds = positive_ds.unbatch().batch(positive_batch_size)
    batched_neg_ds = negative_ds.unbatch().batch(positive_batch_size * negative_sampling_ratio)
    zipped_ds = tf.data.Dataset.zip((batched_pos_ds, batched_neg_ds))
    concat_ds = zipped_ds.map(batched_concat,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False).unbatch()
    # バッチサイズ単位でシャッフルする
    shuffle_buffer_size = positive_batch_size * (negative_sampling_ratio + 1)
    shuffled_ds = concat_ds.shuffle(buffer_size=shuffle_buffer_size)
    return shuffled_ds


def cosine_similarity_matrix(word_vectors: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """N:N のコサイン類似度を計算する"""
    word_norm = np.sqrt(np.sum(word_vectors ** 2, axis=1)).reshape(word_vectors.shape[0], -1)
    normalized_word_vectors = word_vectors / (word_norm + eps)
    cs_matrix = np.dot(normalized_word_vectors, normalized_word_vectors.T)
    return cs_matrix


def most_similar_words(similarities: np.ndarray, top_n: int = 5):
    """コサイン類似度が最も高い単語の ID を得る"""
    similar_word_ids = np.argsort(similarities)[::-1]
    top_n_word_ids = similar_word_ids[:top_n]
    top_n_word_sims = similarities[similar_word_ids][:top_n]
    return zip(top_n_word_ids, top_n_word_sims)


def cosine_similarity_one_to_many(word_vector: np.ndarray,
                                  word_vectors: np.ndarray,
                                  eps: float = 1e-8):
    """1:N のコサイン類似度"""
    normalized_word_vector = word_vector / np.sqrt(np.sum(word_vector ** 2))
    word_norm = np.sqrt(np.sum(word_vectors ** 2, axis=1)).reshape(word_vectors.shape[0], -1)
    normalized_word_vectors = word_vectors / (word_norm + eps)
    return np.dot(normalized_word_vector, normalized_word_vectors.T)


def main():
    # Penn Treebank コーパスを読み込む
    train_sentences = load_corpus('ptb.train.txt')

    # コーパスを単語に分割する
    train_corpus_words = list(sentences_to_words(train_sentences))

    # 単語に ID を振る
    word_to_id = word_id_mappings(train_corpus_words)

    # コーパスの語彙数
    vocab_size = len(word_to_id.keys())

    # データセットを準備する
    # ID に変換したコーパスを行ごとに読み出せるデータセット
    train_word_ids_ds = Dataset.from_generator(lambda: words_to_ids(train_corpus_words, word_to_id),
                                               tf.int32,
                                               output_shapes=[None])

    # 共起したと判断する単語の距離
    CONTEXT_WINDOW_SIZE = 5
    positive_ds = positive_pipeline(train_word_ids_ds, window_size=CONTEXT_WINDOW_SIZE)
    negative_ds = negative_pipeline(train_corpus_words, word_to_id)

    # 正例に対する負例の比率 (一般的に 5 ~ 10)
    NEGATIVE_SAMPLING_RATIO = 5
    train_ds = skip_grams_with_negative_sampling_dataset(positive_ds,
                                                         negative_ds,
                                                         NEGATIVE_SAMPLING_RATIO)

    # モデルとタスクを定義する
    EMBEDDING_SIZE = 100  # 埋め込み次元数
    criterion = BinaryCrossentropy()
    optimizer = Adam(learning_rate=1e-2)
    model = SkipGramWithNegativeSampling(vocab_size, EMBEDDING_SIZE)
    model.compile(optimizer=optimizer,
                  loss=criterion,
                  )

    # データセットを準備する
    TRAIN_BATCH_SIZE = 2 ** 14
    train_ds = train_ds.batch(TRAIN_BATCH_SIZE)
    train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    train_ds = train_ds.cache()

    print('caching train data...')
    num_of_steps_per_epoch = sum(1 for _ in tqdm(train_ds))
    print(f'{num_of_steps_per_epoch=}')
    train_ds = train_ds.repeat()

    callbacks = [
        # WordSim353 データセットを使って単語間の類似度を相関係数で確認する
        WordSimilarity353Callback(word_to_id),
    ]
    # 学習する
    model.fit(train_ds,
              steps_per_epoch=num_of_steps_per_epoch,
              epochs=5,
              callbacks=callbacks,
              verbose=1,
              )

    # モデルから学習させたレイヤーの重みを取り出す
    model_layers = {layer.name: layer for layer in model.layers}
    embedding_layer = model_layers['word_embedding']
    word_vectors = embedding_layer.weights[0].numpy()

    # 単語を表すベクトル間のコサイン類似度を計算する
    cs_matrix = cosine_similarity_matrix(word_vectors)
    # ID -> 単語
    id_to_word = {value: key for key, value in word_to_id.items()}

    # いくつか似ているベクトルを持った単語を確認してみる
    example_words = ['you', 'year', 'car', 'toyota']
    for target_word in example_words:
        # ID に変換した上で最も似ている単語とそのベクトルを取り出す
        print(f'The most similar words of "{target_word}"')
        target_word_id = word_to_id[target_word]
        similarities = cs_matrix[target_word_id, :]
        top_n_most_similars = most_similar_words(similarities, top_n=6)
        # 先頭は自分自身になるので取り除く
        next(top_n_most_similars)
        # 単語と類似度を表示する
        for rank, (similar_word_id, similarity) in enumerate(top_n_most_similars, start=1):
            similar_word = id_to_word[similar_word_id]
            print(f'TOP {rank}: {similar_word} = {similarity}')
        print('-' * 50)

    # いくつか類推語を確認してみる
    analogies = [
        ('king', 'man', 'woman'),
        ('took', 'take', 'go'),
        ('cars', 'car', 'child'),
        ('better', 'good', 'bad'),
    ]
    for word1, word2, word3 in analogies:
        print(f'The most similar words of "{word1}" - "{word2}" + "{word3}"')
        word1_vec = word_vectors[word_to_id[word1]]
        word2_vec = word_vectors[word_to_id[word2]]
        word3_vec = word_vectors[word_to_id[word3]]
        new_vec = word1_vec - word2_vec + word3_vec
        similarities = cosine_similarity_one_to_many(new_vec, word_vectors)
        top_n_most_similars = most_similar_words(similarities)
        # 単語と類似度を表示する
        for rank, (similar_word_id, similarity) in enumerate(top_n_most_similars, start=1):
            similar_word = id_to_word[similar_word_id]
            print(f'TOP {rank}: {similar_word} = {similarity}')
        print('-' * 50)


if __name__ == '__main__':
    main()

上記を実行してみよう。 コーパスの前処理に結構時間がかかる。 今回使った環境では 1 分 47 秒かかった。 学習に関しては、WordSim353 データセットを使った内省的評価で 4 エポック目にはサチる感じ。

$ python sgns.py
caching train data...
2802it [01:47, 26.01it/s]
num_of_steps_per_epoch=2802
Epoch 1/5
2802/2802 [==============================] - 50s 18ms/step - loss: 0.3845
Pearson's r score with WordSim353: 0.2879142572631919
Epoch 2/5
2802/2802 [==============================] - 47s 17ms/step - loss: 0.3412
Pearson's r score with WordSim353: 0.367370159567898
Epoch 3/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3307
Pearson's r score with WordSim353: 0.3898624474454972
Epoch 4/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3248
Pearson's r score with WordSim353: 0.39416965929977094
Epoch 5/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3211
Pearson's r score with WordSim353: 0.39503500234447125
The most similar words of "you"
TOP 1: your = 0.6509251594543457
TOP 2: i = 0.6414255499839783
TOP 3: we = 0.569475531578064
TOP 4: re = 0.5692735314369202
TOP 5: someone = 0.5565952658653259
--------------------------------------------------
The most similar words of "year"
TOP 1: earlier = 0.5841510891914368
TOP 2: month = 0.5817509293556213
TOP 3: period = 0.572060763835907
TOP 4: ago = 0.5633273720741272
TOP 5: last = 0.5298276543617249
--------------------------------------------------
The most similar words of "car"
TOP 1: cars = 0.6105561852455139
TOP 2: luxury = 0.5986034870147705
TOP 3: truck = 0.563898503780365
TOP 4: ford = 0.5133273005485535
TOP 5: auto = 0.5039612054824829
--------------------------------------------------
The most similar words of "toyota"
TOP 1: infiniti = 0.6949869394302368
TOP 2: honda = 0.6433103084564209
TOP 3: mazda = 0.6296555995941162
TOP 4: lexus = 0.6275536417961121
TOP 5: motor = 0.6175971627235413
--------------------------------------------------
The most similar words of "king" - "man" + "woman"
TOP 1: king = 0.7013309001922607
TOP 2: woman = 0.6033017039299011
TOP 3: burger = 0.4819037914276123
TOP 4: md = 0.46715104579925537
TOP 5: egg = 0.45565301179885864
--------------------------------------------------
The most similar words of "took" - "take" + "go"
TOP 1: took = 0.6121359467506409
TOP 2: go = 0.6096295714378357
TOP 3: stands = 0.4905664920806885
TOP 4: hammack = 0.45641931891441345
TOP 5: refuge = 0.4478893578052521
--------------------------------------------------
The most similar words of "cars" - "car" + "child"
TOP 1: child = 0.8112377524375916
TOP 2: women = 0.5026379823684692
TOP 3: cars = 0.4889577627182007
TOP 4: patients = 0.4796864092350006
TOP 5: custody = 0.47176921367645264
--------------------------------------------------
The most similar words of "better" - "good" + "bad"
TOP 1: bad = 0.6880872249603271
TOP 2: better = 0.510699987411499
TOP 3: involved = 0.45395123958587646
TOP 4: serious = 0.4192639887332916
TOP 5: hardest = 0.41672468185424805
--------------------------------------------------

CBOW のときは WordSim353 の評価が 0.25 前後だったことを考えると、今回の 0.39 前後という結果はかなり良く見える。 ただし、CBOW の実験ではコンテキストウィンドウサイズが 1 だったのに対し、上記の SGNS では 5 を使っている。 同じコンテキストウィンドウサイズの 1 に揃えると、評価指標は 0.28 前後まで落ちる。

学習が終わってから確認している類似語は、CBOW のときと同じでなかなか良い感じに見える。 一方で、類推語の方はほとんど上手くいっていない。 類推語を解ける位の単語埋め込みを学習するには、もっと大きなコーパスが必要なのだろうか?

ゼロから作るDeep Learning ❷ ―自然言語処理編

ゼロから作るDeep Learning ❷ ―自然言語処理編

  • 作者:斎藤 康毅
  • 発売日: 2018/07/21
  • メディア: 単行本(ソフトカバー)