読者です 読者をやめる 読者になる 読者になる

CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: SQLAlchemy の生成する SQL をテストするパッケージを作ってみた

Mac OS X Python SQLAlchemy SQLite

SQLAlchemy は Python でよく使われている O/R マッパーの一つ。 今回は、そんな SQLAlchemy が生成する SQL 文を確認するためのパッケージを作ってみたよ、という話。

具体的には、以下の sqlalchemy-profile というパッケージを作ってみた。 このエントリでは、なんでこんなものを作ったのかみたいな話をしてみる。

github.com

使った環境は次の通り。 ただし sqlalchemy-profile 自体はプラットフォームに依存せず Python 2.7, 3.3 ~ 3.6 に対応している。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

O/R マッパーについて

O/R マッパーというのは、プログラミング言語からリレーショナルデータベース (RDB) を良い感じに使うための機能ないしライブラリの総称。 プログラミング言語から RDB を操作するための SQL 文を直に扱ってしまうと、両者のパラダイムの違いから色々な問題が起こる。 この問題は、一般にインピーダンスミスマッチと呼ばれている。 そこで登場するのが O/R マッパーで、これを使うとプログラミング言語のオブジェクトを操作する形で RDB を操作できるようになる。

論よりソースということで、まずは SQLAlchemy の基本的な使い方から見てみよう。 その前に SQLAlchemy 自体をインストールしておく。

$ pip install sqlalchemy

そして次に示すのがサンプルコード。 ユーザ情報を模したモデルクラスを用意して、それを SQLite のオンメモリデータベースで永続化している。 この中には SQL 文が全く登場していないところがポイントとなる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

Base = declarative.declarative_base()


class User(Base):
    """SQLAlchemy のモデルクラス

    このクラスが RDB のテーブルと対応し、インスタンスはテーブルの一レコードに対応する
    ここではユーザの情報を格納するテーブルを模している"""
    __tablename__ = 'users'

    # テーブルの主キー
    id = Column(Integer, primary_key=True)
    # 名前を入れるカラム
    name = Column(Text, nullable=False)


def main():
    # データベースとの接続に使う情報
    # ここでは SQLite のオンメモリデータベースを使う
    # echo=True とすることで生成される SQL 文を確認できる
    engine = create_engine('sqlite:///', echo=True)
    # モデルの情報を元にテーブルを生成する
    Base.metadata.create_all(engine)
    # データベースとのセッションを確立する
    session_maker = sessionmaker(bind=engine)
    session = session_maker()

    # データベースのトランザクションを作る
    with session.begin(subtransactions=True):
        # レコードに対応するモデルのインスタンスを作る
        user = User(name='Alice')
        # そのインスタンスをセッションに追加する
        session.add(user)

    # トランザクションがコミットされてオブジェクトが RDB で永続化される

if __name__ == '__main__':
    main()

上記のサンプルコードでは生成される SQL 文を標準出力に表示するようにしている。 なので、実行するとこんな感じの出力が得られる。

2017-04-20 04:48:30,976 INFO sqlalchemy.engine.base.Engine SELECT CAST('test plain returns' AS VARCHAR(60)) AS anon_1
2017-04-20 04:48:30,976 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,978 INFO sqlalchemy.engine.base.Engine SELECT CAST('test unicode returns' AS VARCHAR(60)) AS anon_1
2017-04-20 04:48:30,978 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,980 INFO sqlalchemy.engine.base.Engine PRAGMA table_info("users")
2017-04-20 04:48:30,980 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,982 INFO sqlalchemy.engine.base.Engine 
CREATE TABLE users (
    id INTEGER NOT NULL, 
    name TEXT NOT NULL, 
    PRIMARY KEY (id)
)


2017-04-20 04:48:30,983 INFO sqlalchemy.engine.base.Engine ()
2017-04-20 04:48:30,984 INFO sqlalchemy.engine.base.Engine COMMIT
2017-04-20 04:48:30,987 INFO sqlalchemy.engine.base.Engine BEGIN (implicit)
2017-04-20 04:48:30,989 INFO sqlalchemy.engine.base.Engine INSERT INTO users (name) VALUES (?)
2017-04-20 04:48:30,989 INFO sqlalchemy.engine.base.Engine ('Alice',)

たしかに Python のオブジェクトを使うだけで RDB を操作できた。便利。 ただし、上記で使った生成した SQL 文を出力する機能はデバッグ用途なので普段は無効にされる場合が多い。

SQL が隠蔽されることのメリットとデメリット

先ほど見た通り O/R マッパーを使うと Python のオブジェクトを通して RDB を操作できるようになる。 これにはインピーダンスミスマッチの解消という多大なメリットがある反面、生成される SQL が隠蔽されるというデメリットもある。

例えば、直接 SQL を書くならそんな非効率なクエリは組まないよね・・・というような内容も、気をつけていないと生成されうる。 これは、典型的には N + 1 問題とか。 それを防ぐには、これまでだとコードから生成される SQL 文を推測したり、あるいは先ほどのようにして実際に目で見て確かめていた。 慣れてくるとどんな SQL 文が発行されるか分かってくるのと、実際に目で見て確かめるのは手間なので大体は前者になっている。

ただ、パフォーマンスチューニングの世界では、推測する前に測定せよという格言もある。 実際に生成される SQL 文を、ユニットテストで確認できるようになっているべきなのでは、という考えに至った。 それが、今回作ったパッケージ sqlalchemy-profile のモチベーションになっている。

ただし、どんな SQL 文が生成されるかは SQLAlchemy のアルゴリズム次第なので、気をつけないとテストのメンテナンス性が低下する恐れはあると思う。 これは、SQLAlchemy のバージョン変更とか、些細なモデルの構造変更でテストを修正する手間がかかるかも、ということ。 とはいえ、それはそれで生成される SQL が変更されたことにちゃんと気づけるのは大事じゃないかという感じでいる。

sqlalchemy-profile について

やっと本題に入るんだけど、前述した問題を解消すべく sqlalchemy-profile という Python のパッケージを作ってみた。 これを使うことで、SQLAlchemy が生成する SQL 文を確かめることができる。

Python のパッケージリポジトリである PyPI にも登録しておいた。 pypi.python.org

インストールは Python のパッケージマネージャの PIP からできる。

$ pip install sqlalchemy-profile

使い方

ここからは sqlalchemy-profile の具体的な使い方について見ていく。 シンプルなのでサンプルコードをいくつか見れば、すぐに分かってもらえると思う。 ちなみに、トラッキングしている SQL 文は今のところ INSERT, UPDATE, SELECT, DELETE の四つ。

以下のサンプルコードでは、最も基本的な使い方を示している。 まず、プロファイラとなる StatementProfiler には SQLAlchemy のデータベースとの接続情報を渡す。 そして、プロファイルしている期間中に実行された SQL 文を記録する、というもの。 ユニットテストで利用することを意図しているので、サンプルコードも Python の unittest モジュールを使うものにしてみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import StatementProfiler


class Test_RawExecute(unittest.TestCase):

    def test(self):
        # データベースとの接続を確立する
        engine = create_engine('sqlite:///')
        connection = engine.connect()

        # データベースとの接続情報を渡してプロファイラをインスタンス化する
        profiler = StatementProfiler(engine)
        # プロファイルを開始する
        profiler.start()

        # SQLAlchemy を使って RDB を操作する
        # ここでは、サンプルコードをシンプルにする目的で低レイヤーな API を使っている
        connection.execute('SELECT 1')
        connection.execute('SELECT 2')

        # プロファイルを停止する
        profiler.stop()

        # 実行された SQL 文の内容を確認する
        assert profiler.count == 2
        assert profiler.select == 2


if __name__ == '__main__':
    unittest.main()

上記では、分かりやすくするためにあえて SQLAlchemy の直接 SQL 文を扱う低レイヤーな API を使っている。

上記を実行するとテストがパスする。

$ python profile101.py 
.
----------------------------------------------------------------------
Ran 1 test in 0.020s

OK

このとき assert しているところの数値を変更すると、当然だけどテストは失敗するようになる。 想定していた SQL 文の数と、実際に発行された数が一致しないため。

$ python profile101.py
F
======================================================================
FAIL: test (__main__.Test_RawExecute)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "profile101.py", line 32, in test
    assert profiler.count == 1
AssertionError

----------------------------------------------------------------------
Ran 1 test in 0.017s

FAILED (failures=1)

O/R マッピングと共に使う

先ほどの例では、分かりやすさのためにあえて SQLAlchemy の直接 SQL 文を扱う低レイヤーな API を使っていた。 もちろん sqlalchemy-profile は O/R マッピングをしたコードでも動作するし、使い方については何も変わらない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    """ユーザ情報を模したモデルクラス"""
    __tablename__ = 'users'

    # 主キー
    id = Column(Integer, primary_key=True)
    # 名前を格納するカラム
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        """テストが実行される前の下準備"""
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        """テストが実行された後の後始末"""
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        # 以下のユーザを模したインスタンスを一通り CRUD していく
        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        # UPDATE
        with session.begin(subtransactions=True):
            user.name = 'Bob'

        # SELECT
        session.query(_User).all()

        # DELETE
        with session.begin(subtransactions=True):
            session.delete(user)

        profiler.stop()

        # SQL 文は各一回ずつ実行されているはず
        assert profiler.count == 4
        assert profiler.insert == 1
        assert profiler.update == 1
        assert profiler.select == 1
        assert profiler.delete == 1


if __name__ == '__main__':
    unittest.main()

with ステートメントと共に使う

これまでの例ではプロファイリング期間を start() メソッドと stop() メソッドで制御したけど、これは with でも代用できる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import StatementProfiler


class Test_WithStatement(unittest.TestCase):

    def test(self):
        engine = create_engine('sqlite:///')
        connection = engine.connect()

        # with ステートメントのスコープで実行された SQL 文を記録する
        with StatementProfiler(engine) as profiler:
            connection.execute('SELECT 1')
            connection.execute('SELECT 2')

        assert profiler.count == 2
        assert profiler.select == 2


if __name__ == '__main__':
    unittest.main()

デコレータと共に使う

with を使うのもめんどくさいなー、というときはテストメソッド自体をデコレータで修飾しちゃうような使い方もできる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import sqlprofile

ENGINE = create_engine('sqlite:///')


class Test_Decorator(unittest.TestCase):

    # ユニットテストのメソッドをデコレータで修飾する
    # メソッド内で実行されることが想定される SQL 文の数を指定する
    @sqlprofile(ENGINE, count=2, select=2)
    def test(self):
        connection = ENGINE.connect()

        connection.execute('SELECT 1')
        connection.execute('SELECT 2')


if __name__ == '__main__':
    unittest.main()

SQL 文の種類と順序まで確認したい

いやいや回数だけのアサーションとかアバウトすぎるでしょ、っていうときは StatementProfiler#sequence を使う。 これで INSERT, UPDATE, SELECT, DELETE が、どんな順番で実行されたかを確認できる。 中身は文字列で、それぞれの操作の頭文字が入っている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        # UPDATE
        with session.begin(subtransactions=True):
            user.name = 'Bob'

        # SELECT
        session.query(_User).all()

        # DELETE
        with session.begin(subtransactions=True):
            session.delete(user)

        profiler.stop()

        # [I]NSERT -> [U]PDATE -> [S]ELECT -> [D]ELETE
        assert profiler.sequence == 'IUSD'


if __name__ == '__main__':
    unittest.main()

もちろんデコレータの API でも使える。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy import create_engine

from sqlalchemy_profile import sqlprofile

ENGINE = create_engine('sqlite:///')


class Test_Decorator(unittest.TestCase):

    # SELECT -> SELECT = SS
    @sqlprofile(ENGINE, seq='SS')
    def test(self):
        connection = ENGINE.connect()

        connection.execute('SELECT 1')
        connection.execute('SELECT 2')


if __name__ == '__main__':
    unittest.main()

もっと厳密にアサーションしたい

いやいや SQL 文の構造までもっと調べたいよ、というときは StatementProfiler#statementsStatementProfiler#statements_with_parameters を使う。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from sqlalchemy.ext import declarative
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import create_engine
from sqlalchemy.orm.session import sessionmaker

from sqlalchemy_profile import StatementProfiler

Base = declarative.declarative_base()


class _User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)


class Test_ORMapping(unittest.TestCase):

    def setUp(self):
        self.engine = create_engine('sqlite:///')
        Base.metadata.create_all(self.engine)
        self.session_maker = sessionmaker(bind=self.engine)

    def tearDown(self):
        Base.metadata.drop_all(self.engine)

    def test(self):
        session = self.session_maker()

        profiler = StatementProfiler(self.engine)
        profiler.start()

        user = _User(name='Alice')

        # INSERT
        with session.begin(subtransactions=True):
            session.add(user)

        profiler.stop()

        assert profiler.count == 1
        assert profiler.insert == 1

        # 生の SQL 文を取得する
        print(profiler.statements)
        print(profiler.statements_with_parameters)


if __name__ == '__main__':
    unittest.main()

こんな感じ。

['INSERT INTO users (name) VALUES (?)']
[('INSERT INTO users (name) VALUES (?)', ('Alice',))]
.
----------------------------------------------------------------------
Ran 1 test in 0.019s

OK

こちらは、今のところデコレータの API では使えない。

まとめ

  • SQLAlchemy の生成する SQL 文を確認するための sqlalchemy-profile というパッケージを作ってみた
  • O/R マッピングをすると、生成される SQL 文をプログラマが把握しにくくなる
  • 非効率なクエリをコードや実行結果から目で見て確認するのは手間がかかる
  • sqlalchemy-profile を使うことで実行される SQL 文をユニットテストで確認できるようになる

もしかすると似たようなことができるパッケージが既にあるかも。

Python: 相関行列を計算してヒートマップを描いてみる

Mac OS X Python matplotlib scikit-learn seaborn NumPy

以前、このブログで相関係数について解説した記事を書いたことがある。 相関係数というのは、データセットのある次元とある次元の関連性を示すものだった。

blog.amedama.jp

この相関係数を、データセットの各次元ごとに計算したものを相関行列と呼ぶ。 データ分析の世界では、それぞれの次元の関連性を見るときに、この相関行列を計算することがある。 また、それを見やすくするためにヒートマップというグラフを用いて図示することが多い。

今回は Python を使って相関行列を計算すると共にヒートマップを描いてみることにした。

使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.5.3

下準備

今回は、相関行列の計算には NumPy を、グラフの描画には seaborn を使うのでインストールしておく。 最後の scikit-learn については、相関行列の計算に使うデータセットを読み込むためだけに使っている。

$ pip install seaborn numpy scikit-learn

相関行列を計算してみる

まずはヒートマップの描く以前に相関行列の計算から。 データセットにはみんな大好きアイリス (あやめ) データセットを用いる。 これは 150 行 4 次元の構造になっている。

>>> from sklearn import datasets
>>> dataset = datasets.load_iris()
>>> features = dataset.data
>>> features.shape
(150, 4)

