[aotautograd] Fix inplace checks in autograd backward functions during functionalization #177213

azahed98 wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177213
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 6522e39 with merge base 08b6f48:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team

Merge failed
Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.

@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
…g functionalization (pytorch#177213)

During `_functionalized_f_helper`, we call `before.copy_(after)`. If the compiled function is a custom autograd function with an in-place op during the backward, this can unintentionally trigger the [in-place correctness check](https://docs.pytorch.org/docs/stable/autograd.html#in-place-correctness-checks). This causes an exception:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

This PR instead wraps the `before.copy_(after)` in a `torch.no_grad()` context to avoid the in-place check.

Example script:

```python
import torch


class MutateBufferInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x * buf.mean()

    @staticmethod
    def backward(ctx, grad_output):
        (buf,) = ctx.saved_tensors
        buf.mul_(2.0)  # in-place op on a saved tensor during the backward
        return grad_output * buf.mean(), None


@torch.compile
def f(x, buf):
    return MutateBufferInBackward.apply(x, buf)


def main():
    device = "cuda"
    x = torch.randn(16, device=device, requires_grad=True)
    buf = torch.randn(16, device=device, requires_grad=True)
    out = f(x, buf)
    out.sum().backward()
    print("Success!")


if __name__ == "__main__":
    main()
```

Pull Request resolved: pytorch#177213
Approved by: https://github.com/aorenste, https://github.com/frgossen
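The failure mode and the fix can be reproduced in isolation, outside of AOTAutograd: an in-place `copy_` into a leaf tensor that requires grad trips autograd's in-place correctness check, while the same copy performed under `torch.no_grad()` is permitted because autograd does not record it. This is a minimal standalone sketch of the pattern, not the actual code in `_functionalized_f_helper`:

```python
import torch

# A leaf tensor with requires_grad=True, standing in for `before`.
leaf = torch.randn(4, requires_grad=True)
update = torch.randn(4)  # standing in for `after`

# Without no_grad, the in-place correctness check raises.
try:
    leaf.copy_(update)
    raised = False
except RuntimeError:
    raised = True
assert raised

# Under no_grad, the same mutation is allowed: autograd does not
# record the copy, so no check is triggered.
with torch.no_grad():
    leaf.copy_(update)
assert torch.equal(leaf.detach(), update)
print("ok")
```

The `torch.no_grad()` wrapper is appropriate here because the copy is bookkeeping performed by the functionalization machinery, not a differentiable operation the user's graph should see.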