
[aotautograd] Fix inplace checks in autograd backward functions during functionalization#177213

Closed
azahed98 wants to merge 1 commit intopytorch:mainfrom
azahed98:fix/aotautograd_inplace

Conversation

@azahed98
Contributor

During `_functionalized_f_helper`, we call `before.copy_(after)`. If the compiled function is a custom autograd function that performs an in-place op during the backward, this copy can unintentionally trigger autograd's [in-place correctness check](https://docs.pytorch.org/docs/stable/autograd.html#in-place-correctness-checks), causing an exception:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

This PR instead wraps the `before.copy_(after)` call in a `torch.no_grad()` context to avoid the inplace check.
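The autograd behavior at the root of this failure can be reproduced in plain eager mode, with no compiler involved: calling `copy_` on a leaf tensor that requires grad raises the same RuntimeError, while wrapping the copy in `torch.no_grad()` (the approach this PR takes) lets it through. A minimal sketch:

```python
import torch

leaf = torch.randn(4, requires_grad=True)
src = torch.zeros(4)

# In-place mutation of a leaf that requires grad trips autograd's check.
try:
    leaf.copy_(src)
except RuntimeError as e:
    print(f"raised: {e}")

# The same copy under no_grad() bypasses the check.
with torch.no_grad():
    leaf.copy_(src)
print(f"leaf is now {leaf.tolist()}")
```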

Example script:

```python
import torch


class MutateBufferInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x * buf.mean()

    @staticmethod
    def backward(ctx, grad_output):
        (buf,) = ctx.saved_tensors
        buf.mul_(2.0)
        return grad_output * buf.mean(), None


@torch.compile
def f(x, buf):
    return MutateBufferInBackward.apply(x, buf)


def main():
    device = "cuda"
    x = torch.randn(16, device=device, requires_grad=True)
    buf = torch.randn(16, device=device, requires_grad=True)

    out = f(x, buf)
    out.sum().backward()
    print("Success!")


if __name__ == "__main__":
    main()
```
    
    

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177213

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 6522e39 with merge base 08b6f48:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync

meta-codesync bot commented Mar 12, 2026

@azahed98 has imported this pull request. If you are a Meta employee, you can view this in D96254286.

@azahed98
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 12, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@azahed98
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…g functionalization (pytorch#177213)


Pull Request resolved: pytorch#177213
Approved by: https://github.com/aorenste, https://github.com/frgossen
