Add context manager for conditional rewrites of torch.* to torch._refs.* calls #81764

IvanYashchuk wants to merge 32 commits into pytorch:master from
Conversation
✅ No Failures (1 Pending) as of commit 6c3e9e7 (more details on the Dr. CI page).
```python
# make_fx doesn't support kwargs, so we need to do this flattening
# and then unflatten the args before calling func
nargs = len(args)
flat_kwargs = list(kwargs.values())
```
Consider using pytree flatten/unflatten instead?
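A minimal sketch of the pytree suggestion, assuming `torch.utils._pytree` is acceptable here (the helper name `call_with_flat_args` is illustrative, not from the PR): flatten `(args, kwargs)` into a list of leaves plus a spec, and unflatten inside the wrapper before calling `func`.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def call_with_flat_args(func, args, kwargs):
    # Flatten args and kwargs together; `spec` remembers the structure.
    flat_args, spec = tree_flatten((args, kwargs))

    def wrapped(flat_args):
        # Rebuild the original (args, kwargs) structure before calling func.
        args, kwargs = tree_unflatten(flat_args, spec)
        return func(*args, **kwargs)

    return wrapped(flat_args)
```

This avoids the manual `nargs`/`flat_kwargs` bookkeeping, since the spec handles arbitrary nesting.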
```diff
@@ -87,6 +91,9 @@ def __torch_function__(
     mapping = torch_to_refs_map()
```
This function seems to be only targeting torch.xxx.
In our dynamo workload, I think we are expecting an input GraphModule with aten.ops. Might want to expand this.
Right, it's only torch.xxx for now. One option is to use dynamo under this context manager and another is to expand the context manager with "aten_to_refs_map", let's handle the extension in a separate PR.
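For readers, a hedged sketch of the torch.xxx-only rewrite being discussed (the real code is `TorchRefsMode`; `RewriteMode` and the mapping below are simplified stand-ins): a `TorchFunctionMode` looks each intercepted callable up in a mapping and substitutes the replacement when one exists.

```python
import torch
from torch.overrides import TorchFunctionMode

class RewriteMode(TorchFunctionMode):
    """Simplified stand-in for TorchRefsMode: substitute callables via a mapping."""

    def __init__(self, mapping):
        super().__init__()
        self.mapping = mapping

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Fall back to the original callable when no replacement is registered.
        func = self.mapping.get(func, func)
        return func(*args, **kwargs)

# Demo: reroute torch.add to a stand-in "ref" implementation.
sentinel = torch.full((3,), 99.0)

def fake_add_ref(a, b):
    return sentinel

x, y = torch.ones(3), torch.ones(3)
with RewriteMode({torch.add: fake_add_ref}):
    out = torch.add(x, y)
```

Because the mapping is keyed on `torch.*` callables, `aten` ops on a GraphModule would indeed fall through untouched, which is the limitation raised above.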
```python
return next(proxy_tensors, None)


def get_isolated_graphmodule(func, args, kwargs):
```
Would we want some special decomposition here to short-cut things like nvfuser.var_mean? With a proper impl_nvfuser this short-cut would work out well 🎉
```python
try:
    for arg in all_args:
        if isinstance(arg, ProxyTensor):
            arg.proxy.tracer = new_tracer
```
Instead of mutating the argument, couldn't you pull out the inner elem and create a fresh ProxyTensor from that? Then you wouldn't need to worry about resetting the proxy afterwards.
Yes, I can, and I posted this example of creating a fresh ProxyTensor on Slack (https://pytorch.slack.com/archives/C03DP57R27M/p1658256451021309?thread_ts=1658256439.210389&cid=C03DP57R27M).
It requires creating a fresh torch.fx.proxy.Proxy, which in turn requires creating a fresh Node (because nodes do not work with arbitrary graphs; the node's graph has to match). Just resetting the tracer seemed cleaner. Would you like me to change that and create a fresh ProxyTensor instead?
I changed it now to create a fresh ProxyTensor instead of mutating.
```python
if isinstance(arg, ProxyTensor):
    arg.proxy.tracer = new_tracer

gm = make_fx(wrapped)(all_args)
```
This doesn't feel like it is enough. If there is an ambient proxy tensor mode active, it was never disabled and will still be interposing on invocations. But your test is passing. Do you have an explanation for why this is working as is?
~~It looks like the answer is, you're in a context where the proxy tensor mode is disabled by the time you called this~~ no, this isn't right
It works because all the calls are recorded on the different tracer which is created in this function.
I see... this feels so bad haha
Do you have a different understanding of the expected behavior? When we have ProxyTensors as inputs to the function, they are all expected to have the same tracer object attached (there's an assert for this somewhere). This tracer object is where the information about the calls gets stored, and when we swap the tracer, that information is recorded elsewhere.
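A pure-Python model of this behavior (toy classes, not the real `ProxyTensor`/`torch.fx` machinery): whichever tracer object is attached to the proxy at call time is the one that records the call, so swapping the tracer redirects recording to a different graph.

```python
class Tracer:
    def __init__(self):
        self.recorded = []

class FakeProxy:
    """Toy stand-in for a torch.fx Proxy: records ops on its tracer."""
    def __init__(self, tracer):
        self.tracer = tracer

    def call(self, op):
        self.tracer.recorded.append(op)

outer, inner = Tracer(), Tracer()
p = FakeProxy(outer)
p.call("sin")      # recorded on the outer tracer

p.tracer = inner   # the swap done inside get_isolated_graphmodule
p.call("cos")      # now recorded on the inner tracer

p.tracer = outer   # reset afterwards
```

The objection in this thread is that the mode's own tracer is not swapped, so factory functions (which have no proxy arguments) still record on the outer tracer.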
One reason it feels bad is because the ProxyTensorMode itself contains a tracer, and that tracer is now inconsistent with the tracers on the proxy objects. In fact, this means that if you have factory functions inside the conditional rewrite, they will go to the wrong graph. So this is definitely wrong!
Yes, this is definitely wrong.
I added a test case with factory functions and two nested get_isolated_graphmodule calls. The test passes with:

```python
mode = torch._C._get_torch_dispatch_mode()
with enable_torch_dispatch_mode(mode.inner, replace=mode):
    make_fx(...)
```

Then I added a test case with two nested make_fx and one get_isolated_graphmodule with factory functions. It also now passes with an ExitStack of enable_torch_dispatch_mode(mode.inner, replace=mode) contexts.
Done with a combination of

Done in #82549.

Done:

```python
with contextlib.ExitStack() as stack:
    while torch._C._get_torch_dispatch_mode() is not None:
        stack.enter_context(maybe_disable_proxy_tensor_mode())
```
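The `ExitStack` pattern used here can be modeled in pure Python (the `disable` helper and mode names below are illustrative stand-ins for `maybe_disable_proxy_tensor_mode`): contexts entered in a loop are unwound in reverse order when the stack exits, restoring the mode stack.

```python
import contextlib

log = []

@contextlib.contextmanager
def disable(name):
    # Stand-in for maybe_disable_proxy_tensor_mode(): pop on enter, restore on exit.
    log.append(f"disable {name}")
    try:
        yield
    finally:
        log.append(f"restore {name}")

modes = ["outer", "middle", "inner"]
with contextlib.ExitStack() as stack:
    # Enter one disabling context per active mode, as in the loop above;
    # ExitStack re-enables them in reverse order on exit.
    for name in modes:
        stack.enter_context(disable(name))
    log.append("make_fx runs here")
```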
~~In the current state~~ Could you please provide a concrete test case that you think the current state of the pull request is not covering? Because, as I mentioned, I couldn't make up a test case for this.
The reason unwrapping doesn't unconditionally work is because the proxy tensor may itself be embedded within another data structure that isn't tree mappable. The most common situation is a tensor subclass. For example, consider this patch: The LoggingTensor prevents the unwrapping from happening on the inside, and so sigmoid shows up in the inner graph.

If you really don't want to handle this case, I suppose that is reasonable, because inside AOTAutograd the expectation is that all tensor subclasses have already been erased, so we should only be passing plain tensors through and this should never happen. But in that case, there ought to be asserts about the assumed preconditions; and I also don't think it is that complicated to make it work for this general case as well.
You also have another problem, which is that maybe_disable_proxy_tensor_mode cannot "see" if there is a proxy mode behind another, more recently pushed mode. When I patch with: the script seems to loop infinitely.
Thanks a lot for the help here! I've added an assert that unwrapped Tensor arguments should not wrap other Tensors for now. I'd like to make it work later in a separate PR.
Is this erasure done by TorchDynamo?
It's done by AOTAutograd / proxy tensor tracing.
```python
assert all(
    getattr(a, "elem", None) is None
    for a in unwrapped_all_args
    if isinstance(a, torch.Tensor)
), "ProxyTensor is wrapped with another Tensor subclass"
```
I don't feel like this assert actually works haha
Maybe... 🤔 But the test passes!
The assumption is that a tensor subclass is actually a subclass of torch.Tensor and not a generic object created with torch.Tensor._make_wrapper_subclass. Is there a better way to test for subclasses?
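One hedged answer to the question (an illustrative helper, not what the PR uses): a plain tensor has type exactly `torch.Tensor`, while any subclass instance, including wrappers built with `torch.Tensor._make_wrapper_subclass`, reports a different concrete type, so a `type(...) is not torch.Tensor` check catches both without relying on an `elem` attribute.

```python
import torch

def is_tensor_subclass(t):
    # True for a subclass instance, False for a plain torch.Tensor (or a non-tensor).
    return isinstance(t, torch.Tensor) and type(t) is not torch.Tensor

class MyTensor(torch.Tensor):
    """Trivial subclass used only to demonstrate the check."""
    pass
```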
```python
for mode in reversed([m for m in modes if not isinstance(m, ProxyTorchDispatchMode)]):
    # mode.restore() doesn't work because mode.inner might be ProxyTorchDispatchMode
    # mode.push() is restricted to modes that don't take any arguments
    stack.enter_context(mode.push())
```
This still seems super error prone. As you say here, this only works for modes that don't take any arguments. It would be much better if there was just a boolean toggle on the proxy dispatch mode you can use to turn it off without actually removing it from the stack.
I just pushed a change to use mode.restore() with modified mode.inner and mode.ancestors, is this too hacky?
I'm against anything that modifies mode.inner, we don't really know how to understand dynamically changing mode stack structure, it's really weird and hard to think about.
It's done on a copy of the mode locally inside the exit stack context and I checked that it doesn't mutate the modes outside the exit stack context.
Instead of the current approach, you would like to see the same as done with in_kernel_invocation_manager and FakeTensorMode.in_kernel_invocation = True/False? But it also mutates the attribute locally, yes not the .inner but a different attribute, still similar.
Yes, that is my preference. The mutation in this case can be treated like a dynamically scoped variable with limited impact (as opposed to changing of inner, which totally changes the semantics of subsequent calls in the stack.)
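The dynamically scoped flag described here can be modeled without any PyTorch internals (`ToyMode` and `temporarily_disabled` are illustrative names): flip a boolean on the mode for the duration of a block and restore it on exit, instead of removing the mode from the stack or rewriting `inner`.

```python
import contextlib

class ToyMode:
    """Stand-in for a dispatch mode with an on/off toggle."""
    def __init__(self):
        self.enabled = True

    def handle(self, op):
        # When disabled, behave as if the mode were not on the stack.
        return f"traced({op})" if self.enabled else op

@contextlib.contextmanager
def temporarily_disabled(mode):
    # Dynamically scoped mutation: limited to the with-block, always restored.
    saved = mode.enabled
    mode.enabled = False
    try:
        yield
    finally:
        mode.enabled = saved
```

This mirrors the `FakeTensorMode.in_kernel_invocation` pattern mentioned above: the stack structure never changes, only a flag with block-scoped lifetime.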
Okay, do you think changing of inner should be disallowed programmatically?
In this situation, the semantics of modes is quite clear, and it is asserted that outside the `with` block the modes are not changed. Isn't it a good thing to exercise various usages of the `.inner` attribute?
Is there an issue for pytorch/torch/utils/_python_dispatch.py, lines 24 to 25 in afafd16?
> Okay, do you think changing of inner should be disallowed programmatically?

Maybe. Python makes it hard to prevent people from doing this sort of thing though lol.
Now, this PR is safer: asserts are added to prevent nested tensor subclasses, and other asserts verify that all tensor modes before and after the context disabling proxy tensor modes are the same and not modified. The last thing to do for this PR to be accepted is to change the approach of disabling proxy tensor modes: #81764 (comment) (partially blocked on #82549 being merged first).
```python
assert torch._C._get_torch_dispatch_mode() is None

# Enable all torch dispatch modes except ProxyTorchDispatchMode
for mode in reversed([m for m in modes if not isinstance(m, ProxyTorchDispatchMode)]):
```
@samdow, what do you think of the approach here modifying .inner and .ancestors to rebuild torch dispatch context skipping ProxyTorchDispatchMode instances?
I think I mostly agree with @ezyang that I don't love this as an idea, but I get that copying has some limitations that make this difficult (ideally we would want to copy while passing a new argument to the constructor...). So what about this idea:

1. Delete the `inner` and `ancestors` attributes for the copy of every mode in the stack (note: they cannot just be set to None but must be deleted in order to get the mode code to work).
2. As we're doing this, one by one push the modes back onto the stack.

This is still imperfect since we are altering the `inner` and `ancestors` elements of the stack, but at least the mode mechanism is still the one in charge of the mode ordering instead of having this repeated code.
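A pure-Python model of this pop-and-repush idea (toy classes, not the real mode machinery): pop every mode, clear the stale chain link, and re-push the ones to keep, so the stack code itself rebuilds the `inner` chain.

```python
class Mode:
    def __init__(self, name, skip=False):
        self.name = name
        self.skip = skip
        self.inner = None

class ModeStack:
    def __init__(self):
        self.top = None

    def push(self, mode):
        mode.inner = self.top  # the stack machinery owns the chain
        self.top = mode

    def pop(self):
        mode, self.top = self.top, self.top.inner
        mode.inner = None      # clear the stale link before re-pushing
        return mode

def rebuild_without_skipped(stack):
    popped = []
    while stack.top is not None:
        popped.append(stack.pop())
    for mode in reversed(popped):  # re-push bottom-first to keep ordering
        if not mode.skip:
            stack.push(mode)
```

As samdow notes, the push machinery, not hand-written code, ends up responsible for mode ordering.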
Also open to thoughts from either of you on this idea
I am still advocating for just setting dynamically scoped variables on the context objects themselves. This saves you from (1) having to copy or (2) mutating the inner pointers resulting in weird behavior
I'm happy with that. And to Ed's other point, I would be happy to have us restrict users updating inner, but I'm not sure if there's a clean way to do it.
ezyang left a comment:

I'll probably be editing this in the near future
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @IvanYashchuk. |
…s.* calls (#81764) (#81764)

Summary: Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not. A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.

Pull Request resolved: #81764
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/900e93d351bf9b0eae89efddabc7ba0c9339396a
Reviewed By: kit1980
Differential Revision: D38359506
fbshipit-source-id: c66ba2c8ee54bf27ae5ab689a8d6237139c56930
…ng torch dispatch mode stack inner attributes (#82643) (#82643)

Summary:

### Description
This PR removes fiddling with the mode stack using copies and ExitStack in favor of a simpler and more straightforward approach.

### Issue
#81764 (comment)

### Testing
No new tests are needed.

Pull Request resolved: #82643
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8092cf60c6d5985f88ab2c4ceceac75b83c428ff
Reviewed By: kit1980
Differential Revision: D38395117
fbshipit-source-id: 0d0dcc9fb7c663181b82fed6bcb048bbd0ffc88c
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.

A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
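A hedged sketch of the `should_fallback_fn` contract described above (`choose_callable` and the toy functions are stand-ins; the real mode consults `torch_to_refs_map()` and nvFuser capability): per call, the predicate decides whether to keep the original callable or substitute its ref.

```python
def choose_callable(func, mapping, should_fallback_fn):
    """Return the replacement for `func`, unless the predicate asks to fall back."""
    ref = mapping.get(func)
    if ref is None or should_fallback_fn(func):
        return func  # keep the original torch.foo
    return ref       # use the torch._refs.foo replacement

# Toy stand-ins for torch.foo / torch._refs.foo:
def torch_foo():
    return "torch.foo"

def refs_foo():
    return "torch._refs.foo"
```

With a predicate like "fall back when the prim decomposition is not nvFuser-executable", this yields the conditional-rewrite behavior of `TorchRefsNvfuserCapabilityMode`.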