Minimal repro:
```python
import torch

m = torch.nn.ReflectionPad2d((1, 2, 3, 4))
m_compiled = torch.compile(m)
inp = torch.randn(2, 3, 8, 8, device='cuda', requires_grad=True)
out = m_compiled(inp)

# Enable deterministic algorithms *after* compiling the forward
torch.use_deterministic_algorithms(True)

# The backward of ReflectionPad2d should error here, because its CUDA
# implementation is nondeterministic -- but the compiled backward doesn't.
out.sum().backward()
```
The core problem comes from the fact that:

(1) `torch.compile` will eagerly trace out a backward graph when we compile the forward (`torch.nn.ReflectionPad2d`)

(2) This requires us to run our meta implementations for the backward graph, including for `aten.reflection_pad2d_backward`, ahead-of-time, during the forward

(3) We don't know at trace time of the forward whether `torch.use_deterministic_algorithms` will be set at runtime during the backward
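The trace-time vs. run-time mismatch in (1)-(3) can be illustrated without `torch` at all. The sketch below is purely hypothetical (the names `deterministic` and `trace_backward` are illustrative, not real PyTorch APIs): a tracer that snapshots a global flag while building the backward cannot react to the flag being flipped before the backward actually runs.

```python
# Hypothetical, torch-free sketch of the bug's shape: the "deterministic"
# flag stands in for torch.use_deterministic_algorithms state.
deterministic = False

def trace_backward():
    # Trace-time decision: the flag's value is baked into the traced
    # backward (analogue of running the backward meta function while
    # compiling the forward).
    flag_at_trace_time = deterministic

    def backward():
        if flag_at_trace_time:
            raise RuntimeError("no deterministic implementation")
        return "grad"

    return backward

bwd = trace_backward()   # "compiled" while deterministic == False
deterministic = True     # user flips the flag before running the backward
print(bwd())             # still returns "grad": the check was frozen at trace time
```

This mirrors the repro: the runtime flip of the flag is invisible to the already-traced backward, so the expected error never fires.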
cc @mruberry @kurtamohler @ezyang @msaroufim @wconstab @anijain2305 @zou3519