
Support calling torch.compile inside non-strict export#164171

Closed
tugsbayasgalan wants to merge 6 commits into gh/tugsbayasgalan/42/base from gh/tugsbayasgalan/42/head

Conversation

@tugsbayasgalan
Contributor

@tugsbayasgalan tugsbayasgalan commented Sep 29, 2025

Stack from ghstack (oldest at bottom):

So this fixes at least two issues:

  1. When we invoke the inductor backend, we apply pre-grad passes that try to find the correct fake mode to use. In the nested case, we run into a clash when there is a closure variable in the inductor region, because non-strict export will have fakified this variable beforehand while the inner torch.compile creates a fresh fake mode. This is not a problem in regular torch.compile because the inner torch.compile gets ignored. I don't know if we are supposed to inherit the fake mode from the parent context in this case, but we can avoid the problem by defaulting to the eager backend, which is fine here because the point of export is to capture aten operators. Going through inductor would mean losing the inner torch.compile ops.
  2. There are custom torch function modes in export that track the number of torch functions executed, and an inner compile doesn't work because this mode state changes and causes a guard failure. I noticed that torch.cond fixes this problem by carefully stashing the torch function mode and deferring it to the backend, so the correct thing to do here is to reuse the torch.cond implementation unconditionally.

The things I did to fix the above:

  1. Always default to the eager backend when compile is invoked inside export. I needed to turn the way torch.cond sets up a fresh tracing env into a util that can be shared.
  2. The previous eager backend for torch.cond was wrong because the context managers didn't actually persist until the backend was invoked.
  3. torch.cond used to only disable the TorchFunctionMetadata tf mode and stash it for later, but in fact we should do this for both TorchFunctionMetadata and PreDispatchTorchFunctionMode.

With the above fixes, we are able to export flex attention.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot

pytorch-bot bot commented Sep 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164171

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 85323e8 with merge base bac0f28:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

tugsbayasgalan added a commit that referenced this pull request Sep 29, 2025
tugsbayasgalan added a commit that referenced this pull request Sep 30, 2025
tugsbayasgalan added a commit that referenced this pull request Sep 30, 2025
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 30, 2025

# Create wrapper that always uses eager backend during export
def export_wrapped_fn(*args, **kwargs):
    with setup_compilation_env(remove_pre_dispatch_tf_mode=False) as backend:
Contributor Author


@ydwu4 I need this on to capture vmap ops at pre-dispatch level inside torch.compile region. I feel we also want this for cond as well? But didn't make the change here to avoid behavior difference. Let me know what you think.

Contributor

@ydwu4 ydwu4 Sep 30, 2025


I think the correct way of doing it is to follow how we handle _temp_remove_metadata_torch_function_mode: it first pops the mode before dynamo tracing, and when dynamo finishes tracing and starts to execute the customized backend (i.e. the backend that make_eager_backend_with_torch_function_mode produces), it restores the mode so that non-strict sees the torch function mode again.

If we take this approach, my mental model would be that 1. vmap will be preserved in the dynamo-traced graph, and 2. when dynamo finishes compilation and non-strict starts to trace the graph, it will see the vmap operations, and since we restore the pre_dispatch torch function mode, we'll be able to trace the vmap operations in the non-strict export graph.

So ideally, we 1. remove the remove_pre_dispatch_tf_mode flag, and 2. merge _temp_remove_metadata_torch_function_mode and _temp_remove_pre_dispatch_tf_mode and create a unified backend that can recover all modes that were popped before dynamo tracing.

The reason we need to remove the pre_dispatch torch function mode for cond is that there are side effects created during tracing of the ops (e.g. enter_autocast_nodes was mutated).

Contributor Author


Done! I did have to change how the backend is implemented, though. Basically we need to return another callback to actually persist the modes when the gm is executed. Previously it was never actually running the modes.
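A minimal sketch of that fix, under assumed names (neither `broken_backend` nor `fixed_backend` is real PyTorch code): the broken version enters the context managers inside the backend call itself, so they have already exited by the time the graph module runs; the fixed version returns a callback that enters them around execution.

```python
import contextlib

def broken_backend(gm, ctx_factories):
    # WRONG: the contexts are entered and exited inside the backend
    # call itself, so they are no longer active when gm runs later.
    with contextlib.ExitStack() as stack:
        for factory in ctx_factories:
            stack.enter_context(factory())
    return gm

def fixed_backend(gm, ctx_factories):
    # RIGHT: return a callback that enters the contexts around the
    # actual execution of the graph module.
    def run(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            for factory in ctx_factories:
                stack.enter_context(factory())
            return gm(*args, **kwargs)
    return run
```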



Differential Revision: [D83569143](https://our.internmc.facebook.com/intern/diff/D83569143)
tugsbayasgalan added a commit that referenced this pull request Oct 1, 2025
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@tugsbayasgalan tugsbayasgalan requested a review from ydwu4 October 1, 2025 02:05
tugsbayasgalan added a commit that referenced this pull request Oct 1, 2025
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

):
    cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None

    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
Contributor Author


This is actually pretty tricky. The reason we saw this op in the first place is that the PreDispatchTorchFunction mode was still active while running through dynamo code; as a result, we ended up proxying dynamo's global-state restoration logic. In the new world, we disable tf modes when running through dynamo, so we don't see this anymore. This is fine because export has its own global-state restoration logic, and it seems wrong to have these in the graph anyway. cc: @ydwu4

Contributor

@ydwu4 ydwu4 left a comment


Looks good!

@ezyang
Contributor

ezyang commented Oct 1, 2025

OK... so the PR description says what bugs you are fixing... but what exactly are you doing in the PR?


with (
    _set_compilation_env(),
    torch._dynamo.utils.disable_cache_limit(),
Contributor


?

Contributor Author


This is just a codemod change.

Contributor

@ydwu4 ydwu4 Oct 2, 2025


Hi Ed, is the question why we're doing a bunch of environment patching here?

The temp-remove-function-mode change is to unblock lazos on the "dynamo inlining torch function mode" work, where hops saw state mutations inside the inlined torch function mode.

What we did is 1. pop the mode before dynamo tracing so dynamo captures a graph without the torch function mode, then 2. create a "patched eager" backend that restores the popped function modes and executes the dynamo-captured graph. This way, export/aot can still trigger the torch function modes when dispatching operators in the "patched eager" backend. Do you see any problems with this workaround? Any suggestions on how we can improve the situation?

Contributor


if it's a preexisting issue, I don't have any smart ideas here lol

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 012887a6d69eeeeae41f85a8b7d0141af6651bba returned non-zero exit code 1

Auto-merging test/export/test_export.py
Auto-merging test/functorch/test_aotdispatch.py
Auto-merging torch/__init__.py
CONFLICT (content): Merge conflict in torch/__init__.py
Auto-merging torch/_higher_order_ops/cond.py
error: could not apply 012887a6d69... Support calling torch.compile inside non-strict export
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

@tugsbayasgalan
Contributor Author

OK... so the PR description says what bugs you are fixing... but what exactly are you doing in the PR?

updated!

@tugsbayasgalan tugsbayasgalan requested a review from ezyang October 2, 2025 15:21
tugsbayasgalan added a commit that referenced this pull request Oct 2, 2025
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@tugsbayasgalan
Contributor Author

@pytorchbot merge -f "Landed internally"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Pull Request resolved: pytorch#164171
Approved by: https://github.com/ydwu4
@github-actions github-actions bot deleted the gh/tugsbayasgalan/42/head branch November 3, 2025 02:17