Fix: Ensure internal ApplyTemplate uses modern autograd API for torch.func.grad + compile #169786

Closed
dumko2001 wants to merge 3 commits into pytorch:main from dumko2001:fix/issue-169783

Conversation

@dumko2001 dumko2001 commented Dec 7, 2025

Contributor

Addresses #169783.

This fixes a RuntimeError when using torch.compile(torch.func.grad(...)) on a function with nested calls involving a custom torch.autograd.Function (e.g., f(f(x))).

Root Cause:
The crash was caused by an internal helper class, ApplyTemplate, defined dynamically in torch/_functorch/autograd_function.py. This class was using the legacy autograd.Function API (def forward(ctx, *args):) and lacked the required setup_context method for AOTAutograd/functorch compatibility.

The Fix:
Refactored ApplyTemplate to adhere to the modern autograd.Function API:

  1. Changed forward signature to def forward(*args):.
  2. Introduced def setup_context(ctx, inputs, output): and moved context-setting logic (like ctx.mark_non_differentiable) into it.
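For readers unfamiliar with the two styles, here is a minimal sketch of the modern autograd.Function layout that the two steps above describe. It is not the ApplyTemplate code from torch/_functorch/autograd_function.py. SortWithIndices and the sample tensors are hypothetical names chosen for illustration, and the sketch assumes a PyTorch release that supports setup_context (2.0 or later).

```python
import torch


class SortWithIndices(torch.autograd.Function):
    """Hypothetical example Function written in the modern style."""

    @staticmethod
    def forward(x):
        # Modern API: forward takes no ctx and only computes the outputs.
        values, indices = torch.sort(x)
        return values, indices

    @staticmethod
    def setup_context(ctx, inputs, output):
        # All context bookkeeping lives here instead of in forward.
        _, indices = output
        ctx.mark_non_differentiable(indices)
        ctx.save_for_backward(indices)

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        # grad_indices is unused because indices is non-differentiable.
        (indices,) = ctx.saved_tensors
        # values[i] == x[indices[i]], so route each gradient back to its source slot.
        return torch.zeros_like(grad_values).index_add(0, indices, grad_values)


x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
weights = torch.tensor([1.0, 10.0, 100.0])
values, indices = SortWithIndices.apply(x)
(values * weights).sum().backward()
print(x.grad, indices.requires_grad)
```

Keeping forward free of ctx is what lets the same Function be used with torch.func transforms, which are documented to require setup_context on custom Functions.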

Testing:
A new regression test, test_compile_grad_nested_autograd_function, has been added to test/functorch/test_aotdispatch.py. It guards against regressions and checks that the compiled result matches the eager output.
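As a rough illustration of the scenario such a test exercises, the sketch below builds a nested call f(f(x)) through a custom Function and compares the compiled gradient with the eager one. This is not the test added in this PR: Cube, nested, and check_compiled_matches_eager are hypothetical names, and the exact reproducer from #169783 may differ. The compiled check is wrapped in a function and not called at import time, since on an affected build it may raise the RuntimeError described above.

```python
import torch


class Cube(torch.autograd.Function):
    """Hypothetical custom Function; not the one from issue #169783."""

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2


def nested(x):
    # f(f(x)): x -> x**3 -> x**9
    return Cube.apply(Cube.apply(x))


def check_compiled_matches_eager(x):
    # Eager reference gradient.
    expected = torch.func.grad(nested)(x)
    # Same transform under torch.compile; aot_eager skips Inductor codegen.
    compiled = torch.compile(torch.func.grad(nested), backend="aot_eager")
    torch.testing.assert_close(compiled(x), expected)


# Eager path only: d/dx x**9 evaluated at x = 2.
eager_result = torch.func.grad(nested)(torch.tensor(2.0))
print(eager_result)
```

Calling check_compiled_matches_eager(torch.tensor(2.0)) runs the compiled comparison.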

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @mlazos

…orch#169783)

Refactors ApplyTemplate in autograd_function.py to implement setup_context, required for functorch/compile compatibility. Adds regression test TestCompileNestedAutograd.
@pytorch-bot

pytorch-bot Bot commented Dec 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169786

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 2d217c7 with merge base 143c71a:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@dumko2001

Contributor Author

@pytorchbot label "bug" "release notes: dynamo"

@pytorch-bot

pytorch-bot Bot commented Dec 7, 2025

Didn't find following labels among repository labels: bug

@dumko2001

Contributor Author

@pytorchbot label "topic: not user facing"

Comment thread test/functorch/test_aotdispatch.py Outdated
self.skipTest("Skipping because it fails in strict cache mode")


class TestCompileNestedAutograd(TestCase):

Contributor

Maybe put this in TestCompileTransforms in test/functorch/test_eager_transforms.py instead?

Contributor

Maybe put this into https://github.com/pytorch/pytorch/blob/main/test/dynamo/test_autograd_function.py instead, please, and use backend="aot_eager".

Contributor Author

@zou3519 Done! I've moved the test to test/dynamo/test_autograd_function.py and updated the torch.compile call to use backend="aot_eager" as requested.

@soulitzer soulitzer left a comment

Contributor

Thanks, had a small comment on test location
Any idea why this only happens when the function is recursively called?

@soulitzer soulitzer requested a review from zou3519 December 12, 2025 22:07
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Dec 12, 2025
@linux-foundation-easycla

linux-foundation-easycla Bot commented Dec 18, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: dumko2001 / name: dumko2001 (2d217c7)

@dumko2001

Contributor Author

Thanks, had a small comment on test location. Any idea why this only happens when the function is recursively called?

@soulitzer Regarding the recursion: the nested call causes the internal ApplyTemplate (which wraps the user's autograd.Function) to be traced by func.grad in a nested context. The internal implementation used the legacy forward(ctx, ...) signature, whereas AOTAutograd enforces the modern autograd.Function API (a static forward without ctx, plus setup_context) so that it can handle tensor saving and context management correctly during these graph captures. The legacy signature prevented setup_context from being invoked properly, which led to the crash.

@zou3519 zou3519 left a comment

Contributor

lgtm if tests pass

@dumko2001 dumko2001 requested a review from zou3519 December 19, 2025 01:24
@zou3519

zou3519 commented Dec 20, 2025

Contributor

still waiting for tests to pass

@dumko2001

Contributor Author

@zou3519 @soulitzer

I have analyzed the pr_time_benchmarks failures. The logic tests (including the new regression test) are passing, but compile_time_instruction_count has regressed on basic_modules_ListOfLinears_inductor.

Root Cause Analysis
This regression is a deterministic side-effect of the fix, not an optimization failure.

  • Legacy Implementation: ApplyTemplate used the legacy forward(ctx, *args) signature. This resulted in a single Python function call during compilation.
  • New Implementation: To fix the RuntimeError and satisfy AOTAutograd/functorch requirements, I refactored ApplyTemplate to the modern torch.autograd.Function API.
    • This splits execution into two distinct calls: forward(*args) and setup_context(ctx, inputs, output).
    • The Autograd engine now incurs additional overhead to dispatch setup_context, unpack inputs, and manage the context object explicitly.
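The difference between the two styles in the list above can be seen by writing one operation both ways. The sketch below uses hypothetical LegacySquare and ModernSquare classes rather than ApplyTemplate itself. It only shows that both styles produce the same gradient under plain eager autograd, and it makes no claim about instruction counts or compile time.

```python
import torch


class LegacySquare(torch.autograd.Function):
    """Legacy style: forward receives ctx and does the bookkeeping itself."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


class ModernSquare(torch.autograd.Function):
    """Modern style: forward computes, setup_context does the bookkeeping."""

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


def eager_grad(fn_cls, value):
    # Plain eager autograd; no torch.func transform or torch.compile involved.
    x = torch.tensor(value, requires_grad=True)
    fn_cls.apply(x).backward()
    return x.grad


legacy_grad = eager_grad(LegacySquare, 3.0)
modern_grad = eager_grad(ModernSquare, 3.0)
print(legacy_grad, modern_grad)
```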

Conclusion
The instruction count increase corresponds to the mandatory overhead of the modern Autograd API (specifically the setup_context mechanism). This structural change is the only way to support nested torch.func.grad transformations correctly.

Since this overhead is intrinsic to the fix and the logic is verified, could you please update the benchmarks or merge with this known baseline shift?

@github-actions

Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label Feb 19, 2026
@github-actions github-actions Bot closed this Mar 21, 2026

Labels

module: dynamo, open source, release notes: dynamo, Stale, topic: not user facing, triaged

4 participants