fix torch.tensor for functionalization #76319
bdhirsh wants to merge 9 commits into gh/bdhirsh/218/base
Conversation
❌ 3 New Failures as of commit 724cc37 (more details on the Dr. CI page). 3 new failures were recognized by patterns; the following CI failures do not appear to be due to upstream breakages.
Right now, using the `torch.Tensor` constructor inside of a functionalized function is broken (and there's a request to use it during tracing for mobile: https://fb.workplace.com/groups/1405155842844877/permalink/5805679106125840/). `torch.Tensor` already has to be handled specially in several other contexts (autograd and functorch). Unfortunately, we can't use the same approach to fix the issue for functionalization; I described the problem in more detail in the code comments. The previous solutions rely on `at::empty()` *not* returning a wrapper (by setting some TLS), and on a later `.to()` call to "promote" the result to a wrapper.

I'm wondering what people's thoughts are on landing this directly, versus trying to be more general and not specialize on functionalization. For example, we could make "wrapper tensor" a first-class concept, e.g. by adding an `unwrap()` function on `TensorImpl` that errors out unless you override it.

[ghstack-poisoned]
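To make the "wrapper tensor as a first-class concept" idea concrete, here is a toy Python sketch (all names here are illustrative; the real `TensorImpl` is a C++ class and works quite differently):

```python
class TensorImpl:
    """Base impl: a plain tensor is not a wrapper, so unwrap() errors out."""
    def __init__(self, data):
        self.data = data

    def unwrap(self):
        raise RuntimeError("not a wrapper tensor")


class FunctionalTensorWrapper(TensorImpl):
    """A wrapper impl overrides unwrap() to expose the inner tensor."""
    def __init__(self, inner):
        super().__init__(inner.data)
        self.inner = inner

    def unwrap(self):
        return self.inner


def maybe_unwrap(t):
    # Generic code (like the tensor constructor) could peel off any kind
    # of wrapper without knowing which subsystem produced it.
    try:
        return t.unwrap()
    except RuntimeError:
        return t


plain = TensorImpl([1, 2, 3])
wrapped = FunctionalTensorWrapper(plain)
assert maybe_unwrap(plain) is plain
assert maybe_unwrap(wrapped) is plain
```

The point of the virtual `unwrap()` is that callers don't need to special-case functionalization, functorch, or any other wrapper subsystem; any wrapper-agnostic code path can unwrap generically.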
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
// @@ -292,7 +294,29 @@ Tensor internal_new_from_data(
    "Expected a Storage of type ", scalar_type,
    " or an _UntypedStorage, but got ", storage_scalar_type);
tensor = at::empty(sizes, at::initialTensorOptions()
                              .dtype(is_typed_storage ? storage_scalar_type : inferred_scalar_type)
                              .pinned_memory(pin_memory)
                              .device(storage.device()));
```
It occurs to me that this is actually very inefficient, right?! We allocate a `sizes` storage, and then throw it out immediately after! If we had some sort of `new_as_strided` (which @albanD was mumbling about at #75994), we could do this all in one go; it would be faster, and you could directly implement functionalization there.
We do have the low-level function in C++; the new version could be added!
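The inefficiency being discussed can be sketched in toy Python: the constructor-from-storage path allocates a fresh storage via `empty()` and then immediately replaces it with the caller's storage via `set_()`. The names `alloc_storage`, `empty`, and `new_with_storage` below are hypothetical stand-ins (the suggested real API was a `new_as_strided`-style factory):

```python
ALLOCATIONS = 0

def alloc_storage(nbytes):
    """Count every storage allocation so we can see the waste."""
    global ALLOCATIONS
    ALLOCATIONS += 1
    return bytearray(nbytes)

def empty(nbytes):
    # Like at::empty(): always allocates a brand-new storage...
    return {"storage": alloc_storage(nbytes)}

def set_(tensor, storage):
    # ...which a subsequent set_(storage) immediately throws away.
    tensor["storage"] = storage

# Current pattern: empty() + set_() -> two allocations, one wasted.
existing = alloc_storage(16)
t = empty(16)
set_(t, existing)
assert ALLOCATIONS == 2

def new_with_storage(storage):
    # Hypothetical one-step factory: build the tensor directly around
    # the existing storage, with no throwaway allocation.
    return {"storage": storage}

ALLOCATIONS = 0
existing = alloc_storage(16)
t = new_with_storage(existing)
assert ALLOCATIONS == 1
```

A one-step factory would also give functionalization a single place to hook the construction, instead of intercepting both the allocation and the storage swap.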
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
// LazyNativeFunctions::empty is explicitly responsible for wrapping its output into a FunctionalTensorWrapper.
// - That leaves us with the problem described here though: at::empty() is going to return a wrapper.
//   One way to generalize this would be to make "wrapper tensor" a first class concept,
//   e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
```
This is a long comment saying why the thing doesn't work, but I think what I'd actually want to read about is how, morally, it should work.
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
} else {
  data_tensor = tensor;
```
How come `data_tensor` gets set in one branch but not the other?
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
if (c10::multiply_integers(tensor.sizes()) != 0) {

// See Note [Functionalization <> torch.Tensor Constructor]
at::Tensor data_tensor;
```
...shadowing the `data_tensor` above?
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
// One way to generalize this would be to make "wrapper tensor" a first class concept,
// e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
```
This is the "storage was passed to the tensor constructor" code path; does anyone actually need this? If we want to do this soundly, we need to identify whether the passed-in storage is a functionalization storage or a regular storage, because it seems to me that unwrapping the functional tensor when it's a functionalization storage would be the wrong thing to do.
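The soundness concern above can be illustrated with a toy Python sketch (all class and function names here are hypothetical simplifications, not PyTorch's real API): the branch taken by `set_` should depend on which kind of storage was passed in, not just on whether the tensor is functional.

```python
class Storage:
    """A plain, regular storage."""

class FunctionalStorage(Storage):
    """A storage tracked by functionalization."""

class Tensor:
    def __init__(self, storage=None):
        self.storage = storage

    def set_(self, storage):
        self.storage = storage

class FunctionalTensor(Tensor):
    """Wrapper around an inner, regular tensor."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

def set_storage(tensor, storage):
    if isinstance(tensor, FunctionalTensor) and not isinstance(storage, FunctionalStorage):
        # A regular storage belongs on the unwrapped inner tensor.
        tensor.inner.set_(storage)
    else:
        # A functionalization storage pairs with the wrapper itself
        # (and a regular tensor just takes the storage directly);
        # unwrapping in that case would attach it to the wrong tensor.
        tensor.set_(storage)

s_reg = Storage()
s_fun = FunctionalStorage()
ft = FunctionalTensor(Tensor())
set_storage(ft, s_reg)
assert ft.inner.storage is s_reg and ft.storage is None
set_storage(ft, s_fun)
assert ft.storage is s_fun
```

Unconditionally unwrapping, as the diff above does, corresponds to always taking the first branch, which is exactly the unsoundness the comment points out.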
torch/csrc/utils/tensor_new.cpp (Outdated)

```cpp
  data_tensor = tensor;
}

if (c10::multiply_integers(data_tensor.sizes()) != 0) {
```
It should be sound to compute this on `tensor.sizes()` too, right?
The code at stake here is not big, so I don't think it's too risky to land this as is (especially as an unblocker). However, I think I have an alternative suggestion for how to do this properly. The general concept of what is going on here is that we first materialize a plain tensor holding the data, and then "lift" it into the wrapped representation.
Hmm, ok. I originally wasn't sure if this was the end-state behavior that we wanted (partially because of the weird interaction it would cause with LTC/XLA), but the description makes sense to me. I'll switch the PR over to do it this way. We can also get things to work out on the LTC/XLA side by having their `lift` implementations handle it.
I didn't think carefully about the LTC/XLA side of things, but it seems similar? You need to make an honest-to-goodness tensor with the data you want, and then lower it into the XLA graph as a constant. That's what `lift` should be doing, I think?
Right now, using the `torch.Tensor` constructor inside of a functionalized function is broken (and there's a request to use it during tracing for mobile: https://fb.workplace.com/groups/1405155842844877/permalink/5805679106125840/).

Update: I took the approach described in the comments (letting `at::empty()` run directly with the data, and using `.to()` to "lift" it into a wrapper), which required a minor change described below.

[ghstack-poisoned]
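The "lift" flow the update describes can be sketched in toy Python (these names are illustrative stand-ins; the real logic lives in C++ and dispatches through `at::empty()` and `.to()`):

```python
FUNCTIONALIZATION_ENABLED = False

class Tensor:
    def __init__(self, data):
        self.data = data

class FunctionalTensorWrapper(Tensor):
    def __init__(self, inner):
        super().__init__(inner.data)
        self.inner = inner

def empty_with_data(data):
    # Step 1: the empty() call runs *below* functionalization, so it
    # always produces a plain tensor holding the real data.
    return Tensor(data)

def to(tensor):
    # Step 2: the later .to() call "lifts" the plain tensor into a
    # wrapper, but only if functionalization is active.
    if FUNCTIONALIZATION_ENABLED and not isinstance(tensor, FunctionalTensorWrapper):
        return FunctionalTensorWrapper(tensor)
    return tensor

def torch_tensor(data):
    return to(empty_with_data(data))

# With functionalization off, the constructor returns a plain tensor.
assert not isinstance(torch_tensor([1.0]), FunctionalTensorWrapper)

# With functionalization on, the result gets lifted into a wrapper
# whose inner tensor still holds the real data.
FUNCTIONALIZATION_ENABLED = True
out = torch_tensor([1.0])
assert isinstance(out, FunctionalTensorWrapper)
assert out.inner.data == [1.0]
```

The key property is that the data is materialized exactly once, in a plain tensor, and wrapping is deferred to a single "lift" point that each wrapper subsystem can own.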
@zhxchen17 I've been working on a big stack of functionalization changes locally, including the feedback from this PR; I actually just pushed out the changes a minute ago. @ezyang I saw that you just gave the PR an approve. Feel free to take a look at the new changes; I added a new `lift` op. I'm going to need to add a companion PR for functorch before I can land this, though: the existing wrapper subclasses in functorch need to know about "lifting".
```cpp
c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_guard(c10::DispatchKey::Python);
c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_snapshot_guard(c10::DispatchKey::PythonTLSSnapshot);
// functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
// tensors returned from operators in special TensorWrapper tensor extension
```
It feels like there should be an easy way to just "exclude everything", but we can probably work that out later.
ezyang left a comment:

I reviewed the new code with `lift` and I like it a lot! Thanks!
I'm actually pretty sure that this doesn't require functorch changes: I was trying to change functorch locally, and my understanding is that now the code in the …

Getting functorch to support …

It also looks like functorch CI is failing on master, but I was able to run …
@pytorchbot merge this please
Hey @bdhirsh.
@pytorchbot revert this as it breaks ONNX tests (which also show up on the PR): https://hud.pytorch.org/minihud?name_filter=pull%20/%20linux-xenial-py3.7-clang7-onnx%20/%20test%20(default,%202,%202,%20linux.2xlarge)
This reverts commit 9edee09. Reverted #76319 on behalf of https://github.com/janeyx99