
[Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond#134732

Closed
mlazos wants to merge 16 commits intogh/mlazos/78/basefrom
gh/mlazos/78/head

Conversation

@mlazos
Contributor

@mlazos mlazos commented Aug 28, 2024

When tracing cond/while in eager, we trace the HOP with the eager backend while the metadata torch function mode is enabled. HOPs disallow the mutations that this torch function mode performs, so the HOP cannot be traced with the mode active. As a result, we use a custom backend which re-enters this mode when tracing these HOPs. Thanks to @ydwu4 for the help with implementing this.
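As a torch-free sketch of the idea (all names here are hypothetical, not the actual dynamo/PyTorch APIs): a backend is a callable that receives a captured graph and returns something runnable, and the custom backend wraps the graph so the mode is re-entered around every execution of the HOP body rather than being active in the outer trace.

```python
# Hypothetical, torch-free model of a "backend that re-enters a mode".
# MODE_STACK stands in for the torch function mode stack.
from contextlib import contextmanager

MODE_STACK = []

@contextmanager
def metadata_mode():
    # stand-in for the metadata torch function mode
    MODE_STACK.append("metadata")
    try:
        yield
    finally:
        MODE_STACK.pop()

def eager_backend(graph_callable):
    # the plain "eager" backend: run the captured graph as-is
    return graph_callable

def reenter_mode_backend(graph_callable):
    # custom backend: re-enter the mode around every call into the graph,
    # so the mode's mutations happen outside the outer trace
    def run(*args, **kwargs):
        with metadata_mode():
            return graph_callable(*args, **kwargs)
    return run

def hop_body(x):
    # the HOP body observes the mode while running under the custom backend
    return (x + 1, list(MODE_STACK))

compiled = reenter_mode_backend(hop_body)
```

Calling `compiled(1)` runs the body with the mode pushed and pops it again afterwards, which is the behavior the custom backend provides for the HOP trace.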

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

@pytorch-bot

pytorch-bot bot commented Aug 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134732

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6d61357 with merge base 23dec79:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mlazos mlazos added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 28, 2024
@mlazos mlazos requested a review from avikchaudhuri August 28, 2024 23:26
@mlazos mlazos requested a review from anijain2305 August 29, 2024 09:09
@mlazos mlazos changed the title [Dynamo] Disable metadata tf mode when tracing cond [Dynamo] Disable metadata tf mode when tracing while/cond Aug 29, 2024
Previously, before dynamo handled torch function mode tracing, these modes were ignored. This preserves that behavior by popping the mode before tracing with dynamo in the handling of cond.




[ghstack-poisoned]
mlazos added 3 commits August 29, 2024 03:49
@zou3519 zou3519 requested a review from ydwu4 September 4, 2024 13:53
@mlazos mlazos changed the title [Dynamo] Disable metadata tf mode when tracing while/cond [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond Sep 5, 2024
…en tracing while/cond"

pytorchmergebot added a commit that referenced this pull request Sep 13, 2024
…tracing while/cond (#134732)"

This reverts commit e504fb7.

Reverted #134732 on behalf of https://github.com/albanD due to Broke tests on main ([comment](#134732 (comment)))
@pytorchmergebot
Collaborator

@mlazos your PR has been successfully reverted.

…en tracing while/cond"

pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
…133137)

This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the `__torch_function__` of modes entered outside of the torch.compile call.
It does not yet support tracing the enter/exit of a torch function mode, or tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and to handle graph breaks as we do for other torch.* context managers.

Previously landed:
#133135
#133136
#133134
#133133
#133132
#133131
#133729
#133130

Pull Request resolved: #133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732
pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
In preparation for tracing through DeviceContext (https://github.com/pytorch/pytorch/blob/defb515306fc53ec62e92937a5a76fa5cbc05b84/torch/utils/_device.py#L66), this PR adds support for calling the setattr of thread-local objects. These objects have a slots-based implementation, and since their setattr doesn't appear to have any side effects, we call this setattr impl when replaying mutations; calling `object.__setattr__` on these objects results in a TypeError.
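The TypeError mentioned above can be seen with the standard library's `threading.local` (used here as a stand-in for the thread-local objects the PR targets; the class name below is hypothetical): `object.__setattr__` refuses to operate on these instances because their type defines its own setattr slot, while calling the type's own `__setattr__` works.

```python
import threading

class DeviceContextLike(threading.local):
    """Hypothetical stand-in for a thread-local object."""
    def __init__(self):
        self.device = None

ctx = DeviceContextLike()

# object.__setattr__ raises TypeError on _thread._local instances,
# since their type has its own C-level setattr implementation.
try:
    object.__setattr__(ctx, "device", "cuda")
    raised = False
except TypeError:
    raised = True

# Calling the type's own __setattr__ works, which is what mutation
# replay can use instead.
type(ctx).__setattr__(ctx, "device", "cuda")
```

This is why replaying mutations must go through the object's own setattr rather than `object.__setattr__`.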

Pull Request resolved: #135443
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137
pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
The semantics of ignored modes previously had edge cases; this eliminates them by, in essence, filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to reduce complexity in #135422. The ignored-modes handling will be removed in a future PR after #135422 lands, since we will then trace through DeviceContexts rather than inserting them into the graph, which needed these extra workarounds for correctness.

Pull Request resolved: #135444
Approved by: https://github.com/anijain2305, https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443
pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
This PR implements tracing of `with` statements using TorchFunction modes which have the default enter/exit behavior (i.e. pushing/popping the mode).

Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure:
1. enter context
2. jump
...
3. exit context

The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).

So for torch function modes the structure of our output code is this:

1. graph call
2. mutate tf mode stack to replay mutations
3. unsupported code
4. on exception, restore stack
5. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3. exit tf mode

To implement the no-op enter of the torch function mode, I added a torch function mode polyfill which no-ops on enter but exits normally. This is needed because we still want to trace the `with` context in the resume function and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
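As a torch-free illustration of that polyfill idea (hypothetical names; the real implementation lives in dynamo's polyfills): the wrapper's `__enter__` does nothing because the side-effects bytecode already pushed the mode, while `__exit__` pops it normally, so the resume function's exit instructions run without double-pushing.

```python
# Illustrative model: MODE_STACK stands in for the torch function mode stack.
MODE_STACK = []

class Mode:
    """A normal mode: enter pushes, exit pops."""
    def __enter__(self):
        MODE_STACK.append(self)
        return self
    def __exit__(self, *exc):
        MODE_STACK.pop()

class NoOpEnterMode:
    """Polyfill-style wrapper: no-op enter, normal exit."""
    def __init__(self, mode):
        self.mode = mode
    def __enter__(self):
        # no-op: the mode was already pushed by the side-effects bytecode
        return self.mode
    def __exit__(self, *exc):
        MODE_STACK.pop()

mode = Mode()
MODE_STACK.append(mode)  # models side-effects bytecode replaying the push
with NoOpEnterMode(mode):
    # exactly one entry on the stack: no duplicate push occurred
    assert MODE_STACK == [mode]
```

After the `with` block the stack is empty again: the mode was entered exactly once (by the replayed push) and exited exactly once (by the resume function's exit instructions), matching eager semantics.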

Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally, at a graph break we exit these block stacks to fully reset the contexts, so that we can soundly re-enter around the unsupported code.

However, once again, in the torch function mode case we do not want to perform any exit side effects at a graph break: we want to preserve the state of the mode stack as-is so that the bytecode described in the first section can properly update it. If we exited here, dynamo would pop the mode off of the symbolic stack and fail to update the true Python torch function mode stack with the suffix bytecode.

All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side-effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.

Pull Request resolved: #135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
@mlazos
Contributor Author

mlazos commented Sep 14, 2024

@pytorchbot revert -m "broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday" -c landrace

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
…guard (#135503)"

This reverts commit e77bd0e.

Reverted #135503 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
This reverts commit 5c67cf1.

Reverted #135502 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
This reverts commit 7743149.

Reverted #135422 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
This reverts commit ce3c74f.

Reverted #135444 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
This reverts commit 149d0b7.

Reverted #135443 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
…compile (#133137)"

This reverts commit 4528777.

Reverted #133137 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 14, 2024
…tracing while/cond (#134732)"

This reverts commit 731b178.

Reverted #134732 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](#134732 (comment)))
@pytorchmergebot
Collaborator

@mlazos your PR has been successfully reverted.

…en tracing while/cond"


Labels

ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: fx, Reverted

6 participants