相関行列の計算には NumPycorrcoef() 関数が使える。 この関数には、相関行列を計算したい次元をリストの形で渡す。 すごくベタに書くとしたら、こんな感じ。

>>> import numpy as np
>>> np.corrcoef([features[:, 0], features[:, 1], features[:, 2], features[:, 3]])
array([[ 1.        , -0.10936925,  0.87175416,  0.81795363],
       [-0.10936925,  1.        , -0.4205161 , -0.35654409],
       [ 0.87175416, -0.4205161 ,  1.        ,  0.9627571 ],
       [ 0.81795363, -0.35654409,  0.9627571 ,  1.        ]])

上記には、それぞれの次元ごとの相関係数が格納されている。 対角要素が全て 1 になっているのは、全く同じデータ同士の相関係数は 1 になるため。

ただ、実際には上記のようなベタ書きをする必要はない。 次元ごとにリストで、というのはようするに行と列を入れ替えれば良いということ。 つまり M 行 N 列のデータなら N 行 M 列に直して渡すことになる。

これは NumPy 行列なら transpose() メソッドで実現できる。

>>> features.transpose().shape
(4, 150)

ようするに、こうなる。

>>> np.corrcoef(features.transpose())
array([[ 1.        , -0.10936925,  0.87175416,  0.81795363],
       [-0.10936925,  1.        , -0.4205161 , -0.35654409],
       [ 0.87175416, -0.4205161 ,  1.        ,  0.9627571 ],
       [ 0.81795363, -0.35654409,  0.9627571 ,  1.        ]])

相関行列はこれで計算できた。

ヒートマップを描いてみる

続いては、先ほどの相関行列をヒートマップにしてみよう。 Python のグラフ描画ライブラリの seaborn には、あらかじめヒートマップを描くための API が用意されている。

次のサンプルコードではアイリスデータセットの相関行列をヒートマップで図示している。 それぞれの行の説明についてはコメントで補足している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn import datasets


def main():
    # アイリスデータセットを読み込む
    dataset = datasets.load_iris()

    features = dataset.data
    feature_names = dataset.feature_names

    # N 行 M 列を M 行 N 列に変換して相関行列を計算する
    correlation_matrix = np.corrcoef(features.transpose())

    # 相関行列のヒートマップを描く
    sns.heatmap(correlation_matrix, annot=True,
                xticklabels=feature_names,
                yticklabels=feature_names)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を実行すると、次のようなグラフが得られる。

f:id:momijiame:20170418225256p:plain

このヒートマップでは、正の相関が強いものほど赤く、負の相関が強いものほど青く図示されている。 相関がないものについては色が薄いことから白に近づくことになる。 上記を見ると、アイリスデータセットには相関の強い次元が多いことが分かる。

主成分分析した結果の相関行列でヒートマップを描いてみる

ここで一つ気になったことを試してみることにした。 主成分分析した結果を相関行列にしてヒートマップで描いてみる、というものだ。 主成分分析とは何ぞや、ということは以下のブログエントリで書いた。

blog.amedama.jp

理屈の上では、主成分分析した結果は互いに直交した次元になるため相関しないはず。 これを相関行列とヒートマップで確かめてみよう、ということ。

次のサンプルコードでは、アイリスデータセットを主成分分析している。 そして、分析した内容に対して相関行列を計算してヒートマップを描いた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.decomposition import PCA


def main():
    # アイリスデータセットを読み込む
    dataset = datasets.load_iris()

    features = dataset.data

    # 特徴量を主成分分析する
    pca = PCA()
    pca.fit(features)

    # 分析にもとづいて特徴量を主成分に変換する
    transformed_features = pca.fit_transform(features)

    # 主成分の相関行列を計算する
    correlation_matrix = np.corrcoef(transformed_features.transpose())

    # 主成分の相関行列をヒートマップで描く
    feature_names = ['PCA{0}'.format(i)
                     for i in range(features.shape[1])]
    sns.heatmap(correlation_matrix, annot=True,
                xticklabels=feature_names,
                yticklabels=feature_names)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

上記を実行すると、次のようなグラフが得られる。

f:id:momijiame:20170418225930p:plain

見事に真っ白で、互いに相関が全然ないことが分かる。 理屈通りの結果になった。

まとめ

  • 二つの次元の関連性を調べるには相関係数を用いる
  • データセットに含まれる全ての次元で相関係数を計算したものを相関行列と呼ぶ
  • 相関行列はヒートマップというグラフで図示することが多い
  • 主成分分析した結果の相関行列は対角要素を覗いてゼロになる

めでたしめでたし。

フレッツ回線が遅すぎる問題を IPv6/IPoE と DS-Lite で解決した

IPv6 Linux

最近というほど最近でもないんだけど、近頃はとにかくフレッツ回線のスループットが出ない。 下手をすると、モバイルネットワークの方が速いので時間帯によってはテザリングをし始めるような始末だった。 今回は、そんなスループットの出ないフレッツ回線を何とか使い物になるようにするまでの流れを書いてみる。

先に断っておくと、今回はいつものような特定の技術に関する解説という側面は強くない。 思考の過程なども含んでいるので、いつもより読み物的な感じになっていると思う。 調べ物をして、それらについて理解した内容のまとめになっている。

結論から書いてしまうと、今回のケースでは IPv6/IPoE 接続と DS-Lite を使って何とかなった。 DS-Lite というのはゲーム端末ではなくて IPv4/IPv6 共存技術の一つである RFC6333 (Dual-Stack Lite Broadband Deployments Following IPv4 Exhaustion) の通称を指す。

そして、背景とか技術的な内容が不要なら下の方にある「利用までの流れ」から読み始めてもらうと手っ取り早いと思う。

フレッツ回線が遅い原因

まず、一般的に「フレッツが遅い」と言われる主な原因は、フレッツ網内にある PPPoE の終端装置がボトルネックになっているため。 技術に明るい個人ユーザレベルでは周知の事実みたいだけど、ISP なんかが公式でこの問題について言及しているものにも以下があった。 この記事を読むと、どのような設備を通ってパケットがインターネットまで届くのかが理解できる。

techlog.iij.ad.jp

ようするに、フレッツ網と ISP の接続境界 (POI: Point of Interface) が詰まっている。 この場合、フレッツ網内も ISP のバックボーン内も転送能力には余裕があるからタチが悪い。 フレッツ網内だけでスループットを測定しても、バックボーン内だけで測定しても、どちらも何ら問題は出ない。 また、終端装置の設備は都道府県ごと・ISP ごとに用意されているので、状況はまちまちらしい。 とはいえ、都内では特に多くの ISP で上記の輻輳が起きているようだ。

そんな背景もあって、これまで遅いなと感じたときには PPPoE のセッションを張り直してみることもあった。 たまに、比較的空いている終端装置につながることもあって、そうするとスループットが改善したりする。 PPPoE ガチャ。

じゃあ、ボトルネックになっている終端装置を増強すれば良いじゃないか、という話になるんだけど、それがどうやら難しいらしい。 何故なら、ISP が終端装置を増強したくても、その権限は NTT 側にあるため。 NTT 的には PPPoE セッション数に応じて設備増強をするポリシーでいるようで、ISP から要望があっても拒んでいる。 これは、総務省が収集している NGN (フレッツ光ネクストのこと) に対するパブリックコメントを見ると分かる。

www.soumu.go.jp

ここには ISP の NTT に対する恨みつらみが書かれていて、収集された意見に対して「終端装置」で検索をかけると、それが分かる。

インターネット上のコンテンツはリッチ化を続けており、一人あたりのトラフィック量は急速に増え続けている。

www.soumu.go.jp

そうした状況の中で、フレッツ網の設備増強基準が時代にそぐわないものになっているんだろう。

解決策について

話がだいぶ大きくなったけど、ここからは個人の話になる。 この問題の解決策としては、当初いくつかの選択肢が考えられた。

まず一つ目は、そもそも別の回線キャリアに乗り換えてしまうこと。 これは、例えば NURO のようなサービスを指している。 ただ、調べていくとこのやり方だとかなりお金がかかることが分かった。 具体的には、初期工事費用が万単位でかかるのと、月額料金も現状よりだいぶ上がりそうな感じ。 初期工事費用については乗り換えキャンペーンやキャッシュバックで補填するとしても、月額料金は如何ともしがたい。

二つ目は、そもそもフレッツ回線を解約してしまうこと。 モバイルネットワークはそれなりにスループットが出ているので、これに一本化してしまう。 必要な経費は、固定回線っぽく使うために必要な SIM フリーのモバイルルータを購入する資金のみ。 具体的には、以下のようなやつ。

指す SIM カードについては MVNO で複数枚出せるプランに加入しているので問題なし。 ただ、これだと転送量上限がある点に不安を覚えた。 自宅で技術的な検証をしていると数 GB のファイルとかダウンロードする機会もあるし、これはまずい。

三つ目は、今あるフレッツ回線をなんとかする方法。 上記の記事でも挙げられているけど、ようするにフレッツ網内にある PPPoE の終端装置さえ迂回できればボトルネックは解消できる。 この迂回する方法としては IPoE 接続を使うやり方が挙げられる。

IPoE について

従来、フレッツ網で IPv4 の通信をするときは PPPoE というトンネリングプロトコルが使われている。 これは、主にユーザの認証のために用いられる。 このプロトコルの終端装置がフレッツ網内でボトルネックになっていることは前述した通り。 この通信方式を、以後は IPv4/PPPoE と記述する。

それに対して IPv6 の通信では、上記の PPPoE を使うやり方と IPoE という二つの方法がある。 これらの通信方式を、以後はそれぞれ IPv6/PPPoE と IPv6/IPoE と記述する。 その中でも IPv6/IPoE については現時点で輻輳しておらず快適らしい。

IPoE というのは Internet Protocol over Ethernet の略で、トンネリングプロトコルを使わず直にイーサネットを扱うやり方を指す。 IPoE については以下のサイトが分かりやすい。

IPv6 IPoEの仕組み:Geekなぺーじ

また、これは副次的なものだけど IPv6/IPoE を使うと下りの速度が IPv4/PPPoE よりもサービススペックとして上がる場合もある。 例えば、フレッツ光ネクストのマンション・ハイスピードタイプを使う場合を考えてみよう。 以下の公式サイトを見ると、但し書きで IPv6/IPoE は受信速度が 1Gbps と書かれている。

flets.com

ちなみに、上りと下りの両方が 1Gbps のプランを契約すると、だいたい月額料金が ¥500 は上がる。 それに対し、月額料金はそのままで下りに関しては 1Gbps で使えるのはでかい。

ただし、全ての ISP が IPv6/IPoE に対応しているわけではない。 例えば、東京で IPv6/IPoE が使える ISP は次の通り。 それぞれの回線プランごとにも使える・使えないなどが細かく分かれている。

flets.com

IPv6/IPoE を使えば全て解決?

また、IPv6/IPoE さえ使うことができれば全て解決かというと、そうもいかない。 なぜかというと、ちまたのサービスによって IPv6 への対応状況がまちまちだから。 IPv6 に対応していないホストへの通信は IPv4/PPPoE を経由するので今まで通りのスループットしか出ない。

それに、IPv4 と IPv6 のデュアルスタックな環境で、どちらが使われるかは設定や OS の実装にも影響を受ける。 低速な IPv4/PPPoE と高速な IPv6/IPoE のどちらも使えるようになっているときに IPv4/PPPoE が優先して使われてしまうと、遅いままになる。

IPv4 のトラフィックも IPoE に流したい

そこで考えられるのが IPv4 の通信も IPv6/IPoE 経由で流してしまう、という発想。 これには、元々は IPv4 の縮小期に使うために検討されていた技術を応用する。 色々な種類があるんだけど、日本では主に MAP-EDS-Lite が使われている。 これは、一般に IPv4 over IPv6 トンネルと呼ばれるもので IPv4 のパケットを IPv6 のパケットで包んで送ってしまう手法になっている。

MAP-E と DS-Lite それぞれの手法の違いは、次のサイトが詳しい。 ようするに NAPT (Network Address Port Translation) が CPE (Customer Private Edge) にあるか CGN (Carrier Grade NAT) になるかの違い。

techlog.iij.ad.jp

IPv6/IPoE を提供しているのは誰?

ちょっと話が巻き戻るんだけど、そもそも IPv6/IPoE を使ったインターネットへの疎通性を提供しているのは ISP ではない。 これは、いわゆるネイティブ接続事業者 (VNE: Virtual Network Enabler) が担っている。 VNE は最近まで三社だけに限定されていて、それぞれの固定回線キャリアが出資している IX (Internet eXchange) 事業者が同時に VNE も担っていた。 ISP は、そのいずれかと契約することで IPv6/IPoE をユーザが使えるようにしている。

そんな VNE なんだけど、事業者によって IPv4 を IPv6 を通して使うために採用している技術はそれぞれ異なる。 例えば JPNE (KDDI 系) は MAP-E を、マルチフィード (NTT 系) は DS-Lite をそれぞれ採用している。 BBIX (ソフトバンク系) については公開されているドキュメントを調べても、具体的に何を使っているかは分からなかった。

JPNE は MAP-E を「v6 プラス」というサービス名で提供している。

v6プラス | サービス紹介 | JPNE | 日本ネットワークイネイブラー株式会社

マルチフィードは DS-Lite を「IPv4インターネット接続オプションサービス」というサービス名で提供している。

サービスのご案内 - transixサービス|インターネットマルチフィード株式会社

BBIX は「IPv6 IPoE + IPv4 ハイブリッドサービス」というサービスを提供している。 IPv4 over IPv6 トンネルを使っている、ということは分かるんだけど具体的な方式は分からない。

IPv6 IPoE + IPv4 ハイブリッドサービス サービス概要|ソリューション|BBIX株式会社

つまり、まとめるとこう。

VNE サービス名 方式
JPNE v6 プラス MAP-E
マルチフィード IPv4インターネット接続オプションサービス DS-Lite
BBIX IPv6 IPoE + IPv4 ハイブリッドサービス IPv4 over IPv6 トンネル (具体的な方式は不明)

なので、続いては IPv6/IPoE を使うにしても ISP が契約している VNE が何処かを確認したい。 これは、それぞれの方式によって用意すべきルータなどが異なるため。

契約している VNE については、公開している ISP もあれば非公開にしている ISP もある。 もし、非公開な ISP だと上記のサービスが使えたとしても公式なサポートではなくなる。 つまり、いつ通信できなくなっても文句は言えない自己責任ということ。

ちなみに、今回自分が使った ISP も契約している VNE は非公開だった。 ただし、ウェブサイトでの下調べや、使ってみた結果としてはマルチフィードであることは明らかだった。 つまり、DS-Lite に対応したルータを用意すれば良い、ということになる。

