Skip to content

[Dynamo] Trace torch function modes entered outside of torch.compile#133137

Closed
mlazos wants to merge 61 commits intogh/mlazos/70/basefrom
gh/mlazos/70/head
Closed

[Dynamo] Trace torch function modes entered outside of torch.compile#133137
mlazos wants to merge 61 commits intogh/mlazos/70/basefrom
gh/mlazos/70/head

Conversation

@mlazos
Copy link
Contributor

@mlazos mlazos commented Aug 9, 2024

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Stack from ghstack (oldest at bottom):

Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133137

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 75e0854 with merge base 23dec79 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Chillee
Copy link
Collaborator

Chillee commented Aug 10, 2024

Would be good to test compiling create_block_mask in FlexAttention.

block_mask_a = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=True)

So, make sure that torch.compile(create_block_mask, fullgraph=True) works.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
mlazos added a commit that referenced this pull request Aug 10, 2024
ghstack-source-id: 1e86fa6
Pull Request resolved: #133137
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auto_functionalized expecttest changes still look fishy, but otherwise this PR LGTM

@mlazos
Copy link
Contributor Author

mlazos commented Sep 9, 2024

The auto_functionalized expecttest changes still look fishy, but otherwise this PR LGTM

Can you approve? You're the one with the most context at this point lol

@zou3519
Copy link
Contributor

zou3519 commented Sep 9, 2024

I want to poke at the auto_functionalized expecttests a bit more. The PR in its current state looks like it is regressing Dynamo torch_function support for torch.ops.*

…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
@mlazos
Copy link
Contributor Author

mlazos commented Sep 9, 2024

I want to poke at the auto_functionalized expecttests a bit more. The PR in its current state looks like it is regressing Dynamo torch_function support for torch.ops.*

Per our discussion, I've fixed this by adding torch.ops.* to the check for torch function overriding in torch variable.

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some last comments, please read

@mlazos
Copy link
Contributor Author

mlazos commented Sep 9, 2024

Some last comments, please read

Yup, I'll push those before merging. Thanks for the review! The ops thing was a good catch.

…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
@clee2000
Copy link
Contributor

@pytorchbot revert -m "something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph GH job link HUD commit link, newly added test yesterday" -c landrace

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@mlazos your PR has been successfully reverted.

…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

@mlazos your PR has been successfully reverted.

…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

@mlazos your PR has been successfully reverted.

…ch.compile"

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. 
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.





Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants