Fix: Ensure internal ApplyTemplate uses modern autograd API for torch.func.grad + compile #169786
Fix: Ensure internal ApplyTemplate uses modern autograd API for torch.func.grad + compile #169786dumko2001 wants to merge 3 commits into
Conversation
…orch#169783) Refactors ApplyTemplate in autograd_function.py to implement setup_context, required for functorch/compile compatibility. Adds regression test TestCompileNestedAutograd.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169786
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit 2d217c7 with merge base 143c71a ( BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "bug" "release notes: dynamo" |
|
Didn't find following labels among repository labels: bug |
|
@pytorchbot label "topic: not user facing" |
| self.skipTest("Skipping because it fails in strict cache mode") | ||
|
|
||
|
|
||
| class TestCompileNestedAutograd(TestCase): |
There was a problem hiding this comment.
Maybe put this in TestCompileTransforms in test/functorch/test_eager_transforms.py` instead?
There was a problem hiding this comment.
Maybe put this into https://github.com/pytorch/pytorch/blob/main/test/dynamo/test_autograd_function.py please and use backend=aot_eager
There was a problem hiding this comment.
@zou3519 Done! I've moved the test to test/dynamo/test_autograd_function.py and updated the torch.compile call to use backend="aot_eager" as requested.
soulitzer
left a comment
There was a problem hiding this comment.
Thanks, had a small comment on test location
Any idea why this only happens when the function is recursively called?
|
|
@soulitzer Regarding the recursion: The recursion causes the internal ApplyTemplate (which wraps the user's autograd.Function) to be traced by func.grad in a nested context. The issue was that the internal implementation used the legacy forward(ctx, ...) signature. AOTAutograd enforces the modern autograd.Function API (static forward without ctx + setup_context) to correctly handle tensor saving and context management during these complex graph captures. The legacy signature prevented setup_context from being properly invoked, leading to the crash. |
493dfaa to
935a455
Compare
|
still waiting for tests to pass |
|
I have analyzed the Root Cause Analysis
Conclusion Since this overhead is intrinsic to the fix and the logic is verified, could you please update the benchmarks or merge with this known baseline shift? |
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Addresses #169783.
This fixes a
RuntimeErrorwhen usingtorch.compile(torch.func.grad(...))on a function with nested calls involving a customtorch.autograd.Function(e.g.,f(f(x))).Root Cause:
The crash was caused by an internal helper class,
ApplyTemplate, defined dynamically intorch/_functorch/autograd_function.py. This class was using the legacyautograd.FunctionAPI (def forward(ctx, *args):) and lacked the requiredsetup_contextmethod for AOTAutograd/functorch compatibility.The Fix:
Refactored
ApplyTemplateto adhere to the modernautograd.FunctionAPI:forwardsignature todef forward(*args):.def setup_context(ctx, inputs, output):and moved context-setting logic (likectx.mark_non_differentiable) into it.Testing:
A new regression test,
test_compile_grad_nested_autograd_function, has been added totest/functorch/test_aotdispatch.pyto ensure the issue does not regress and the compiled result is correct against the eager output.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @mlazos