CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: SQLAlchemy の生成する SQL をテストするパッケージを作ってみた

SQLAlchemy は Python でよく使われている O/R マッパーの一つ。 今回は、そんな SQLAlchemy が生成する SQL 文を確認するためのパッケージを作ってみたよ、という話。

具体的には、以下の sqlalchemy-profile というパッケージを作ってみた。 このエントリでは、なんでこんなものを作ったのかみたいな話をしてみる。

github.com

使った環境は次の通り。 ただし sqlalchemy-profile 自体はプラットフォームに依存せず Python 2.7, 3.3 ~ 3.6 に対応している。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

O/R マッパーについて

O/R マッパーというのは、プログラミング言語からリレーショナルデータベース (RDB) を良い感じに使うための機能ないしライブラリの総称。 プログラミング言語から RDB を操作するための SQL 文を直に扱ってしまうと、両者のパラダイムの違いから色々な問題が起こる。 この問題は、一般にインピーダンスミスマッチと呼ばれている。 そこで登場するのが O/R マッパーで、これを使うとプログラミング言語のオブジェクトを操作する形で RDB を操作できるようになる。

論よりソースということで、まずは SQLAlchemy の基本的な使い方から見てみよう。 その前に SQLAlchemy 自体をインストールしておく。

$ pip install sqlalchemy

そして次に示すのがサンプルコード。 ユーザ情報を模したモデルクラスを用意して、それを SQLite のオンメモリデータベースで永続化している。 この中には SQL 文が全く登場していないところがポイントとなる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

Base = declarative.declarative_base()


class User(Base):
    """SQLAlchemy のモデルクラス

    このクラスが RDB のテーブルと対応し、インスタンスはテーブルの一レコードに対応する
    ここではユーザの情報を格納するテーブルを模している"""
    __tablename__ = 'users'

    # テーブルの主キー
    id = Column(Integer, primary_key=True)
    # 名前を入れるカラム
    name = Column(Text, nullable=False)


def main():
    # データベースとの接続に使う情報
    # ここでは SQLite のオンメモリデータベースを使う
    # echo=True とすることで生成される SQL 文を確認できる
    engine = create_engine('sqlite:///', echo=True)
    # モデルの情報を元にテーブルを生成する
    Base.metadata.create_all(engine)
    # データベースとのセッションを確立する
    session_maker = sessionmaker(bind=engine)
    session = session_maker()

    # データベースのトランザクションを作る
    with session.begin(subtransactions=True):
        # レコードに対応するモデルのインスタンスを作る
        user = User(name='Alice')
        # そのインスタンスをセッションに追加する
        session.add(user)

    # トランザクションがコミットされてオブジェクトが RDB で永続化される

if __name__ == '__main__':
    main()

上記のサンプルコードでは生成される SQL 文を標準出力に表示するようにしている。 なので、実行するとこんな感じの出力が得られる。

2017-04-20 04:48:30,976 INFO sqlalchemy.engine.base.Engine SELECT CAST('test plain returns' AS VARCHAR(60)) AS anon_1
2017-04-20 04:48:30,976 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,978 INFO sqlalchemy.engine.base.Engine SELECT CAST('test unicode returns' AS VARCHAR(60)) AS anon_1
2017-04-20 04:48:30,978 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,980 INFO sqlalchemy.engine.base.Engine PRAGMA table_info("users")
2017-04-20 04:48:30,980 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,982 INFO sqlalchemy.engine.base.Engine 
CREATE TABLE users (
    id INTEGER NOT NULL, 
    name TEXT NOT NULL, 
    PRIMARY KEY (id)
)


2017-04-20 04:48:30,983 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,984 INFO sqlalchemy.engine.base.Engine COMMIT
2017-04-20 04:48:30,987 INFO sqlalchemy.engine.base.Engine BEGIN (implicit)
2017-04-20 04:48:30,989 INFO sqlalchemy.engine.base.Engine INSERT INTO users (name) VALUES (?)
2017-04-20 04:48:30,989 INFO sqlalchemy.engine.base.Engine ('Alice',)

たしかに Python のオブジェクトを使うだけで RDB を操作できた。便利。 ただし、上記で使った生成した SQL 文を出力する機能はデバッグ用途なので普段は無効にされる場合が多い。

SQL が隠蔽されることのメリットとデメリット

先ほど見た通り O/R マッパーを使うと Python のオブジェクトを通して RDB を操作できるようになる。 これにはインピーダンスミスマッチの解消という多大なメリットがある反面、生成される SQL が隠蔽されるというデメリットもある。

例えば、直接 SQL を書くならそんな非効率なクエリは組まないよね・・・というような内容も、気をつけていないと生成されうる。 これは、典型的には N + 1 問題とか。 それを防ぐには、これまでだとコードから生成される SQL 文を推測したり、あるいは先ほどのようにして実際に目で見て確かめていた。 慣れてくるとどんな SQL 文が発行されるか分かってくるのと、実際に目で見て確かめるのは手間なので大体は前者になっている。

ただ、パフォーマンスチューニングの世界では、推測する前に測定せよという格言もある。 実際に生成される SQL 文を、ユニットテストで確認できるようになっているべきなのでは、という考えに至った。 それが、今回作ったパッケージ sqlalchemy-profile のモチベーションになっている。

ただし、どんな SQL 文が生成されるかは SQLAlchemy のアルゴリズム次第なので、気をつけないとテストのメンテナンス性が低下する恐れはあると思う。 これは、SQLAlchemy のバージョン変更とか、些細なモデルの構造変更でテストを修正する手間がかかるかも、ということ。 とはいえ、それはそれで生成される SQL が変更されたことにちゃんと気づけるのは大事じゃないかという感じでいる。

sqlalchemy-profile について

やっと本題に入るんだけど、前述した問題を解消すべく sqlalchemy-profile という Python のパッケージを作ってみた。 これを使うことで、SQLAlchemy が生成する SQL 文を確かめることができる。

Python のパッケージリポジトリである PyPI にも登録しておいた。 pypi.python.org

インストールは Python のパッケージマネージャの PIP からできる。

$ pip install sqlalchemy-profile

使い方

ここからは sqlalchemy-profile の具体的な使い方について見ていく。 シンプルなのでサンプルコードをいくつか見れば、すぐに分かってもらえると思う。 ちなみに、トラッキングしている SQL 文は今のところ INSERT, UPDATE, SELECT, DELETE の四つ。

以下のサンプルコードでは、最も基本的な使い方を示している。 まず、プロファイラとなる StatementProfiler には SQLAlchemy のデータベースとの接続情報を渡す。 そして、プロファイルしている期間中に実行された SQL 文を記録する、というもの。 ユニットテストで利用することを意図しているので、サンプルコードも Python の unittest モジュールを使うものにしてみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import StatementProfiler


class Test_RawExecute(unittest.TestCase):

    def test(self):
        # データベースとの接続を確立する
        engine = create_engine('sqlite:///')
        connection = engine.connect()

        # データベースとの接続情報を渡してプロファイラをインスタンス化する
        profiler = StatementProfiler(engine)
        # プロファイルを開始する
        profiler.start()

        # SQLAlchemy を使って RDB を操作する
        # ここでは、サンプルコードをシンプルにする目的で低レイヤーな API を使っている
        connection.execute('SELECT 1')
        connection.execute('SELECT 2')

        # プロファイルを停止する
        profiler.stop()

        # 実行された SQL 文の内容を確認する
        assert profiler.count == 2
        assert profiler.select == 2


if __name__ == '__main__':
    unittest.main()

上記では、分かりやすくするためにあえて SQLAlchemy の直接 SQL 文を扱う低レイヤーな API を使っている。

上記を実行するとテストがパスする。

$ python profile101.py 
.
----------------------------------------------------------------------
Ran 1 test in 0.020s

OK

このとき assert しているところの数値を変更すると、当然だけどテストは失敗するようになる。 想定していた SQL 文の数と、実際に発行された数が一致しないため。

$ python profile101.py
F
======================================================================
FAIL: test (__main__.Test_RawExecute)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "profile101.py", line 32, in test
    assert profiler.count == 1
AssertionError

----------------------------------------------------------------------
Ran 1 test in 0.017s

FAILED (failures=1)

O/R マッピングと共に使う

先ほどの例では、分かりやすさのためにあえて SQLAlchemy の直接 SQL 文を扱う低レイヤーな API を使っていた。 もちろん sqlalchemy-profile は O/R マッピングをしたコードでも動作するし、使い方については何も変わらない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    """ユーザ情報を模したモデルクラス"""
    __tablename__ = 'users'

    # 主キー
    id = Column(Integer, primary_key=True)
    # 名前を格納するカラム
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        """テストが実行される前の下準備"""
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        """テストが実行された後の後始末"""
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        # 以下のユーザを模したインスタンスを一通り CRUD していく
        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        # UPDATE
        with session.begin(subtransactions=True):
            user.name = 'Bob'

        # SELECT
        session.query(_User).all()

        # DELETE
        with session.begin(subtransactions=True):
            session.delete(user)

        profiler.stop()

        # SQL 文は各一回ずつ実行されているはず
        assert profiler.count == 4
        assert profiler.insert == 1
        assert profiler.update == 1
        assert profiler.select == 1
        assert profiler.delete == 1


if __name__ == '__main__':
    unittest.main()

with ステートメントと共に使う

これまでの例ではプロファイリング期間を start() メソッドと stop() メソッドで制御したけど、これは with でも代用できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import StatementProfiler


class Test_WithStatement(unittest.TestCase):

    def test(self):
        engine = create_engine('sqlite:///')
        connection = engine.connect()

        # with ステートメントのスコープで実行された SQL 文を記録する
        with StatementProfiler(engine) as profiler:
            connection.execute('SELECT 1')
            connection.execute('SELECT 2')

        assert profiler.count == 2
        assert profiler.select == 2


if __name__ == '__main__':
    unittest.main()

デコレータと共に使う

with を使うのもめんどくさいなー、というときはテストメソッド自体をデコレータで修飾しちゃうような使い方もできる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import sqlprofile

ENGINE = create_engine('sqlite:///')


class Test_Decorator(unittest.TestCase):

    # ユニットテストのメソッドをデコレータで修飾する
    # メソッド内で実行されることが想定される SQL 文の数を指定する
    @sqlprofile(ENGINE, count=2, select=2)
    def test(self):
        connection = ENGINE.connect()

        connection.execute('SELECT 1')
        connection.execute('SELECT 2')


if __name__ == '__main__':
    unittest.main()

SQL 文の種類と順序まで確認したい

いやいや回数だけのアサーションとかアバウトすぎるでしょ、っていうときは StatementProfiler#sequence を使う。 これで INSERT, UPDATE, SELECT, DELETE が、どんな順番で実行されたかを確認できる。 中身は文字列で、それぞれの操作の頭文字が入っている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        # UPDATE
        with session.begin(subtransactions=True):
            user.name = 'Bob'

        # SELECT
        session.query(_User).all()

        # DELETE
        with session.begin(subtransactions=True):
            session.delete(user)

        profiler.stop()

        # [I]NSERT -> [U]PDATE -> [S]ELECT -> [D]ELETE
        assert profiler.sequence == 'IUSD'


if __name__ == '__main__':
    unittest.main()

もちろんデコレータの API でも使える。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import sqlprofile

ENGINE = create_engine('sqlite:///')


class Test_Decorator(unittest.TestCase):

    # SELECT -> SELECT = SS
    @sqlprofile(ENGINE, seq='SS')
    def test(self):
        connection = ENGINE.connect()

        connection.execute('SELECT 1')
        connection.execute('SELECT 2')


if __name__ == '__main__':
    unittest.main()

もっと厳密にアサーションしたい

いやいや SQL 文の構造までもっと調べたいよ、というときは StatementProfiler#statementsStatementProfiler#statements_with_parameters を使う。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        profiler.stop()

        assert profiler.count == 1
        assert profiler.insert == 1

        # 生の SQL 文を取得する
        print(profiler.statements)
        print(profiler.statements_with_parameters)


if __name__ == '__main__':
    unittest.main()

こんな感じ。

['INSERT INTO users (name) VALUES (?)']
[('INSERT INTO users (name) VALUES (?)', ('Alice',))]
.
----------------------------------------------------------------------
Ran 1 test in 0.019s

OK

こちらは、今のところデコレータの API では使えない。

まとめ

  • SQLAlchemy の生成する SQL 文を確認するための sqlalchemy-profile というパッケージを作ってみた
  • O/R マッピングをすると、生成される SQL 文をプログラマが把握しにくくなる
  • 非効率なクエリをコードや実行結果から目で見て確認するのは手間がかかる
  • sqlalchemy-profile を使うことで実行される SQL 文をユニットテストで確認できるようになる

もしかすると似たようなことができるパッケージが既にあるかも。