とはいえ、公式には非公開とされているので、ここでは具体的にどの ISP を使っているか書くことは控えることにする。 参考までにということであれば、個別に聞いてもらえたらお伝えします。 IPoE を申し込んでも初期費用や追加の月額料金がかからない点は良かったと思う。

ちなみに、元々は 3 社に限られていた VNE 事業者の枠は、最近になって 16 社にまで拡大された。 そのため、ISP 自身が直接 VNE 事業に乗り出す例も出てきている。 具体的には、以下。

「IPv6接続機能」提供開始のご案内|プロバイダ ASAHIネット |料金、接続品質、満足度で比較して選ばれるISP

ただし、そのような新規参入した事業者が今後どういった IPv4/IPv6 共存技術を採用するのかは分からない。 同種の技術としては MAP-E や DS-Lite 以外にも 464XLATMAP-T4rd/SAM といった選択肢もある。 また、いつ頃から提供し始めるのかも分からない。 そもそも、そのようなサービスを VNE が提供しなければいけない、といったような決まりは存在しないので。

ここらへんは、どうも以前に「VNE はみんなで 4rd/SAM を使っていこう」みたいな流れがあったっぽいんだけど。 今はどうしてこうなったのかな。

NTT NGNネイティブ方式上でIPv4 over IPv6提供へ:Geekなぺーじ

利用までの流れ

続いては IPv6/IPoE と DS-Lite を使えるようにするまでに実施した流れなどについて書く。 ちなみに IPoE についての知見は、どうやら何処のサポートセンターにもあまりたまっていないらしい。 そのため、問い合わせて返ってきた答えが間違っていたということも平気で起こる、というか起こった。

また、不安に思う人もいるかもしれないのであらかじめ書いておくと、IPv6/IPoE + DS-Lite を一旦使い始めたからといって従来の IPv4/PPPoE が使えなくなるわけではない。 両者は二者択一というわけではなく ISP によって IPv4/PPPoE の機能は提供され続ける。 これは、あくまで VNE が IPv6/IPoE を通して IPv4 の通信も通せるようにする機能を「オプションとして」提供しているに過ぎないため。 仮に何らかの問題があったとしても (遅いとはいえ) IPv4/PPPoE に戻すことはできる。 さらに言えば、例えば複数のルータを使ったり業務用ルータを使うなどの工夫をすれば、両者を同時に併用することも可能になっている。

フレッツの契約内容を確認する

まずは、自分が使っているフレッツの契約内容を確認する。 これは、前述した通りフレッツの契約内容によって ISP の対応状況が異なるため。

ちなみに、サービス名などが途中で変わっていたというケースもあるみたい。 今回のケースは正にそれで、いつの間にか B フレッツマンションタイプからフレッツ光ネクストマンションタイプになっていた。 念を入れるならサポートセンターに問い合わせると今の契約内容を書いた紙が送ってもらえる。

契約内容が分かったら、フレッツの公式サイトで自分の使っている ISP がそれに対応しているかを確認しよう。

flets.com

また、上記の表に載っていなくても実は対応していた、ということもあるみたい。 なので、あきらめずに ISP のサポートセンターに問い合わせてみるのも手だと思う。

ISP に IPv6/IPoE の利用を申し込む

続いて、対応していることが分かったら ISP に IPoE を使いたい旨を伝えて申し込む。 これはおそらくサポートセンターに電話することになると思う。

同時に、ISP が契約している VNE についても確認しておこう。 前述した通り VNE によって用意するルータが異なってくる。 JPNE なら MAP-E 対応のものになるし、マルチフィードなら DS-Lite 対応のものになる。 また、前述した通り契約している VNE を公式には非公開にしている ISP もある。 なので、ウェブサイトであらかじめ下調べするのが手っ取り早いだろう。

ちなみに、自分の場合はここで一度「使っている回線がフレッツ光ネクストではないので無理」と断られた。 ただ、結論としては ISP に登録されている内容がおかしくなっていただけだった。 前述したフレッツの契約内容を調べた上で再度問い合わせたところ受け付けてもらえた。

フレッツに v6 オプションを申し込む

続いて、フレッツ側でも v6 オプションというサービスを有効にする。 これにはサポートセンターに電話する以外にも、サービス情報サイトという Web サービス経由を使って自分で設定することもできる。 また、ISP によってはユーザの同意を得た上で代理で申し込んでくれる場合もあるようだ。

flets.com

ちなみに、NTT のサポートセンターからは一度「VDSL 配線方式なので IPoE は使えない」と言われた。 しかし、結論としては VDSL 配線方式だろうと光配線方式だろうと関係なく IPoE は使える。 ここは、調べても情報が出てこないので、あやうく工事をするところだった。

IPv6/IPoE + DS-Lite 対応ルータを用意する

ここまでで回線のレベルでは IPv6/IPoE が使えるようになる。 続いては、それを実際に使ってみるフェーズ。 今回のケースでは ISP が契約している VNE がマルチフィードだった。

そのため、DS-Lite に対応したルータを用意する必要がある。 これについては、マルチフィードで接続性を検証した機種を公式サイトで公開している。

transix DS-Lite 接続確認機種情報

ここから選ぶなら、入手性や設定の敷居という面では Buffalo のルータを選ぶのが良さそう。 もし、腕に覚えがあるようなら YAMAHA も業務用ルータとしては意外と安く買えるから良いのかも。 上記の検証で使われたそのものズバリな機種は古くて既に入手がしづらくなっているみたいだけど、後継機種はこれかな。

また、Buffalo のウェブサイトを調べたところ、次のようなページを見つけた。 どうやら、ここに載っている機種であれば MAP-E と DS-Lite の両方に対応しているらしい。

buffalo.jp

とはいえ DS-Lite の機構は極めてシンプルで、CPE 側に関しては IPv6 の SLAAC と IPv4 over IPv6 トンネルの機能さえあれば動作する。 なので、ちょっとがんばれば Linux なんかを使って自作することもできる。 例えば次のサイトでは Raspberry Pi を使って DS-Lite のルータを作る方法を紹介してる。

techlog.iij.ad.jp

つまり、自分で作ればタダみたいなもんやぞ! いや、筐体の費用がかかるので、タダではないです。

それと、買うにしても作るにしてもファイアーウォールはしっかりかけよう。 IPv4/PPPoE を使っていた頃とは、使うべきルールが異なっているはずなので。 IPv6 をブリッジしてデュアルスタックにするなら、特に気をつけた方が良い。

使う

あとは DS-Lite を設定したルータでインターネットへの疎通があることを確認して使うだけ。

例えば、以下は Google に traceroute してみたところ。

$ traceroute www.google.co.jp
traceroute to www.google.co.jp (216.58.197.195), 64 hops max, 52 byte packets
 1  172.16.0.1 (172.16.0.1)  1.437 ms  1.103 ms  1.015 ms
 2  ike-gw00.transix.jp (14.0.9.66)  6.579 ms  6.584 ms  6.410 ms
 3  ike-bbrt10.transix.jp (14.0.9.65)  6.456 ms  6.336 ms  6.506 ms
 4  bek-bbrt10.transix.jp (14.0.8.66)  7.005 ms
    bek-bbrt10.transix.jp (14.0.8.94)  6.952 ms
    bek-bbrt10.transix.jp (14.0.8.78)  6.999 ms
 5  210.173.176.243 (210.173.176.243)  7.474 ms  7.551 ms  8.019 ms
 6  108.170.242.193 (108.170.242.193)  7.615 ms  9.450 ms  8.119 ms
 7  72.14.233.221 (72.14.233.221)  7.687 ms  7.825 ms  7.665 ms
 8  nrt13s48-in-f195.1e100.net (216.58.197.195)  7.958 ms  7.940 ms  7.865 ms

ちゃんと Transix を経由していることが経路の逆引き結果から分かる。

DS-Lite を使うことで、以前は混雑しているときの RTT が 35ms 以上あったところが 15ms 未満まで短縮された。 それに伴って TCP のスループットも、ひどいときは 1Mbps 前後だったのが常時 50Mbps 以上出るようになった。

(2017/4/9 追記)

一つ注意点としては DS-Lite は NAPT が CGN になっているのでポート開放がユーザレベルでは制御できない。 家でサーバを立てている場合なんかには、それらを IPv4/PPPoE の方にルーティングしてやる必要がある。 これが v6 プラスの MAP-E なら NAPT が CPE にあるので原理上は開けられるはずなんだけど。 とはいえ、その場合も一つの IPv4 グローバルアドレスを複数の CPE で共有する関係で使えるポートレンジは限られるはず。

もう一つ懸念しているのは CGN の NAPT セッション数が足りなくなったりしないかという点。 同じく DS-Lite を試している他のブログでそれらしき事象が起こったらしい言及があった。 まだ遭遇したことはないんだけど、使うアプリケーションや時間帯によっては影響を受けるかもしれない。

まとめ

  • フレッツが遅い原因は PPPoE の終端装置がボトルネックになっているせい
  • IPv6/IPoE を使うとボトルネックになっている PPPoE の終端装置を迂回できる
  • ただし、使っているフレッツの契約と ISP の組み合わせで使えるかは異なる
  • また、そのままデュアルスタックで使うと IPv4 の通信が遅いままで残ってしまう
  • その解決策として IPv6/IPoE を使って IPv4 をトランスポートする技術がある
  • 採用している技術は IPv6/IPoE を提供している VNE ごとに異なる
  • JPNE なら MAP-E を、MF なら DS-Lite を使っている
  • 使うには ISP への申込みとフレッツで v6 オプションの有効化がいる
  • ISP によって契約している VNE が異なる
  • ISP が契約している VNE を非公開としてるなら、上記も非公式サポートとなる
  • DS-Lite 対応ルータは MF が公開している
  • とはいえ、ぶっちゃけ IP-IP トンネルさえあれば何とかなる
  • IPoE 関連でサポートセンターの返答は間違っていることが多い

Ω<つまり、これは PPPoE の終端装置をあえて輻輳させることで IPv6 への移行を促す陰謀だったんだよ!

ΩΩΩ<な、なんだってー?!

おわり。

Ubuntu 16.04 LTS で NVIDIA Docker を使ってみる

Ubuntu16.04LTS 機械学習 TensorFlow Keras Docker

以前、このブログで Keras/TensorFlow の学習を GPU (CUDA) で高速化する記事を書いた。 このときは、それぞれの環境の分離には Python の virtualenv を使っていた。

blog.amedama.jp

今回は、別の選択肢として NVIDIA Docker を使う方法を試してみる。 NVIDIA Docker というのは NVIDIA が公式で出している Docker から CUDA を使えるようにするユーティリティ群と Docker イメージ。 このやり方だと Docker ホストには NVIDIA Driver さえ入っていれば動作する。 そして、CUDA Toolkit と cuDNN は Docker コンテナの中に配備される。 ここら辺の概念図は GitHub の公式ページが分かりやすい。

github.com

この方法なら、それぞれのコンテナ毎に異なるバージョンの CUDA Toolkit と cuDNN を組み合わせて使うことができる。 例えば、あるコンテナでは CUDA Toolkit 7.5 + cuDNN 4.0 を、別のコンテナでは CUDA Toolkit 8.0 + cuDNN 5.1 を、といった具合に。

今回使った環境は次の通り。

$ cat /etc/lsb-release
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=16.04
DISTRIB_CODENAME=xenial
DISTRIB_DESCRIPTION="Ubuntu 16.04.2 LTS"
$ uname -r
4.4.0-71-generic

GPU については GTX 1050 Ti を使っている。

$ lspci | grep -i nvidia
01:00.0 VGA compatible controller: NVIDIA Corporation Device 1c82 (rev a1)
01:00.1 Audio device: NVIDIA Corporation Device 0fb9 (rev a1)

NVIDIA Driver をインストールする

まずは NVIDIA Driver をインストールする。 ドライバを単体で入れても良いんだけど、ついでだから CUDA Toolkit ごと入れてしまう。 次の公式ページからインストール用のファイルを取得する。

developer.nvidia.com

今回はネットワークインストール用の deb ファイルを使った。 ボタンを Linux > x86_64 > Ubuntu > 16.04 > deb (network) のように操作する。

得られたリンクから deb ファイルを取得する。

$ wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb

インストールしてからリポジトリの内容を更新する。

$ sudo dpkg -i cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
$ sudo apt-get update

これで CUDA Toolkit がインストールできる。 依存パッケージとして、ついでに NVIDIA Driver も入る。

$ sudo apt-get install cuda

インストールが終わったら、一旦再起動しておこう。

$ sudo shutdown -r now

Docker エンジンをインストールする

続いて NVIDIA Docker の動作に必要な Docker エンジンをインストールする。

まずは、必要な依存パッケージをインストールする。

$ sudo apt-get install apt-transport-https ca-certificates curl software-properties-common

続いて APT の認証鍵をインストールする。

$ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -

Docker エンジンをインストールするための公式リポジトリを登録する。

$ sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable"

リポジトリの内容を更新する。

$ sudo apt-get update

Docker エンジン (Community Edition) をインストールする。

$ sudo apt-get -y install docker-ce

これで、自動的に Docker のサービスが起動する。 docker version コマンドを実行してエラーにならなければ上手くいっている。

$ sudo docker version
Client:
 Version:      17.03.1-ce
 API version:  1.27
 Go version:   go1.7.5
 Git commit:   c6d412e
 Built:        Mon Mar 27 17:14:09 2017
 OS/Arch:      linux/amd64

Server:
 Version:      17.03.1-ce
 API version:  1.27 (minimum version 1.12)
 Go version:   go1.7.5
 Git commit:   c6d412e
 Built:        Mon Mar 27 17:14:09 2017
 OS/Arch:      linux/amd64
 Experimental: false

NVIDIA Docker をインストールする

続いて、肝心の NVIDIA Docker をインストールする。 今のところ、配布パッケージはローカルインストールの deb ファイルしかないようだ。 もしパッケージが更新されても自動更新はされないはずなので、また入れ直す必要があるね。

github.com

ローカルインストール用の deb ファイルを取得する。

$ wget https://github.com/NVIDIA/nvidia-docker/releases/download/v1.0.1/nvidia-docker_1.0.1-1_amd64.deb

取得したファイルをインストールする。

$ sudo dpkg -i nvidia-docker_1.0.1-1_amd64.deb

これで、自動的に nvidia-docker サービスが起動する。

$ systemctl list-units --type=service | grep -i nvidia-docker
nvidia-docker.service                                 loaded active running NVIDIA Docker plugin

まずは、CUDA Toolkit 8.0 + cuDNN 5 がインストールされたコンテナをダウンロードしてこよう。

$ sudo docker pull nvidia/cuda:8.0-cudnn5-runtime

終わったら、動作確認として先ほどのコンテナで nvidia-smi コマンドを実行してみよう。 次のように、エラーにならず結果が出力されれば上手くいっている。

