Support calling torch.compile inside non-strict export #164171
tugsbayasgalan wants to merge 6 commits into gh/tugsbayasgalan/42/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164171
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure) As of commit 85323e8 with merge base bac0f28. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This fixes at least two issues:

1) When we invoke the inductor backend, we apply pre-grad passes that try to find the correct fake mode to use. In the nested case we hit a clash when there is a closure variable in the inductor region, because non-strict export will have fakified this variable beforehand while the inner torch.compile creates a fresh fake mode. This is not a problem in regular torch.compile because the inner torch.compile is ignored. I don't know whether we are supposed to inherit the fake mode from the parent context here, but we can avoid the problem by defaulting to the eager backend, which is fine because the point of export is to capture aten operators; going through inductor would mean losing the ops inside the inner torch.compile.

2) Export installs custom torch function modes that track the number of torch functions executed, and the inner compile fails with a guard failure because this mode state changes. torch.cond fixes this by carefully stashing the torch function mode and deferring it to the backend, so the right thing to do here is to reuse the torch.cond implementation unconditionally.

With the above fixes, we are able to export flex attention.
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
torch/__init__.py (Outdated)

    # Create wrapper that always uses eager backend during export
    def export_wrapped_fn(*args, **kwargs):
        with setup_compilation_env(remove_pre_dispatch_tf_mode=False) as backend:
@ydwu4 I need this on to capture vmap ops at the pre-dispatch level inside the torch.compile region. I think we may want this for cond as well, but I didn't make that change here to avoid a behavior difference. Let me know what you think.
I think the correct way to do it is to follow how we handle _temp_remove_metadata_torch_function_mode: it first pops the mode before dynamo tracing, and when dynamo finishes tracing and starts to execute the customized backend (i.e. the backend that make_eager_backend_with_torch_function_mode produces), it restores the mode so that non-strict sees the torch function mode again.
With this approach, my mental model is: 1. vmap will be preserved in the dynamo-traced graph; 2. when dynamo finishes compilation and non-strict starts to trace the graph, it will see the vmap operations, and since we restore the pre-dispatch torch function mode, we'll be able to trace the vmap operations into the non-strict export graph.
So ideally we would 1. remove the remove_pre_dispatch_tf_mode flag, and 2. merge _temp_remove_metadata_torch_function_mode and _temp_remove_pre_dispatch_tf_mode into a unified backend that recovers all modes that were popped before dynamo tracing.
The reason we need to remove the pre-dispatch torch function mode for cond is that there are side effects created during tracing of the ops (e.g. enter_autocast_nodes was mutated).
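The pop/restore pattern described above can be sketched without any torch dependency. The mode stack, `temp_remove_mode`, and `make_eager_backend_with_modes` below are toy stand-ins for the real dynamo machinery, kept only to show the control flow (pop before tracing, restore while the captured graph runs):

```python
import contextlib

mode_stack = []  # stand-in for the torch function mode stack


@contextlib.contextmanager
def temp_remove_mode(mode_type):
    # Pop matching modes before "dynamo tracing" and hand them to the caller.
    removed = [m for m in mode_stack if isinstance(m, mode_type)]
    for m in removed:
        mode_stack.remove(m)
    yield removed
    # Note: the modes are restored by the backend, not on context exit.


def make_eager_backend_with_modes(removed_modes):
    # The backend re-pushes the popped modes only while the captured graph
    # (`gm`) actually executes, then pops them again.
    def backend(gm):
        def run(*args):
            mode_stack.extend(removed_modes)
            try:
                return gm(*args)
            finally:
                for m in removed_modes:
                    mode_stack.remove(m)
        return run
    return backend


class MetadataMode: ...


mode_stack.append(MetadataMode())

with temp_remove_mode(MetadataMode) as removed:
    assert mode_stack == []  # "dynamo" traces without the mode
    backend = make_eager_backend_with_modes(removed)
    run = backend(lambda x: (len(mode_stack), x + 1))
    seen_during_run, out = run(1)

assert seen_during_run == 1  # the mode was active while the graph ran
assert out == 2
assert mode_stack == []
```

This mirrors the mental model above: the dynamo-captured graph is clean of torch function modes, yet the outer tracer still observes them when the backend executes the graph.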
Done! I did have to change how the backend is implemented though. Basically we need to return another callback that actually persists the modes while the gm is executed; previously the backend was never actually running the modes.
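The bug being fixed here can be shown in miniature. In the broken version (hypothetical names, not the real torch code) the context manager is entered and exited while the backend is being built, so it is long gone by the time the graph runs; the fix returns a callback that enters it around the actual call:

```python
import contextlib

active = []  # records which contexts are currently entered


@contextlib.contextmanager
def tf_mode():
    active.append("tf_mode")
    try:
        yield
    finally:
        active.remove("tf_mode")


def broken_backend(gm):
    # Bug: the mode is entered and exited here, at backend-construction
    # time, so it is no longer active when `gm` later runs.
    with tf_mode():
        pass
    return gm


def fixed_backend(gm):
    # Fix: return a callback that enters the mode around the actual call.
    def run(*args):
        with tf_mode():
            return gm(*args)
    return run


gm = lambda: list(active)

assert broken_backend(gm)() == []           # mode already gone
assert fixed_backend(gm)() == ["tf_mode"]   # mode active during execution
assert active == []
```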
Differential Revision: [D83569143](https://our.internmc.facebook.com/intern/diff/D83569143)
    ):
        cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
        _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
This is actually pretty tricky. The reason we saw this op in the first place is that the PreDispatchTorchFunction mode was still active while running through dynamo code, so we ended up proxying dynamo's global-state restoration logic into the graph. In the new world we disable tf modes while running through dynamo, so we don't see this anymore. This is fine because export has its own global-state restoration logic, and it seems wrong to have these ops in the graph anyway. cc: @ydwu4
OK... so the PR description says what bugs you are fixing... but what exactly are you doing in the PR?
    with (
        _set_compilation_env(),
        torch._dynamo.utils.disable_cache_limit(),
This is just a codemod change.
Hi Ed, is the question why we're doing a bunch of environment patching here?
The temp-remove-function-mode change is to unblock lazos from the "dynamo inlining torch function mode" work, where hops saw state mutations inside the inlined torch function mode.
What we did is: 1. pop the mode before dynamo tracing so dynamo captures a graph without the torch function mode, then 2. create a "patched eager" backend that restores the popped function modes and executes the dynamo-captured graph. That way export/aot can still trigger the torch function modes when dispatching operators in the "patched eager" backend. Do you see any problems with this workaround? Any suggestions on how we can improve the situation?
if it's a preexisting issue, I don't have any smart ideas here lol
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command (details for Dev Infra team; raised by workflow job).
updated!
This fixes at least two issues:

1) When we invoke the inductor backend, we apply pre-grad passes that try to find the correct fake mode to use. In the nested case we hit a clash when there is a closure variable in the inductor region, because non-strict export will have fakified this variable beforehand while the inner torch.compile creates a fresh fake mode. This is not a problem in regular torch.compile because the inner torch.compile is ignored. I don't know whether we are supposed to inherit the fake mode from the parent context here, but we can avoid the problem by defaulting to the eager backend, which is fine because the point of export is to capture aten operators; going through inductor would mean losing the ops inside the inner torch.compile.

2) Export installs custom torch function modes that track the number of torch functions executed, and the inner compile fails with a guard failure because this mode state changes. torch.cond fixes this by carefully stashing the torch function mode and deferring it to the backend, so the right thing to do here is to reuse the torch.cond implementation unconditionally.

What I did to fix the above:

1) Always default to the eager backend when compile is invoked inside export. I factored the way torch.cond sets up a fresh tracing env into a utility that can be shared.

2) The previous eager backend for torch.cond was wrong because the context managers didn't actually persist until the backend was invoked.

3) torch.cond used to disable only the TorchFunctionMetadata tf mode and stash it for later, but we should do this for both TorchFunctionMetadata and PreDispatchTorchFunctionMode.

With the above fixes, we are able to export flex attention.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela
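The three fixes combine into one shared utility. `setup_compilation_env` is the name that appears in the diff, but the body below is a hedged plain-Python sketch, not the real implementation: it pops *both* torch function modes before tracing and yields an eager backend that restores them only while the captured graph executes.

```python
import contextlib

mode_stack = []  # stand-in for the real torch function mode stack


class TorchFunctionMetadata: ...
class PreDispatchTorchFunctionMode: ...


@contextlib.contextmanager
def setup_compilation_env():
    # Fix 3: stash *both* mode kinds, not just TorchFunctionMetadata.
    removed = [
        m for m in mode_stack
        if isinstance(m, (TorchFunctionMetadata, PreDispatchTorchFunctionMode))
    ]
    for m in removed:
        mode_stack.remove(m)

    # Fix 1 + 2: an eager backend whose returned callback restores the
    # modes at graph-execution time, so they persist until the gm runs.
    def eager_backend(gm):
        def run(*args):
            mode_stack.extend(removed)
            try:
                return gm(*args)
            finally:
                for m in removed:
                    mode_stack.remove(m)
        return run

    yield eager_backend


mode_stack[:] = [TorchFunctionMetadata(), PreDispatchTorchFunctionMode()]
with setup_compilation_env() as backend:
    assert mode_stack == []  # dynamo traces against a clean stack
    out = backend(lambda: len(mode_stack))()
assert out == 2              # both modes were live while the graph ran
assert mode_stack == []
```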
@pytorchbot merge -f "Landed internally"
Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#164171. Approved by: https://github.com/ydwu4