Python: 相関行列を計算してヒートマップを描いてみる

以前、このブログで相関係数について解説した記事を書いたことがある。 相関係数というのは、データセットのある次元とある次元の関連性を示すものだった。

blog.amedama.jp

この相関係数を、データセットの各次元ごとに計算したものを相関行列と呼ぶ。 データ分析の世界では、それぞれの次元の関連性を見るときに、この相関行列を計算することがある。 また、それを見やすくするためにヒートマップというグラフを用いて図示することが多い。

今回は Python を使って相関行列を計算すると共にヒートマップを描いてみることにした。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.5.3

下準備

今回は、相関行列の計算には NumPy を、グラフの描画には seaborn を使うのでインストールしておく。 最後の scikit-learn については、相関行列の計算に使うデータセットを読み込むためだけに使っている。

$ pip install seaborn numpy scikit-learn

相関行列を計算してみる

まずはヒートマップの描く以前に相関行列の計算から。 データセットにはみんな大好きアイリス (あやめ) データセットを用いる。 これは 150 行 4 次元の構造になっている。

>>> from sklearn import datasets
>>> dataset = datasets.load_iris()
>>> features = dataset.data
>>> features.shape
(150, 4)

相関行列の計算には NumPycorrcoef() 関数が使える。 この関数には、相関行列を計算したい次元をリストの形で渡す。 すごくベタに書くとしたら、こんな感じ。

>>> import numpy as np
>>> np.corrcoef([features[:, 0], features[:, 1], features[:, 2], features[:, 3]])
array([[ 1.        , -0.10936925,  0.87175416,  0.81795363],
       [-0.10936925,  1.        , -0.4205161 , -0.35654409],
       [ 0.87175416, -0.4205161 ,  1.        ,  0.9627571 ],
       [ 0.81795363, -0.35654409,  0.9627571 ,  1.        ]])

上記には、それぞれの次元ごとの相関係数が格納されている。 対角要素が全て 1 になっているのは、全く同じデータ同士の相関係数は 1 になるため。

ただ、実際には上記のようなベタ書きをする必要はない。 次元ごとにリストで、というのはようするに行と列を入れ替えれば良いということ。 つまり M 行 N 列のデータなら N 行 M 列に直して渡すことになる。

これは NumPy 行列なら transpose() メソッドで実現できる。

>>> features.transpose().shape
(4, 150)

ようするに、こうなる。

>>> np.corrcoef(features.transpose())
array([[ 1.        , -0.10936925,  0.87175416,  0.81795363],
       [-0.10936925,  1.        , -0.4205161 , -0.35654409],
       [ 0.87175416, -0.4205161 ,  1.        ,  0.9627571 ],
       [ 0.81795363, -0.35654409,  0.9627571 ,  1.        ]])

相関行列はこれで計算できた。

ヒートマップを描いてみる

続いては、先ほどの相関行列をヒートマップにしてみよう。 Python のグラフ描画ライブラリの seaborn には、あらかじめヒートマップを描くための API が用意されている。

次のサンプルコードではアイリスデータセットの相関行列をヒートマップで図示している。 それぞれの行の説明についてはコメントで補足している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn import datasets


def main():
    # アイリスデータセットを読み込む
    dataset = datasets.load_iris()

    features = dataset.data
    feature_names = dataset.feature_names

    # N 行 M 列を M 行 N 列に変換して相関行列を計算する
    correlation_matrix = np.corrcoef(features.transpose())

    # 相関行列のヒートマップを描く
    sns.heatmap(correlation_matrix, annot=True,
                xticklabels=feature_names,
                yticklabels=feature_names)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を実行すると、次のようなグラフが得られる。

f:id:momijiame:20170418225256p:plain

このヒートマップでは、正の相関が強いものほど赤く、負の相関が強いものほど青く図示されている。 相関がないものについては色が薄いことから白に近づくことになる。 上記を見ると、アイリスデータセットには相関の強い次元が多いことが分かる。

主成分分析した結果の相関行列でヒートマップを描いてみる

ここで一つ気になったことを試してみることにした。 主成分分析した結果を相関行列にしてヒートマップで描いてみる、というものだ。 主成分分析とは何ぞや、ということは以下のブログエントリで書いた。

blog.amedama.jp

理屈の上では、主成分分析した結果は互いに直交した次元になるため相関しないはず。 これを相関行列とヒートマップで確かめてみよう、ということ。

次のサンプルコードでは、アイリスデータセットを主成分分析している。 そして、分析した内容に対して相関行列を計算してヒートマップを描いた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.decomposition import PCA


def main():
    # アイリスデータセットを読み込む
    dataset = datasets.load_iris()

    features = dataset.data

    # 特徴量を主成分分析する
    pca = PCA()
    pca.fit(features)

    # 分析にもとづいて特徴量を主成分に変換する
    transformed_features = pca.fit_transform(features)

    # 主成分の相関行列を計算する
    correlation_matrix = np.corrcoef(transformed_features.transpose())

    # 主成分の相関行列をヒートマップで描く
    feature_names = ['PCA{0}'.format(i)
                     for i in range(features.shape[1])]
    sns.heatmap(correlation_matrix, annot=True,
                xticklabels=feature_names,
                yticklabels=feature_names)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を実行すると、次のようなグラフが得られる。

f:id:momijiame:20170418225930p:plain

見事に真っ白で、互いに相関が全然ないことが分かる。 理屈通りの結果になった。

まとめ

  • 二つの次元の関連性を調べるには相関係数を用いる
  • データセットに含まれる全ての次元で相関係数を計算したものを相関行列と呼ぶ
  • 相関行列はヒートマップというグラフで図示することが多い
  • 主成分分析した結果の相関行列は対角要素を覗いてゼロになる

めでたしめでたし。

Python: scikit-learn で主成分分析 (PCA) してみる

主成分分析 (PCA) は、主にデータ分析や統計の世界で使われる道具の一つ。 データセットに含まれる次元が多いと、データ分析をするにせよ機械学習をするにせよ分かりにくさが増える。 そんなとき、主成分分析を使えば取り扱う必要のある次元を圧縮 (削減) できる。 ただし、ここでいう圧縮というのは非可逆なもので、いくらか失われる情報は出てくる。 今回は、そんな主成分分析を Python の scikit-learn というライブラリを使って試してみることにした。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

下準備

あらかじめ、今回使う Python のパッケージを pip でインストールしておく。

$ pip install matplotlib scipy scikit-learn

主成分分析の考え方

前述した通り、主成分分析はデータセットの次元を圧縮 (削減) するのに用いる。 ただし、実は元のデータセットと分析結果で次元数を変えないようにすることもできる。 それじゃあ圧縮できていないじゃないかという話になるんだけど、実は分析結果では次元ごとの性質が異なっている。 これは、例えるなら「すごく重要な次元・それなりに重要な次元・あんまり重要じゃない次元」と分かれているような感じ。 そして、その中から重要な次元をいくつかピックアップして使えば、次元の数が減るというわけ。 もちろん、そのとき選ばれなかった「あんまり重要じゃない次元」に含まれていた情報は失われてしまう。

では、主成分分析ではどのような基準で次元の重要さを決めるのだろうか。 これは、データの分散が大きな次元ほど、より多くの情報を含んでいると考える。 分散というのは、データのバラつきの大きさを表す統計量なので、ようするに値がバラけている方が価値が大きいと捉える。 分散が小さいというのは、ようするにどの値も似たり寄ったりで差異を見出すのが難しいということ。 それに対し、分散が大きければ値ごとの違いも見つけやすくなる。

例えば、次のような x 次元と y 次元から成る、二次元のデータを考えてみよう。 この中には (1, 2), (2, 4), (3, 6) という三点の要素が含まれる。

f:id:momijiame:20170402110001p:plain

ここで x 次元の標本分散は  \frac{2}{3} で、y 次元の標本分散は  \frac{8}{3} になる。 主成分分析の考え方でいくと y 次元の方が分散が大きいので、より重要といえる。 ただ、上記のデータは二つの次元が相関しているようだ。 相関しているということは、似たような情報を含む次元が二つある、とも捉えることができる。

では、上記で相関に沿って新しい次元を作ってみたら、どうなるだろうか?

f:id:momijiame:20170402121107p:plain

値の間隔はピタゴラスの定理から  \sqrt{5} となることが分かる。 これは x 次元の間隔である 1 や y 次元の間隔である 2 よりも大きい。

f:id:momijiame:20170402121423p:plain

間隔が大きいということは分散も大きくなることが分かる。

続いては、先ほどの相関に沿って作った次元とは直交する軸でさらに新しい次元を作ってみよう。

f:id:momijiame:20170402121737p:plain

今度は、新しい次元からそれぞれの要素を見てみよう。 このとき、全ての要素は同じ場所にいるので間隔は 0 になっている。 つまり、分散も 0 なので、この次元には全然情報が含まれていないことになる。

上記の作業によって、情報がたくさん含まれる次元と、全く含まれない次元に分けることができた。 あとは、最初に作った情報がたくさん含まれる次元だけを使えば、二次元を一次元に圧縮できたことになる。 実は、これこそ正に主成分分析でしている作業を表している。

実際に試してみる

やっていることの概要は分かったので、次は実際にその通りになるのか試してみよう。 データセットとしては、まずは先ほどの三点をそのまま使ってみる。

次のサンプルコードでは、相関した三点のデータを主成分分析している。 そして、元のデータと分析結果を散布図にした。 また、分析結果の各次元の寄与率というものも出力している。 scikit-learn では

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA


def main():
    # y = 2x
    features = np.array([[1, 2], [2, 4], [3, 6]])

    # グラフ描画サイズを設定する
    plt.figure(figsize=(12, 4))

    # 元データをプロットする
    plt.subplot(1, 2, 1)
    plt.scatter(features[:, 0], features[:, 1])
    plt.title('origin')
    plt.xlabel('x')
    plt.ylabel('y')

    # 主成分分析する
    pca = PCA()
    pca.fit(features)

    # 分析結果を元にデータセットを主成分に変換する
    transformed = pca.fit_transform(features)

    # 主成分をプロットする
    plt.subplot(1, 2, 2)
    plt.scatter(transformed[:, 0], transformed[:, 1])
    plt.title('principal component')
    plt.xlabel('pc1')
    plt.ylabel('pc2')

    # 主成分の次元ごとの寄与率を出力する
    print(pca.explained_variance_ratio_)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

それでは、上記のサンプルコードを実行してみよう。 出力されたリストは、分析結果の各次元の寄与率を表している。

$ python pca.py 
[ 1.  0.]

寄与率というのは、前述した「各次元の重要度」を表したもの。 その次元に元のデータからどれだけの割合で情報が含まれているかで、全てを足すと 1 になるように作られている。 つまり、主成分分析をした結果から全ての次元を使えば、元のデータセットから情報の損失は起こらない。 ただし、それだと次元も圧縮できないことになる。

先ほどの出力結果を見ると、最初の次元に寄与率が全て集中している。 つまり、最初の次元だけに全ての情報が含まれていることになる。 これは、先ほど主成分分析の概要を図示したときに得られた結論と一致している。

では、上記をグラフでも確認してみよう。 次のグラフは、主成分分析の前後を散布図で比べたもの。 左が元データで、右が分析結果となっている。

f:id:momijiame:20170402123229p:plain

見て分かる通り、先ほど図示した内容と一致している。 ちなみに、主成分分析では分析結果として得られた次元のことを第 n 主成分と呼ぶ。 例えば、最初に作った次元なら第一主成分、次に作った次元なら第二主成分という風になる。 今回の例では第一主成分に必要な情報が全て集中した。

アイリスデータセットを主成分分析してみる

次はもうちょっとだけそれっぽいデータを使ってみることにする。 みんな大好きアイリスデータセットは、あやめという花の特徴量と品種を含んでいる。 この特徴量は四次元なので、別々のグラフに分けたりしないと本来は可視化できない。 今回は、主成分分析を使って二次元に圧縮して可視化してみることにしよう。

次のサンプルコードではアイリスデータセットの次元を主成分分析している。 そして、分析結果から第二主成分までを取り出して散布図に可視化した。 また、同時に寄与率と累積寄与率を出力するようにした。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn import datasets


def main():
    dataset = datasets.load_iris()

    features = dataset.data
    targets = dataset.target

    # 主成分分析する
    pca = PCA(n_components=2)
    pca.fit(features)

    # 分析結果を元にデータセットを主成分に変換する
    transformed = pca.fit_transform(features)

    # 主成分をプロットする
    for label in np.unique(targets):
        plt.scatter(transformed[targets == label, 0],
                    transformed[targets == label, 1])
    plt.title('principal component')
    plt.xlabel('pc1')
    plt.ylabel('pc2')

    # 主成分の寄与率を出力する
    print('各次元の寄与率: {0}'.format(pca.explained_variance_ratio_))
    print('累積寄与率: {0}'.format(sum(pca.explained_variance_ratio_)))

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