$ sudo nvidia-docker run --rm nvidia/cuda nvidia-smi
Mon Apr  3 12:49:17 2017
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.39                 Driver Version: 375.39                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 105...  Off  | 0000:01:00.0      On |                  N/A |
| 35%   27C    P8    35W /  75W |     61MiB /  4038MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

これで Docker コンテナから GPU を使えるようになった。

カスタムイメージをビルドして使ってみる

先ほどは公式が配布しているバニラなイメージを使って動作確認をしてみた。 続いてはカスタムイメージをビルドして実行してみることにしよう。 サンプルとしては Keras と GPU 版 TensorFlow をインストールしてみることにした。

まずは適当に Dockerfile を書く。

$ cat << 'EOF' > Dockerfile
FROM nvidia/cuda:8.0-cudnn5-runtime

LABEL maintainer "example@example.jp"

RUN apt-get update
RUN apt-get -y install python3-pip curl
RUN pip3 install keras tensorflow-gpu
EOF

適当な名前でイメージをビルドする。

$ sudo docker build -t test/myimage .

ビルドしたイメージをインタラクティブモードで起動する。

$ sudo nvidia-docker run --rm -i -t test/myimage /bin/bash

Keras の CNN サンプルをダウンロードしてくる。

# curl -O https://raw.githubusercontent.com/fchollet/keras/master/examples/mnist_cnn.py
# echo 'K.clear_session()' >> mnist_cnn.py

そして実行する。 MNIST データセットのダウンロードとか GPU デバイスのスピンアップに時間がかかるので、これもイメージのビルドに含めた方が良かったかも。

# python3 mnist_cnn.py
Using TensorFlow backend.
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcublas.so.8.0 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcudnn.so.5 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcufft.so.8.0 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcurand.so.8.0 locally
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:910] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
I tensorflow/core/common_runtime/gpu/gpu_device.cc:885] Found device 0 with properties:
name: GeForce GTX 1050 Ti
major: 6 minor: 1 memoryClockRate (GHz) 1.455
pciBusID 0000:01:00.0
Total memory: 3.94GiB
Free memory: 3.84GiB
I tensorflow/core/common_runtime/gpu/gpu_device.cc:906] DMA: 0
I tensorflow/core/common_runtime/gpu/gpu_device.cc:916] 0:   Y
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1050 Ti, pci bus id: 0000:01:00.0)
60000/60000 [==============================] - 106s - loss: 0.3181 - acc: 0.9041 - val_loss: 0.0792 - val_acc: 0.9739
Epoch 2/12
60000/60000 [==============================] - 9s - loss: 0.1105 - acc: 0.9674 - val_loss: 0.0572 - val_acc: 0.9811
Epoch 3/12
60000/60000 [==============================] - 9s - loss: 0.0851 - acc: 0.9743 - val_loss: 0.0440 - val_acc: 0.9858
Epoch 4/12
60000/60000 [==============================] - 9s - loss: 0.0700 - acc: 0.9793 - val_loss: 0.0387 - val_acc: 0.9869
Epoch 5/12
60000/60000 [==============================] - 9s - loss: 0.0610 - acc: 0.9818 - val_loss: 0.0380 - val_acc: 0.9873
Epoch 6/12
60000/60000 [==============================] - 9s - loss: 0.0554 - acc: 0.9833 - val_loss: 0.0324 - val_acc: 0.9901
Epoch 7/12
60000/60000 [==============================] - 9s - loss: 0.0512 - acc: 0.9849 - val_loss: 0.0311 - val_acc: 0.9892
Epoch 8/12
60000/60000 [==============================] - 9s - loss: 0.0475 - acc: 0.9861 - val_loss: 0.0331 - val_acc: 0.9894
Epoch 9/12
60000/60000 [==============================] - 9s - loss: 0.0432 - acc: 0.9871 - val_loss: 0.0304 - val_acc: 0.9902
Epoch 10/12
60000/60000 [==============================] - 9s - loss: 0.0421 - acc: 0.9874 - val_loss: 0.0313 - val_acc: 0.9892
Epoch 11/12
60000/60000 [==============================] - 9s - loss: 0.0400 - acc: 0.9880 - val_loss: 0.0305 - val_acc: 0.9902
Epoch 12/12
60000/60000 [==============================] - 9s - loss: 0.0377 - acc: 0.9887 - val_loss: 0.0290 - val_acc: 0.9910
Test loss: 0.0289570764096
Test accuracy: 0.991

良い感じ。

Python: scikit-learn で主成分分析 (PCA) してみる

Mac OS X Python 統計 機械学習 scikit-learn

主成分分析 (PCA) は、主にデータ分析や統計の世界で使われる道具の一つ。 データセットに含まれる次元が多いと、データ分析をするにせよ機械学習をするにせよ分かりにくさが増える。 そんなとき、主成分分析を使えば取り扱う必要のある次元を圧縮 (削減) できる。 ただし、ここでいう圧縮というのは非可逆なもので、いくらか失われる情報は出てくる。 今回は、そんな主成分分析を Python の scikit-learn というライブラリを使って試してみることにした。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

下準備

あらかじめ、今回使う Python のパッケージを pip でインストールしておく。

$ pip install matplotlib scipy scikit-learn

主成分分析の考え方

前述した通り、主成分分析はデータセットの次元を圧縮 (削減) するのに用いる。 ただし、実は元のデータセットと分析結果で次元数を変えないようにすることもできる。 それじゃあ圧縮できていないじゃないかという話になるんだけど、実は分析結果では次元ごとの性質が異なっている。 これは、例えるなら「すごく重要な次元・それなりに重要な次元・あんまり重要じゃない次元」と分かれているような感じ。 そして、その中から重要な次元をいくつかピックアップして使えば、次元の数が減るというわけ。 もちろん、そのとき選ばれなかった「あんまり重要じゃない次元」に含まれていた情報は失われてしまう。

では、主成分分析ではどのような基準で次元の重要さを決めるのだろうか。 これは、データの分散が大きな次元ほど、より多くの情報を含んでいると考える。 分散というのは、データのバラつきの大きさを表す統計量なので、ようするに値がバラけている方が価値が大きいと捉える。 分散が小さいというのは、ようするにどの値も似たり寄ったりで差異を見出すのが難しいということ。 それに対し、分散が大きければ値ごとの違いも見つけやすくなる。

例えば、次のような x 次元と y 次元から成る、二次元のデータを考えてみよう。 この中には (1, 2), (2, 4), (3, 6) という三点の要素が含まれる。

f:id:momijiame:20170402110001p:plain

ここで x 次元の標本分散は  \frac{2}{3} で、y 次元の標本分散は  \frac{8}{3} になる。 主成分分析の考え方でいくと y 次元の方が分散が大きいので、より重要といえる。 ただ、上記のデータは二つの次元が相関しているようだ。 相関しているということは、似たような情報を含む次元が二つある、とも捉えることができる。

では、上記で相関に沿って新しい次元を作ってみたら、どうなるだろうか?

f:id:momijiame:20170402121107p:plain

値の間隔はピタゴラスの定理から  \sqrt{5} となることが分かる。 これは x 次元の間隔である 1 や y 次元の間隔である 2 よりも大きい。

f:id:momijiame:20170402121423p:plain

間隔が大きいということは分散も大きくなることが分かる。

続いては、先ほどの相関に沿って作った次元とは直交する軸でさらに新しい次元を作ってみよう。

f:id:momijiame:20170402121737p:plain

今度は、新しい次元からそれぞれの要素を見てみよう。 このとき、全ての要素は同じ場所にいるので間隔は 0 になっている。 つまり、分散も 0 なので、この次元には全然情報が含まれていないことになる。

上記の作業によって、情報がたくさん含まれる次元と、全く含まれない次元に分けることができた。 あとは、最初に作った情報がたくさん含まれる次元だけを使えば、二次元を一次元に圧縮できたことになる。 実は、これこそ正に主成分分析でしている作業を表している。

実際に試してみる

やっていることの概要は分かったので、次は実際にその通りになるのか試してみよう。 データセットとしては、まずは先ほどの三点をそのまま使ってみる。

次のサンプルコードでは、相関した三点のデータを主成分分析している。 そして、元のデータと分析結果を散布図にした。 また、分析結果の各次元の寄与率というものも出力している。 scikit-learn では

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA


def main():
    # y = 2x
    features = np.array([[1, 2], [2, 4], [3, 6]])

    # グラフ描画サイズを設定する
    plt.figure(figsize=(12, 4))

    # 元データをプロットする
    plt.subplot(1, 2, 1)
    plt.scatter(features[:, 0], features[:, 1])
    plt.title('origin')
    plt.xlabel('x')
    plt.ylabel('y')

    # 主成分分析する
    pca = PCA()
    pca.fit(features)

    # 分析結果を元にデータセットを主成分に変換する
    transformed = pca.fit_transform(features)

    # 主成分をプロットする
    plt.subplot(1, 2, 2)
    plt.scatter(transformed[:, 0], transformed[:, 1])
    plt.title('principal component')
    plt.xlabel('pc1')
    plt.ylabel('pc2')

    # 主成分の次元ごとの寄与率を出力する
    print(pca.explained_variance_ratio_)

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

それでは、上記のサンプルコードを実行してみよう。 出力されたリストは、分析結果の各次元の寄与率を表している。

$ python pca.py 
[ 1.  0.]

寄与率というのは、前述した「各次元の重要度」を表したもの。 その次元に元のデータからどれだけの割合で情報が含まれているかで、全てを足すと 1 になるように作られている。 つまり、主成分分析をした結果から全ての次元を使えば、元のデータセットから情報の損失は起こらない。 ただし、それだと次元も圧縮できないことになる。

先ほどの出力結果を見ると、最初の次元に寄与率が全て集中している。 つまり、最初の次元だけに全ての情報が含まれていることになる。 これは、先ほど主成分分析の概要を図示したときに得られた結論と一致している。

では、上記をグラフでも確認してみよう。 次のグラフは、主成分分析の前後を散布図で比べたもの。 左が元データで、右が分析結果となっている。

f:id:momijiame:20170402123229p:plain

見て分かる通り、先ほど図示した内容と一致している。 ちなみに、主成分分析では分析結果として得られた次元のことを第 n 主成分と呼ぶ。 例えば、最初に作った次元なら第一主成分、次に作った次元なら第二主成分という風になる。 今回の例では第一主成分に必要な情報が全て集中した。

アイリスデータセットを主成分分析してみる

次はもうちょっとだけそれっぽいデータを使ってみることにする。 みんな大好きアイリスデータセットは、あやめという花の特徴量と品種を含んでいる。 この特徴量は四次元なので、別々のグラフに分けたりしないと本来は可視化できない。 今回は、主成分分析を使って二次元に圧縮して可視化してみることにしよう。

次のサンプルコードではアイリスデータセットの次元を主成分分析している。 そして、分析結果から第二主成分までを取り出して散布図に可視化した。 また、同時に寄与率と累積寄与率を出力するようにした。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn import datasets


def main():
    dataset = datasets.load_iris()

    features = dataset.data
    targets = dataset.target

    # 主成分分析する
    pca = PCA(n_components=2)
    pca.fit(features)

    # 分析結果を元にデータセットを主成分に変換する
    transformed = pca.fit_transform(features)

    # 主成分をプロットする
    for label in np.unique(targets):
        plt.scatter(transformed[targets == label, 0],
                    transformed[targets == label, 1])
    plt.title('principal component')
    plt.xlabel('pc1')
    plt.ylabel('pc2')

    # 主成分の寄与率を出力する
    print('各次元の寄与率: {0}'.format(pca.explained_variance_ratio_))
    print('累積寄与率: {0}'.format(sum(pca.explained_variance_ratio_)))

    # グラフを表示する
    plt.show()


if __name__ == '__main__':
    main()

それでは、上記を実行してみよう。 コンソールには寄与率と累積寄与率が表示される。

$ python pcairis.py 
各次元の寄与率: [ 0.92461621  0.05301557]
累積寄与率: 0.9776317750248034

寄与率は先ほど説明した通りで、ここでの累積寄与率は使うことにした次元の寄与率を足したもの。 ようするに、今回の場合なら第一主成分と第二主成分の寄与率を足したものになっている。 累積寄与率は約 0.97 で、ようするに第二主成分までで元のデータの 97% が表現できていることが分かる。

同時に、次のような散布図が表示される。 これは、第一主成分と第二主成分を x 軸と y 軸に取って散布図にしたもの。 点の色の違いは品種を表している。

f:id:momijiame:20170402125421p:plain

本来なら四次元の特徴量で複数の散布図になるところを、主成分分析を使うことで一つの散布図にできた。

まとめ

今回は Python の scikit-learn を使って主成分分析について学んだ。 データセットに含まれる次元が多いと、データ分析なら分かりにくいし、機械学習なら計算量が増えてしまう。 そんなとき主成分分析を使えば、重要さが異なる新たな次元を含んだデータが分析結果として得られる。 その中から、重要なものをいくつかピックアップして使えば、データの損失を最小限に抑えて次元を減らすことができる。

参考文献

実践 機械学習システム

実践 機械学習システム

Python: ソケットプログラミングのアーキテクチャパターン

Mac OS X Python

今回はソケットプログラミングについて。 ソケットというのは Unix 系のシステムでネットワークを扱うとしたら、ほぼ必ずといっていいほど使われているもの。 ホスト間の通信やホスト内での IPC など、ネットワークを抽象化したインターフェースになっている。

そんな幅広く使われているソケットだけど、取り扱うときには色々なアーキテクチャパターンが考えられる。 また、比較的低レイヤーな部分なので、効率的に扱うためにはシステムコールなどの、割りと OS レベルに近い知識も必要になってくる。 ここらへんの話は、体系的に語られているドキュメントが少ないし、あっても鈍器のような本だったりする。 そこで、今回はそれらについてざっくりと見ていくことにした。

尚、今回はプログラミング言語として Python を使うけど、何もこれは特定の言語に限った話ではない。 どんな言語を使うにしても、あるいは表面上は抽象化されたインターフェースで隠蔽されていても、内部的にはソケットが使われている。 例えば Java サーブレットや Ruby on Rails で Web アプリケーションを書くにしても、それが動くサーバの通信部分はソケットで書かれていることだろう。

動作確認に使った環境は次の通り。

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.12.4
BuildVersion:   16E195
$ python --version
Python 3.6.1

ブロッキングとノンブロッキングについて

まず、ソケットを扱うには大きく分けて「ブロッキングで使うか・ノンブロッキングで使うか」を選ぶことになる。 その中でも基本となる使い方はブロッキングで、こちらの方が逐次的なプログラミングモデルとなりやすいので理解も早い。 ではノンブロッキングにはどんなメリットがあるかというと、こちらは通信相手が増えたときのパフォーマンス面 (スケーラビリティ) で優れている。

このエントリでは、ソケットの扱い方をブロッキング・ノンブロッキングと分けた上で、それぞれにどんなアーキテクチャパターンが考えられるか見ていく。 しかし、その前にまずは事前知識としてソケットにおけるブロッキング・ノンブロッキングという概念自体の説明から入ろう。

