CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PyTorch で Apple Silicon GPU を使ってみる

PyTorch v1.12 以降では、macOS において Apple Silicon あるいは AMD の GPU を使ったアクセラレーションが可能になっているらしい。 バックエンドの名称は Metal Performance Shaders (MPS) という。 意外と簡単に使えるようなので、今回は手元の Mac で試してみた。

使った環境は次のとおり。 GPU が 19 コアの Apple M2 Pro を積んだ Mac mini を使用している。

$ sw_vers
ProductName:        macOS
ProductVersion:     14.4.1
BuildVersion:       23E224
$ sysctl machdep.cpu.brand_string     
machdep.cpu.brand_string: Apple M2 Pro
$ pip list | grep -i torch
torch                     2.2.1
$ python -V               
Python 3.10.14

もくじ

下準備

あらかじめ、必要なパッケージをインストールする。 特に意識しなくても MPS バックエンドが有効なバイナリが入る。

$ pip install torch tqdm numpy

インストールできたら Python のインタプリタを起動する。

$ python

そして、PyTorch のパッケージをインポートしておく。

>>> import torch

MPS バックエンドを使ってみる

MPS バックエンドが有効かどうかは以下のコードで確認できる。 True が返ってくれば利用できる状態にある。

>>> torch.backends.mps.is_available()
True

使い方は CUDA バックエンドと変わらない。 テンソルやモデルを .to() メソッドで転送するだけ。 このとき、引数に "mps" を指定すれば良い。

>>> x = torch.randn(2, 3, 4).to("mps")
>>> x.shape
torch.Size([2, 3, 4])
>>> x.device
device(type='mps', index=0)

ちゃんと転送できた。

簡単にベンチマークしてみる

続いては、どれくらいパフォーマンスが出るのか気になるので簡単にベンチマークしてみる。 PyTorch のベンチマークのページ 1 を参考に、以下のようなコードを用意した。 いくつかのサイズやスレッド数の組み合わせで、行列の積や和を計算している。

from itertools import product

from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark


def device():
    """環境毎に利用できるアクセラレータを返す"""
    if torch.backends.mps.is_available():
        # macOS w/ Apple Silicon or AMD GPU
        return "mps"
    if torch.cuda.is_available():
        # NVIDIA GPU
        return "cuda"
    return "cpu"


def batched_dot_mul_sum(a, b):
    """mul -> sum"""
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    """bmm -> flatten"""
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


DEVICE = device()
print(f"device: {DEVICE}")


results = []