それでは、上記を実行してみよう。 コンソールには寄与率と累積寄与率が表示される。

$ python pcairis.py 
各次元の寄与率: [ 0.92461621  0.05301557]
累積寄与率: 0.9776317750248034

寄与率は先ほど説明した通りで、ここでの累積寄与率は使うことにした次元の寄与率を足したもの。 ようするに、今回の場合なら第一主成分と第二主成分の寄与率を足したものになっている。 累積寄与率は約 0.97 で、ようするに第二主成分までで元のデータの 97% が表現できていることが分かる。

同時に、次のような散布図が表示される。 これは、第一主成分と第二主成分を x 軸と y 軸に取って散布図にしたもの。 点の色の違いは品種を表している。

f:id:momijiame:20170402125421p:plain

本来なら四次元の特徴量で複数の散布図になるところを、主成分分析を使うことで一つの散布図にできた。

まとめ

今回は Python の scikit-learn を使って主成分分析について学んだ。 データセットに含まれる次元が多いと、データ分析なら分かりにくいし、機械学習なら計算量が増えてしまう。 そんなとき主成分分析を使えば、重要さが異なる新たな次元を含んだデータが分析結果として得られる。 その中から、重要なものをいくつかピックアップして使えば、データの損失を最小限に抑えて次元を減らすことができる。

参考文献

実践 機械学習システム

実践 機械学習システム

Python: ソケットプログラミングのアーキテクチャパターン

今回はソケットプログラミングについて。 ソケットというのは Unix 系のシステムでネットワークを扱うとしたら、ほぼ必ずといっていいほど使われているもの。 ホスト間の通信やホスト内での IPC など、ネットワークを抽象化したインターフェースになっている。

そんな幅広く使われているソケットだけど、取り扱うときには色々なアーキテクチャパターンが考えられる。 また、比較的低レイヤーな部分なので、効率的に扱うためにはシステムコールなどの、割りと OS レベルに近い知識も必要になってくる。 ここらへんの話は、体系的に語られているドキュメントが少ないし、あっても鈍器のような本だったりする。 そこで、今回はそれらについてざっくりと見ていくことにした。

尚、今回はプログラミング言語として Python を使うけど、何もこれは特定の言語に限った話ではない。 どんな言語を使うにしても、あるいは表面上は抽象化されたインターフェースで隠蔽されていても、内部的にはソケットが使われている。 例えば Java サーブレットや Ruby on Rails で Web アプリケーションを書くにしても、それが動くサーバの通信部分はソケットで書かれていることだろう。

動作確認に使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

もくじ

ブロッキングとノンブロッキングについて

まず、ソケットを扱うには大きく分けて「ブロッキングで使うか・ノンブロッキングで使うか」を選ぶことになる。 その中でも基本となる使い方はブロッキングで、こちらの方が逐次的なプログラミングモデルとなりやすいので理解も早い。 ではノンブロッキングにはどんなメリットがあるかというと、こちらは通信相手が増えたときのパフォーマンス面 (スケーラビリティ) で優れている。

このエントリでは、ソケットの扱い方をブロッキング・ノンブロッキングと分けた上で、それぞれにどんなアーキテクチャパターンが考えられるか見ていく。 しかし、その前にまずは事前知識としてソケットにおけるブロッキング・ノンブロッキングという概念自体の説明から入ろう。

まず、ソケットというオブジェクトに対してはデータの読み込みや書き込みを指示できる。 読み込まれるデータは通信相手から送られてきたもので、書き込まれたデータは通信相手に送り届けられる。 しかし、それらのデータを読み書きする指示は即座に完了するわけではない。 具体的には、ソケットには読み書きができる状態とできない状態があるためだ。 読み書きができないというのは、わんこそばで例えると口の中でもぐもぐしている最中で、次のそばを口に入れられない状態を指す。

では、もしも読み込みや書き込みができない状態にあるソケットに対して、その指示を出したらどう振る舞うのだろうか。 ブロッキング・ノンブロッキングの違いというのは、正にこの「どう振る舞うか」の違いを指す。 ブロッキングというのは、読み書きができる状態になるまで、じっとそのまま待つことを意味している。 それに対して、ノンブロッキングは読み書きができない状態にあるときエラーを出してすぐに処理を終了する。

これで、ソケットのブロッキング・ノンブロッキングの違いについて説明できた。

ソケットをブロッキングで扱う場合

さて、前フリが長くなったけど、ここからは具体的なアーキテクチャパターンを見ていくことにしよう。 初めは、基本的な使い方であるソケットをブロッキングで扱う場合から。

今回、サンプルコードとして題材にするのはエコーサーバにした。 エコーサーバというのは、クライアントから送られてきたデータを、そのままオウム返しでクライアントに送り返すサーバのことをいう。

実装については IPv4 のループバックアドレスを使って TCP:37564 ポートでクライアントからの接続を待ち受けるようにした。 ループバックアドレスとは何か、みたいな TCP/IP 的な概念についての説明は省く。 これは、今回の主題として扱うアーキテクチャパターンという範疇からは、ちょっと外れるため。

あと、クライアントサイドについても自分で書いても良いんだけど、今回はありものを使うことにした。 ここでは netcat というツールを使うことにしよう。 netcatHomebrew を使ってインストールする。

$ brew install netcat

Homebrew が入っていないときは入れる感じで。

シングルスレッド

まずは、ソケットがブロッキングで、それをシングルスレッドで扱う場合から考えてみよう。 これが、最もシンプルなパターンといえるはず。

早速だけどサンプルコードを以下に示す。 それぞれの処理の内容はコメントで補足している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    # IPv4/TCP のソケットを用意する
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # 'Address already in use' の回避策 (必須ではない)
    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    # 待ち受けるアドレスとポートを指定する
    # もし任意のアドレスで Listen したいときは '' を使う
    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    # クライアントをいくつまでキューイングするか
    serversocket.listen(128)

    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            # クライアントソケットから指定したバッファバイト数だけデータを受け取る
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except OSError:
                break

            # 受信したデータの長さが 0 ならクライアントからの切断を表す
            if len(message) == 0:
                break

            # 受信したデータをそのまま送り返す (エコー)
            sent_message = message
            while True:
                # 送信できたバイト数が返ってくる
                sent_len = clientsocket.send(sent_message)
                # 全て送れたら完了
                if sent_len == len(sent_message):
                    break
                # 送れなかった分をもう一度送る
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        # 後始末
        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

サーバにおけるソケットプログラミングの基本的な流れは次の通り。

  • ソケットを作る (socket)
  • 待ち受けるアドレスとポートを指定する (bind)
  • 接続キューの長さを指定して接続を待ち受ける (listen)
  • 接続してきたクライアントからソケットを取得する (accept)
  • 取得したクライアントのソケットに対して読み書きする (send/recv)

このパターンでは、上記の一連の処理を一つのスレッドでこなしていく。

それではサンプルコードを実行してみよう。 これで、エコーサーバが起動する。 とはいえ、クライアントが接続しない限り特に何も表示されることはない。

$ python singlethread.py

続いて、別のターミナルを開いて netcat を実行しよう。 次のようにすると、先ほど起動したエコーサーバに接続できる。

$ nc localhost 37564

すると、エコーサーバを起動したターミナルに、クライアントからの接続を表す表示が出るはず。

$ python singlethread.py
New client: 127.0.0.1:63917

さらに netcat のターミナルで文字列を入力して Enter すると、同じ内容がまた表示される。 これは、送信した内容がエコーサーバからオウム返しで返ってきたことを意味する。

$ nc localhost 37564
hogehoge
hogehoge

エコーサーバのターミナルを見ると、送受信した内容が表示されている。

$ python singlethread.py
New client: 127.0.0.1:63917
Recv: b'hogehoge\n'
Send: b'hogehoge\n'

netcat は Ctrl キーと C キーを一緒に押すことで終了できる。 これでサーバとの接続も切断される。

$ nc localhost 37564
hogehoge
hogehoge
^C

サーバの方にもクライアントとの接続が切れた旨が表示された。

$ python singlethread.py
New client: 127.0.0.1:63917
Recv: b'hogehoge\n'
Send: b'hogehoge\n'
Recv: b''
Bye-Bye: 127.0.0.1:63917

ここまで見た限り、このパターンで何の問題も無いように見える。 しかし、クライアントを二つにすると問題点が分かってくる。

サーバを一旦終了して、もう一度起動し直そう。 ちなみにサーバについても Ctrl-C で終了できる。

$ python singlethread.py

そして、改めて別のターミナルから netcat で接続する。

$ nc localhost 37564

クライアントが一つなら、サーバは接続を正常に受け付ける。

$ python singlethread.py
New client: 127.0.0.1:49746

では、さらにもう一つターミナルを開いて netcat で接続してみると、どうだろうか?

$ nc localhost 37564

今度は、サーバ側に接続を受け付けたメッセージが表示されない。

$ python singlethread.py
New client: 127.0.0.1:49746

そう、ソケットをブロッキングかつシングルスレッドで扱う場合、二つ以上のクライアントを同時に上手くさばくことができない。 なぜなら、唯一のスレッドは最初のクライアントからデータを読み書きする仕事に従事しているからだ。

先ほどのサンプルコードでいえば以下、クライアントからの新たなデータの到来を待ち続けて (ブロックして) いることだろう。

message = clientsocket.recv(1024)

唯一のスレッドが一つのクライアントにかかりきりなので、以下の別のクライアントからの接続を受け付ける処理は実行されない。 クライアントからの接続は、ソケットの接続キューに積まれたまま放置プレイを食らう。

clientsocket, (client_address, client_port) = serversocket.accept()

今かかりきりになっている相手との通信が終わるまで、別のクライアントは受け付けることができないというわけ。

マルチスレッド

ソケットをブロッキングで扱うとき、シングルスレッドでは二つ以上のクライアントを上手くさばけないことが分かった。 そこで、次はクライアントを処理するスレッドを複数用意してマルチスレッドにしてみよう。

先ほどの内容に手を加えて、マルチスレッドにしたサンプルコードは次の通り。 先ほどと同じ処理についてはコメントを省いて、新たに追加したり変更したところにコメントを書いている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import threading


def client_handler(clientsocket, client_address, client_port):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        try:
            message = clientsocket.recv(1024)
            print('Recv: {0} from {1}:{2}'.format(message,
                                                  client_address,
                                                  client_port))
        except OSError:
            break

        if len(message) == 0:
            break

        sent_message = message
        while True:
            sent_len = clientsocket.send(sent_message)
            if sent_len == len(sent_message):
                break
            sent_message = sent_message[sent_len:]
        print('Send: {0} to {1}:{2}'.format(message,
                                            client_address,
                                            client_port))

    clientsocket.close()
    print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        # 接続してきたクライアントを処理するスレッドを用意する
        client_thread = threading.Thread(target=client_handler,
                                         args=(clientsocket,
                                               client_address,
                                               client_port))
        # 親 (メイン) スレッドが死んだら子も道連れにする
        client_thread.daemon = True
        # スレッドを起動する
        client_thread.start()


if __name__ == '__main__':
    main()

先ほどとの違いは、クライアントとの接続に対してスレッドが一対一で生成されるところだ。 プログラムが起動された直後に生成されるメインスレッドは、クライアントからの接続を受け付ける仕事だけに専念している。 実際に受け付けたクライアントとの接続の処理は、新たに生成した子スレッドに任せるわけだ。

では、上記サンプルコードの動作を確認してみよう。 まずはエコーサーバを起動する。

$ python multithread.py

そして、二つのターミナルからエコーサーバに接続してみよう。

$ nc localhost 37564

すると、今度は二つのクライアントから接続を受け付けた旨が表示された。

$ python multithread.py
New client: 127.0.0.1:51027
New client: 127.0.0.1:51028

それぞれのクライアントのターミナルで文字列を入力すると、ちゃんとエコーバックされるし上手く動いている。

$ python multithread.py
New client: 127.0.0.1:51027
New client: 127.0.0.1:51028
Recv: b'hogehoge\n' from 127.0.0.1:51027
Send: b'hogehoge\n' to 127.0.0.1:51027
Recv: b'hogehoge\n' from 127.0.0.1:51028
Send: b'hogehoge\n' to 127.0.0.1:51028

マルチスレッド (スレッドプール)

先ほどの例では、クライアントを処理する部分をマルチスレッド化することで、二つ以上のクライアントを同時にさばけるようになった。 しかし、実は先ほどのやり方ではクライアントの接続数がどんどん増えていくと問題になってくることがある。 それは、メモリの使用量とコンテキストスイッチにかかるコストの増加だ。