まず、ソケットというオブジェクトに対してはデータの読み込みや書き込みを指示できる。 読み込まれるデータは通信相手から送られてきたもので、書き込まれたデータは通信相手に送り届けられる。 しかし、それらのデータを読み書きする指示は即座に完了するわけではない。 具体的には、ソケットには読み書きができる状態とできない状態があるためだ。 読み書きができないというのは、わんこそばで例えると口の中でもぐもぐしている最中で、次のそばを口に入れられない状態を指す。

では、もしも読み込みや書き込みができない状態にあるソケットに対して、その指示を出したらどう振る舞うのだろうか。 ブロッキング・ノンブロッキングの違いというのは、正にこの「どう振る舞うか」の違いを指す。 ブロッキングというのは、読み書きができる状態になるまで、じっとそのまま待つことを意味している。 それに対して、ノンブロッキングは読み書きができない状態にあるときエラーを出してすぐに処理を終了する。

これで、ソケットのブロッキング・ノンブロッキングの違いについて説明できた。

ソケットをブロッキングで扱う場合

さて、前フリが長くなったけど、ここからは具体的なアーキテクチャパターンを見ていくことにしよう。 初めは、基本的な使い方であるソケットをブロッキングで扱う場合から。

今回、サンプルコードとして題材にするのはエコーサーバにした。 エコーサーバというのは、クライアントから送られてきたデータを、そのままオウム返しでクライアントに送り返すサーバのことをいう。

実装については IPv4 のループバックアドレスを使って TCP:37564 ポートでクライアントからの接続を待ち受けるようにした。 ループバックアドレスとは何か、みたいな TCP/IP 的な概念についての説明は省く。 これは、今回の主題として扱うアーキテクチャパターンという範疇からは、ちょっと外れるため。

あと、クライアントサイドについても自分で書いても良いんだけど、今回はありものを使うことにした。 ここでは netcat というツールを使うことにしよう。 netcatHomebrew を使ってインストールする。

$ brew install netcat

Homebrew が入っていないときは入れる感じで。

シングルスレッド

まずは、ソケットがブロッキングで、それをシングルスレッドで扱う場合から考えてみよう。 これが、最もシンプルなパターンといえるはず。

早速だけどサンプルコードを以下に示す。 それぞれの処理の内容はコメントで補足している。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    # IPv4/TCP のソケットを用意する
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # 'Address already in use' の回避策 (必須ではない)
    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    # 待ち受けるアドレスとポートを指定する
    # もし任意のアドレスで Listen したいときは '' を使う
    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    # クライアントをいくつまでキューイングするか
    serversocket.listen(128)

    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            # クライアントソケットから指定したバッファバイト数だけデータを受け取る
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except OSError:
                break

            # 受信したデータの長さが 0 ならクライアントからの切断を表す
            if len(message) == 0:
                break

            # 受信したデータをそのまま送り返す (エコー)
            sent_message = message
            while True:
                # 送信できたバイト数が返ってくる
                sent_len = clientsocket.send(sent_message)
                # 全て送れたら完了
                if sent_len == len(sent_message):
                    break
                # 送れなかった分をもう一度送る
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        # 後始末
        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

サーバにおけるソケットプログラミングの基本的な流れは次の通り。

  • ソケットを作る (socket)
  • 待ち受けるアドレスとポートを指定する (bind)
  • 接続キューの長さを指定して接続を待ち受ける (listen)
  • 接続してきたクライアントからソケットを取得する (accept)
  • 取得したクライアントのソケットに対して読み書きする (send/recv)

このパターンでは、上記の一連の処理を一つのスレッドでこなしていく。

それではサンプルコードを実行してみよう。 これで、エコーサーバが起動する。 とはいえ、クライアントが接続しない限り特に何も表示されることはない。

$ python singlethread.py

続いて、別のターミナルを開いて netcat を実行しよう。 次のようにすると、先ほど起動したエコーサーバに接続できる。

$ nc localhost 37564

すると、エコーサーバを起動したターミナルに、クライアントからの接続を表す表示が出るはず。

$ python singlethread.py
New client: 127.0.0.1:63917

さらに netcat のターミナルで文字列を入力して Enter すると、同じ内容がまた表示される。 これは、送信した内容がエコーサーバからオウム返しで返ってきたことを意味する。

$ nc localhost 37564
hogehoge
hogehoge

エコーサーバのターミナルを見ると、送受信した内容が表示されている。

$ python singlethread.py
New client: 127.0.0.1:63917
Recv: b'hogehoge\n'
Send: b'hogehoge\n'

netcat は Ctrl キーと C キーを一緒に押すことで終了できる。 これでサーバとの接続も切断される。

$ nc localhost 37564
hogehoge
hogehoge
^C

サーバの方にもクライアントとの接続が切れた旨が表示された。

$ python singlethread.py
New client: 127.0.0.1:63917
Recv: b'hogehoge\n'
Send: b'hogehoge\n'
Recv: b''
Bye-Bye: 127.0.0.1:63917

ここまで見た限り、このパターンで何の問題も無いように見える。 しかし、クライアントを二つにすると問題点が分かってくる。

サーバを一旦終了して、もう一度起動し直そう。 ちなみにサーバについても Ctrl-C で終了できる。

$ python singlethread.py

そして、改めて別のターミナルから netcat で接続する。

$ nc localhost 37564

クライアントが一つなら、サーバは接続を正常に受け付ける。

$ python singlethread.py
New client: 127.0.0.1:49746

では、さらにもう一つターミナルを開いて netcat で接続してみると、どうだろうか?

$ nc localhost 37564

今度は、サーバ側に接続を受け付けたメッセージが表示されない。

$ python singlethread.py
New client: 127.0.0.1:49746

そう、ソケットをブロッキングかつシングルスレッドで扱う場合、二つ以上のクライアントを同時に上手くさばくことができない。 なぜなら、唯一のスレッドは最初のクライアントからデータを読み書きする仕事に従事しているからだ。

先ほどのサンプルコードでいえば以下、クライアントからの新たなデータの到来を待ち続けて (ブロックして) いることだろう。

message = clientsocket.recv(1024)

唯一のスレッドが一つのクライアントにかかりきりなので、以下の別のクライアントからの接続を受け付ける処理は実行されない。 クライアントからの接続は、ソケットの接続キューに積まれたまま放置プレイを食らう。

clientsocket, (client_address, client_port) = serversocket.accept()

今かかりきりになっている相手との通信が終わるまで、別のクライアントは受け付けることができないというわけ。

マルチスレッド

ソケットをブロッキングで扱うとき、シングルスレッドでは二つ以上のクライアントを上手くさばけないことが分かった。 そこで、次はクライアントを処理するスレッドを複数用意してマルチスレッドにしてみよう。

先ほどの内容に手を加えて、マルチスレッドにしたサンプルコードは次の通り。 先ほどと同じ処理についてはコメントを省いて、新たに追加したり変更したところにコメントを書いている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import threading


def client_handler(clientsocket, client_address, client_port):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        try:
            message = clientsocket.recv(1024)
            print('Recv: {0} from {1}:{2}'.format(message,
                                                  client_address,
                                                  client_port))
        except OSError:
            break

        if len(message) == 0:
            break

        sent_message = message
        while True:
            sent_len = clientsocket.send(sent_message)
            if sent_len == len(sent_message):
                break
            sent_message = sent_message[sent_len:]
        print('Send: {0} to {1}:{2}'.format(message,
                                            client_address,
                                            client_port))

    clientsocket.close()
    print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        # 接続してきたクライアントを処理するスレッドを用意する
        client_thread = threading.Thread(target=client_handler,
                                         args=(clientsocket,
                                               client_address,
                                               client_port))
        # 親 (メイン) スレッドが死んだら子も道連れにする
        client_thread.daemon = True
        # スレッドを起動する
        client_thread.start()


if __name__ == '__main__':
    main()

先ほどとの違いは、クライアントとの接続に対してスレッドが一対一で生成されるところだ。 プログラムが起動された直後に生成されるメインスレッドは、クライアントからの接続を受け付ける仕事だけに専念している。 実際に受け付けたクライアントとの接続の処理は、新たに生成した子スレッドに任せるわけだ。

では、上記サンプルコードの動作を確認してみよう。 まずはエコーサーバを起動する。

$ python multithread.py

そして、二つのターミナルからエコーサーバに接続してみよう。

$ nc localhost 37564

すると、今度は二つのクライアントから接続を受け付けた旨が表示された。

$ python multithread.py
New client: 127.0.0.1:51027
New client: 127.0.0.1:51028

それぞれのクライアントのターミナルで文字列を入力すると、ちゃんとエコーバックされるし上手く動いている。

$ python multithread.py
New client: 127.0.0.1:51027
New client: 127.0.0.1:51028
Recv: b'hogehoge\n' from 127.0.0.1:51027
Send: b'hogehoge\n' to 127.0.0.1:51027
Recv: b'hogehoge\n' from 127.0.0.1:51028
Send: b'hogehoge\n' to 127.0.0.1:51028

マルチスレッド (スレッドプール)

先ほどの例では、クライアントを処理する部分をマルチスレッド化することで、二つ以上のクライアントを同時にさばけるようになった。 しかし、実は先ほどのやり方ではクライアントの接続数がどんどん増えていくと問題になってくることがある。 それは、メモリの使用量とコンテキストスイッチにかかるコストの増加だ。

スレッドというのは、新たに作ろうとするとそれ用のコンテキストを必要とする。 この、コンテキストというのは各スレッドの状態を保持しておくために必要なメモリに他ならない。 スレッドあたりのコンテキストのサイズは状態や実装に依存するので、これくらいとはなかなか言いづらいものがある。 とはいえ、一つ一つが小さくてもクライアントの接続数が増えれば決してばかにできないサイズになってくる。

また、コンテキストスイッチというのは、CPU が処理しているスレッドを OS が途中で切り替える作業のことをいう。 まず、CPU というのは同時に処理できるスレッドの数が、あらかじめ製品ごとに決まっている。 例えば、今売られている Intel や AMD の x86-64 アーキテクチャの CPU を例に挙げてみよう。 この場合は、物理コアあたり 1 または 2 スレッドである場合が多い。 つまり、同時に処理できるスレッドには機械的な上限がある。 ちなみに、物理コアあたり同時 2 スレッドの製品については、OS からは論理コアが 2 つあるように扱われる。

にも関わらず、実のところ私たちは普段からそれよりも多くのスレッドを同時に起動して扱っている。 なぜそんなことができるかというと、CPU が実行するスレッドを、OS が途中で別のスレッドに入れ替えているためだ。 この入れ替えは、ごく短時間で行われているので、見かけ上はたくさんのスレッドが同時に実行できているかのように見える。

しかしながら、この入れ替え作業には短時間ながらもちろん時間がかかる。 そして、CPU で同時に処理できるスレッドの数に対して、OS が扱うスレッドの数が増えてくると、その頻度も上がる。 これによって、切り替え作業に要する時間が増えて、だんだんと CPU が非効率な使われ方をしてしまうことがある。

先ほどのサンプルコードでは、まさに上記の二つが問題となる。 なぜなら、生成するスレッドの数に上限を設けていないからだ。 上限がないと、クライアントの数に応じてどんどんスレッドが増え続ける。 結果として、メモリを消費すると共に CPU が非効率な使われ方をしてしまう。

スレッドが多すぎるとまずいという問題点が分かったところで、次はスレッドを生成する数に上限を設けてみよう。 具体的には、あらかじめスレッドを既定数だけ生成して、それらに仕事を割り振る形にする。 この手法は、一般にスレッドプールと呼ばれている。 スレッドプールの中にいる各スレッドは、ワーカースレッドと呼ばれる。

次のサンプルコードはスレッドプールを使った実装になっている。 生成されたワーカーがサーバソケットからの接続を奪い合う形になる。 これなら、あらかじめ決まった数を超えるスレッドは生成されないので、前述したような問題は発生しない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import threading


def worker_thread(serversocket):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカスレッド同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # サーバソケットを渡してワーカースレッドを起動する
    NUMBER_OF_THREADS = 10
    for _ in range(NUMBER_OF_THREADS):
        thread = threading.Thread(target=worker_thread, args=(serversocket, ))
        thread.daemon = True
        thread.start()

    while True:
        # メインスレッドは遊ばせておく (ハンドラを処理させても構わない)
        time.sleep(1)


if __name__ == '__main__':
    main()

ただし、上記にも注意点がある。 それは、あらかじめプールしたスレッド数を超えてクライアントをさばくことができない、という点だ。 プール数を超えた接続があったときは、他のクライアントとの接続が切れるまで、ソケットは処理されないままキューに積まれてしまう。

実行結果については、先ほどと変わらないので省略する。

ちなみに、蛇足だけど Mac OS X に関してはプロセスごとに生成できるスレッド数があらかじめ制限されているようだ。 例えば、次のようなサンプルコードを用意して、たくさんのスレッドを起動してみる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import threading
import time


def loop():
    """各スレッドは特に何もしない"""
    while True:
        time.sleep(1)


def main():
    # ネイティブスレッドをたくさん起動してみる
    for _ in range(10000):
        t = threading.Thread(target=loop)
        t.daemon = True
        t.start()
        # 動作中のスレッド数を出力する
        print(threading.active_count())


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると 2049 個目のスレッドを起動するところで例外になった。

$ python toomanythreads.py
...(省略)...
2046
2047
2048
Traceback (most recent call last):
  File "toomanythreads.py", line 25, in <module>
    main()
  File "toomanythreads.py", line 19, in main
    t.start()
  File "/Users/amedama/.pyenv/versions/3.6.1/lib/python3.6/threading.py", line 846, in start
    _start_new_thread(self._bootstrap, ())
RuntimeError: can't start new thread

このリミットは、どうやら次のカーネルパラメータでかかっているらしい。

$ sysctl -n kern.num_taskthreads
2048

Mac OS X においては、スレッドの生成数に上限を設けないと、メモリの枯渇などを待つことなくサーバが突然死することになる。

マルチプロセス (プロセスプール)

先ほどの例では、スレッドプールを使うことで同時に処理できるクライアントの数を増やしつつ、リソースの消費を抑えることができた。 しかしながら、実はここまでの例では、パフォーマンスを求める上で、まだ使い切れていないリソースが残っている。 それは、複数の CPU コアだ。

実は Python の標準的な処理系である CPython には、とある制限が存在している。 それは、一つのプロセスで同時に実行できるスレッドの数が一つだけ、というもの。 一般的に、これはグローバルインタプリタロック (Global Interpreter Lock, GIL) と呼ばれている。 この制限は、Python/C API で書かれた拡張モジュールを Python から扱いやすくするために存在する。