# 行列サイズ x スレッド数の組み合わせでベンチマークする
sizes = [1, 64, 1024, 10000]
for b, n in tqdm(list(product(sizes, sizes))):
    label = "Batched dot"
    sub_label = f"[{b}, {n}]"
    x = torch.ones((b, n)).to(DEVICE)
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt="batched_dot_mul_sum(x, x)",
            setup="from __main__ import batched_dot_mul_sum",
            globals={"x": x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description="mul/sum",
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt="batched_dot_bmm(x, x)",
            setup="from __main__ import batched_dot_bmm",
            globals={"x": x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description="bmm",
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

Apple M2 Pro (GPU 19C)

実際に、上記を実行してみよう。 まずは Apple M2 Pro の環境から。

$ python bench.py 
device: mps
100%|███████████████████████████████████████████| 16/16 [02:12<00:00,  8.27s/it]
[-------------- Batched dot --------------]
                      |  mul/sum  |   bmm  
1 threads: --------------------------------
      [1, 1]          |     49.9  |    30.8
      [1, 64]         |     48.7  |    30.1
      [1, 1024]       |     48.8  |    30.0
      [1, 10000]      |     51.5  |    30.1
      [64, 1]         |     50.1  |    30.3
      [64, 64]        |     49.1  |    30.2
      [64, 1024]      |     54.9  |    30.1
      [64, 10000]     |     58.0  |    30.0
      [1024, 1]       |     49.9  |    30.0
      [1024, 64]      |     55.5  |    30.4
      [1024, 1024]    |     54.9  |    30.0
      [1024, 10000]   |    400.2  |    90.0
      [10000, 1]      |     53.7  |    30.5
      [10000, 64]     |     56.0  |    31.0
      [10000, 1024]   |    271.2  |   107.0
      [10000, 10000]  |   6594.7  |    31.2
4 threads: --------------------------------
      [1, 1]          |     52.1  |    31.5
      [1, 64]         |     50.6  |    31.3
      [1, 1024]       |     50.5  |    30.5
      [1, 10000]      |     53.2  |    31.3
      [64, 1]         |     52.7  |    31.3
      [64, 64]        |     51.2  |    30.3
      [64, 1024]      |     56.7  |    30.5
      [64, 10000]     |     59.6  |    30.7
      [1024, 1]       |     51.5  |    30.6
      [1024, 64]      |     56.6  |    30.7
      [1024, 1024]    |     57.1  |    30.7
      [1024, 10000]   |     64.5  |   204.3
      [10000, 1]      |     55.3  |    35.1
      [10000, 64]     |     58.0  |    34.4
      [10000, 1024]   |    590.8  |   223.3
      [10000, 10000]  |  32409.0  |  1498.3
16 threads: -------------------------------
      [1, 1]          |     51.6  |    30.8
      [1, 64]         |     51.1  |    30.4
      [1, 1024]       |     50.6  |    30.4
      [1, 10000]      |     53.7  |    30.7
      [64, 1]         |     51.7  |    30.6
      [64, 64]        |     50.4  |    30.4
      [64, 1024]      |     57.1  |    30.7
      [64, 10000]     |     59.5  |    30.5
      [1024, 1]       |     51.2  |    30.3
      [1024, 64]      |     56.3  |    30.8
      [1024, 1024]    |     57.3  |    31.0
      [1024, 10000]   |     60.3  |   106.8
      [10000, 1]      |     54.9  |    34.9
      [10000, 64]     |     57.2  |    34.5
      [10000, 1024]   |    400.3  |   220.7
      [10000, 10000]  |  32418.2  |  1503.2
32 threads: -------------------------------
      [1, 1]          |     51.1  |    30.6
      [1, 64]         |     50.4  |    30.6
      [1, 1024]       |     50.7  |    30.5
      [1, 10000]      |     53.0  |    30.5
      [64, 1]         |     51.8  |    30.7
      [64, 64]        |     50.4  |    30.2
      [64, 1024]      |     56.7  |    30.6
      [64, 10000]     |     59.3  |    30.5
      [1024, 1]       |     51.3  |    30.6
      [1024, 64]      |     56.6  |    34.5
      [1024, 1024]    |     57.8  |    33.5
      [1024, 10000]   |    447.3  |   202.6
      [10000, 1]      |     54.3  |    35.3
      [10000, 64]     |     57.0  |    34.5
      [10000, 1024]   |    591.2  |   219.7
      [10000, 10000]  |  32443.3  |  1493.3

Times are in microseconds (us).

NVIDIA GeForce RTX 3060

さきほどの結果は、もちろん CPU よりは全然速い。 とはいえ、他の GPU などに比べてどれくらい速いのかイメージしにくい。 そこで、厳密な比較にはならないものの RTX 3060 を積んだ Linux の環境でも実行してみる。

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:    20.04
Codename:   focal
$ pip list | grep -i torch
torch                    2.2.2
$ python -V
Python 3.10.14

先ほどのコードを実行する。

$ python bench.py 
device: cuda
100%|███████████████████████████████████████████| 16/16 [03:08<00:00, 11.81s/it]
[-------------- Batched dot --------------]
                      |  mul/sum  |   bmm  
1 threads: --------------------------------
      [1, 1]          |      6.5  |     6.8
      [1, 64]         |      6.5  |     6.8
      [1, 1024]       |      6.4  |     7.8
      [1, 10000]      |      6.3  |     7.8
      [64, 1]         |      6.3  |     6.6
      [64, 64]        |      6.4  |     6.8
      [64, 1024]      |      6.6  |     6.8
      [64, 10000]     |     25.1  |    10.2
      [1024, 1]       |      6.4  |     6.7
      [1024, 64]      |      6.3  |     6.7
      [1024, 1024]    |     40.2  |    15.6
      [1024, 10000]   |    375.3  |   179.1
      [10000, 1]      |      6.3  |    32.6
      [10000, 64]     |     29.2  |    34.9
      [10000, 1024]   |    374.7  |   123.5
      [10000, 10000]  |   3603.7  |  1672.6
4 threads: --------------------------------
      [1, 1]          |      6.5  |     6.9
      [1, 64]         |      6.5  |     6.9
      [1, 1024]       |      6.4  |     7.8
      [1, 10000]      |      6.4  |     7.8
      [64, 1]         |      6.4  |     6.6
      [64, 64]        |      6.5  |     6.8
      [64, 1024]      |      6.6  |     6.9
      [64, 10000]     |     25.1  |    10.2
      [1024, 1]       |      6.3  |     6.7
      [1024, 64]      |      6.4  |     6.7
      [1024, 1024]    |     40.2  |    15.6
      [1024, 10000]   |    375.3  |   179.1
      [10000, 1]      |      6.3  |    32.6
      [10000, 64]     |     29.2  |    34.9
      [10000, 1024]   |    374.9  |   123.5
      [10000, 10000]  |   3602.4  |  1672.5
16 threads: -------------------------------
      [1, 1]          |      6.5  |     6.9
      [1, 64]         |      6.5  |     6.7
      [1, 1024]       |      6.5  |     7.9
      [1, 10000]      |      6.3  |     7.8
      [64, 1]         |      6.3  |     6.6
      [64, 64]        |      6.4  |     6.8
      [64, 1024]      |      6.5  |     6.9
      [64, 10000]     |     25.1  |    10.2
      [1024, 1]       |      6.3  |     6.7
      [1024, 64]      |      6.4  |     6.7
      [1024, 1024]    |     40.3  |    15.6
      [1024, 10000]   |    375.3  |   179.1
      [10000, 1]      |      6.4  |    32.6
      [10000, 64]     |     29.2  |    34.9
      [10000, 1024]   |    374.9  |   123.5
      [10000, 10000]  |   3604.9  |  1672.4
32 threads: -------------------------------
      [1, 1]          |      6.6  |     6.9
      [1, 64]         |      6.4  |     6.8
      [1, 1024]       |      6.5  |     7.9
      [1, 10000]      |      6.4  |     7.8
      [64, 1]         |      6.3  |     6.7
      [64, 64]        |      6.6  |     6.8
      [64, 1024]      |      6.6  |     6.9
      [64, 10000]     |     25.1  |    10.3
      [1024, 1]       |      6.4  |     6.8
      [1024, 64]      |      6.3  |     6.8
      [1024, 1024]    |     40.2  |    15.6
      [1024, 10000]   |    375.1  |   179.2
      [10000, 1]      |      6.4  |    32.6
      [10000, 64]     |     29.2  |    34.9
      [10000, 1024]   |    374.9  |   123.5
      [10000, 10000]  |   3604.6  |  1672.4

Times are in microseconds (us).

こちらの環境の方が多くの場合に 2 ~ 10 倍程度速いことがわかる。 ただし、一部サイズの大きな bmm を使った演算に関しては、むしろ Apple Silicon の方が速いようだ。 また、消費電力は RTX 3060 の方が 20 倍近く大きい 2

まとめ

Apple Silicon の GPU は、そこまで速くないにしてもワットパフォーマンスには優れている。 また、CPU に比べればずっと速いので PyTorch で気軽に使えるのはありがたい。

参考

developer.apple.com

pytorch.org

pytorch.org



  1. https://pytorch.org/tutorials/recipes/recipes/benchmark.html
  2. Apple M2 Pro の GPU は実測で最大 10W 程度、RTX 3060 はカタログで 170W