
torch.compile doesn't respect use_deterministic_algorithms during the backward() #113707

@bdhirsh

Description


Minimal repro:

import torch

m = torch.nn.ReflectionPad2d((1, 2, 3, 4))
m_compiled = torch.compile(m)

inp = torch.randn(2, 3, 8, 8, device='cuda', requires_grad=True)
out = m_compiled(inp)

# set use_deterministic_algorithms to True
torch.use_deterministic_algorithms(True)

# the backward of ReflectionPad2d should error, because it doesn't have a deterministic implementation!
out.sum().backward()

The core problem comes from the fact that:

(1) torch.compile will eagerly trace out a backward graph when we compile the forward (torch.nn.ReflectionPad2d)

(2) This requires us to run our meta implementations for the backward graph, including for aten.reflection_pad2d_backward, ahead-of-time, during the forward

(3) We don't know during trace-time of the forward whether or not torch.use_deterministic_algorithms will be set at runtime during the backward
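The three points above boil down to a trace-time vs. run-time mismatch. A minimal, PyTorch-free sketch of the failure mode (the names `DETERMINISTIC`, `pad_backward`, and `trace` are hypothetical stand-ins, not PyTorch internals): an ahead-of-time tracer bakes in the value a global flag had at trace time, so a later change to the flag is invisible to the compiled artifact.

```python
# Hypothetical sketch, not PyTorch internals: why a flag read at trace
# time cannot stand in for its value at run time.

DETERMINISTIC = False  # stand-in for the use_deterministic_algorithms state

def pad_backward():
    # An eager kernel checks the flag at the moment it actually runs.
    if DETERMINISTIC:
        raise RuntimeError("no deterministic implementation")
    return "nondeterministic result"

def trace(fn):
    # An ahead-of-time tracer specializes on whatever the flag is *now*.
    baked_in = DETERMINISTIC
    def compiled():
        if baked_in:  # stale: ignores later changes to DETERMINISTIC
            raise RuntimeError("no deterministic implementation")
        return "nondeterministic result"
    return compiled

compiled_backward = trace(pad_backward)  # traced while DETERMINISTIC is False
DETERMINISTIC = True                     # user flips the flag afterwards

# Eager raises; the traced version silently runs the nondeterministic path.
try:
    pad_backward()
except RuntimeError:
    print("eager: raised")
print("compiled:", compiled_backward())
```

In the real repro, `trace` corresponds to torch.compile tracing the backward graph while compiling the forward, before `torch.use_deterministic_algorithms(True)` is ever called.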

cc @mruberry @kurtamohler @ezyang @msaroufim @wconstab @anijain2305 @zou3519

Labels

- module: aotdispatch (umbrella label for AOTAutograd issues)
- module: determinism
- module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op)
- oncall: pt2
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
