
Attach ProxyTorchDispatchMode to ProxyTensor and use the mode in __torch_dispatch__ #82549

Closed

IvanYashchuk wants to merge 7 commits into pytorch:master from IvanYashchuk:proxy-tensor-mode

Conversation

@IvanYashchuk
Collaborator

Migrates ProxyTensors to always invoke ProxyTorchDispatchMode instead of calling into proxy_call.

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 31, 2022

✅ No Failures (33 Pending)

As of commit 5f02e1a (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

        raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

-    proxy_mode = ProxyTorchDispatchMode(fx_tracer) if trace_factory_functions else nullcontext()
+    proxy_mode = ProxyTorchDispatchMode(fx_tracer, trace_factory_functions=trace_factory_functions)
Contributor

nb: from a perf perspective this is worse than before, since you're forced to go into the torch dispatch hook even when you don't do anything with it. It doesn't matter too much, though, since trace_factory_functions is usually true.

Collaborator Author

Well, now we always need to create a ProxyTorchDispatchMode to attach to the ProxyTensor. When trace_factory_functions is false we still need the dispatch mode, not a nullcontext, so that we can pass it to the wrap_key function.

Contributor

What breaks if we don't enable the mode (keep this line as it was previously), and don't have the trace_factory_functions argument? In my mind, if we don't have the mode on shouldn't it go:

Factory Function -> No Python Key (since no arguments have the python key and there's no mode) -> Not Traced

Non-factory function -> Python Key (from arguments) -> Mode -> redispatch on python key -> mode's torch dispatch
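The two dispatch paths above can be modeled in plain Python. This is only a sketch of the decision logic, not PyTorch internals; `FakeProxyTensor`, `python_key`, and `is_traced` are made-up names for illustration:

```python
class FakeProxyTensor:
    # Stand-in for ProxyTensor: its arguments carry the Python dispatch key.
    python_key = True

def is_traced(args, mode_active):
    """Would this call reach __torch_dispatch__?

    A call is traced iff some argument carries the Python dispatch key
    or a dispatch mode is currently active.
    """
    has_python_key = any(getattr(a, "python_key", False) for a in args)
    return has_python_key or mode_active

# Factory function (no tensor arguments): traced only when a mode is on.
assert not is_traced([], mode_active=False)
assert is_traced([], mode_active=True)

# Non-factory function: a proxy-tensor argument forces dispatch regardless.
assert is_traced([FakeProxyTensor()], mode_active=False)
```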

Collaborator Author

AttributeError: 'nullcontext' object has no attribute 'restore' is raised when proxy_mode.restore() is called. We can try to intercept this case and call proxy_call with the nullcontext passed as the mode, but then at least one test (TestProxyTensorCPU.test_mode_tracing_factory_function_no_factory_function_cpu) fails with
TypeError: unsupported operand type(s) for +: 'ProxyTensor' and 'Tensor'.

else:
assert proxy_mode is arg.proxy_mode, "All arguments must be in the same proxy mode"

with enable_torch_dispatch_mode(proxy_mode):
Contributor

Looks like this is a latent bug from the fake tensor code; does proxy_mode.restore() work?

Collaborator Author

When using proxy_mode.restore() TestProxyTensorCPU.test_decomposition_interpreter_cpu fails with

RuntimeError: <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7f084183ddb0> does not have any ancestors. Use the standard version instead of restore

Contributor

This fixes the error:

diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index a7e38074eb6..d5bafefc0f2 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -476,6 +476,9 @@ class DecompositionInterpreter(torch.fx.Interpreter):
         return out
 
     def run(self, *args, **kwargs):
+        # snapshot the current mode stack
+        with self.mode:
+            pass
         with decompose(self.decomposition_table):
             return super().run(*args, **kwargs)
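Why entering the mode once helps can be sketched in plain Python. This is an illustrative model of a mode stack, not the real `torch.utils._mode_utils` code; `SketchMode` and `_mode_stack` are invented names. The point is that `restore()` needs an ancestor snapshot, and that snapshot is only recorded when the mode is actually entered:

```python
_mode_stack = []  # the currently active modes, innermost last

class SketchMode:
    def __init__(self):
        self.ancestors = None  # unknown until the mode is first entered

    def __enter__(self):
        # Entering records the current stack as this mode's ancestors.
        self.ancestors = tuple(_mode_stack)
        _mode_stack.append(self)
        return self

    def __exit__(self, *exc):
        _mode_stack.pop()

    def restore(self):
        if self.ancestors is None:
            raise RuntimeError(
                "mode does not have any ancestors; "
                "use the standard version instead of restore")
        return self  # the real code would re-push the ancestor chain

m = SketchMode()
try:
    m.restore()  # fails: the mode was never entered, no snapshot exists
except RuntimeError:
    pass

with m:          # "snapshot the current mode stack", as in the diff above
    pass
m.restore()      # now succeeds
```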
 

I need to think more about what went wrong on the API design side though.

cc @samdow: what's going on here is that people are allocating mode objects but then never actually entering them (while manually calling API functions that make use of the mode object).

Contributor

Ahh yeah, that's a hard-to-understand error message. I'll change it so it's clearer that that's what's happening.

@ezyang
Contributor

ezyang commented Aug 1, 2022

I see you hacking it up. I'd like to properly diagnose why the underlying API isn't working in this case. As a stopgap, instead of relying on storing the mode on the proxy tensor for dispatch, you could solely use it to determine if tracing is enabled or not.

@samdow
Contributor

samdow commented Aug 1, 2022

Okay, full disclosure: I did have #81708, which I just haven't had the bandwidth to clean up and to look at the failures for. I'll leave some comments, but I'm happy to use either PR and close the other.

Comment on lines +289 to +296
# Verify that the proxy mode for all arguments is the same
proxy_mode = None
for arg in pytree.tree_flatten((args, kwargs))[0]:
if isinstance(arg, ProxyTensor):
if proxy_mode is None:
proxy_mode = arg.proxy_mode
else:
assert proxy_mode is arg.proxy_mode, "All arguments must be in the same proxy mode"
Contributor

Small nit: Proxy also checks that all the tracers are the same (the equivalent of this). I'm not attached to either one, but if we want the same error message as before, we can just break after we find the first proxy mode, assume they're all the same, and let Proxy deal with it if that's not the case.

Collaborator Author

Sure, I can change that.

Contributor

Actually, @samdow, didn't you write some utilities for doing this? Or am I misremembering

Contributor

Oh, actually, yeah: if we need to do something like this in the future, we just need to map over the arguments to collect all the proxy_modes and then call all_modes_same_scope(modes) (https://github.com/pytorch/pytorch/blob/master/torch/utils/_mode_utils.py#L124).
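The break-early variant discussed above can be sketched in plain Python. `ProxyTensor` here is a stub, not the real class, and `find_proxy_mode` is an invented helper name; the real code would flatten args with pytree first:

```python
class ProxyTensor:
    # Stub standing in for torch.fx.experimental.proxy_tensor.ProxyTensor.
    def __init__(self, proxy_mode):
        self.proxy_mode = proxy_mode

def find_proxy_mode(flat_args):
    """Return the mode of the first ProxyTensor argument, or None.

    We stop at the first match and assume the rest agree; the Proxy
    machinery itself raises if the tracers actually differ.
    """
    for arg in flat_args:
        if isinstance(arg, ProxyTensor):
            return arg.proxy_mode
    return None

mode = object()
assert find_proxy_mode([1, ProxyTensor(mode), ProxyTensor(mode)]) is mode
assert find_proxy_mode([1, 2.0]) is None
```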

@ezyang left a comment

🤞 no reference cycles lol (cc @eellison, the only reference cycle from mode on tensor was the converter cache right?)

@IvanYashchuk
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

Merge failed due to Refusing to merge as mandatory check(s) Lint failed for rule superuser
Raised by https://github.com/pytorch/pytorch/actions/runs/2781445565

@IvanYashchuk
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

Merge failed due to New commits were pushed while merging. Please rerun the merge command.
Raised by https://github.com/pytorch/pytorch/actions/runs/2781632681

@IvanYashchuk
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Aug 2, 2022

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 3, 2022
Attach ProxyTorchDispatchMode to ProxyTensor and use the mode in __torch_dispatch__ (#82549)

Summary:
Migrates ProxyTensors to always invoke ProxyTorchDispatchMode instead of calling into `proxy_call`.

Pull Request resolved: #82549
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a3316cb3c7d7295de0ff686bf7b7888f2c4665ef

Reviewed By: kit1980

Differential Revision: D38359546

fbshipit-source-id: c7747e20f578e1ea2ec15c5e65d2a9ff4dbbf9a7