PyTorch を使っていると、はるか遠く離れた場所で計算した結果に nan や inf が含まれることで、思いもよらない場所から非直感的なエラーを生じることがある。
あるいは、自動微分したときにゼロ除算が生じるようなパターンでは、順伝搬の結果だけ見ていても原因にたどり着くことが難しい。
こういった問題は、デバッガなどを使って地道に原因を探ろうとすると多くの手間と時間がかかる。
そんな折、PyTorch にはそうした問題に対処する上で有益な機能があることを知った。
具体的には、以下の関数を使うと自動でバックプロパゲーションが上手くいかない箇所を見つけることができる。
今回は、この機能について書いてみる。
torch.autograd.set_detect_anomaly()
torch.autograd.detect_anomaly()
使った環境は次のとおり。
$ sw_vers
ProductName: macOS
ProductVersion: 12.6.2
BuildVersion: 21G320
$ python -V
Python 3.10.9
$ pip list | grep -i torch
torch 1.13.1
もくじ
下準備
あらかじめ PyTorch と NumPy をインストールしておく。
$ pip install torch numpy
入力によって逆伝搬が上手くいかないコード (1)
例として RMSLE (Root Mean Squared Logarithmic Error) を計算する場合について考える。
以下のサンプルコードでは、RMSLE の計算を RMSLELoss
というクラスで実装している。
このコードは入力によってはバックプロパゲーションが上手くいかない。
具体的には、入力されるモデルの予測 (y_pred
) に -1
以下の値が含まれたとき torch.log1p()
の返り値に inf
が含まれる。
import torch
from torch import nn
class RMSLELoss(nn.Module):
"""Root Mean Squared Logarithmic Error"""
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, y_pred, y_true):
log_y_pred = torch.log1p(y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle)
return rmsle_loss
def main():
y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
out.backward()
print(y_pred.grad)
print(y_true.grad)
if __name__ == '__main__':
main()
上記を実行してみよう。
入力に -1
以下の値が入ると、最終的な結果が inf
になっている。
そして y_pred
と y_true
の勾配に nan
が確認できる。
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([nan, -0., -0.], dtype=torch.float64)
tensor([nan, 0., 0.], dtype=torch.float64)
上手くいかない箇所を自動で見つける
では、今回の主題となる torch.autograd.set_detect_anomaly()
を使ってみよう。
この関数には、第一引数に真偽値のフラグを渡して機能の有効・無効を切り替える。
もちろんデフォルトでは機能は無効となっており、デバッグをするときだけ有効にすることが推奨されている。
これは、機能を有効にするとバックプロパゲーションにおいて値のチェックが逐一入ることによるオーバーヘッドが生じるため。
import torch
from torch import nn
class RMSLELoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, y_pred, y_true):
log_y_pred = torch.log1p(y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle)
return rmsle_loss
def main():
torch.autograd.set_detect_anomaly(True)
y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
out.backward()
print(y_pred.grad)
print(y_true.grad)
if __name__ == '__main__':
main()
上記を実行してみよう。
すると、MseLossBackward0
の結果において値に nan が含まれることが示されている。
$ python anodet.py
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error:
File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module>
main()
File "/Users/amedama/Documents/temporary/anodet.py", line 27, in main
out = loss_fn(y_pred, y_true)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/amedama/Documents/temporary/anodet.py", line 14, in forward
msle = self.mse_loss(log_y_pred, log_y_true)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward
return F.mse_loss(input, target, reduction=self.reduction)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss
return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack
return traceback.format_stack()
(Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module>
main()
File "/Users/amedama/Documents/temporary/anodet.py", line 31, in main
out.backward()
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
このように、バックプロパゲーションが上手くいかない場所を自動で検知できた。
問題を修正する (1)
では、問題を修正するため試しに y_pred
の値の下限が 0
となるように torch.clamp()
の処理を挟んでみよう。
import torch
from torch import nn
class RMSLELoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, y_pred, y_true):
clamped_y_pred = torch.clamp(y_pred, min=0.)
log_y_pred = torch.log1p(clamped_y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle)
return rmsle_loss
def main():
torch.autograd.set_detect_anomaly(True)
y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
out.backward()
if __name__ == '__main__':
main()
実行すると、今度は例外にならずに済んでいる。
y_pred
と y_true
の勾配にも nan
は登場しない。
$ python anodet.py
tensor(1.1501, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([0.0000, -0.4018, -0.1328], dtype=torch.float64)
tensor([0.1061, 0.1004, 0.0531], dtype=torch.float64)
入力によって逆伝搬が上手くいかないコード (2)
さて、これで万事解決かと思いきや、実はまだ問題が残っている。
損失がゼロになるときを考えると torch.sqrt()
のバックプロパゲーションにおいてゼロ除算が生じるため。
これは順伝搬では値に nan や inf が登場しないことから問題に気づきにくそう。
import torch
from torch import nn
class RMSLELoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, y_pred, y_true):
clamped_y_pred = torch.clamp(y_pred, min=0.)
log_y_pred = torch.log1p(clamped_y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle)
return rmsle_loss
def main():
torch.autograd.set_detect_anomaly(True)
y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
out.backward()
print(y_pred.grad)
print(y_true.grad)
if __name__ == '__main__':
main()
実行すると、今度も MseLossBackward0
において結果に nan が含まれると指摘されている。
$ python anodet.py
tensor(0., dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error:
File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module>
main()
File "/Users/amedama/Documents/temporary/anodet.py", line 28, in main
out = loss_fn(y_pred, y_true)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/amedama/Documents/temporary/anodet.py", line 15, in forward
msle = self.mse_loss(log_y_pred, log_y_true)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward
return F.mse_loss(input, target, reduction=self.reduction)
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss
return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack
return traceback.format_stack()
(Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module>
main()
File "/Users/amedama/Documents/temporary/anodet.py", line 32, in main
out.backward()
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
問題を修正する (2)
では、先ほどの問題を修正するために torch.sqrt()
の計算に小さな値を足してみよう。
import torch
from torch import nn
class RMSLELoss(nn.Module):
def __init__(self, epsilon=1e-5):
super().__init__()
self.mse_loss = nn.MSELoss()
self.epsilon = epsilon
def forward(self, y_pred, y_true):
clamped_y_pred = torch.clamp(y_pred, min=0.)
log_y_pred = torch.log1p(clamped_y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle + self.epsilon)
return rmsle_loss
def main():
torch.autograd.set_detect_anomaly(True)
y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
out.backward()
print(y_pred.grad)
print(y_true.grad)
if __name__ == '__main__':
main()
実行すると、今度は例外にならない。
$ python anodet.py
tensor(0.0032, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0.], dtype=torch.float64)
特定のスコープでチェックする
ちなみに、特定のスコープでだけ backward()
の結果をチェックしたいときは torch.autograd.detect_anomaly()
が使える。
これはコンテキストマネージャになっているため、チェックしたい部分にだけ入れて使うことができる。
import torch
from torch import nn
class RMSLELoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, y_pred, y_true):
log_y_pred = torch.log1p(y_pred)
log_y_true = torch.log1p(y_true)
msle = self.mse_loss(log_y_pred, log_y_true)
rmsle_loss = torch.sqrt(msle)
return rmsle_loss
def main():
y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)
loss_fn = RMSLELoss()
out = loss_fn(y_pred, y_true)
print(out)
with torch.autograd.detect_anomaly():
out.backward()
if __name__ == '__main__':
main()
とはいえ、そんなに出番は無さそうかな。
また、トレースバックに含まれる情報も torch.autograd.set_detect_anomaly()
より少なくなっているようだ。
$ python anodet.py
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/Documents/temporary/anodet.py:29: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. No forward pass information available. Enable detect anomaly during forward pass for more information. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:97.)
Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "/Users/amedama/Documents/temporary/anodet.py", line 34, in <module>
main()
File "/Users/amedama/Documents/temporary/anodet.py", line 30, in main
out.backward()
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
まとめ
今回は PyTorch でバックプロパゲーションが上手くいかない場所を自動で見つけることのできる機能を試してみた。
参考
pytorch.org
pytorch.org