この GIL がある処理系では、CPU に複数の論理コアがあったとしても、同時に使われるのが一つだけに制限されてしまう。 つまり、先ほどの例では、マルチスレッドにしても実際に使われている CPU 論理コアは同時に一つだけだった。 ようするに、複数のスレッドを OS が一つの CPU 論理コアの上で切り替え (コンテキストスイッチ) ながら動作する。

ちなみに、コンピュータの処理には、大きく分けて入出力 (I/O) が主体になるものと計算 (CPU) が主体になるものがある。 CPU が主体となるのは、例えば科学計算のようなもの。 それに対して、今回の例であるエコーサーバのようなプログラムは、CPU の処理がほとんどない。 処理時間のほとんどを I/O の待ちに使っていることから、入出力が主体のプログラムといえる。

つまり、今回取り扱うエコーサーバは CPU の処理能力がボトルネックになりにくい。 ようするに、あえて CPU の能力を最大限引き出すようなコードにする必然性は薄い。 しかしながら、アーキテクチャパターンの紹介という意味では重要だと思う。 なので、その方法についても記述しておこう。

その方法というのは、具体的にはプログラムを複数のプロセスで動かす。 前述した通り GIL はプロセスあたりの同時実行スレッド数を一つに制限するというものだった。 なので、プロセスを複数立ち上げてしまえば、同時実行スレッド数をプログラム全体で見たときに増やすことができる。

次のサンプルコードでは、スレッドの代わりにプロセスを複数起動 (マルチプロセス) している。 Python でマルチプロセスを扱う方法としては、例えば標準ライブラリの multiprocessing モジュールがある。 起動するプロセスの数は CPU の論理コア数と同じにした。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import multiprocessing


def worker_process(serversocket):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカープロセス同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # プロセス数は CPU のコア数前後に合わせると良い
    NUMBER_OF_PROCESS = multiprocessing.cpu_count()
    # サーバソケットを渡してワーカープロセスを起動する
    for _ in range(NUMBER_OF_PROCESS):
        process = multiprocessing.Process(target=worker_process,
                                          args=(serversocket, ))
        # デーモンプロセスにする (親プロセスが死んだら子も道連れに死ぬ)
        process.daemon = True
        # プロセスを起動する
        process.start()

    while True:
        time.sleep(1)


if __name__ == '__main__':
    main()

マルチプロセスを使うときの注意点についても見ていこう。 これは、マルチスレッドの場合とほとんど変わらない。 つまり、プロセスを作るにもコンテキストが必要であり、コンテキストスイッチが起こるということだ。 そのため、同時に起動するプロセス数は制限してやる必要がある。 しかも、必要なリソースの量はスレッドに比べるとずっと多い。 そのため、一般的には起動するプロセス数は CPU の論理コアの数前後が良いとされている。

また、マルチプロセス固有の問題としては、プロセス間での値の共有が挙げられる。 マルチスレッドであれば、同一プロセス内でメモリ空間を共有していた。 なので、例えばグローバル変数の値をスレッド間で情報を共有する手段にもできた。 それに対し、マルチプロセスではプロセス同士でメモリ空間は共有していない。 そのため、別の何らかの IPC を使って情報をやり取りしなければいけない。

尚、繰り返しになるけどマルチプロセスにする必要があるのは、あくまで GIL があることに由来している。 もし、これがない処理系やプログラミング言語を使うなら、単にマルチスレッドにするだけで大丈夫。 ちゃんと CPU のコアを使い切ってくれるはず。

マルチプロセス・マルチスレッド

先ほどの例では、プロセスを複数立ち上げることで CPU の能力を使い切れるようにした。 ただし、マルチプロセスではあるものの、それぞれのプロセスでは一つのスレッドしか動かしていなかった。 そこで、次は各プロセスの中をマルチスレッドにしてみよう。 これなら、マルチプロセスかつマルチスレッドになって CPU と I/O の両方を上手く使い切れるはず。

次のサンプルコードでは、各ワーカープロセスの中でさらにスレッドプールを動かしている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import multiprocessing
import threading


def worker_thread(serversocket):
    """クライアントとの接続を処理するハンドラ (スレッド)"""
    while True:
        # クライアントからの接続を待ち受ける (接続されるまでブロックする)
        # ワーカースレッド同士でクライアントからの接続を奪い合う
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {0} from {1}:{2}'.format(message,
                                                      client_address,
                                                      client_port))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {0} to {1}:{2}'.format(message,
                                                client_address,
                                                client_port))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def worker_process(serversocket):
    """クライアントとの接続を受け付けるハンドラ (プロセス)"""

    # サーバソケットを渡してワーカースレッドを起動する
    NUMBER_OF_THREADS = 10
    for _ in range(NUMBER_OF_THREADS):
        thread = threading.Thread(target=worker_thread, args=(serversocket, ))
        thread.start()
        # ここではワーカーをデーモンスレッドにする必要はない (死ぬときはプロセスごと逝くので)

    while True:
        # ワーカープロセスのメインスレッドは遊ばせておく
        time.sleep(1)


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    NUMBER_OF_PROCESSES = multiprocessing.cpu_count()
    for _ in range(NUMBER_OF_PROCESSES):
        process = multiprocessing.Process(target=worker_process,
                                          args=(serversocket, ))
        process.daemon = True
        process.start()

    while True:
        time.sleep(1)


if __name__ == '__main__':
    main()

実行結果については、これまで変わらないので省略する。

ひとまず、ソケットをブロッキングで扱う場合のアーキテクチャパターンについては、これでおわり。

ソケットをノンブロッキングで扱う場合

続いては、ソケットをノンブロッキングで扱う場合について見ていこう。 前述した通り、ソケットをノンブロッキングで扱うと、読み書きなどを指示してもブロックが起きない。 その代わり、もし読み書きの準備ができていないときはその旨がエラーで返ってくる。

とりあえずノンブロッキングにしてみよう

最初に、ノンブロッキングなソケットをブロッキングっぽく扱ったときの挙動を確認しておこう。 具体的に、どんなことが起こるのだろうか?

次のサンプルコードは、最初に示したシングルスレッドのサーバに一行だけ手を加えている。 それは、サーバソケットを setblocking() メソッドでノンブロッキングモードにしているところだ。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # ソケットをノンブロッキングモードにする
    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                sent_len = clientsocket.send(sent_message)
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、すぐに例外が出て終了してしまう。

$ python nonblocking.py
Traceback (most recent call last):
  File "nonblocking.py", line 48, in <module>
    main()
  File "nonblocking.py", line 22, in main
    clientsocket, (client_address, client_port) = serversocket.accept()
  File "/Users/amedama/.pyenv/versions/3.6.1/lib/python3.6/socket.py", line 205, in accept
    fd, addr = self._accept()
BlockingIOError: [Errno 35] Resource temporarily unavailable

上記の BlockingIOError という例外は、まだ準備が整っていないにも関わらず指示が出されたときに上がる。 今回の場合だと、クライアントからの接続が到着していないのに accept() メソッドを呼び出している。 ブロッキングモードのソケットなら、そのまま到着するまで待ってくれていた。 それに対し、ノンブロッキングモードでは呼び出した時点で到着していないなら即座に例外となってしまう。 正に、これがブロッキングとノンブロッキングの挙動の違い。

準備が整うまで待つ (ビジーループ)

先ほどの例で分かるように、ソケットをノンブロッキングで使うとブロッキングとは使い勝手が異なっている。 具体的には、ソケットの準備が整うのを勝手に待ってくれるわけではないので、自分で意図的に待たなければいけない。

では、どのようにすれば待つことができるだろうか。 一つのやり方としては、エラーが出なくなるまで定期的に実行してみる方法が考えられる。 この、何度も自分から試しに行くやり方はポーリングと呼ばれる。 その中でも、それぞれの試行間隔を全く空けないものはビジーループという。

次のサンプルコードではノンブロッキングなソケットをビジーループで待ちながら処理している。 ただし、あらかじめ言っておくと、このやり方は間違っている。 ソケットをノンブロッキングで扱うとき、こんなソースコードは書いちゃいけない。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        try:
            clientsocket, (client_address, client_port) = serversocket.accept()
        except (BlockingIOError, socket.error):
            # まだソケットの準備が整っていない
            continue

        print('New client: {0}:{1}'.format(client_address, client_port))

        while True:
            try:
                message = clientsocket.recv(1024)
                print('Recv: {}'.format(message))
            except (BlockingIOError,  socket.error):
                # まだソケットの準備が整っていない
                continue
            except OSError:
                break

            if len(message) == 0:
                break

            sent_message = message
            while True:
                try:
                    sent_len = clientsocket.send(sent_message)
                except (BlockingIOError,  socket.error):
                    # まだソケットの準備が整っていない
                    continue
                if sent_len == len(sent_message):
                    break
                sent_message = sent_message[sent_len:]
            print('Send: {}'.format(message))

        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


if __name__ == '__main__':
    main()

上記のサンプルコードは一応動作するものの、複数のクライアントを処理することができない。 それに、ビジーループを使っているとプロセスの CPU 使用率が 100% になってしまう。 繰り返しになるけど、ソケットをノンブロッキングで扱うとき、こんなソースコードは書いちゃだめ。

準備が整うまで待つ (イベントループ)

ビジーループでは色々と難しいことが分かったところで、次は実用的に待つ方法を見ていこう。 これには、イベントループや I/O 多重化と呼ばれる手法というかシステムコールを用いる。 システムコールというのは OS のカーネルに備わっている API のことだ。 ユーザランドのプログラムは、このシステムコールを呼び出すことで OS の機能が利用できる。

システムコールの中には、ソケットの状態を監視して、変更されたときにそれを通知してくれるものがある。 より正確には、監視できるものはファイルやソケットに汎用的に割り当てられるファイルディスクリプタだ。

イベントループにはいくつかの種類があるものの、ここでは古典的な select(2) を使うやり方を見ていく。 次のサンプルコードは、エコーサーバを select(2) システムコールで実装したもの。 ただし、先に断っておくと、これは実装している機能の割にコード量が多いし、逐次的でないから読みにくいと思う。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import select


# 読み取りが可能になるまで待っているソケットと、可能になったときに呼び出されるハンドラ・引数の対応を持つ
read_waiters = {}
# 書き込みが可能になるまで待っているソケットと、可能になったときに呼び出されるハンドラ・引数の対応を持つ
write_waiters = {}
# 接続してきたクライアントとの接続情報を格納する
connections = {}


def accept_handler(serversocket):
    """サーバソケットが読み取り可能になったとき呼ばれるハンドラ"""
    # 準備ができているので、すぐに accept() できる
    clientsocket, (client_address, client_port) = serversocket.accept()

    # クライアントソケットもノンブロックモードにする
    clientsocket.setblocking(False)

    # 接続してきたクライアントの情報を出力する
    # ただし、厳密に言えば print() もブロッキング I/O なので避けるべき
    print('New client: {0}:{1}'.format(client_address, client_port))

    # ひとまずクライアントの一覧に入れておく
    connections[clientsocket.fileno()] = (clientsocket,
                                          client_address,
                                          client_port)

    # 次はクライアントのソケットが読み取り可能になるまで待つ
    read_waiters[clientsocket.fileno()] = (recv_handler,
                                           (clientsocket.fileno(), ))

    # 次のクライアントからの接続を待ち受ける
    read_waiters[serversocket.fileno()] = (accept_handler, (serversocket, ))


def recv_handler(fileno):
    """クライアントソケットが読み取り可能になったとき呼ばれるハンドラ"""
    def terminate():
        """クライアントとの接続が切れたときの後始末"""
        # クライアント一覧から取り除く
        del connections[clientsocket.fileno()]
        # ソケットを閉じる
        clientsocket.close()
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))

    # クライアントとの接続情報を取り出す
    clientsocket, client_address, client_port = connections[fileno]

    try:
        # 準備ができているので、すぐに recv() できる
        message = clientsocket.recv(1024)
    except OSError:
        terminate()
        return

    if len(message) == 0:
        terminate()
        return

    print('Recv: {0} to {1}:{2}'.format(message,
                                        client_address,
                                        client_port))

    # 次はクライアントのソケットが書き込み可能になるまで待つ
    write_waiters[fileno] = (send_handler, (fileno, message))


def send_handler(fileno, message):
    """クライアントソケットが書き込み可能になったとき呼ばれるハンドラ"""
    # クライアントとの接続情報を取り出す
    clientsocket, client_address, client_port = connections[fileno]

    # 準備ができているので、すぐに send() できる
    sent_len = clientsocket.send(message)
    print('Send: {0} to {1}:{2}'.format(message[:sent_len],
                                        client_address,
                                        client_port))

    if sent_len == len(message):
        # 全て送ることができたら、次はまたソケットが読み取れるようになるのを待つ
        read_waiters[clientsocket.fileno()] = (recv_handler,
                                               (clientsocket.fileno(), ))
    else:
        # 送り残している内容があったら、再度ソケットが書き込み可能になるまで待つ
        write_waiters[fileno] = (send_handler,
                                 (fileno, message[sent_len:]))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # ソケットをノンブロックモードにする
    serversocket.setblocking(False)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    # クライアントからの接続がくるまで待つ
    read_waiters[serversocket.fileno()] = (accept_handler, (serversocket, ))

    while True:
        # ソケットが読み取り・書き込み可能になるまで待つ
        rlist, wlist, _ = select.select(read_waiters.keys(),
                                        write_waiters.keys(),
                                        [],
                                        60)

        # 読み取り可能になったソケット (のファイル番号) の一覧
        for r_fileno in rlist:
            # 読み取り可能になったときに呼んでほしいハンドラを取り出す
            handler, args = read_waiters.pop(r_fileno)
            # ハンドラを実行する
            handler(*args)

        # 書き込み可能になったソケット (のファイル番号の一覧)
        for w_fileno in wlist:
            # 書き込み可能になったときに呼んでほしいハンドラを取り出す
            handler, args = write_waiters.pop(w_fileno)
            # ハンドラを実行する
            handler(*args)


if __name__ == '__main__':
    main()

Python では select(2) システムコールの薄いラッパとして select モジュールが使える。 このモジュールが提供する select() 関数には、ファイルディスクリプタの入ったリストを渡す。

ファイルディスクリプタというのは、名前だけ聞くと難しそうだけどただの整数に過ぎない。 これは、各ソケットやファイルなどを使うときに OS が割り当てた一意な整数を指している。 ようするに 10 とか 20 とかいう数字が、何らかのソケットやファイルなどを表す。 ソケットに割り当てられたファイルディスクリプタは fileno() メソッドで得られる。

select() 関数には、読み込みや書き込みの準備ができたら通知してほしいファイルディスクリプタを渡す。 そして select() 関数を呼び出すと、そこでブロックした後に、準備ができたファイルディスクリプタが返される。 返ってきたファイルディスクリプタは、既に読み書きができるようになっているので指示を出しても例外にはならない。