スレッドというのは、新たに作ろうとするとそれ用のコンテキストを必要とする。 この、コンテキストというのは各スレッドの状態を保持しておくために必要なメモリに他ならない。 スレッドあたりのコンテキストのサイズは状態や実装に依存するので、これくらいとはなかなか言いづらいものがある。 とはいえ、一つ一つが小さくてもクライアントの接続数が増えれば決してばかにできないサイズになってくる。

また、コンテキストスイッチというのは、CPU が処理しているスレッドを OS が途中で切り替える作業のことをいう。 まず、CPU というのは同時に処理できるスレッドの数が、あらかじめ製品ごとに決まっている。 例えば、今売られている Intel や AMD の x86-64 アーキテクチャの CPU を例に挙げてみよう。 この場合は、物理コアあたり 1 または 2 スレッドである場合が多い。 つまり、同時に処理できるスレッドには機械的な上限がある。 ちなみに、物理コアあたり同時 2 スレッドの製品については、OS からは論理コアが 2 つあるように扱われる。

にも関わらず、実のところ私たちは普段からそれよりも多くのスレッドを同時に起動して扱っている。 なぜそんなことができるかというと、CPU が実行するスレッドを、OS が途中で別のスレッドに入れ替えているためだ。 この入れ替えは、ごく短時間で行われているので、見かけ上はたくさんのスレッドが同時に実行できているかのように見える。

しかしながら、この入れ替え作業には短時間ながらもちろん時間がかかる。 そして、CPU で同時に処理できるスレッドの数に対して、OS が扱うスレッドの数が増えてくると、その頻度も上がる。 これによって、切り替え作業に要する時間が増えて、だんだんと CPU が非効率な使われ方をしてしまうことがある。

先ほどのサンプルコードでは、まさに上記の二つが問題となる。 なぜなら、生成するスレッドの数に上限を設けていないからだ。 上限がないと、クライアントの数に応じてどんどんスレッドが増え続ける。 結果として、メモリを消費すると共に CPU が非効率な使われ方をしてしまう。

スレッドが多すぎるとまずいという問題点が分かったところで、次はスレッドを生成する数に上限を設けてみよう。 具体的には、あらかじめスレッドを既定数だけ生成して、それらに仕事を割り振る形にする。 この手法は、一般にスレッドプールと呼ばれている。 スレッドプールの中にいる各スレッドは、ワーカースレッドと呼ばれる。

次のサンプルコードはスレッドプールを使った実装になっている。 生成されたワーカーがサーバソケットからの接続を奪い合う形になる。 これなら、あらかじめ決まった数を超えるスレッドは生成されないので、前述したような問題は発生しない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import threading


def worker_thread(serversocket):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカスレッド同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # サーバソケットを渡してワーカースレッドを起動する
    NUMBER_OF_THREADS = 10
    for _ in range(NUMBER_OF_THREADS):
        thread = threading.Thread(target=worker_thread, args=(serversocket, ))
        thread.daemon = True
        thread.start()

    while True:
        # メインスレッドは遊ばせておく (ハンドラを処理させても構わない)
        time.sleep(1)


if __name__ == '__main__':
    main()

ただし、上記にも注意点がある。 それは、あらかじめプールしたスレッド数を超えてクライアントをさばくことができない、という点だ。 プール数を超えた接続があったときは、他のクライアントとの接続が切れるまで、ソケットは処理されないままキューに積まれてしまう。

実行結果については、先ほどと変わらないので省略する。

ちなみに、蛇足だけど Mac OS X に関してはプロセスごとに生成できるスレッド数があらかじめ制限されているようだ。 例えば、次のようなサンプルコードを用意して、たくさんのスレッドを起動してみる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import threading
import time


def loop():
    """各スレッドは特に何もしない"""
    while True:
        time.sleep(1)


def main():
    # ネイティブスレッドをたくさん起動してみる
    for _ in range(10000):
        t = threading.Thread(target=loop)
        t.daemon = True
        t.start()
        # 動作中のスレッド数を出力する
        print(threading.active_count())


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると 2049 個目のスレッドを起動するところで例外になった。

$ python toomanythreads.py
...(省略)...
2046
2047
2048
Traceback (most recent call last):
  File "toomanythreads.py", line 25, in <module>
    main()
  File "toomanythreads.py", line 19, in main
    t.start()
  File "/Users/amedama/.pyenv/versions/3.6.1/lib/python3.6/threading.py", line 846, in start
    _start_new_thread(self._bootstrap, ())
RuntimeError: can't start new thread

このリミットは、どうやら次のカーネルパラメータでかかっているらしい。

$ sysctl -n kern.num_taskthreads
2048

Mac OS X においては、スレッドの生成数に上限を設けないと、メモリの枯渇などを待つことなくサーバが突然死することになる。

マルチプロセス (プロセスプール)

先ほどの例では、スレッドプールを使うことで同時に処理できるクライアントの数を増やしつつ、リソースの消費を抑えることができた。 しかしながら、実はここまでの例では、パフォーマンスを求める上で、まだ使い切れていないリソースが残っている。 それは、複数の CPU コアだ。

実は Python の標準的な処理系である CPython には、とある制限が存在している。 それは、一つのプロセスで同時に実行できるスレッドの数が一つだけ、というもの。 一般的に、これはグローバルインタプリタロック (Global Interpreter Lock, GIL) と呼ばれている。 この制限は、Python/C API で書かれた拡張モジュールを Python から扱いやすくするために存在する。

この GIL がある処理系では、CPU に複数の論理コアがあったとしても、同時に使われるのが一つだけに制限されてしまう。 つまり、先ほどの例では、マルチスレッドにしても実際に使われている CPU 論理コアは同時に一つだけだった。 ようするに、複数のスレッドを OS が一つの CPU 論理コアの上で切り替え (コンテキストスイッチ) ながら動作する。

ちなみに、コンピュータの処理には、大きく分けて入出力 (I/O) が主体になるものと計算 (CPU) が主体になるものがある。 CPU が主体となるのは、例えば科学計算のようなもの。 それに対して、今回の例であるエコーサーバのようなプログラムは、CPU の処理がほとんどない。 処理時間のほとんどを I/O の待ちに使っていることから、入出力が主体のプログラムといえる。

つまり、今回取り扱うエコーサーバは CPU の処理能力がボトルネックになりにくい。 ようするに、あえて CPU の能力を最大限引き出すようなコードにする必然性は薄い。 しかしながら、アーキテクチャパターンの紹介という意味では重要だと思う。 なので、その方法についても記述しておこう。

その方法というのは、具体的にはプログラムを複数のプロセスで動かす。 前述した通り GIL はプロセスあたりの同時実行スレッド数を一つに制限するというものだった。 なので、プロセスを複数立ち上げてしまえば、同時実行スレッド数をプログラム全体で見たときに増やすことができる。

次のサンプルコードでは、スレッドの代わりにプロセスを複数起動 (マルチプロセス) している。 Python でマルチプロセスを扱う方法としては、例えば標準ライブラリの multiprocessing モジュールがある。 起動するプロセスの数は CPU の論理コア数と同じにした。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import multiprocessing


def worker_process(serversocket):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカープロセス同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # プロセス数は CPU のコア数前後に合わせると良い
    NUMBER_OF_PROCESS = multiprocessing.cpu_count()
    # サーバソケットを渡してワーカープロセスを起動する
    for _ in range(NUMBER_OF_PROCESS):
        process = multiprocessing.Process(target=worker_process,
                                          args=(serversocket, ))
        # デーモンプロセスにする (親プロセスが死んだら子も道連れに死ぬ)
        process.daemon = True
        # プロセスを起動する
        process.start()

    while True:
        time.sleep(1)


if __name__ == '__main__':
    main()

マルチプロセスを使うときの注意点についても見ていこう。 これは、マルチスレッドの場合とほとんど変わらない。 つまり、プロセスを作るにもコンテキストが必要であり、コンテキストスイッチが起こるということだ。 そのため、同時に起動するプロセス数は制限してやる必要がある。 しかも、必要なリソースの量はスレッドに比べるとずっと多い。 そのため、一般的には起動するプロセス数は CPU の論理コアの数前後が良いとされている。

また、マルチプロセス固有の問題としては、プロセス間での値の共有が挙げられる。 マルチスレッドであれば、同一プロセス内でメモリ空間を共有していた。 なので、例えばグローバル変数の値をスレッド間で情報を共有する手段にもできた。 それに対し、マルチプロセスではプロセス同士でメモリ空間は共有していない。 そのため、別の何らかの IPC を使って情報をやり取りしなければいけない。

尚、繰り返しになるけどマルチプロセスにする必要があるのは、あくまで GIL があることに由来している。 もし、これがない処理系やプログラミング言語を使うなら、単にマルチスレッドにするだけで大丈夫。 ちゃんと CPU のコアを使い切ってくれるはず。

マルチプロセス・マルチスレッド

先ほどの例では、プロセスを複数立ち上げることで CPU の能力を使い切れるようにした。 ただし、マルチプロセスではあるものの、それぞれのプロセスでは一つのスレッドしか動かしていなかった。 そこで、次は各プロセスの中をマルチスレッドにしてみよう。 これなら、マルチプロセスかつマルチスレッドになって CPU と I/O の両方を上手く使い切れるはず。

次のサンプルコードでは、各ワーカープロセスの中でさらにスレッドプールを動かしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import multiprocessing
import threading


