
Commit 6522e39

[aotautograd] Fix inplace checks in autograd backward functions during functionalization
1 parent 08b6f48

2 files changed

Lines changed: 44 additions & 1 deletion
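
For context, the error this commit avoids comes from autograd's check_inplace guard: while grad mode is enabled, in-place mutation of a leaf tensor that requires grad is rejected. A minimal eager-mode sketch of that guard (independent of this patch):

    import torch

    # Mutating a requires_grad=True leaf in-place with grad mode on
    # trips autograd's check_inplace guard.
    t = torch.ones(2, requires_grad=True)
    try:
        t.mul_(2.0)
    except RuntimeError as e:
        print(e)  # a leaf Variable that requires grad is being used in an in-place operation.

Eager backward does not hit this guard, since backward functions run with grad mode disabled unless create_graph=True; per the fix below, the trace-time copy_ ran with grad mode on, which is the case the new test exercises.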


test/functorch/test_aotdispatch.py

Lines changed: 39 additions & 0 deletions
@@ -9118,6 +9118,45 @@ def _inps():
         self.assertEqual(ref_inps_after_fw, inps_after_fw)
         self.assertEqual(ref_inps_after_bw, inps_after_bw)

+    def test_mutations_in_bw_requires_grad_input(self):
+        # Inplace mutation of a requires_grad=True forward input during the
+        # backward pass should not trigger the check_inplace error
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, dummy, inplace_tensor):
+                ctx.save_for_backward(inplace_tensor)
+                return dummy.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                (inplace_tensor,) = ctx.saved_tensors
+                inplace_tensor.mul_(2.0)
+                return grad_output, None
+
+        def fn(dummy, inplace_tensor):
+            return AF.apply(dummy, inplace_tensor)
+
+        def _inps():
+            dummy = torch.zeros((2,), requires_grad=True)
+            # requires_grad=True is what triggered the bug
+            inplace_tensor = torch.ones((2,), requires_grad=True)
+            return dummy, inplace_tensor
+
+        inps = _inps()
+        out = fn(*inps)
+        ref_inps_after_fw = [x.clone().detach() for x in inps]
+        out.sum().backward()
+        ref_inps_after_bw = [x.clone().detach() for x in inps]
+
+        inps = _inps()
+        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inps)
+        inps_after_fw = [x.clone().detach() for x in inps]
+        out.sum().backward()
+        inps_after_bw = [x.clone().detach() for x in inps]
+
+        self.assertEqual(ref_inps_after_fw, inps_after_fw)
+        self.assertEqual(ref_inps_after_bw, inps_after_bw)
+
     def test_mutation_of_input_in_fw_and_bw(self):
         class AF(torch.autograd.Function):
             @staticmethod

torch/_functorch/_aot_autograd/graph_capture_wrappers.py

Lines changed: 5 additions & 1 deletion
@@ -969,7 +969,11 @@ def _post_forward(primals: Any) -> None:
             raise AssertionError(
                 f"expected both before and after to be Tensors, got {type(before)} and {type(after)}"
             )
-        before.copy_(after)
+        # no_grad prevents the FakeTensor's requires_grad from
+        # triggering check_inplace during tracing. The
+        # requires_grad case is checked at runtime instead
+        with torch.no_grad():
+            before.copy_(after)
         meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
             idx
         )
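
A standalone illustration of why wrapping the copy in no_grad is enough, with plain tensors standing in for the FakeTensors that AOTAutograd traces with (`before`/`after` mirror the variable names in `_post_forward`; this is a sketch, not the AOTAutograd code path itself):

    import torch

    before = torch.ones(2, requires_grad=True)  # stand-in for the traced primal
    after = torch.full((2,), 2.0)               # its value after the backward mutation

    # Without no_grad this copy_ would trip check_inplace, since
    # `before` is a leaf that requires grad; under no_grad the
    # in-place write-back is permitted.
    with torch.no_grad():
        before.copy_(after)

    print(before)  # tensor([2., 2.], requires_grad=True)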