先ほどのサンプルコードでは、そのようにして準備ができたものに対して読み書きをしている。 ビジーループと比べると CPU を使い切ることもなく、複数のクライアントを処理できる。 また、大きなポイントとしてはシングルスレッドにも関わらず、複数のクライアントを処理できているところだ。 これはソケットをブロッキングで使っていたときとの大きな違いだろう。

ちなみに、今回使った select(2) システムコールにはパフォーマンス上の問題が知られている。 そのため、実用的な用途で使われることはそこまで多くない。 代わりに、BSD 系なら kqueue(2)、Linux であれば epoll(2) が用いられる。 ただし、select(2) なら大抵のプラットフォームで使えるので、それらに比べると移植性が高いというメリットはある。

また「ソケットやファイルなど」と前述した通り、実はブロッキング・ノンブロッキングという概念はソケットに限った話ではない。 ファイルやデバイスについてもノンブロッキングで扱うことはできる。 そして、これはノンブロッキングなソケットプログラミングをする上で重要な意味を持ってくる。 詳細は後述するものの、これはノンブロッキングとブロッキングを同じスレッドで混ぜて使うと問題が発生する、というもの。

尚、前述した通り先ほどのサンプルコードはシングルプロセス・シングルスレッドで動作している。 そのため、複数の CPU コアを使い切ることはできない。 使い切れるようにするときは、マルチプロセスにする必要がある。 もちろん、これは GIL の制約のためにプロセスを複数立ち上げる必要があるに過ぎない。 別の処理系やプログラミング言語であれば、単にマルチスレッドにするだけで良い。 いずれの場合でも、それぞれのスレッドごとにイベントループを用意する。

ノンブロッキング I/O をラップした API やライブラリを使う

先ほどの例ではイベントループのシステムコールを使ってノンブロッキングなソケットを処理してみた。 とはいえ、実際にシステムコールを直接使ってソケットプログラミングする機会は、あまりないと思う。 なぜなら、先ほどのサンプルコードを見て分かる通り、それらの API はそのままでは扱いにくい上にコード量も増えてしまうため。

実際には、イベントループをラップしたライブラリを使ってプログラミングすることになると思う。 どんなライブラリがあるかはプログラミング言語ごとに異なる。 例えば C 言語なら libev が有名だと思うし Python なら Twisted などがある。 また、Python に関しては 3.4 から標準ライブラリに asyncio というモジュールが追加された。 次は、この asyncio を使ってみることにしよう。

Python の asyncio には色んなレイヤーの API が用意されている。 それこそ、先ほどのシステムコールを直接使うのと大差ないようなコードも書ける。 しかし、それだとライブラリを使う意味がないので、もうちょっと抽象度の高いものを使ってみた。 次のサンプルコードでは asyncio を使ってエコーサーバを実装している。 コードを見て分かる通り、先ほどと比べるとだいぶコード量が減って読みやすくなっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio


class EchoServer(asyncio.Protocol):

    def connection_made(self, transport):
        """クライアントからの接続があったときに呼ばれるイベントハンドラ"""
        # 接続をインスタンス変数として保存する
        self.transport = transport

        # 接続元の情報を出力する
        client_address, client_port = self.transport.get_extra_info('peername')
        print('New client: {0}:{1}'.format(client_address, client_port))

    def data_received(self, data):
        """クライアントからデータを受信したときに呼ばれるイベントハンドラ"""
        # 受信した内容を出力する
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Recv: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

        # 受信したのと同じ内容を返信する
        self.transport.write(data)
        print('Send: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

    def connection_lost(self, exc):
        """クライアントとの接続が切れたときに呼ばれるイベントハンドラ"""
        # 接続が切れたら後始末をする
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))
        self.transport.close()


def main():
    host = 'localhost'
    port = 37564

    # イベントループを用意する
    ev_loop = asyncio.get_event_loop()

    # 指定したアドレスとポートでサーバを作る
    factory = ev_loop.create_server(EchoServer, host, port)
    # サーバを起動する
    server = ev_loop.run_until_complete(factory)

    try:
        # イベントループを起動する
        ev_loop.run_forever()
    finally:
        # 後始末
        server.close()
        ev_loop.run_until_complete(server.wait_closed())
        ev_loop.close()


if __name__ == '__main__':
    main()

注目すべきは、もはやソースコードの中に socket モジュールが登場していないところ。 それらは Protocol や Transport といった抽象的なオブジェクトに取って代わられている。

では、本当に内部でイベントループのシステムコールが使われているのかを調べてみよう。 まずは上記のサンプルコードを実行して、エコーサーバを起動する。

$ python asyncioserver.py

続いて、別のターミナルを開いたら上記エコーサーバが動いているプロセスの ID を調べる。

$ ps auxww | grep [a]syncioserver
amedama        31018   0.0  0.2  2430616  17344 s000  S+    7:58PM   0:00.16 python asyncioserver.py

そして、プロセスで発行されるシステムコールをトレースする dtruss コマンドを仕掛ける。

$ sudo dtruss -p 31018

準備ができたらクライアントを接続する。

$ nc localhost 37564

すると、次のように kevent(2) システムコールが発行されていることが分かる。 kevent(2) システムコールは kqueue(2) と共に用いるイベントループのためのシステムコール。

$ sudo dtruss -p 31018
SYSCALL(args)            = return
...
kevent(0x3, 0x0, 0x0)            = 0 0
getsockname(0xA, 0x7FFF50F55B00, 0x7FFF50F55AFC)                 = 0 0
setsockopt(0xA, 0x6, 0x1)                = 0 0
kevent(0x3, 0x0, 0x0)            = 0 0
write(0x1, "New client: 127.0.0.1:51822\n\0", 0x1C)              = 28 0
kevent(0x3, 0x10F8FB6F0, 0x1)            = 0 0
kevent(0x3, 0x0, 0x0)            = 0 0

どうやら、ちゃんと内部がノンブロッキングな世界になっていることが確認できた。 しかも、プラットフォームに応じたパフォーマンスに優れるイベントループをちゃんと使ってくれている。

ノンブロッキングとブロッキングは混ぜるな危険

ちなみに、ノンブロッキングなソケットプログラミングをする上では重要なポイントが一つある。 それは、ノンブロッキングなソケットを扱うスレッドで、ブロッキングな操作をしてはいけない、という点。 もちろん、前述した通りブロッキング・ノンブロッキングという概念はソケットに限った話ではない。 つまり、言い換えるとノンブロッキングな I/O とブロッキングな I/O は同じスレッドで混ぜてはいけない。

二つ前のセクションで登場した select システムコールを使ったサンプルコードを思い出してほしい。 あのサンプルコードでは、シングルスレッドで複数のクライアントをさばいていた。 では、もしその一つしかないスレッドが何処かでブロックしたら、何が起こるだろうか? これは、そのスレッドでさばいている全ての処理が、そこで停止してしまうことを意味する。 これは、ノンブロッキングな I/O を扱う上で登場する代表的な問題の一つ。

どのようなことが起こるかを実際に確かめてみよう。 次のサンプルコードでは、データを受信した際に time.sleep() 関数を使っている。 これには、実行したスレッドを指定した時間だけブロックさせる効果がある。 正に、ノンブロッキングなスレッドへのブロッキングな操作の混入といえる。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio
import time


class EchoServer(asyncio.Protocol):

    def connection_made(self, transport):
        self.transport = transport

        client_address, client_port = self.transport.get_extra_info('peername')
        print('New client: {0}:{1}'.format(client_address, client_port))

    def data_received(self, data):
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Recv: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

        # 何らかの処理で、イベントループのスレッドがブロックしてしまった!
        print('Go to sleep...')
        time.sleep(20)

        self.transport.write(data)
        print('Send: {0} to {1}:{2}'.format(data,
                                            client_address,
                                            client_port))

    def connection_lost(self, exc):
        client_address, client_port = self.transport.get_extra_info('peername')
        print('Bye-Bye: {0}:{1}'.format(client_address, client_port))
        self.transport.close()


def main():
    host = 'localhost'
    port = 37564

    ev_loop = asyncio.get_event_loop()

    factory = ev_loop.create_server(EchoServer, host, port)
    server = ev_loop.run_until_complete(factory)

    try:
        ev_loop.run_forever()
    finally:
        server.close()
        ev_loop.run_until_complete(server.wait_closed())
        ev_loop.close()


if __name__ == '__main__':
    main()

上記のサンプルコードを実行してエコーサーバを起動しよう。

$ python asyncblock.py

続いて別のターミナルから nc コマンドでサーバに接続したら適当な文字列を入力する。

$ nc localhost 37564
hogehoge

これでイベントループを回しているスレッドはブロックを起こした。

$ python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...

すかさず別のターミナルから nc でクライアントを追加してみよう。

$ nc localhost 37564

すると、今度はサーバに新しいクライアントが追加された旨は表示されない。 エコーサーバ全体の処理が、一箇所で停止してしまっているからだ。

$ python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...

もうしばらく待つと、スレッドのブロックが解除されて新しいクライアントの接続が受理される。

python asyncblock.py
New client: 127.0.0.1:51883
Recv: b'hogehoge\n' to 127.0.0.1:51883
Go to sleep...
Send: b'hogehoge\n' to 127.0.0.1:51883
New client: 127.0.0.1:51884

このように、ノンブロッキング I/O を扱うスレッドにブロッキング I/O のコードが混入すると、全てがそこで停止してしまう。

そして、真にこの問題が恐ろしいのは、混入に気づきにくい点かもしれない。 先ほどのサンプルコードは極端な例なので、使ってみるだけでも明確に変化を知覚できた。 しかしながら、実際にはブロッキング I/O の処理は人間にとって一瞬なので気づくことは難しいかもしれない。 にも関わらず、そのタイミングで一連の処理が全て停止していることに間違いはない。 結果として、パフォーマンスの低下をもたらす。

また、世の中のほとんどのライブラリはブロッキング I/O を使って実装されている。 例えば、外部の WebAPI を叩こうとそのまま requests でも使おうものなら、それだけでアウトだ。 それに、HTTP のような分かりやすい I/O 以外にもキューのような基本的な部品であっても操作をブロックしたりする。

つまり、新たに何かを使おうとしたら、それにブロックする操作が混入していないかをあらかじめ調べる必要がある。 さらに、ブロックする操作が含まれると分かったら、それをブロックしないようにする方法を模索しなきゃいけない。 以上のように、イベントループを中心に据えた非同期なフレームワークというのは、一般的な認識よりもずっと扱いが難しいと思う。

ブロッキング I/O が混入する問題へのアプローチについて

ノンブロッキング I/O を扱うスレッドにブロッキング I/O が混ざり込む問題に対するアプローチはいくつかある。 もちろん、混入しないように人間が頑張ってコードを見張る、というのは最も基本的なやり方の一つ。

それ以外には、プログラミング言語のレベルでブロッキング I/O を排除してしまうという選択肢もある。 これは例えば JavaScript (Node.js) が採用している。 Golang も、ネットワーク部分に関してはノンブロッキング I/O しか用意していないらしい。 初めからブロッキング I/O の操作が存在していないなら、そもそも混入することはない。

それ以外には、モンキーパッチを当てるというアプローチもある。 つまり、ブロッキング I/O を使うコードを、全てノンブロッキング I/O を使うように書き換えてしまう。 Python であれば、例えば EventletGevent といったサードパーティー製ライブラリがこれにあたる。

試しに Eventlet を使った例を見てみよう。 まずは Python のパッケージマネージャである pip を使って Eventlet をインストールしておく。

$ pip install eventlet

それでは Eventlet の魔法をお見せしよう。 次のサンプルコードは、最初に示したマルチスレッドの例に、たった二行だけコードを追加している。 その冒頭に追加した二行こそ、まさにモンキーパッチを当てるためのコードになっている。 たったこれだけで、ブロッキングだった世界がノンブロッキングな世界に書き換わってしまう。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# 標準ライブラリにモンキーパッチを当てる
# ブロッキング I/O を使った操作が裏側で全てノンブロッキング I/O を使うように書き換えられる
import eventlet
eventlet.monkey_patch()

import socket
import threading


def client_handler(clientsocket, client_address, client_port):
    """クライアントとの接続を処理するハンドラ"""
    while True:
        try:
            message = clientsocket.recv(1024)
            print('Recv: {0} from {1}:{2}'.format(message,
                                                  client_address,
                                                  client_port))
        except OSError:
            break

        if len(message) == 0:
            break

        sent_message = message
        while True:
            sent_len = clientsocket.send(sent_message)
            if sent_len == len(sent_message):
                break
            sent_message = sent_message[sent_len:]
        print('Send: {0} to {1}:{2}'.format(message,
                                            client_address,
                                            client_port))

    clientsocket.close()
    print('Bye-Bye: {0}:{1}'.format(client_address, client_port))


def main():
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    host = 'localhost'
    port = 37564
    serversocket.bind((host, port))

    serversocket.listen(128)

    while True:
        clientsocket, (client_address, client_port) = serversocket.accept()
        print('New client: {0}:{1}'.format(client_address, client_port))

        client_thread = threading.Thread(target=client_handler,
                                         args=(clientsocket,
                                               client_address,
                                               client_port))
        client_thread.daemon = True
        client_thread.start()


if __name__ == '__main__':
    main()

本当にイベントループが使われているのか確かめてみることにしよう。 まずは、上記のサンプルコードを実行する。

$ python eventletserver.py

続いて別のターミナルを開いて、上記で実行しているプロセス ID を調べる。

$ ps auxww | grep [e]ventletserver
amedama         7796   0.0  0.1  2426888  19488 s000  S+    8:44PM   0:00.17 python eventletserver.p

dtruss コマンドでプロセス内で発行されるシステムコールをトレースする。

$ sudo dtruss -p 7796

クライアントからサーバに接続してみよう。

$ nc localhost 37564

すると、次のように dtruss の実行結果に kevent システムコールが登場している。 本当に、モンキーパッチを当てるだけでノンブロッキング I/O を使うようになった。

$ sudo dtruss -p 7796
SYSCALL(args)            = return
kevent(0x4, 0x101256710, 0x1)            = 0 0
accept(0x3, 0x7FFF5F8A7750, 0x7FFF5F8A774C)              = 7 0
ioctl(0x7, 0x20006601, 0x0)              = 0 0
ioctl(0x7, 0x8004667E, 0x7FFF5F8A7A04)           = 0 0
ioctl(0x7, 0x8004667E, 0x7FFF5F8A74E4)           = 0 0
write(0x1, "New client: 127.0.0.1:54132\n\0", 0x1C)              = 28 0
recvfrom(0x7, 0x7FE86782BC20, 0x400)             = -1 Err#35
kevent(0x4, 0x101256710, 0x1)            = 0 0
kevent(0x4, 0x7FE866D11020, 0x0)                 = 0 0
accept(0x3, 0x7FFF5F8A7750, 0x7FFF5F8A774C)              = -1 Err#35
kevent(0x4, 0x101256710, 0x1)            = 0 0