def worker_thread(serversocket):
    """クライアントとの接続を処理するハンドラ (スレッド)"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカースレッド同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def worker_process(serversocket):
    """クライアントとの接続を受け付けるハンドラ (プロセス)"""

    # サーバソケットを渡してワーカースレッドを起動する
    NUMBER_OF_THREADS = 10
    for _ in range(NUMBER_OF_THREADS):
        thread = threading.Thread(target=worker_thread, args=(serversocket, ))
        thread.start()
        # ここではワーカーをデーモンスレッドにする必要はない (死ぬときはプロセスごと逝くので)

    while True:
        # ワーカープロセスのメインスレッドは遊ばせておく
        time.sleep(1)


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    NUMBER_OF_PROCESSES = multiprocessing.cpu_count()
    for _ in range(NUMBER_OF_PROCESSES):
        process = multiprocessing.Process(target=worker_process,
                                          args=(serversocket, ))
        process.daemon = True
        process.start()

    while True:
        time.sleep(1)


if __name__ == '__main__':
    main()

実行結果については、これまで変わらないので省略する。

ひとまず、ソケットをブロッキングで扱う場合のアーキテクチャパターンについては、これでおわり。

ソケットをノンブロッキングで扱う場合

続いては、ソケットをノンブロッキングで扱う場合について見ていこう。 前述した通り、ソケットをノンブロッキングで扱うと、読み書きなどを指示してもブロックが起きない。 その代わり、もし読み書きの準備ができていないときはその旨がエラーで返ってくる。

とりあえずノンブロッキングにしてみよう

最初に、ノンブロッキングなソケットをブロッキングっぽく扱ったときの挙動を確認しておこう。 具体的に、どんなことが起こるのだろうか?

次のサンプルコードは、最初に示したシングルスレッドのサーバに一行だけ手を加えている。 それは、サーバソケットを setblocking() メソッドでノンブロッキングモードにしているところだ。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # ソケットをノンブロッキングモードにする
    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、すぐに例外が出て終了してしまう。

$ python nonblocking.py
Traceback (most recent call last):
  File "nonblocking.py", line 48, in <module>
    main()
  File "nonblocking.py", line 22, in main
    clientsocket, (client_address, client_port) = serversocket.accept()
  File "/Users/amedama/.pyenv/versions/3.6.1/lib/python3.6/socket.py", line 205, in accept
    fd, addr = self._accept()
BlockingIOError: [Errno 35] Resource temporarily unavailable

上記の BlockingIOError という例外は、まだ準備が整っていないにも関わらず指示が出されたときに上がる。 今回の場合だと、クライアントからの接続が到着していないのに accept() メソッドを呼び出している。 ブロッキングモードのソケットなら、そのまま到着するまで待ってくれていた。 それに対し、ノンブロッキングモードでは呼び出した時点で到着していないなら即座に例外となってしまう。 正に、これがブロッキングとノンブロッキングの挙動の違い。

準備が整うまで待つ (ビジーループ)

先ほどの例で分かるように、ソケットをノンブロッキングで使うとブロッキングとは使い勝手が異なっている。 具体的には、ソケットの準備が整うのを勝手に待ってくれるわけではないので、自分で意図的に待たなければいけない。

では、どのようにすれば待つことができるだろうか。 一つのやり方としては、エラーが出なくなるまで定期的に実行してみる方法が考えられる。 この、何度も自分から試しに行くやり方はポーリングと呼ばれる。 その中でも、それぞれの試行間隔を全く空けないものはビジーループという。

次のサンプルコードではノンブロッキングなソケットをビジーループで待ちながら処理している。 ただし、あらかじめ言っておくと、このやり方は間違っている。 ソケットをノンブロッキングで扱うとき、こんなソースコードは書いちゃいけない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        try:
            clientsocket, (client_address, client_port) = serversocket.accept()
        except (BlockingIOError, socket.error):
            # まだソケットの準備が整っていない
            continue

        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except (BlockingIOError,  socket.error):
                # まだソケットの準備が整っていない
                continue
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                try:
                    sent_len = clientsocket.send(sent_message)
                except (BlockingIOError,  socket.error):
                    # まだソケットの準備が整っていない
                    continue
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

上記のサンプルコードは一応動作するものの、複数のクライアントを処理することができない。 それに、ビジーループを使っているとプロセスの CPU 使用率が 100% になってしまう。 繰り返しになるけど、ソケットをノンブロッキングで扱うとき、こんなソースコードは書いちゃだめ。

準備が整うまで待つ (イベントループ)

ビジーループでは色々と難しいことが分かったところで、次は実用的に待つ方法を見ていこう。 これには、イベントループや I/O 多重化と呼ばれる手法というかシステムコールを用いる。 システムコールというのは OS のカーネルに備わっている API のことだ。 ユーザランドのプログラムは、このシステムコールを呼び出すことで OS の機能が利用できる。

システムコールの中には、ソケットの状態を監視して、変更されたときにそれを通知してくれるものがある。 より正確には、監視できるものはファイルやソケットに汎用的に割り当てられるファイルディスクリプタだ。

イベントループにはいくつかの種類があるものの、ここでは古典的な select(2) を使うやり方を見ていく。 次のサンプルコードは、エコーサーバを select(2) システムコールで実装したもの。 ただし、先に断っておくと、これは実装している機能の割にコード量が多いし、逐次的でないから読みにくいと思う。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import select


# 読み取りが可能になるまで待っているソケットと、可能になったときに呼び出されるハンドラ・引数の対応を持つ
read_waiters = {}
# 書き込みが可能になるまで待っているソケットと、可能になったときに呼び出されるハンドラ・引数の対応を持つ
write_waiters = {}
# 接続してきたクライアントとの接続情報を格納する
connections = {}


def accept_handler(serversocket):
    """サーバソケットが読み取り可能になったとき呼ばれるハンドラ"""
    # 準備ができているので、すぐに accept() できる
    clientsocket, (client_address, client_port) = serversocket.accept()

    # クライアントソケットもノンブロックモードにする
    clientsocket.setblocking(False)

    # 接続してきたクライアントの情報を出力する
    # ただし、厳密に言えば print() もブロッキング I/O なので避けるべき
    print('New client: {0}:{1}'.format(client_address, client_port))

    # ひとまずクライアントの一覧に入れておく
    connections[clientsocket.fileno()] = (clientsocket,
                                          client_address,
                                          client_port)

    # 次はクライアントのソケットが読み取り可能になるまで待つ
    read_waiters[clientsocket.fileno()] = (recv_handler,
                                           (clientsocket.fileno(), ))

    # 次のクライアントからの接続を待ち受ける
    read_waiters[serversocket.fileno()] = (accept_handler, (serversocket, ))


def recv_handler(fileno):
    """クライアントソケットが読み取り可能になったとき呼ばれるハンドラ"""
    def terminate():
        """クライアントとの接続が切れたときの後始末"""
        # クライアント一覧から取り除く
        del connections[clientsocket.fileno()]
        # ソケットを閉じる
        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))

    # クライアントとの接続情報を取り出す
    clientsocket, client_address, client_port = connections[fileno]

    try:
        # 準備ができているので、すぐに recv() できる
        message = clientsocket.recv(1024)
    except OSError:
        terminate()
        return

    if len(message) == 0:
        terminate()
        return

    print('Recv: {0} to {1}:{2}'.format(message,
                                        client_address,
                                        client_port))

    # 次はクライアントのソケットが書き込み可能になるまで待つ
    write_waiters[fileno] = (send_handler, (fileno, message))


def send_handler(fileno, message):
    """クライアントソケットが書き込み可能になったとき呼ばれるハンドラ"""
    # クライアントとの接続情報を取り出す
    clientsocket, client_address, client_port = connections[fileno]

    # 準備ができているので、すぐに send() できる
    sent_len = clientsocket.send(message)
    print('Send: {0} to {1}:{2}'.format(message[:sent_len],
                                        client_address,
                                        client_port))

    if sent_len == len(message):
        # 全て送ることができたら、次はまたソケットが読み取れるようになるのを待つ
        read_waiters[clientsocket.fileno()] = (recv_handler,
                                               (clientsocket.fileno(), ))
    else:
        # 送り残している内容があったら、再度ソケットが書き込み可能になるまで待つ
        write_waiters[fileno] = (send_handler,
                                 (fileno, message[sent_len:]))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # ソケットをノンブロックモードにする
    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # クライアントからの接続がくるまで待つ
    read_waiters[serversocket.fileno()] = (accept_handler, (serversocket, ))

    while True:
        # ソケットが読み取り・書き込み可能になるまで待つ
        rlist, wlist, _ = select.select(read_waiters.keys(),
                                        write_waiters.keys(),
                                        [],
                                        60)

        # 読み取り可能になったソケット (のファイル番号) の一覧
        for r_fileno in rlist:
            # 読み取り可能になったときに呼んでほしいハンドラを取り出す
            handler, args = read_waiters.pop(r_fileno)
            # ハンドラを実行する
            handler(*args)

        # 書き込み可能になったソケット (のファイル番号の一覧)
        for w_fileno in wlist:
            # 書き込み可能になったときに呼んでほしいハンドラを取り出す
            handler, args = write_waiters.pop(w_fileno)
            # ハンドラを実行する
            handler(*args)


if __name__ == '__main__':
    main()

Python では select(2) システムコールの薄いラッパとして select モジュールが使える。 このモジュールが提供する select() 関数には、ファイルディスクリプタの入ったリストを渡す。

ファイルディスクリプタというのは、名前だけ聞くと難しそうだけどただの整数に過ぎない。 これは、各ソケットやファイルなどを使うときに OS が割り当てた一意な整数を指している。 ようするに 10 とか 20 とかいう数字が、何らかのソケットやファイルなどを表す。 ソケットに割り当てられたファイルディスクリプタは fileno() メソッドで得られる。

select() 関数には、読み込みや書き込みの準備ができたら通知してほしいファイルディスクリプタを渡す。 そして select() 関数を呼び出すと、そこでブロックした後に、準備ができたファイルディスクリプタが返される。 返ってきたファイルディスクリプタは、既に読み書きができるようになっているので指示を出しても例外にはならない。

先ほどのサンプルコードでは、そのようにして準備ができたものに対して読み書きをしている。 ビジーループと比べると CPU を使い切ることもなく、複数のクライアントを処理できる。 また、大きなポイントとしてはシングルスレッドにも関わらず、複数のクライアントを処理できているところだ。 これはソケットをブロッキングで使っていたときとの大きな違いだろう。

ちなみに、今回使った select(2) システムコールにはパフォーマンス上の問題が知られている。 そのため、実用的な用途で使われることはそこまで多くない。 代わりに、BSD 系なら kqueue(2)、Linux であれば epoll(2) が用いられる。 ただし、select(2) なら大抵のプラットフォームで使えるので、それらに比べると移植性が高いというメリットはある。

また「ソケットやファイルなど」と前述した通り、実はブロッキング・ノンブロッキングという概念はソケットに限った話ではない。 ファイルやデバイスについてもノンブロッキングで扱うことはできる。 そして、これはノンブロッキングなソケットプログラミングをする上で重要な意味を持ってくる。 詳細は後述するものの、これはノンブロッキングとブロッキングを同じスレッドで混ぜて使うと問題が発生する、というもの。

尚、前述した通り先ほどのサンプルコードはシングルプロセス・シングルスレッドで動作している。 そのため、複数の CPU コアを使い切ることはできない。 使い切れるようにするときは、マルチプロセスにする必要がある。 もちろん、これは GIL の制約のためにプロセスを複数立ち上げる必要があるに過ぎない。 別の処理系やプログラミング言語であれば、単にマルチスレッドにするだけで良い。 いずれの場合でも、それぞれのスレッドごとにイベントループを用意する。

ノンブロッキング I/O をラップした API やライブラリを使う

先ほどの例ではイベントループのシステムコールを使ってノンブロッキングなソケットを処理してみた。 とはいえ、実際にシステムコールを直接使ってソケットプログラミングする機会は、あまりないと思う。 なぜなら、先ほどのサンプルコードを見て分かる通り、それらの API はそのままでは扱いにくい上にコード量も増えてしまうため。

実際には、イベントループをラップしたライブラリを使ってプログラミングすることになると思う。 どんなライブラリがあるかはプログラミング言語ごとに異なる。 例えば C 言語なら libev が有名だと思うし Python なら Twisted などがある。 また、Python に関しては 3.4 から標準ライブラリに asyncio というモジュールが追加された。 次は、この asyncio を使ってみることにしよう。

Python の asyncio には色んなレイヤーの API が用意されている。 それこそ、先ほどのシステムコールを直接使うのと大差ないようなコードも書ける。 しかし、それだとライブラリを使う意味がないので、もうちょっと抽象度の高いものを使ってみた。 次のサンプルコードでは asyncio を使ってエコーサーバを実装している。 コードを見て分かる通り、先ほどと比べるとだいぶコード量が減って読みやすくなっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio


class EchoServer(asyncio.Protocol):

    def connection_made(self, transport):
        """クライアントからの接続があったときに呼ばれるイベントハンドラ"""
        # 接続をインスタンス変数として保存する
        self.transport = transport

        # 接続元の情報を出力する
        client_address, client_port = self.transport.get_extra_info('peername')
        print('New client: {0}:{1}'.format(client_address, client_port))

    def data_received(self, data):
        """クライアントからデータを受信したときに呼ばれるイベントハンドラ"""
        # 受信した内容を出力する
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Recv: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

        # 受信したのと同じ内容を返信する
        self.transport.write(data)
        print('Send: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

    def connection_lost(self, exc):
        """クライアントとの接続が切れたときに呼ばれるイベントハンドラ"""
        # 接続が切れたら後始末をする
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))
        self.transport.close()


def main():
    host = 'localhost'
    port = 37564

    # イベントループを用意する
    ev_loop = asyncio.get_event_loop()

    # 指定したアドレスとポートでサーバを作る
    factory = ev_loop.create_server(EchoServer, host, port)
    # サーバを起動する
    server = ev_loop.run_until_complete(factory)

    try:
        # イベントループを起動する
        ev_loop.run_forever()
    finally:
        # 後始末
        server.close()
        ev_loop.run_until_complete(server.wait_closed())
        ev_loop.close()


if __name__ == '__main__':
    main()

注目すべきは、もはやソースコードの中に socket モジュールが登場していないところ。 それらは Protocol や Transport といった抽象的なオブジェクトに取って代わられている。

では、本当に内部でイベントループのシステムコールが使われているのかを調べてみよう。 まずは上記のサンプルコードを実行して、エコーサーバを起動する。

$ python asyncioserver.py

続いて、別のターミナルを開いたら上記エコーサーバが動いているプロセスの ID を調べる。

$ ps auxww | grep [a]syncioserver
amedama        31018   0.0  0.2  2430616  17344 s000  S+    7:58PM   0:00.16 python asyncioserver.py

そして、プロセスで発行されるシステムコールをトレースする dtruss コマンドを仕掛ける。

$ sudo dtruss -p 31018

準備ができたらクライアントを接続する。

$ nc localhost 37564

すると、次のように kevent(2) システムコールが発行されていることが分かる。 kevent(2) システムコールは kqueue(2) と共に用いるイベントループのためのシステムコール。

$ sudo dtruss -p 31018
SYSCALL(args)            = return
...
kevent(0x3, 0x0, 0x0)            = 0 0
getsockname(0xA, 0x7FFF50F55B00, 0x7FFF50F55AFC)                 = 0 0
setsockopt(0xA, 0x6, 0x1)                = 0 0
kevent(0x3, 0x0, 0x0)            = 0 0
write(0x1, "New client: 127.0.0.1:51822\n\0", 0x1C)              = 28 0
kevent(0x3, 0x10F8FB6F0, 0x1)            = 0 0
kevent(0x3, 0x0, 0x0)            = 0 0

どうやら、ちゃんと内部がノンブロッキングな世界になっていることが確認できた。 しかも、プラットフォームに応じたパフォーマンスに優れるイベントループをちゃんと使ってくれている。

ノンブロッキングとブロッキングは混ぜるな危険

ちなみに、ノンブロッキングなソケットプログラミングをする上では重要なポイントが一つある。 それは、ノンブロッキングなソケットを扱うスレッドで、ブロッキングな操作をしてはいけない、という点。 もちろん、前述した通りブロッキング・ノンブロッキングという概念はソケットに限った話ではない。 つまり、言い換えるとノンブロッキングな I/O とブロッキングな I/O は同じスレッドで混ぜてはいけない。

二つ前のセクションで登場した select システムコールを使ったサンプルコードを思い出してほしい。 あのサンプルコードでは、シングルスレッドで複数のクライアントをさばいていた。 では、もしその一つしかないスレッドが何処かでブロックしたら、何が起こるだろうか? これは、そのスレッドでさばいている全ての処理が、そこで停止してしまうことを意味する。 これは、ノンブロッキングな I/O を扱う上で登場する代表的な問題の一つ。

どのようなことが起こるかを実際に確かめてみよう。 次のサンプルコードでは、データを受信した際に time.sleep() 関数を使っている。 これには、実行したスレッドを指定した時間だけブロックさせる効果がある。 正に、ノンブロッキングなスレッドへのブロッキングな操作の混入といえる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio
import time


class EchoServer(asyncio.Protocol):

    def connection_made(self, transport):
        self.transport = transport

        client_address, client_port = self.transport.get_extra_info('peername')
        print('New client: {0}:{1}'.format(client_address, client_port))

    def data_received(self, data):
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Recv: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

        # 何らかの処理で、イベントループのスレッドがブロックしてしまった!
        print('Go to sleep...')
        time.sleep(20)

        self.transport.write(data)
        print('Send: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

    def connection_lost(self, exc):
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))
        self.transport.close()


def main():
    host = 'localhost'
    port = 37564

    ev_loop = asyncio.get_event_loop()

    factory = ev_loop.create_server(EchoServer, host, port)
    server = ev_loop.run_until_complete(factory)

    try:
        ev_loop.run_forever()
    finally:
        server.close()
        ev_loop.run_until_complete(server.wait_closed())
        ev_loop.close()


if __name__ == '__main__':
    main()

上記のサンプルコードを実行してエコーサーバを起動しよう。

$ python asyncblock.py

続いて別のターミナルから nc コマンドでサーバに接続したら適当な文字列を入力する。

$ nc localhost 37564
hogehoge

これでイベントループを回しているスレッドはブロックを起こした。

$ python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...

すかさず別のターミナルから nc でクライアントを追加してみよう。

$ nc localhost 37564

すると、今度はサーバに新しいクライアントが追加された旨は表示されない。 エコーサーバ全体の処理が、一箇所で停止してしまっているからだ。

$ python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...

もうしばらく待つと、スレッドのブロックが解除されて新しいクライアントの接続が受理される。

python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...
Send: b'hogehoge\n' to 127.0.0.1:51883
New client: 127.0.0.1:51884

このように、ノンブロッキング I/O を扱うスレッドにブロッキング I/O のコードが混入すると、全てがそこで停止してしまう。

そして、真にこの問題が恐ろしいのは、混入に気づきにくい点かもしれない。 先ほどのサンプルコードは極端な例なので、使ってみるだけでも明確に変化を知覚できた。 しかしながら、実際にはブロッキング I/O の処理は人間にとって一瞬なので気づくことは難しいかもしれない。 にも関わらず、そのタイミングで一連の処理が全て停止していることに間違いはない。 結果として、パフォーマンスの低下をもたらす。

また、世の中のほとんどのライブラリはブロッキング I/O を使って実装されている。 例えば、外部の WebAPI を叩こうとそのまま requests でも使おうものなら、それだけでアウトだ。 それに、HTTP のような分かりやすい I/O 以外にもキューのような基本的な部品であっても操作をブロックしたりする。

つまり、新たに何かを使おうとしたら、それにブロックする操作が混入していないかをあらかじめ調べる必要がある。 さらに、ブロックする操作が含まれると分かったら、それをブロックしないようにする方法を模索しなきゃいけない。 以上のように、イベントループを中心に据えた非同期なフレームワークというのは、一般的な認識よりもずっと扱いが難しいと思う。

ブロッキング I/O が混入する問題へのアプローチについて

ノンブロッキング I/O を扱うスレッドにブロッキング I/O が混ざり込む問題に対するアプローチはいくつかある。 もちろん、混入しないように人間が頑張ってコードを見張る、というのは最も基本的なやり方の一つ。

それ以外には、プログラミング言語のレベルでブロッキング I/O を排除してしまうという選択肢もある。 これは例えば JavaScript (Node.js) が採用している。 Golang も、ネットワーク部分に関してはノンブロッキング I/O しか用意していないらしい。 初めからブロッキング I/O の操作が存在していないなら、そもそも混入することはない。

それ以外には、モンキーパッチを当てるというアプローチもある。 つまり、ブロッキング I/O を使うコードを、全てノンブロッキング I/O を使うように書き換えてしまう。 Python であれば、例えば EventletGevent といったサードパーティー製ライブラリがこれにあたる。

試しに Eventlet を使った例を見てみよう。 まずは Python のパッケージマネージャである pip を使って Eventlet をインストールしておく。

$ pip install eventlet

それでは Eventlet の魔法をお見せしよう。 次のサンプルコードは、最初に示したマルチスレッドの例に、たった二行だけコードを追加している。 その冒頭に追加した二行こそ、まさにモンキーパッチを当てるためのコードになっている。 たったこれだけで、ブロッキングだった世界がノンブロッキングな世界に書き換わってしまう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# 標準ライブラリにモンキーパッチを当てる
# ブロッキング I/O を使った操作が裏側で全てノンブロッキング I/O を使うように書き換えられる
import eventlet
eventlet.monkey_patch()

import socket
import threading


def client_handler(clientsocket, client_address, client_port):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        try:
            message = clientsocket.recv(1024)
            print('Recv: {0} from {1}:{2}'.format(message,
                                                  client_address,
                                                  client_port))
        except OSError:
            break

        if len(message) == 0:
            break

        sent_message = message
        while True:
            sent_len = clientsocket.send(sent_message)
            if sent_len == len(sent_message):
                break
            sent_message = sent_message[sent_len:]
        print('Send: {0} to {1}:{2}'.format(message,
                                            client_address,
                                            client_port))

    clientsocket.close()
    print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        client_thread = threading.Thread(target=client_handler,
                                         args=(clientsocket,
                                               client_address,
                                               client_port))
        client_thread.daemon = True
        client_thread.start()


if __name__ == '__main__':
    main()

本当にイベントループが使われているのか確かめてみることにしよう。 まずは、上記のサンプルコードを実行する。

$ python eventletserver.py

続いて別のターミナルを開いて、上記で実行しているプロセス ID を調べる。

$ ps auxww | grep [e]ventletserver
amedama         7796   0.0  0.1  2426888  19488 s000  S+    8:44PM   0:00.17 python eventletserver.p

dtruss コマンドでプロセス内で発行されるシステムコールをトレースする。

$ sudo dtruss -p 7796

クライアントからサーバに接続してみよう。

$ nc localhost 37564

すると、次のように dtruss の実行結果に kevent システムコールが登場している。 本当に、モンキーパッチを当てるだけでノンブロッキング I/O を使うようになった。

$ sudo dtruss -p 7796
SYSCALL(args)            = return
kevent(0x4, 0x101256710, 0x1)            = 0 0
accept(0x3, 0x7FFF5F8A7750, 0x7FFF5F8A774C)              = 7 0
ioctl(0x7, 0x20006601, 0x0)              = 0 0
ioctl(0x7, 0x8004667E, 0x7FFF5F8A7A04)           = 0 0
ioctl(0x7, 0x8004667E, 0x7FFF5F8A74E4)           = 0 0
write(0x1, "New client: 127.0.0.1:54132\n\0", 0x1C)              = 28 0
recvfrom(0x7, 0x7FE86782BC20, 0x400)             = -1 Err#35
kevent(0x4, 0x101256710, 0x1)            = 0 0
kevent(0x4, 0x7FE866D11020, 0x0)                 = 0 0
accept(0x3, 0x7FFF5F8A7750, 0x7FFF5F8A774C)              = -1 Err#35
kevent(0x4, 0x101256710, 0x1)            = 0 0

注目すべきは、逐次的なプログラミングモデルを保ったまま、それが実現できているところだろう。 asyncio の例でも、データの読み書きなどは逐次的に書くことができたものの、基本はイベントドリブンだった。 しかし Eventlet のコードでは、完全にブロッキング I/O を使っているときと同じように書くことができている。

これが一体どのようにして実現されているかというと、主にグリーンスレッドの寄与が大きい。 Eventlet では、カーネルで実装されたネイティブスレッドの代わりにユーザランドで実装されたグリーンスレッドを用いる。 グリーンスレッドには、実装によってコルーチン、軽量プロセス、協調スレッドなど色々な呼び方がある。

カーネルで実装されたネイティブスレッドとの大きな違いは、コンテキストスイッチのタイミングがプログラムで制御できるところにある。 ネイティブスレッドのコンテキストスイッチはカーネルのスケジューラ次第なので、基本的にプログラムからは制御できない。 それに対し、グリーンスレッドでは実行中のスレッドが自発的に処理を手放さない限りコンテキストスイッチが起こらない。 つまり、I/O などの外的な要因がない限りグリーンスレッドは決定論的に動作することを意味している。

Eventlet では、モンキーパッチを使うと既存のスレッドやソケットがインターフェースはそのままに書き換えられる。 そして、本来ならブロックするコードに処理が到達したタイミングでコンテキストスイッチが起こるように変化する。 コンテキストスイッチする先は、読み書きの準備が整った I/O を処理しているグリーンスレッドだ。 これは、先ほどシステムコールをトレースした通り、イベントループを使って判断している。 そして、コンテキストスイッチした元のグリーンスレッドは、イベントループを使って処理中の I/O が読み書きができるようになるまで待たされる。 ちなみに Golang はプログラミング言語のレベルで上記を実現していて、それは goroutine と呼ばれている。

このアーキテクチャでは、逐次的なプログラミングモデルを保ったままノンブロッキング I/O を使った恩恵が受けられる。 また、グリーンスレッドは一般的にネイティブスレッドよりもコンテキストに必要なメモリのサイズが小さい。 つまり、同時に多くのクライアントをさばきやすい。

ただし、Eventlet のようなモンキーパッチを使ったアプローチには抵抗がある人も多いかもしれない。 実際のところ Eventlet にはクセが全くないとは言えないし、よく分からずに使うのはやめた方が良いと思う。 ただし、名誉のために言っておくと Eventlet は OpenStack のような巨大なプロジェクトでも使われている実績のあるライブラリだ。

ちなみに、モンキーパッチでは一つだけブロッキング I/O の混入を防げないところがある。 それは Python/C API を使って書かれた拡張モジュールだ。 コンパイル済みの拡張モジュールに対しては、個別に対応しない限り自動でモンキーパッチが効くことはない。 これは、典型的には Python/C API で書かれたデータベースドライバで問題になることが多い。

まとめ

今回はソケットプログラミングにおいて、どういったアーキテクチャが考えられるかについて見てきた。 まず、ソケットは大きく分けてブロッキングで使うかノンブロッキングで使うかという選択肢がある。

ブロッキングは、逐次的なプログラミングモデルで扱いやすいことから理解もしやすい。 ただし、複数のクライアントをさばくにはマルチスレッドやマルチプロセスにする必要がある。 それらは必要なコンテキストのサイズやスイッチのコストも大きいことから、スケーラビリティの面で問題となりやすい。

それに対し、ノンブロッキングはイベントドリブンなプログラミングモデルとなりやすいことから理解が難しい。 しかしながら、イベントループを使うことでシングルスレッドでも複数のクライアントを効率的にさばける。

また、ブロッキング・ノンブロッキングというのはソケットに限った概念ではない。 ファイルなども同じようにノンブロッキングで扱うことができる。

ノンブロッキングで I/O を扱うときの注意点としては、同じスレッドをブロックさせてはいけない、というところ。 言い換えると、イベントループを回しているスレッドにブロッキングな I/O のコードを混入させてはいけない。 もし混入するとパフォーマンス低下をもたらす。

ブロッキング I/O が混入する問題に対するアプローチは、言語や処理系、ライブラリによっていくつかある。 例えば JavaScript (Node.js) では、プログラミング言語自体にブロッキング I/O を扱う API がない。 それ以外だと、スクリプト言語ならモンキーパッチで動的に実装を書き換えてしまうというものもある。

参考文献

詳解UNIXプログラミング 第3版

詳解UNIXプログラミング 第3版

Python: KMeans 法を実装してみる

KMeans 法は、機械学習における教師なし学習のクラスタリングという問題を解くためのアルゴリズム。 教師なし学習というのは、事前に教師データというヒントが与えられないことを指している。 その上で、クラスタリングというのは未知のデータに対していくつかのまとまりを作る問題をいう。

今回取り扱う KMeans 法は、比較的単純なアルゴリズムにも関わらず広く使われているものらしい。 実際に書いてみても、基本的な実装であればたしかにとてもシンプルだった。 ただし、データの初期化をするところで一点考慮すべき内容があることにも気づいたので、それについても書く。 KMeans 法の具体的なアルゴリズムについてはサンプルコードと共に後述する。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.3
BuildVersion:   16D32
$ python --version
Python 3.5.3

依存パッケージをインストールする

あらかじめ、今回ソースコードで使う依存パッケージをインストールしておく。 グラフ描画ライブラリの matplotlib を Mac OS X で動かすには、pip でのインストール以外にもちょっとした設定が必要になる。

$ pip install scipy scikit-learn matplotlib
$ mkdir -p ~/.matplotlib
$ cat << 'EOF' > ~/.matplotlib/matplotlibrc
backend: TkAgg
EOF

まずは scikit-learn のお手本から

本来はお手本の例については後回しにしたいんだけど、今回は構成の都合で先に示しておく。 Python の機械学習系ライブラリである scikit-learn には KMeans 法の実装も用意されている。 自分で書いた実装をお披露目する前に、どんなものなのか見てもらいたい。

次のサンプルコードでは、ダミーのデータを生成した上でそれをクラスタリングしている。 scikit-learn にはデータセットを生成する機能も備わっている。 その中でも make_blobs() という関数は、いくつかのまとまりを持ったダミーデータを作ってくれる。 あとは、そのダミーデータを KMeans 法の実装である sklearn.cluster.KMeans で処理させている。 クラスタリングした処理結果は matplotlib を使って可視化した。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.cluster import KMeans


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    # クラスタのセントロイド (重心) を描く
    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

上記のサンプルコードでは、クラスタ数として適当に 5 を選んだ。 ここで「クラスタ数を選ぶ」というのはポイントとなる。 KMeans 法では、あらかじめクラスタ数をハイパーパラメータとして指定するためだ。 つまり、データをいくつのまとまりとして捉えるかを人間が教えてやる必要がある。 クラスタリングの手法によっては、これをアルゴリズムがやってくれる場合もある。

それでは、上記のサンプルコードを実行してみよう。

$ python kmeans_scikit.py

すると、こんな感じのグラフが出力される。 それぞれの特徴ベクトルが所属するクラスタが、別々の色で表現されている。 また、それぞれのクラスタの真ん中らへんにある黒い円は、後述するセントロイドという概念を表している。 f:id:momijiame:20170319081628p:plain ちなみに、生成されるダミーデータは毎回異なるので、出力されるグラフも異なる。

自分で実装してみる

scikit-learn がクラスタリングしたお手本を見たところで、次は KMeans 法を自分で書いてみよう。

まず、KMeans 法ではセントロイド (重心) という概念が重要になる。 セントロイドというのは、文字通りクラスタの中心に置かれるものだ。 アルゴリズムの第一歩としては、このセントロイドを作ることになる。 セントロイドは、前述した通り各クラスタの中心に置かれる。 そのため、セントロイドはハイパーパラメータとして指定したクラスタ数だけ必要となる。 また、中心というのは、クラスタに属する特徴ベクトルの各次元でそれらの平均値にいる、ということ。

そして、データセットに含まれる特徴ベクトルは、必ず最寄りのセントロイドに属する。 ここでいう最寄りとは、特徴ベクトル同士のユークリッド距離が近いという意味だ。 属したセントロイドが、それぞれのクラスタを表している。

上記の基本的な概念を元に KMeans 法のアルゴリズムは次のようなステップに分けることができる

  • 最初のセントロイドを何らかの方法で決める
  • 特徴ベクトルを最寄りのセントロイドに所属させる
  • 各クラスタごとにセントロイドを再計算する
  • 上記 2, 3 番目の操作を繰り返す
  • もしイテレーション前後で所属するセントロイドが変化しなければ、そこで処理を終了する

ここで、一番最初のステップである「最初のセントロイドを何らかの方法で決める」のがポイントだった。 例えば、次のページではあらかじめ各特徴ベクトルをランダムにクラスタリングしてから、セントロイドを決めるやり方を取っている。

tech.nitoyon.com

上記のやり方であれば、次のようなサンプルコードになる。 ここでは KMeans というクラスで KMeans 法を実装している。 インターフェースは、先ほどお手本として提示した scikit-learn と揃えた。 そのため main() 関数は基本的に変わっておらず、見るべきは KMeans クラスだけとなっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
import numpy as np


class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
        # 初期データは各要素に対してランダムにラベルをつける
        pred = np.random.randint(0, self.n_clusters, len(features))

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):

            # 各クラスタごとにセントロイド (重心) を計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

            # 各特徴ベクトルから最短距離となるセントロイドを基準に新しいラベルをつける
            new_pred = np.array([
                np.array([
                    self._euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容を比較して、もし同じなら終了
                break

            pred = new_pred

        return pred

    def _euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python kmeans_random.py

実行すると、次のようなグラフが表示される。 f:id:momijiame:20170319153805p:plain 一見すると、上手くいっているようだ。

しかし、何度か実行すると、次のようなエラーが出ることがある。

$ python kmeans_random.py 
kmeans_random.py:41: RuntimeWarning: Mean of empty slice.
  for i in range(self.n_clusters)])
/Users/amedama/.virtualenvs/kmeans/lib/python3.5/site-packages/numpy/core/_methods.py:73: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)

表示されるグラフは、次のようなものになる。 f:id:momijiame:20170319153914p:plain 何やら、ぜんぜん上手いことクラスタリングできていない。

上記が起こる原因について調べたところ、セントロイドの初期化に問題があることが分かった。 初期化時に特徴ベクトルに対してランダムにラベルをつけるやり方では、一つの特徴ベクトルも属さないセントロイドが生じる恐れがあるためだ。

ランダムにラベルをつけるやり方では、セントロイドはおのずとデータセット全体の平均あたりに生じやすくなる。 場合によっては、セントロイドが別のセントロイドに囲まれるなどして、一つの特徴ベクトルも属さないセントロイドがでてくる。 サンプルコードでは、そのような事態を想定していなかったのでセントロイドが計算できなくなって上手く動作しなかった。

先ほどの KMeans 法を図示しているページでも、特徴ベクトルの数 N に対してクラスタの数 k を増やして実行してみよう。 一つの特徴ベクトルも属さないクラスタが生じることが分かる。

初期化を改良してみる

一つの特徴ベクトルも属さないクラスタが生じることについて、問題がないと捉えることもできるかもしれない。 その場合には、一つも属さないクラスタのセントロイドを、前回の位置から動かなさないようにすることでケアできそうだ。

とはいえ、ここでは別のアプローチで問題を解決してみることにする。 それというのは、初期のセントロイドをデータセットに含まれるランダムな特徴ベクトルにする、というものだ。 つまり、最初に作るセントロイドの持つ特徴ベクトルがランダムに選んだ要素の特徴ベクトルと一致する。 こうすれば各セントロイドは、少なくとも一つ以上の特徴ベクトルが属することは確定できる。 ユークリッド距離がゼロになる特徴ベクトルが、必ず一つは存在するためだ。

次のサンプルコードでは、初期のセントロイドの作り方だけを先ほどと変更している。 ここでポイントとなるのは、最初にセントロイドとして採用する特徴ベクトルは重複しないように決める必要があるということだ。 そこで、初期のセントロイドは特徴ベクトルをシャッフルした上で先頭から k 個を取り出して決めている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
import numpy as np


class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
        # 要素の中からセントロイド (重心) の初期値となる候補をクラスタ数だけ選び出す
        feature_indexes = np.arange(len(features))
        np.random.shuffle(feature_indexes)
        initial_centroid_indexes = feature_indexes[:self.n_clusters]
        self.cluster_centers_ = features[initial_centroid_indexes]

        # ラベル付けした結果となる配列はゼロで初期化しておく
        pred = np.zeros(features.shape)

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):
            # 各要素から最短距離のセントロイドを基準にラベルを更新する
            new_pred = np.array([
                np.array([
                    self._euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容が同じなら終了
                break

            pred = new_pred

            # 各クラスタごとにセントロイド (重心) を再計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

        return pred

    def _euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

それでは、上記のサンプルコードを実行してみよう。

$ python kmeans_select.py

これまでと変わらないけど、次のようなグラフが表示されるはず。 f:id:momijiame:20170319155439p:plain

今度は N_CLUSTERS を増やしたり、何度やっても先ほどのようなエラーにはならない。

まとめ

今回は教師なし学習のクラスタリングという問題を解くアルゴリズムの KMeans 法を実装してみた。 KMeans 法はシンプルなアルゴリズムだけど、セントロイドをどう初期化するかは流派があるみたい。

はじめてのパターン認識

はじめてのパターン認識

Python: k 近傍法を実装してみる

k 近傍法 (k-Nearest Neighbor algorithm) というのは、機械学習において教師あり学習分類問題を解くためのアルゴリズム。 教師あり学習における分類問題というのは、あらかじめ教師信号として特徴ベクトルと正解ラベルが与えられるものをいう。 その教師信号を元に、未知の特徴ベクトルが与えられたときに正解ラベルを予想しましょう、というもの。

k 近傍法は機械学習アルゴリズムの中でも特にシンプルな実装になっている。 じゃあ、シンプルな分だけ性能が悪いかというと、そんなことはない。 分類精度であれば、他のアルゴリズムに比べても引けを取らないと言われている。 ただし、計算量が多いという重大な欠点がある。 そのため、それを軽減するための改良アルゴリズムも数多く提案されている。

k 近傍法では、与えられた未知の特徴ベクトルを、近い場所にある教師信号の正解ラベルを使って分類する。 特徴ベクトルで近くにあるものは似たような性質を持っているはず、という考え方になっている。 今回は、そんな k 近傍法の基本的な実装を Python で書いてみることにした。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.3
BuildVersion:   16D32
$ python --version
Python 3.5.3

依存パッケージをインストールする

あらかじめ、今回のソースコードで使う依存パッケージをインストールしておく。

$ pip install numpy scipy scikit-learn

最近傍法を実装してみる

k 近傍法では、未知の特徴ベクトルの近くにある k 点の教師信号を用いる。 この k 点を 1 にしたときのことを特に最近傍法 (Nearest Neighbor algorithm) と呼ぶ。 一番近い場所にある教師信号の正解ラベルを返すだけなので、さらに実装しやすい。 そこで、まずは最近傍法から書いてみることにしよう。

次のサンプルコードでは最近傍法を NearestNeighbors というクラスで実装している。 インターフェースは scikit-learn っぽくしてみた。 分類するデータセットは Iris (あやめ) を使っている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from sklearn import datasets
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score


class NearestNeighbors(object):

    def __init__(self):
        self._train_data = None
        self._target_data = None

    def fit(self, train_data, target_data):
        """訓練データを学習する"""
        # あらかじめ計算しておけるものが特にないので保存だけする
        self._train_data = train_data
        self._target_data = target_data

    def predict(self, x):
        """訓練データから予測する"""
        # 判別する点と教師データとのユークリッド距離を計算する
        distances = np.array([self._distance(p, x) for p in self._train_data])
        # 最もユークリッド距離の近い要素のインデックスを得る
        nearest_index = distances.argmin()
        # 最も近い要素のラベルを返す
        return self._target_data[nearest_index]

    def _distance(self, p0, p1):
        """二点間のユークリッド距離を計算する"""
        return np.sum((p0 - p1) ** 2)


def main():
    # Iris データセットをロードする
    iris_dataset = datasets.load_iris()

    # 特徴データとラベルデータを取り出す
    features = iris_dataset.data
    targets = iris_dataset.target

    # LOO 法で汎化性能を調べる
    predicted_labels = []

    loo = LeaveOneOut()
    for train, test in loo.split(features):
        train_data = features[train]
        target_data = targets[train]

        # モデルを学習させる
        model = NearestNeighbors()
        model.fit(train_data, target_data)
        
        # 一つ抜いたテストデータを識別させる
        predicted_label = model.predict(features[test])
        predicted_labels.append(predicted_label)

    # 正解率を出力する
    score = accuracy_score(targets, predicted_labels)
    print(score)


if __name__ == '__main__':
    main()

上記のサンプルコードでは Leave-One-Out 法というやり方で交差検証をしている。

交差検証というのは、学習に使わなかったデータを使って正解を導くことができたか調べる方法を指す。 モデルの性能は、未知のデータに対する対処能力で比べる必要がある。 この、未知のデータに対する対処能力のことを汎化性能と呼ぶ。 交差検証をすることで、この汎化性能を測ることができる。

Leave-One-Out 法では、教師信号の中から検証用のデータをあらかじめ一つだけ抜き出しておく。 そして、それをモデルが正解できるのか調べるやり方だ。 抜き出す対象を一つずつずらしながら、データセットに含まれる要素の数だけ繰り返す。 他の交差検証に比べると計算量は増えるものの、厳密で分かりやすい。

上記のサンプルコードの実行結果は次の通り。

$ python nn.py 
0.96

汎化性能で 96% の正解率が得られた。

scikit-learn を使う場合

ちなみに、自分で書く代わりに scikit-learn にある実装を使う場合も紹介しておく。

次のサンプルコードは k 近傍法の実装を scikit-learn の KNeighborsClassifier に代えたもの。 インターフェースを揃えてあったので、使うクラスが違う以外は先ほどと同じソースコードになっている。 scikit-learn で最近傍法をしたいときは KNeighborsClassifier の k に 1 を指定するだけで良い。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn import datasets
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def main():
    iris_dataset = datasets.load_iris()

    features = iris_dataset.data
    targets = iris_dataset.target

    predicted_labels = []

    loo = LeaveOneOut()
    for train, test in loo.split(features):
        train_data = features[train]
        target_data = targets[train]

        model = KNeighborsClassifier(n_neighbors=1)
        model.fit(train_data, target_data)

        predicted_label = model.predict(features[test])
        predicted_labels.append(predicted_label)

    score = accuracy_score(targets, predicted_labels)
    print(score)


if __name__ == '__main__':
    main()

上記のサンプルコードの実行結果は次の通り。

$ python knn_scikit.py 
0.96

当然だけど同じ班化性能になっている。

k 近傍法を実装してみる

先ほど示した最近傍法の実装では、最寄りの教師信号だけを使うものとなっていた。 今度は、より汎用的に近くにある k 点の教師信号を使う実装にしてみる。

次のサンプルコードでは KNearestNeighbors クラスのコンストラクタに k を渡せるようになっている。 実装としては、分類するときに教師信号をユークリッド距離でソートした上で k 個を取り出している。 ひとまず k については 3 を指定した。 もしこれを 1 にすれば最近傍法になる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import Counter

import numpy as np

from sklearn import datasets
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score


class KNearestNeighbors(object):

    def __init__(self, k=1):
        self._train_data = None
        self._target_data = None
        self._k = k

    def fit(self, train_data, target_data):
        """訓練データを学習する"""
        # あらかじめ計算しておけるものが特にないので保存だけする
        self._train_data = train_data
        self._target_data = target_data

    def predict(self, x):
        """訓練データから予測する"""
        # 判別する点と教師データとのユークリッド距離を計算する
        distances = np.array([self._distance(p, x) for p in self._train_data])
        # ユークリッド距離の近い順でソートしたインデックスを得る
        nearest_indexes = distances.argsort()[:self._k]
        # 最も近い要素のラベルを返す
        nearest_labels = self._target_data[nearest_indexes]
        # 近傍のラベルで一番多いものを予測結果として返す
        c = Counter(nearest_labels)
        return c.most_common(1)[0][0]

    def _distance(self, p0, p1):
        """二点間のユークリッド距離を計算する"""
        return np.sum((p0 - p1) ** 2)


def main():
    iris_dataset = datasets.load_iris()

    features = iris_dataset.data
    targets = iris_dataset.target

    predicted_labels = []

    loo = LeaveOneOut()
    for train, test in loo.split(features):
        train_data = features[train]
        target_data = targets[train]

        model = KNearestNeighbors(k=3)
        model.fit(train_data, target_data)

        predicted_label = model.predict(features[test])
        predicted_labels.append(predicted_label)

    score = accuracy_score(targets, predicted_labels)
    print(score)


if __name__ == '__main__':
    main()

上記の実行結果は次の通り。

$ python knn.py 
0.96

汎化性能は k=1 のときと変わらないようだ。

最適な k を探す

k 近傍法では、計算に近傍何点を使うか (ようするに k) がハイパーパラメータとなっている。 ハイパーパラメータというのは、機械学習において人間が調整する必要のあるパラメータのことをいう。

次は、最適な k を探してみることにする。 といっても、やることは単に総当りで探すだけ。

せっかくならパラメータによる汎化性能の違いを可視化したい。 そこで matplotlib も入れておこう。

$ pip install matplotlib
$ mkdir -p ~/.matplotlib
$ cat << 'EOF' > ~/.matplotlib/matplotlibrc
backend: TkAgg
EOF

次のサンプルコードでは k を 1 ~ 20 の間で調整しながら総当りで汎化性能を計算している。 データセットごとに最適な k が異なるところを見ておきたいので Iris (あやめ) と Digits (数字) で調べることにした。 自分で実行するときは、データセットのロード部分にあるコメントアウトを切り替えてほしい。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt

from sklearn import datasets
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def main():
    # データセットをロードする
    dataset = datasets.load_digits()
    # dataset = datasets.load_iris()

    # 特徴データとラベルデータを取り出す
    features = dataset.data
    targets = dataset.target

    # 検証する近傍数の上限
    K = 20
    ks = range(1, K + 1)

    # 使う近傍数ごとに正解率を計算する
    accuracy_scores = []
    for k in ks:
        # Leave-One-Out 法で汎化性能を測る
        predicted_labels = []
        loo = LeaveOneOut()
        for train, test in loo.split(features):
            train_data = features[train]
            target_data = targets[train]

            # モデルを学習させる    
            model = KNeighborsClassifier(n_neighbors=k)
            model.fit(train_data, target_data)
    
            # 一つだけ取り除いたテストデータを識別させる
            predicted_label = model.predict(features[test])
            predicted_labels.append(predicted_label)
    
        # 正解率を計算する
        score = accuracy_score(targets, predicted_labels)
        print('k={0}: {1}'.format(k, score))

        accuracy_scores.append(score)

    # 使う近傍数ごとの正解率を折れ線グラフで可視化する
    X = list(ks)
    plt.plot(X, accuracy_scores)

    plt.xlabel('k')
    plt.ylabel('accuracy rate')
    plt.show()
    

if __name__ == '__main__':
    main()

まずはデータセットとして Digits を使ったときから。 実行結果は次のようになる。

$ python knn_tuning.py 
k=1: 0.988313856427379
k=2: 0.986644407345576
k=3: 0.988313856427379
k=4: 0.9877573734001113
k=5: 0.9877573734001113
k=6: 0.9855314412910406
k=7: 0.9855314412910406
k=8: 0.9844184752365053
k=9: 0.9833055091819699
k=10: 0.9821925431274346
k=11: 0.9844184752365053
k=12: 0.9827490261547023
k=13: 0.9844184752365053
k=14: 0.9816360601001669
k=15: 0.9816360601001669
k=16: 0.9805230940456316
k=17: 0.9805230940456316
k=18: 0.9794101279910963
k=19: 0.9766277128547579
k=20: 0.9782971619365609

どうやら Digits のときは k を 1 か 3 にするのが良さそうだ。 f:id:momijiame:20170318133735p:plain

続いて Iris を使ったとき。

$ python knn_tuning.py 
k=1: 0.96
k=2: 0.9466666666666667
k=3: 0.96
k=4: 0.96
k=5: 0.9666666666666667
k=6: 0.96
k=7: 0.9666666666666667
k=8: 0.9666666666666667
k=9: 0.9666666666666667
k=10: 0.9733333333333334
k=11: 0.9733333333333334
k=12: 0.96
k=13: 0.9666666666666667
k=14: 0.9733333333333334
k=15: 0.9733333333333334
k=16: 0.9666666666666667
k=17: 0.9733333333333334
k=18: 0.9733333333333334
k=19: 0.98
k=20: 0.98

今度は全然違うグラフになった。 どうやら Iris なら k は 20 にするのが良いらしい。 もしかすると、さらに増やすと良い可能性もある? f:id:momijiame:20170318133823p:plain

まとめ

今回は Python を使って教師あり学習の分類問題を解くためのアルゴリズムの一つ、k 近傍法を実装してみた。 k 近傍法は単純な割に分類精度は決して低くないものの、計算量が多いという欠点がある。 k 近傍法では、計算に近傍何点を使うのが適しているかはデータセットによって異なる。 そのため、異なる k を使って汎化性能を測定して決定しよう。

ちなみに、計算量の多さを軽減するための手法としては、圧縮型 k 近傍法、分岐限定法、疑似最近傍探索などがあるようだ。 それらについては、機会があれば改めて実装してみたい。

はじめてのパターン認識

はじめてのパターン認識

Python: データセットを標準化する効果を最近傍法で確かめる

データセットの標準化については、このブログでも何回か扱っている。 しかし、実際にデータセットを標準化したときの例については試していなかった。

blog.amedama.jp

blog.amedama.jp

そこで、今回は UCI の提供する小麦 (seeds) データセットを最近傍法で分類したときに、スコアが上がる様を見てみたいと思う。 あらかじめ、いくつかの説明変数が教師信号として与えられるので、そこから小麦の品種を分類 (Classification) する。

UCI Machine Learning Repository: seeds Data Set

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.3
BuildVersion:   16D32
$ python --version
Python 3.5.3

下準備

まずは今回使う Python のパッケージを pip でインストールする。 どれも科学計算系で定番のやつ。

$ pip install numpy pandas scipy scikit-learn

データセットを標準化しない場合

今回使う最近傍法というアルゴリズムでは、分類したい点から最も近くにあるデータの種別を使って分類する。 近さにはユークリッド距離を使うため、データセットの説明変数の大きさや単位に影響を受けやすい。 例えば説明変数の中に 1,000m ~ 10,000m を取る次元と 0.1cm ~ 1cm を取る次元があるとしよう。 当然ながら、ユークリッド距離を計算するとそのままでは前者の次元の影響が大きくなってしまう。 今回使うデータセットでも原理的には同じ問題が発生する。

次のサンプルコードでは、データセットを標準化しない状態で最近傍法を使った分類を実施している。 そして Leave-One-Out 法を使って、計測した汎化性能を出力する。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pandas as pd
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def main():
    # 小麦データセットをダウンロードする
    dataset_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt'  # noqa
    df = pd.read_csv(dataset_url, delim_whitespace=True, header=None)

    # データフレームから説明変数と目的変数を取り出す
    features = df.loc[:, :6].get_values()
    targets = df.loc[:, 7].get_values()

    # 予測した結果の入れ物
    predicted_labels = []

    # LOO で交差検証する
    loo = LeaveOneOut()
    for train, test in loo.split(features):
        train_data = features[train]
        target_data = targets[train]

        # k-NN 法を使う
        model = KNeighborsClassifier(n_neighbors=1)
        # 訓練データを学習させる
        model.fit(train_data, target_data)
        # テストデータを予測させる
        predicted_label = model.predict(features[test])
        # 予測した結果を追加する
        predicted_labels.append(predicted_label)

    # 正解率を出力する
    score = accuracy_score(targets, predicted_labels)
    print(score)


if __name__ == '__main__':
    main()

上記のサンプルコードの実行結果は次の通り。 データセットを標準化しない状態では約 90.5% の汎化性能が得られた。

$ python withoutnorm.py
0.904761904762

データセットを標準化する場合

それでは、次はデータセットを標準化する場合を試してみよう。 標準化されたデータセットでは、説明変数の各次元の値が平均 0 で標準偏差が 1 になる。 つまり、元々の単位や大きさは無くなってそれぞれの値の間隔の比率だけが残されることになる。

次のサンプルコードでは、先ほどとデータセットを標準化するところだけ変更している。 データセットの標準化には scikit-learn に用意されている StandardScaler を用いた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pandas as pd
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler


def main():
    dataset_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt'  # noqa
    df = pd.read_csv(dataset_url, delim_whitespace=True, header=None)

    features = df.loc[:, :6].get_values()
    targets = df.loc[:, 7].get_values()

    # データセットを Z-Score に標準化する
    sc = StandardScaler()
    sc.fit(features)
    normalized_features = sc.transform(features)

    predicted_labels = []

    loo = LeaveOneOut()
    for train, test in loo.split(normalized_features):
        train_data = normalized_features[train]
        target_data = targets[train]

        model = KNeighborsClassifier(n_neighbors=1)
        model.fit(train_data, target_data)

        predicted_label = model.predict(normalized_features[test])
        predicted_labels.append(predicted_label)

    score = accuracy_score(targets, predicted_labels)
    print(score)


if __name__ == '__main__':
    main()

上記の実行結果は次の通り。 今度は汎化性能が約 93.8% に上昇している。 データセットを標準化するだけで分類精度が 3.3% 上がった。

$ python withnorm.py
0.938095238095

まとめ

データセットを標準化することで、使う機械学習アルゴリズムによっては汎化性能を上げることができる。