CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PyTorch で RMSProp を実装してみる

今回は、以下の記事の続きとして PyTorch で RMSProp のオプティマイザを実装してみる。

blog.amedama.jp

上記では PyTorch で Adagrad のオプティマイザを実装した。 Adagrad は学習率の調整に過去の勾配の平方和の累積を使っている。 このやり方には、イテレーションが進むと徐々に学習が進みにくくなってしまう問題がある。 そこで、RMSProp では学習率の調整に過去の勾配の平方和の指数移動平均を使っている。 これによって、徐々に学習が進みにくくなる問題を解決した。

使った環境は次のとおり。

$ sw_vers     
ProductName:        macOS
ProductVersion:     15.1
BuildVersion:       24B83
$ python -V                               
Python 3.12.7
$ pip list | egrep -i "(torch|matplotlib)"
matplotlib        3.9.2
torch             2.5.0

もくじ

下準備

あらかじめ PyTorch と Matplotlib をインストールしておく。

$ pip install torch matplotlib 

PyTorch 組み込みの RMSProp を試す

まずは PyTorch に組み込みで用意されている RMSProp の振る舞いを確認しておこう。

以下にサンプルコードを示す。 扱う問題は先に示した記事と同じもの。 この問題設定や初期値などは「ゼロから作るDeep Learning 1」に記載されている内容を参考にしている。 このサンプルコードでは、関数の出力をゼロに近づけるようにパラメータを更新する。

import torch
from matplotlib import pyplot as plt
from torch import nn
from torch import optim


class ExampleFunction(nn.Module):

    def __init__(self, a, x, b, y):
        super(ExampleFunction, self).__init__()
        self.a = a
        self.x = nn.Parameter(torch.tensor([x]))
        self.b = b
        self.y = nn.Parameter(torch.tensor([y]))

    def forward(self):
        return self.a * self.x**2 + self.b * self.y**2


def main():
    model = ExampleFunction(a=1 / 20, x=-7.0, b=1.0, y=2.0)

    optimizer = optim.RMSprop(model.parameters(), lr=0.1)

    trajectory_x = [model.x.detach().numpy()[0]]
    trajectory_y = [model.y.detach().numpy()[0]]

    num_epochs = 30
    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()

        outputs = model()

        outputs.backward()

        optimizer.step()

        x = model.x.detach().numpy()[0]
        trajectory_x.append(x)
        y = model.y.detach().numpy()[0]
        trajectory_y.append(y)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.plot(trajectory_x, trajectory_y, marker="o", markersize=5, label="Trajectory")
    ax.legend()
    ax.grid(True)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.show()


if __name__ == "__main__":
    main()

上記に名前をつけて実行する。

$ python torchrmsprop.py

すると、次のようなグラフが得られる。 グラフは、パラメータが更新されていく様子を表している。

RMSProp で最適化したパラメータの軌跡

RMSProp のアルゴリズムを実装する

次に RMSProp のオプティマイザを自作する。 サンプルコードを以下に示す。 サンプルコードでは CustomRMSProp という名前でオプティマイザを実装した。

from collections.abc import Iterable
from typing import Any

import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Optimizer


class ExampleFunction(nn.Module):

    def __init__(self, a, x, b, y):
        super(ExampleFunction, self).__init__()
        self.a = a
        self.x = nn.Parameter(torch.tensor([x]))
        self.b = b
        self.y = nn.Parameter(torch.tensor([y]))

    def forward(self):
        return self.a * self.x**2 + self.b * self.y**2


class CustomRMSProp(Optimizer):
    """自作した RMSProp のオプティマイザ"""

    def __init__(
        self,
        params: Iterable,
        lr: float = 1e-3,
        alpha=0.99,
        eps=1e-10,
        momentum=0.0,
    ):
        defaults: dict[str, Any] = dict(
            lr=lr,
            eps=eps,
            alpha=alpha,
            momentum=momentum,
        )
        super(CustomRMSProp, self).__init__(params, defaults)

    def step(self, closure=None):
        """RMSProp の更新式を実装した step() メソッド

        (更新式)
        v_0 = 0  ※ 論文では 1 で初期化している
        m_0 = 0
        v_{t+1} = rho * v_t + (1 - rho) * grad(L(theta_t))^2
        m_{t+1} = gamma * m_t + eta / (sqrt(v_{t+1}) + eps) * grad(L(theta_t))
        theta_{t+1} = theta_t - m_{t+1}

        theta: パラメータ (重み)
        eta: 学習率
        grad(L(theta)): 損失関数の勾配
        v: 過去の勾配の2乗の指数移動平均
        m: 過去の勾配を加味したモーメント
        rho: v を計算するときの指数移動平均の係数
        gamma: m を計算するときの指数移動平均の係数
        eps: ゼロ除算を防ぐための小さな値
        """
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                # v_0 = 0 に対応する
                if "v" not in self.state[param]:
                    self.state[param]["v"] = torch.zeros_like(param.data)
                # m_0 = 0 に対応する
                if "m" not in self.state[param]:
                    self.state[param]["m"] = torch.zeros_like(param.data)
                # v_{t+1} = rho * v_t + (1 - rho) * grad(L(theta_t))^2 に対応する
                self.state[param]["v"] = group["alpha"] * self.state[param]["v"] + (1 - group["alpha"]) * param.grad ** 2
                # m_{t+1} = gamma * m_t + eta / (sqrt(v_{t+1}) + eps) * grad(L(theta_t)) に対応する
                self.state[param]["m"] = group["momentum"] * self.state[param]["m"] + group["lr"] / (torch.sqrt(self.state[param]["v"]) + group["eps"]) * param.grad
                # theta_{t+1} = theta_t - m_{t+1} に対応する
                param.data -= self.state[param]["m"]


def main():
    model = ExampleFunction(a=1 / 20, x=-7.0, b=1.0, y=2.0)

    optimizer = CustomRMSProp(model.parameters(), lr=0.1)

    trajectory_x = [model.x.detach().numpy()[0]]
    trajectory_y = [model.y.detach().numpy()[0]]

    num_epochs = 30
    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()

        outputs = model()

        outputs.backward()

        optimizer.step()

        x = model.x.detach().numpy()[0]
        trajectory_x.append(x)
        y = model.y.detach().numpy()[0]
        trajectory_y.append(y)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.plot(trajectory_x, trajectory_y, marker="o", markersize=5, label="Trajectory")
    ax.legend()
    ax.grid(True)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.show()


if __name__ == "__main__":
    main()

上記に適当な名前をつけて実行する。

$ python customrmsprop.py

すると、次のようなグラフが得られる。

自作した RMSProp で最適化したパラメータの軌跡

先ほど確認した PyTorch 組み込みの RMSProp の結果と一致していることが分かる。

更新式とコードの対応関係について

ここからは更新式とコードについて見ていく。 RMSProp の更新式は以下のようになっている。

 \displaystyle
v_0 = 0 \\
m_0 = 0 \\
v_{t+1} = \rho v_t + (1 - \rho) \nabla_{\theta} L(\theta_t)^2 \\
m_{t+1} = \gamma m_t + \eta \frac{1}{\sqrt{v_{t+1} + \epsilon}} \nabla_{\theta} L(\theta_t) \\
\theta_{t+1} = \theta_t - m_{t+1}

数式と、プログラムの変数の対応関係は次のとおり。

  •  \theta
    • param.data
  •  \eta
    • group["lr"]
  •  \nabla_{\theta} L(\theta)
    • param.grad
  •  \rho
    • group["alpha"]
  •  \gamma
    • group["momentum"]
  •  v
    • self.state[param]["v"]
  •  m
    • self.state[param]["m"]
  •  \epsilon
    • group["eps"]

式から、 v には過去の勾配の平方和の指数移動平均が入ることが分かる。 過去の勾配がどれくらい値に影響するかは  \rho の係数を使って制御する。 この係数が大きいほど過去の値が重視され、小さいほど直近の値が重視される。 また、 m の計算式から分かるようにパラメータの更新量を決めるためにモーメントが用いられている。 ここに関しても  \gamma を係数にした指数移動平均になっている。

コードとの対応関係を見ていこう。 まずは、以下の更新式に対応するコードから。

 \displaystyle
v_0 = 0 \\
m_0 = 0

ここでは  v m がまだない状態、つまり初期状態のとき変数をゼロで初期化している。

                # v_0 = 0 に対応する
                if "v" not in self.state[param]:
                    self.state[param]["v"] = torch.zeros_like(param.data)
                # m_0 = 0 に対応する
                if "m" not in self.state[param]:
                    self.state[param]["m"] = torch.zeros_like(param.data)

続いては以下の更新式に対応するコード。

 \displaystyle
v_{t+1} = \rho v_t + (1 - \rho) \nabla_{\theta} L(\theta_t)^2

ここでは勾配の平方和について指数移動平均を求めている。

                # v_{t+1} = rho * v_t + (1 - rho) * grad(L(theta_t))^2 に対応する
                self.state[param]["v"] = group["alpha"] * self.state[param]["v"] + (1 - group["alpha"]) * param.grad ** 2

続いて以下の更新式に対応するコード。

 \displaystyle
m_{t+1} = \gamma m_t + \eta \frac{1}{\sqrt{v_{t+1} + \epsilon}} \nabla_{\theta} L(\theta_t)

ここではパラメータの更新量について指数移動平均を求めている。

                # m_{t+1} = gamma * m_t + eta / (sqrt(v_{t+1}) + eps) * grad(L(theta_t)) に対応する
                self.state[param]["m"] = group["momentum"] * self.state[param]["m"] + group["lr"] / (torch.sqrt(self.state[param]["v"]) + group["eps"]) * param.grad

そして、最後の以下の更新式に対応するコード。

 \displaystyle
\theta_{t+1} = \theta_t - m_{t+1}

ここでは実際にパラメータを更新している。

                # theta_{t+1} = theta_t - m_{t+1} に対応する
                param.data -= self.state[param]["m"]

いじょう。

参考

arxiv.org