注目すべきは、逐次的なプログラミングモデルを保ったまま、それが実現できているところだろう。 asyncio の例でも、データの読み書きなどは逐次的に書くことができたものの、基本はイベントドリブンだった。 しかし Eventlet のコードでは、完全にブロッキング I/O を使っているときと同じように書くことができている。

これが一体どのようにして実現されているかというと、主にグリーンスレッドの寄与が大きい。 Eventlet では、カーネルで実装されたネイティブスレッドの代わりにユーザランドで実装されたグリーンスレッドを用いる。 グリーンスレッドには、実装によってコルーチン、軽量プロセス、協調スレッドなど色々な呼び方がある。

カーネルで実装されたネイティブスレッドとの大きな違いは、コンテキストスイッチのタイミングがプログラムで制御できるところにある。 ネイティブスレッドのコンテキストスイッチはカーネルのスケジューラ次第なので、基本的にプログラムからは制御できない。 それに対し、グリーンスレッドでは実行中のスレッドが自発的に処理を手放さない限りコンテキストスイッチが起こらない。 つまり、I/O などの外的な要因がない限りグリーンスレッドは決定論的に動作することを意味している。

Eventlet では、モンキーパッチを使うと既存のスレッドやソケットがインターフェースはそのままに書き換えられる。 そして、本来ならブロックするコードに処理が到達したタイミングでコンテキストスイッチが起こるように変化する。 コンテキストスイッチする先は、読み書きの準備が整った I/O を処理しているグリーンスレッドだ。 これは、先ほどシステムコールをトレースした通り、イベントループを使って判断している。 そして、コンテキストスイッチした元のグリーンスレッドは、イベントループを使って処理中の I/O が読み書きができるようになるまで待たされる。 ちなみに Golang はプログラミング言語のレベルで上記を実現していて、それは goroutine と呼ばれている。

このアーキテクチャでは、逐次的なプログラミングモデルを保ったままノンブロッキング I/O を使った恩恵が受けられる。 また、グリーンスレッドは一般的にネイティブスレッドよりもコンテキストに必要なメモリのサイズが小さい。 つまり、同時に多くのクライアントをさばきやすい。

ただし、Eventlet のようなモンキーパッチを使ったアプローチには抵抗がある人も多いかもしれない。 実際のところ Eventlet にはクセが全くないとは言えないし、よく分からずに使うのはやめた方が良いと思う。 ただし、名誉のために言っておくと Eventlet は OpenStack のような巨大なプロジェクトでも使われている実績のあるライブラリだ。

ちなみに、モンキーパッチでは一つだけブロッキング I/O の混入を防げないところがある。 それは Python/C API を使って書かれた拡張モジュールだ。 コンパイル済みの拡張モジュールに対しては、個別に対応しない限り自動でモンキーパッチが効くことはない。 これは、典型的には Python/C API で書かれたデータベースドライバで問題になることが多い。

まとめ

今回はソケットプログラミングにおいて、どういったアーキテクチャが考えられるかについて見てきた。 まず、ソケットは大きく分けてブロッキングで使うかノンブロッキングで使うかという選択肢がある。

ブロッキングは、逐次的なプログラミングモデルで扱いやすいことから理解もしやすい。 ただし、複数のクライアントをさばくにはマルチスレッドやマルチプロセスにする必要がある。 それらは必要なコンテキストのサイズやスイッチのコストも大きいことから、スケーラビリティの面で問題となりやすい。

それに対し、ノンブロッキングはイベントドリブンなプログラミングモデルとなりやすいことから理解が難しい。 しかしながら、イベントループを使うことでシングルスレッドでも複数のクライアントを効率的にさばける。

また、ブロッキング・ノンブロッキングというのはソケットに限った概念ではない。 ファイルなども同じようにノンブロッキングで扱うことができる。

ノンブロッキングで I/O を扱うときの注意点としては、同じスレッドをブロックさせてはいけない、というところ。 言い換えると、イベントループを回しているスレッドにブロッキングな I/O のコードを混入させてはいけない。 もし混入するとパフォーマンス低下をもたらす。

ブロッキング I/O が混入する問題に対するアプローチは、言語や処理系、ライブラリによっていくつかある。 例えば JavaScript (Node.js) では、プログラミング言語自体にブロッキング I/O を扱う API がない。 それ以外だと、スクリプト言語ならモンキーパッチで動的に実装を書き換えてしまうというものもある。

参考文献

詳解UNIXプログラミング 第3版

詳解UNIXプログラミング 第3版

UNIXネットワークプログラミング〈Vol.1〉ネットワークAPI:ソケットとXTI

UNIXネットワークプログラミング〈Vol.1〉ネットワークAPI:ソケットとXTI

  • 作者: W.リチャードスティーヴンス,W.Richard Stevens,篠田陽一
  • 出版社/メーカー: ピアソンエデュケーション
  • 発売日: 1999/07
  • メディア: 単行本
  • 購入: 8人 クリック: 151回
  • この商品を含むブログ (37件) を見る

Python: KMeans 法を実装してみる

機械学習 scikit-learn Python NumPy Mac OS X matplotlib

KMeans 法は、機械学習における教師なし学習のクラスタリングという問題を解くためのアルゴリズム。 教師なし学習というのは、事前に教師データというヒントが与えられないことを指している。 その上で、クラスタリングというのは未知のデータに対していくつかのまとまりを作る問題をいう。

今回取り扱う KMeans 法は、比較的単純なアルゴリズムにも関わらず広く使われているものらしい。 実際に書いてみても、基本的な実装であればたしかにとてもシンプルだった。 ただし、データの初期化をするところで一点考慮すべき内容があることにも気づいたので、それについても書く。 KMeans 法の具体的なアルゴリズムについてはサンプルコードと共に後述する。

今回使った環境は次の通り。

$ sw_vers 
ProductName:    Mac OS X
ProductVersion: 10.12.3
BuildVersion:   16D32
$ python --version
Python 3.5.3

依存パッケージをインストールする

あらかじめ、今回ソースコードで使う依存パッケージをインストールしておく。 グラフ描画ライブラリの matplotlib を Mac OS X で動かすには、pip でのインストール以外にもちょっとした設定が必要になる。

$ pip install scipy scikit-learn matplotlib
$ mkdir -p ~/.matplotlib
$ cat << 'EOF' > ~/.matplotlib/matplotlibrc
backend: TkAgg
EOF

まずは scikit-learn のお手本から

本来はお手本の例については後回しにしたいんだけど、今回は構成の都合で先に示しておく。 Python の機械学習系ライブラリである scikit-learn には KMeans 法の実装も用意されている。 自分で書いた実装をお披露目する前に、どんなものなのか見てもらいたい。

次のサンプルコードでは、ダミーのデータを生成した上でそれをクラスタリングしている。 scikit-learn にはデータセットを生成する機能も備わっている。 その中でも make_blobs() という関数は、いくつかのまとまりを持ったダミーデータを作ってくれる。 あとは、そのダミーデータを KMeans 法の実装である sklearn.cluster.KMeans で処理させている。 クラスタリングした処理結果は matplotlib を使って可視化した。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.cluster import KMeans


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    # クラスタのセントロイド (重心) を描く
    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

上記のサンプルコードでは、クラスタ数として適当に 5 を選んだ。 ここで「クラスタ数を選ぶ」というのはポイントとなる。 KMeans 法では、あらかじめクラスタ数をハイパーパラメータとして指定するためだ。 つまり、データをいくつのまとまりとして捉えるかを人間が教えてやる必要がある。 クラスタリングの手法によっては、これをアルゴリズムがやってくれる場合もある。

それでは、上記のサンプルコードを実行してみよう。

$ python kmeans_scikit.py

すると、こんな感じのグラフが出力される。 それぞれの特徴ベクトルが所属するクラスタが、別々の色で表現されている。 また、それぞれのクラスタの真ん中らへんにある黒い円は、後述するセントロイドという概念を表している。 f:id:momijiame:20170319081628p:plain ちなみに、生成されるダミーデータは毎回異なるので、出力されるグラフも異なる。

自分で実装してみる

scikit-learn がクラスタリングしたお手本を見たところで、次は KMeans 法を自分で書いてみよう。

まず、KMeans 法ではセントロイド (重心) という概念が重要になる。 セントロイドというのは、文字通りクラスタの中心に置かれるものだ。 アルゴリズムの第一歩としては、このセントロイドを作ることになる。 セントロイドは、前述した通り各クラスタの中心に置かれる。 そのため、セントロイドはハイパーパラメータとして指定したクラスタ数だけ必要となる。 また、中心というのは、クラスタに属する特徴ベクトルの各次元でそれらの平均値にいる、ということ。

そして、データセットに含まれる特徴ベクトルは、必ず最寄りのセントロイドに属する。 ここでいう最寄りとは、特徴ベクトル同士のユークリッド距離が近いという意味だ。 属したセントロイドが、それぞれのクラスタを表している。

上記の基本的な概念を元に KMeans 法のアルゴリズムは次のようなステップに分けることができる

  • 最初のセントロイドを何らかの方法で決める
  • 特徴ベクトルを最寄りのセントロイドに所属させる
  • 各クラスタごとにセントロイドを再計算する
  • 上記 2, 3 番目の操作を繰り返す
  • もしイテレーション前後で所属するセントロイドが変化しなければ、そこで処理を終了する

ここで、一番最初のステップである「最初のセントロイドを何らかの方法で決める」のがポイントだった。 例えば、次のページではあらかじめ各特徴ベクトルをランダムにクラスタリングしてから、セントロイドを決めるやり方を取っている。

tech.nitoyon.com

上記のやり方であれば、次のようなサンプルコードになる。 ここでは KMeans というクラスで KMeans 法を実装している。 インターフェースは、先ほどお手本として提示した scikit-learn と揃えた。 そのため main() 関数は基本的に変わっておらず、見るべきは KMeans クラスだけとなっている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
import numpy as np


class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
        # 初期データは各要素に対してランダムにラベルをつける
        pred = np.random.randint(0, self.n_clusters, len(features))

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):

            # 各クラスタごとにセントロイド (重心) を計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

            # 各特徴ベクトルから最短距離となるセントロイドを基準に新しいラベルをつける
            new_pred = np.array([
                np.array([
                    self._euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容を比較して、もし同じなら終了
                break

            pred = new_pred

        return pred

    def _euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

上記を実行してみよう。

$ python kmeans_random.py

実行すると、次のようなグラフが表示される。 f:id:momijiame:20170319153805p:plain 一見すると、上手くいっているようだ。

しかし、何度か実行すると、次のようなエラーが出ることがある。

$ python kmeans_random.py 
kmeans_random.py:41: RuntimeWarning: Mean of empty slice.
  for i in range(self.n_clusters)])
/Users/amedama/.virtualenvs/kmeans/lib/python3.5/site-packages/numpy/core/_methods.py:73: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)

表示されるグラフは、次のようなものになる。 f:id:momijiame:20170319153914p:plain 何やら、ぜんぜん上手いことクラスタリングできていない。

上記が起こる原因について調べたところ、セントロイドの初期化に問題があることが分かった。 初期化時に特徴ベクトルに対してランダムにラベルをつけるやり方では、一つの特徴ベクトルも属さないセントロイドが生じる恐れがあるためだ。

ランダムにラベルをつけるやり方では、セントロイドはおのずとデータセット全体の平均あたりに生じやすくなる。 場合によっては、セントロイドが別のセントロイドに囲まれるなどして、一つの特徴ベクトルも属さないセントロイドがでてくる。 サンプルコードでは、そのような事態を想定していなかったのでセントロイドが計算できなくなって上手く動作しなかった。

先ほどの KMeans 法を図示しているページでも、特徴ベクトルの数 N に対してクラスタの数 k を増やして実行してみよう。 一つの特徴ベクトルも属さないクラスタが生じることが分かる。

初期化を改良してみる

一つの特徴ベクトルも属さないクラスタが生じることについて、問題がないと捉えることもできるかもしれない。 その場合には、一つも属さないクラスタのセントロイドを、前回の位置から動かなさないようにすることでケアできそうだ。

とはいえ、ここでは別のアプローチで問題を解決してみることにする。 それというのは、初期のセントロイドをデータセットに含まれるランダムな特徴ベクトルにする、というものだ。 つまり、最初に作るセントロイドの持つ特徴ベクトルがランダムに選んだ要素の特徴ベクトルと一致する。 こうすれば各セントロイドは、少なくとも一つ以上の特徴ベクトルが属することは確定できる。 ユークリッド距離がゼロになる特徴ベクトルが、必ず一つは存在するためだ。

次のサンプルコードでは、初期のセントロイドの作り方だけを先ほどと変更している。 ここでポイントとなるのは、最初にセントロイドとして採用する特徴ベクトルは重複しないように決める必要があるということだ。 そこで、初期のセントロイドは特徴ベクトルをシャッフルした上で先頭から k 個を取り出して決めている。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from matplotlib import pyplot as plt
from sklearn import datasets
import numpy as np


class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
        # 要素の中からセントロイド (重心) の初期値となる候補をクラスタ数だけ選び出す
        feature_indexes = np.arange(len(features))
        np.random.shuffle(feature_indexes)
        initial_centroid_indexes = feature_indexes[:self.n_clusters]
        self.cluster_centers_ = features[initial_centroid_indexes]

        # ラベル付けした結果となる配列はゼロで初期化しておく
        pred = np.zeros(features.shape)

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):
            # 各要素から最短距離のセントロイドを基準にラベルを更新する
            new_pred = np.array([
                np.array([
                    self._euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容が同じなら終了
                break

            pred = new_pred

            # 各クラスタごとにセントロイド (重心) を再計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

        return pred

    def _euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)


def main():
    # クラスタ数
    N_CLUSTERS = 5

    # Blob データを生成する
    dataset = datasets.make_blobs(centers=N_CLUSTERS)

    # 特徴データ
    features = dataset[0]
    # 正解ラベルは使わない
    # targets = dataset[1]

    # クラスタリングする
    cls = KMeans(n_clusters=N_CLUSTERS)
    pred = cls.fit_predict(features)

    # 各要素をラベルごとに色付けして表示する
    for i in range(N_CLUSTERS):
        labels = features[pred == i]
        plt.scatter(labels[:, 0], labels[:, 1])

    centers = cls.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], s=100,
                facecolors='none', edgecolors='black')

    plt.show()


if __name__ == '__main__':
    main()

それでは、上記のサンプルコードを実行してみよう。

$ python kmeans_select.py

これまでと変わらないけど、次のようなグラフが表示されるはず。 f:id:momijiame:20170319155439p:plain

今度は N_CLUSTERS を増やしたり、何度やっても先ほどのようなエラーにはならない。

まとめ

今回は教師なし学習のクラスタリングという問題を解くアルゴリズムの KMeans 法を実装してみた。 KMeans 法はシンプルなアルゴリズムだけど、セントロイドをどう初期化するかは流派があるみたい。

はじめてのパターン認識

はじめてのパターン認識