Attach ProxyTorchDispatchMode to ProxyTensor and use the mode in __torch_dispatch__ #82549

IvanYashchuk wants to merge 7 commits into pytorch:master from
Conversation
✅ No Failures (33 Pending) as of commit 5f02e1a (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
        raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

-   proxy_mode = ProxyTorchDispatchMode(fx_tracer) if trace_factory_functions else nullcontext()
+   proxy_mode = ProxyTorchDispatchMode(fx_tracer, trace_factory_functions=trace_factory_functions)
NB: from a perf perspective this is worse than before, as you're forced to go into the torch dispatch hook even though you don't do anything with it. It doesn't matter too much, though, since trace_factory_functions is usually true.
Well, now we always need to create a ProxyTorchDispatchMode to attach to the ProxyTensor. When trace_factory_functions is false we still need the dispatch mode (not a nullcontext) so it can be passed to the wrap_key function.
What breaks if we don't enable the mode (keep this line as it was previously), and don't have the trace_factory_functions argument? In my mind, if we don't have the mode on shouldn't it go:
Factory Function -> No Python Key (since no arguments have the python key and there's no mode) -> Not Traced
Non-factory function -> Python Key (from arguments) -> Mode -> redispatch on python key -> mode's torch dispatch
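The two dispatch paths described above can be sketched with a toy model. This is plain Python, not the real PyTorch dispatcher; `ToyProxyTensor`, `dispatch`, and `_active_mode` are hypothetical stand-ins used only to illustrate the factory vs. non-factory flows:

```python
# Toy model of the two dispatch paths (NOT the real PyTorch dispatcher).
_active_mode = None  # stands in for an enabled torch dispatch mode


class ToyProxyTensor:
    """Stand-in for a tensor subclass carrying the 'python key'."""
    def __init__(self, value):
        self.value = value


def dispatch(fn, *args):
    # Non-factory path: a ToyProxyTensor argument always triggers tracing,
    # regardless of whether a mode is active.
    if any(isinstance(a, ToyProxyTensor) for a in args):
        unwrapped = [a.value if isinstance(a, ToyProxyTensor) else a for a in args]
        return ("traced", fn(*unwrapped))
    # Factory path (no tensor arguments): traced only if a mode is active.
    if _active_mode is not None:
        return ("traced", fn(*args))
    return ("not traced", fn(*args))
```

With no mode active, a factory call like `dispatch(lambda: 1)` takes the "not traced" path, while any call with a `ToyProxyTensor` argument is traced — matching the two arrows above.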
AttributeError: 'nullcontext' object has no attribute 'restore' is raised when proxy_mode.restore() is called. We can try to intercept this case and call proxy_call with the nullcontext passed as the mode, but then at least one test fails, TestProxyTensorCPU.test_mode_tracing_factory_function_no_factory_function_cpu, with
TypeError: unsupported operand type(s) for +: 'ProxyTensor' and 'Tensor'.
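The first failure is a simple duck-typing mismatch and can be reproduced in isolation: `nullcontext` is a valid context manager, but it is not a drop-in replacement for a dispatch mode because it lacks the mode-specific `restore()` method.

```python
from contextlib import nullcontext

# Minimal repro of the failure mode described above: nullcontext enters and
# exits fine, but calling a mode-specific method on it raises.
mode = nullcontext()
try:
    mode.restore()
except AttributeError as e:
    message = str(e)  # e.g. 'nullcontext' object has no attribute 'restore'
```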
        else:
            assert proxy_mode is arg.proxy_mode, "All arguments must be in the same proxy mode"

    with enable_torch_dispatch_mode(proxy_mode):
Looks like this is a latent bug from the fake tensor code; does proxy_mode.restore() work?
When using proxy_mode.restore(), TestProxyTensorCPU.test_decomposition_interpreter_cpu fails with
RuntimeError: <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7f084183ddb0> does not have any ancestors. Use the standard version instead of restore
This fixes the error:
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index a7e38074eb6..d5bafefc0f2 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -476,6 +476,9 @@ class DecompositionInterpreter(torch.fx.Interpreter):
return out
def run(self, *args, **kwargs):
+ # snapshot the current mode stack
+ with self.mode:
+ pass
with decompose(self.decomposition_table):
return super().run(*args, **kwargs)
I need to think more about what went wrong on the API design side, though.
cc @samdow: what's going on here is people are allocating mode objects, but then never actually entering them (and manually calling API functions that make use of the mode object).
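The `with self.mode: pass` trick in the diff can be illustrated with a toy mode. `ToyMode` is a hypothetical stand-in, not the real TorchDispatchMode: entering and immediately exiting records an ancestor snapshot of the (toy) mode stack, after which `restore()` stops raising.

```python
class ToyMode:
    """Hypothetical stand-in for a dispatch mode whose restore() requires
    that the mode was entered at least once (so it has recorded ancestors)."""
    def __init__(self):
        self.ancestors = None

    def __enter__(self):
        # On first entry, snapshot the current mode stack as our ancestors.
        if self.ancestors is None:
            self.ancestors = []  # empty snapshot of the (toy) mode stack
        return self

    def __exit__(self, *exc):
        return False

    def restore(self):
        if self.ancestors is None:
            raise RuntimeError(f"{self!r} does not have any ancestors. "
                               "Use the standard version instead of restore")
        return self


mode = ToyMode()
# Entering and immediately exiting snapshots the mode stack, which is
# exactly what the `with self.mode: pass` lines in the diff above do.
with mode:
    pass
mode.restore()  # no longer raises
```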
Ahh, yeah, that's a hard-to-understand error message. I'll change it so it's clearer that that's what's happening.
I see you hacking it up. I'd like to properly diagnose why the underlying API isn't working in this case. As a stopgap, instead of relying on storing the mode on the proxy tensor for dispatch, you could use it solely to determine whether tracing is enabled or not.

Okay, full disclosure: I did have #81708 that I just haven't had the bandwidth to clean up and look at the failures. I'll leave some comments, but I'm happy to use either and close the other.
    # Verify that the proxy mode for all arguments is the same
    proxy_mode = None
    for arg in pytree.tree_flatten((args, kwargs))[0]:
        if isinstance(arg, ProxyTensor):
            if proxy_mode is None:
                proxy_mode = arg.proxy_mode
            else:
                assert proxy_mode is arg.proxy_mode, "All arguments must be in the same proxy mode"
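A self-contained version of the check above can be sketched as follows. Both helpers are illustrative stand-ins: `tree_flatten` is a minimal substitute for `pytree.tree_flatten`, and this `ProxyTensor` only carries the `proxy_mode` attribute.

```python
class ProxyTensor:
    """Illustrative stand-in: only the proxy_mode attribute matters here."""
    def __init__(self, proxy_mode):
        self.proxy_mode = proxy_mode


def tree_flatten(obj):
    # Minimal stand-in for pytree.tree_flatten: returns the leaves of
    # nested tuples/lists/dicts.
    if isinstance(obj, (list, tuple)):
        return [leaf for x in obj for leaf in tree_flatten(x)]
    if isinstance(obj, dict):
        return [leaf for v in obj.values() for leaf in tree_flatten(v)]
    return [obj]


def find_proxy_mode(args, kwargs):
    # Verify that the proxy mode for all ProxyTensor arguments is the same.
    proxy_mode = None
    for arg in tree_flatten((args, kwargs)):
        if isinstance(arg, ProxyTensor):
            if proxy_mode is None:
                proxy_mode = arg.proxy_mode
            else:
                assert proxy_mode is arg.proxy_mode, \
                    "All arguments must be in the same proxy mode"
    return proxy_mode
```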
Small nit: Proxy also checks that all the tracers are the same (the equivalent of this). I'm not attached to either one, but if we want the same error message as before, we can just break after we find the first proxy mode, assume they're all the same, and let Proxy deal with it if that's not the case.
Sure, I can change that.
Actually, @samdow, didn't you write some utilities for doing this? Or am I misremembering
Oh, actually, yeah: if we need to do something like this in the future, we just need to map to get all the proxy_modes and then call all_modes_same_scope(modes) (https://github.com/pytorch/pytorch/blob/master/torch/utils/_mode_utils.py#L124).
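The shape of such a helper can be sketched as below. This is a hypothetical approximation of the `all_modes_same_scope(modes)` idea, not the actual `_mode_utils` implementation:

```python
def all_modes_same(modes):
    # Hypothetical approximation of an all_modes_same_scope-style helper:
    # checks that every mode in the list is the same object.
    modes = list(modes)
    if not modes:
        return True
    first = modes[0]
    return all(m is first for m in modes)
```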
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Merge failed due to: Refusing to merge as mandatory check(s) Lint failed for rule superuser.

@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Merge failed due to: New commits were pushed while merging. Please rerun the merge command.

@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.
Hey @IvanYashchuk. |
…rch_dispatch__ (#82549) (#82549)

Summary: Migrates ProxyTensors to always invoke ProxyTorchDispatchMode instead of calling into `proxy_call`.

Pull Request resolved: #82549
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a3316cb3c7d7295de0ff686bf7b7888f2c4665ef
Reviewed By: kit1980
Differential Revision: D38359546
fbshipit-source-id: c7747e20f578e1ea2ec15c5e65d2a9ff4dbbf9a7