Add context manager for conditional rewrites of torch.* to torch._refs.* calls#81764

Closed
IvanYashchuk wants to merge 32 commits into pytorch:master from IvanYashchuk:nvfuser-context

Conversation

@IvanYashchuk
Collaborator

Adds a new context manager TorchRefsNvfuserCapabilityMode for conditional rewrites of torch.* calls to torch._refs.* based on whether the decomposition, consisting of prims, supports nvFuser execution.

A new optional argument is added to TorchRefsMode: should_fallback_fn, a callable that decides whether the original torch.foo or the replacement torch._refs.foo should be used.
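The should_fallback_fn contract can be sketched without PyTorch internals. Everything below (the make_refs_dispatcher helper and the toy foo/refs_foo pair) is a hypothetical stand-in for the real __torch_function__-based TorchRefsMode, shown only to illustrate how the callable chooses between the original call and its _refs replacement:

```python
# Hypothetical stand-in for TorchRefsMode's fallback logic. The real code
# hooks __torch_function__; here we intercept plain callables instead.

def make_refs_dispatcher(torch_to_refs, should_fallback_fn=None):
    """Rewrite func to its refs equivalent unless
    should_fallback_fn(func, args, kwargs) asks for the original."""
    def dispatch(func, *args, **kwargs):
        ref = torch_to_refs.get(func)
        if ref is None:
            return func(*args, **kwargs)      # no known ref: run original
        if should_fallback_fn is not None and should_fallback_fn(func, args, kwargs):
            return func(*args, **kwargs)      # fallback requested
        return ref(*args, **kwargs)           # rewrite to the ref
    return dispatch

def foo(x):                                   # toy "torch.foo"
    return ("torch", x)

def refs_foo(x):                              # toy "torch._refs.foo"
    return ("refs", x)

dispatch = make_refs_dispatcher({foo: refs_foo})
print(dispatch(foo, 1))                       # ('refs', 1)

# A predicate mimicking "the prim decomposition isn't nvFuser-executable":
dispatch_fb = make_refs_dispatcher({foo: refs_foo},
                                   should_fallback_fn=lambda f, a, k: True)
print(dispatch_fb(foo, 1))                    # ('torch', 1)
```

The real mode makes this decision per call site, so unsupported decompositions silently stay on the eager torch.* path.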

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 20, 2022

🔗 Helpful links

✅ No Failures (1 Pending)

As of commit 6c3e9e7 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


# make_fx doesn't support kwargs, so we need to do this flattening
# and then unflatten the args before calling func
nargs = len(args)
flat_kwargs = list(kwargs.values())
Contributor

Consider using pytree flatten/unflatten instead?

Collaborator Author

Done.
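For context, the pytree suggestion boils down to flattening (args, kwargs) into a flat leaf list plus a spec, then unflattening before calling func. The real code would use torch.utils._pytree.tree_flatten/tree_unflatten; the round trip can be sketched in plain Python (all helper names below are made up for illustration):

```python
# Minimal sketch of the suggested pytree round trip: flatten (args, kwargs)
# into leaves + spec, then rebuild both before invoking func.
def flatten_args_kwargs(args, kwargs):
    keys = sorted(kwargs)                      # spec records arity + kwarg names
    leaves = list(args) + [kwargs[k] for k in keys]
    spec = (len(args), keys)
    return leaves, spec

def unflatten_args_kwargs(leaves, spec):
    nargs, keys = spec
    args = tuple(leaves[:nargs])
    kwargs = dict(zip(keys, leaves[nargs:]))
    return args, kwargs

def call_with_flat(func, leaves, spec):
    args, kwargs = unflatten_args_kwargs(leaves, spec)
    return func(*args, **kwargs)

leaves, spec = flatten_args_kwargs((1, 2), {"b": 4, "a": 3})
print(leaves)  # [1, 2, 3, 4]
print(call_with_flat(lambda x, y, a, b: x + y + a + b, leaves, spec))  # 10
```

Unlike the manual nargs/flat_kwargs bookkeeping, the spec travels with the leaves, so nested containers are handled uniformly by the real pytree utilities.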

@@ -87,6 +91,9 @@ def __torch_function__(
mapping = torch_to_refs_map()
Collaborator

@jjsjann123 jjsjann123 Jul 21, 2022

This function seems to be only targeting torch.xxx.

In our dynamo workload, I think we are expecting an input GraphModule with aten.ops. Might want to expand this.

Collaborator Author

Right, it's only torch.xxx for now. One option is to use dynamo under this context manager, and another is to expand the context manager with an "aten_to_refs_map"; let's handle the extension in a separate PR.

return next(proxy_tensors, None)


def get_isolated_graphmodule(func, args, kwargs):
Collaborator

Would we want some special decomposition here to short-cut things like nvfuser.var_mean? With a proper impl_nvfuser this short-cut would work out well 🎉

try:
for arg in all_args:
if isinstance(arg, ProxyTensor):
arg.proxy.tracer = new_tracer
Contributor

Instead of mutating the argument, couldn't you pull out the inner elem and create a fresh ProxyTensor from that? Then you wouldn't need to worry about resetting the proxy afterwards.

Collaborator Author

Yes I can and I posted this example of creating a fresh ProxyTensor on Slack (https://pytorch.slack.com/archives/C03DP57R27M/p1658256451021309?thread_ts=1658256439.210389&cid=C03DP57R27M)

It requires creating a fresh torch.fx.proxy.Proxy, which in turn requires creating a fresh Node (because nodes do not work with arbitrary graphs; the graph has to match). Just resetting the tracer seemed cleaner. Would you like me to change that and create a fresh ProxyTensor instead?

Collaborator Author

I changed it now to create a fresh ProxyTensor instead of mutating.

@ezyang ezyang requested a review from zou3519 July 21, 2022 15:15
if isinstance(arg, ProxyTensor):
arg.proxy.tracer = new_tracer

gm = make_fx(wrapped)(all_args)
Contributor

This doesn't feel like it is enough. If there is an ambient proxy tensor mode active, it was never disabled and will still be interposing on invocations. But your test is passing. Do you have an explanation for why this is working as is?

Contributor

@ezyang ezyang Jul 21, 2022

It looks like the answer is: you're in a context where the proxy tensor mode is disabled by the time you called this... no, this isn't right.

Collaborator Author

It works because all the calls are recorded on a different tracer, which is created in this function.

Contributor

I see... this feels so bad haha

Collaborator Author

Do you have a different understanding of the expected behavior? When we have ProxyTensors as inputs to the function, they are all expected to have the same tracer object attached (there's an assert for this somewhere). This tracer object is where the information about the calls gets stored, so when we swap the tracer, that information is stored elsewhere.

Contributor

One reason it feels bad is because the ProxyTensorMode itself contains a tracer, and that tracer is now inconsistent with the tracers on the proxy objects. In fact, this means that if you have factory functions inside the conditional rewrite, they will go to the wrong graph. So this is definitely wrong!

Collaborator Author

Yes, this is definitely wrong.
I added a test case with factory functions and two nested get_isolated_graphmodule calls. The test passes with:

mode = torch._C._get_torch_dispatch_mode()
with enable_torch_dispatch_mode(mode.inner, replace=mode):
    make_fx(...)

Then I added a test case with two nested make_fx calls and one get_isolated_graphmodule with factory functions. It also now passes with an ExitStack of enable_torch_dispatch_mode(mode.inner, replace=mode) contexts.

@IvanYashchuk
Collaborator Author

  1. Write a function that walks up the current mode stack, and looks for ProxyTensorMode

Done with a combination of maybe_disable_proxy_tensor_mode() and while torch._C._get_torch_dispatch_mode() is not None.

  1. Make every ProxyTensor hold a reference to ProxyTensorMode (you can look at FakeTensor to see an example of how this is done)

Done in #82549.

6. Make isolated graph disable all proxy tensor modes on the mode stack (using (1) to find the modes)

Done:

with contextlib.ExitStack() as stack:
    while torch._C._get_torch_dispatch_mode() is not None:
        stack.enter_context(maybe_disable_proxy_tensor_mode())

then runs make_fx with the arguments as is

make_fx is now run with the unwrapped elem of the proxy tensors. I think steps (3), (4), and (5) are not needed in this case?

In its current state, get_isolated_graphmodule is a very simple function that unwraps the given proxy tensors and runs make_fx with all outer ProxyTensorModes disabled.

Could you please provide a concrete test case you think the current state of the pull request is not covering? As I mentioned, I couldn't come up with a test case for

What if a inner tensor from the isolated graph mode escapes (e.g. by mutation)

@ezyang
Contributor

ezyang commented Jul 31, 2022

Could you please provide a concrete test case you think the current state of the pull request is not covering? As I mentioned, I couldn't come up with a test case for

The reason unwrapping doesn't unconditionally work is that the proxy tensor may itself be embedded within another data structure that isn't tree-mappable. The most common situation is a tensor subclass. For example, consider this patch:

diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 9ac4e0a470..9d1635ecfa 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -169,7 +169,10 @@ class TestProxyTensor(TestCase):
             self.assertTrue(is_any_sum(gm))
             return torch.sigmoid(x)
 
+        from torch.testing._internal.logging_tensor import LoggingTensor
+
         def f2(x):
+            x = LoggingTensor(x)
             gm = get_isolated_graphmodule(f1, (x,), {})
             self.assertFalse(is_any_sum(gm))
             self.assertTrue(is_any_sigmoid(gm))

The LoggingTensor prevents the unwrapping from happening on the inside, and so sigmoid shows up in the inner graph.

If you really don't want to handle this case, I suppose that is reasonable, because inside AOTAutograd the expectation is that all tensor subclasses have already been erased, so we should only be passing plain tensors through and this should never happen. But in that case, there ought to be asserts about the assumed preconditions; and I also don't think it is that complicated to make it work for this general case as well.

@ezyang
Contributor

ezyang commented Jul 31, 2022

You also have another problem: maybe_disable_proxy_tensor_mode cannot "see" a proxy mode that is behind another, more recently pushed mode. When I patch with:

diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 9ac4e0a470..4cd100f18f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -169,8 +169,11 @@ class TestProxyTensor(TestCase):
             self.assertTrue(is_any_sum(gm))
             return torch.sigmoid(x)
 
+        from torch.testing._internal.logging_tensor import LoggingTensorMode
+
         def f2(x):
-            gm = get_isolated_graphmodule(f1, (x,), {})
+            with LoggingTensorMode():
+                gm = get_isolated_graphmodule(f1, (x,), {})
             self.assertFalse(is_any_sum(gm))
             self.assertTrue(is_any_sigmoid(gm))
             return torch.digamma(x)

the script seems to infinite loop.

@IvanYashchuk
Collaborator Author

Thanks a lot for the help here!
I've fixed the infinite loop problem: currently, all dispatch modes are disabled first and then pushed again, skipping proxy tensor modes.

I've added an assert that unwrapped Tensor arguments should not wrap other Tensors for now. I'd like to make it work later in a separate PR.

inside AOTAutograd the expectation is that all tensor subclasses have already been erased, so we should only be passing plain tensors through and this should never happen.

Is this erasure done by TorchDynamo?

@ezyang
Contributor

ezyang commented Aug 1, 2022

Is this erasure done by TorchDynamo?

It's done by AOTAutograd / proxy tensor tracing

assert all(
    getattr(a, "elem", None) is None
    for a in unwrapped_all_args
    if isinstance(a, torch.Tensor)
), "ProxyTensor is wrapped with another Tensor subclass"
Contributor

I don't feel like this assert actually works haha

Collaborator Author

Maybe... 🤔 But the test passes!

The assumption is that a tensor subclass is actually a subclass of torch.Tensor and not a generic object created with torch.Tensor._make_wrapper_subclass. Is there a better way to test for subclasses?

for mode in reversed([m for m in modes if not isinstance(m, ProxyTorchDispatchMode)]):
# mode.restore() doesn't work because mode.inner might be ProxyTorchDispatchMode
# mode.push() is restricted to modes that don't take any arguments
stack.enter_context(mode.push())
Contributor

This still seems super error prone. As you say here, this only works for modes that don't take any arguments. It would be much better if there was just a boolean toggle on the proxy dispatch mode you can use to turn it off without actually removing it from the stack.
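The boolean-toggle idea can be sketched as follows; ToyProxyMode and disable are made-up names illustrating a mode that stays on the stack but checks an enabled flag in its dispatch, with the flag treated as a dynamically scoped variable:

```python
import contextlib

# Toy version of the boolean-toggle idea: instead of removing the mode
# from the stack, flip an `enabled` flag that its dispatch consults.
class ToyProxyMode:
    def __init__(self):
        self.enabled = True

    def dispatch(self, func, *args):
        if not self.enabled:
            return func(*args)        # behave as if the mode weren't there
        return ("traced", func(*args))

@contextlib.contextmanager
def disable(mode):
    old, mode.enabled = mode.enabled, False
    try:
        yield
    finally:
        mode.enabled = old            # dynamically scoped restore

mode = ToyProxyMode()
assert mode.dispatch(abs, -3) == ("traced", 3)
with disable(mode):
    assert mode.dispatch(abs, -3) == 3   # still on the stack, but inert
assert mode.dispatch(abs, -3) == ("traced", 3)
print("ok")
```

Because the stack itself is never restructured, this sidesteps the mode.push() restriction on modes that take constructor arguments.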

Collaborator Author

I just pushed a change to use mode.restore() with modified mode.inner and mode.ancestors, is this too hacky?

Contributor

I'm against anything that modifies mode.inner; we don't really know how to reason about a dynamically changing mode stack structure, it's really weird and hard to think about.

Collaborator Author

It's done on a copy of the mode locally inside the exit stack context and I checked that it doesn't mutate the modes outside the exit stack context.

Collaborator Author

Instead of the current approach, you would like to see the same thing as is done with in_kernel_invocation_manager and FakeTensorMode.in_kernel_invocation = True/False? But that also mutates an attribute locally; yes, it's not .inner but a different attribute, still similar.

Contributor

Yes, that is my preference. The mutation in this case can be treated like a dynamically scoped variable with limited impact (as opposed to changing of inner, which totally changes the semantics of subsequent calls in the stack.)

Collaborator Author

Okay, do you think changing of inner should be disallowed programmatically?

Collaborator Author

In this situation, the semantics of modes is quite clear, and it is asserted that the modes are not changed outside the "with" block. Isn't it a good thing to exercise various usages of the .inner attribute?
Is there an issue for

# - We need a better user-facing api for torch._C._DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.

Contributor

Okay, do you think changing of inner should be disallowed programmatically?

Maybe. Python makes it hard to prevent people from doing this sort of thing though lol.

@IvanYashchuk
Collaborator Author

IvanYashchuk commented Aug 1, 2022

Now this PR is safer: asserts were added to prevent nested tensor subclasses, and other asserts verify that the tensor modes before and after the context that disables proxy tensor modes are the same and not modified.

The last thing to do for this PR to be accepted is to change the approach of disabling proxy tensor modes: #81764 (comment) (partially blocked on #82549 being merged first).

assert torch._C._get_torch_dispatch_mode() is None

# Enable all torch dispatch modes except ProxyTorchDispatchMode
for mode in reversed([m for m in modes if not isinstance(m, ProxyTorchDispatchMode)]):
Collaborator Author

@samdow, what do you think of the approach here of modifying .inner and .ancestors to rebuild the torch dispatch context while skipping ProxyTorchDispatchMode instances?

Contributor

I think I mostly agree with @ezyang that I don't love this as an idea, but I get that copying has some limitations that make this difficult (ideally we would want to copy while passing a new argument to the constructor...). So what about this idea:
(1) delete the inner and ancestor attributes for the copy of every mode in the stack (note: they cannot just be set to None but must be deleted in order to get the mode code to work)
(2) as we're doing this, one by one push the modes back onto the stack

This is still imperfect, since we are altering the inner and ancestors elements of the stack, but at least the mode mechanism is still in charge of the mode ordering instead of having this repeated code.
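A toy version of ideas (1) and (2), with a made-up Mode chain standing in for the real dispatch mode stack: walk the stack, drop the unwanted class of modes, and let the push machinery rebuild the inner links on fresh copies so the originals are never mutated:

```python
# Toy illustration of rebuilding a linked mode stack while skipping one
# class of modes. Modes form a chain via `inner`, top to bottom.
class Mode:
    def __init__(self, name):
        self.name = name
        self.inner = None

def push(stack_top, mode):
    mode.inner = stack_top          # the push machinery sets inner
    return mode

def rebuild_without(stack_top, skip_name):
    kept = []
    m = stack_top
    while m is not None:            # walk top -> bottom collecting survivors
        if m.name != skip_name:
            kept.append(m)
        m = m.inner
    top = None
    for mode in reversed(kept):     # re-push bottom-up so inner is rebuilt
        top = push(top, Mode(mode.name))  # fresh copies; originals untouched
    return top

top = None
for name in ["fake", "proxy", "logging"]:
    top = push(top, Mode(name))

new_top = rebuild_without(top, "proxy")
names = []
m = new_top
while m:
    names.append(m.name)
    m = m.inner
print(names)  # ['logging', 'fake'] -- same order, proxy mode gone
```

Using fresh copies keeps the original stack intact, which matches the concern about dynamically changing .inner on live modes.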

Contributor

Also open to thoughts from either of you on this idea

Contributor

I am still advocating for just setting dynamically scoped variables on the context objects themselves. This saves you from (1) having to copy or (2) mutating the inner pointers resulting in weird behavior

Contributor

I'm happy with that. And to Ed's other point, I would be happy for us to restrict users from updating inner, but I'm not sure there's a clean way to.

Contributor

@ezyang ezyang left a comment

I'll probably be editing this in the near future

@IvanYashchuk
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Aug 2, 2022

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 3, 2022
…s.* calls (#81764) (#81764)

Summary:
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.

A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.

Pull Request resolved: #81764
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/900e93d351bf9b0eae89efddabc7ba0c9339396a

Reviewed By: kit1980

Differential Revision: D38359506

fbshipit-source-id: c66ba2c8ee54bf27ae5ab689a8d6237139c56930
pytorchmergebot pushed a commit that referenced this pull request Aug 3, 2022
…ng torch dispatch mode stack inner attributes (#82643)

### Description
This PR removes fiddling with the mode stack using copies and ExitStack in favor of a simpler and more straightforward approach.

### Issue
#81764 (comment)

### Testing
No new tests are needed.

Pull Request resolved: #82643
Approved by: https://github.com/ezyang
facebook-github-bot pushed a commit that referenced this pull request Aug 4, 2022
…ng torch dispatch mode stack inner attributes (#82643) (#82643)

Summary:
### Description
This PR removes fiddling with the mode stack using copies and ExitStack in favor of a simpler and more straightforward approach.

### Issue
#81764 (comment)

### Testing
No new tests are needed.

Pull Request resolved: #82643
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8092cf60c6d5985f88ab2c4ceceac75b83c428ff

Reviewed By: kit1980

Differential Revision: D38395117

fbshipit-source-id: 0d0dcc9fb7c663181b82fed6bcb048bbd0ffc88c

Labels

cla signed Merged module: fx module: nvfuser module: primTorch open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants