Make __new__ on tensor subclasses match _make_subclass #73727
ezyang wants to merge 1 commit into gh/ezyang/1090/base
Conversation
Previously, calling `SubclassTensor(tensor)` would give you a `SubclassTensor` where the underlying `at::Tensor` was computed by an `alias()` call. In particular, a `grad_fn` would be created in this situation. This is usually not what people want, as the alias `grad_fn` is oblivious to the subclass's semantics (and just as likely to be wrong), and it means that you cannot use the constructor to directly create a leaf `SubclassTensor` that `requires_grad=True`.

This PR changes the meaning of this call so that `SubclassTensor(tensor)` is equivalent to `torch.Tensor._make_subclass(SubclassTensor, tensor)`; that is to say, the underlying `at::Tensor` is created by a `detach()` call (deleting `grad_fn`), and furthermore `requires_grad` defaults to False (but you can set it explicitly afterwards). I keep exactly the old behavior if you call a normal Tensor, which could be somewhat confusing as it doesn't match exactly.

I'm not sure if this is completely correct. Here are some other ways we could skin the cat:

- `detach()`, but propagate requires_grad-ness. This lets an idiom like `TensorSubclass(torch.empty(2, requires_grad=True))` do the intuitive thing.
- `detach()`, ignore the input's `requires_grad`, and also accept a `requires_grad` kwarg for setting `requires_grad` directly. This means you would write `TensorSubclass(torch.empty(2), requires_grad=True)` to create a leaf node.
- Same as above, but assert that `requires_grad=False` or that we are in no_grad mode. This would remind users that if they want a non-leaf tensor subclass, they are obligated to think about the autograd semantics for this boundary.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
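To make the behavioral change concrete, here is a small pure-Python mock (not actual PyTorch code; `MockTensor`, `alias`, and `detach` are simplified stand-ins for the real tensor machinery) contrasting the old `alias()`-based construction with the new `detach()`-based one:

```python
class MockTensor:
    """Simplified stand-in for a torch.Tensor's autograd metadata."""
    def __init__(self, requires_grad=False, grad_fn=None):
        self.requires_grad = requires_grad
        self.grad_fn = grad_fn

def alias(t):
    # Old behavior: aliasing a grad-requiring tensor records a grad_fn,
    # so the result is a non-leaf that tracks history through AliasBackward.
    if t.requires_grad:
        return MockTensor(requires_grad=True, grad_fn="AliasBackward")
    return MockTensor()

def detach(t):
    # New behavior: detach severs the autograd graph; the result is a leaf
    # with requires_grad=False (which can be flipped back on explicitly).
    return MockTensor(requires_grad=False, grad_fn=None)

src = MockTensor(requires_grad=True)

old = alias(src)
assert old.grad_fn == "AliasBackward"       # non-leaf: can't be a grad-requiring leaf

new = detach(src)
assert new.grad_fn is None and not new.requires_grad
new.requires_grad = True                    # a proper leaf, set explicitly afterwards
```

The key point the mock illustrates: only the detach path yields an object with no history, so only it can serve as a leaf with `requires_grad=True`.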
💊 CI failures summary and remediations — as of commit df853a1 (more details on the Dr. CI page):

🕵️ 10 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
ghstack-source-id: df8e576
Pull Request resolved: #73727
```diff
  TORCH_CHECK(type != &THPVariableType, "Cannot directly construct _TensorBase; subclass it and then construct that");
  jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
- auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
+ auto tensor = torch::utils::legacy_tensor_ctor(type, torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
```
Do I read correctly that this constructor does not accept the `requires_grad` flag today at all? And thus can't be used to do `Foo(torch.empty(2), requires_grad=True)`?

In that case, I think I like the idea of raising an error if the input requires grad and we are not in no_grad mode. We can even recommend using a custom autograd `Function` if they want to make this construction differentiable.
Based on your chat comments I was planning to re-add the `requires_grad` field; it seemed like a good idea.
So the final thing would be:
- You have a `requires_grad` flag you can use to create a leaf that requires grad.
- The given Tensor can never require gradients.
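Under that design, construction would behave roughly like this pure-Python sketch (`make_subclass` here is a hypothetical stand-in, not the real `torch.Tensor._make_subclass` signature):

```python
def make_subclass(data_requires_grad, requires_grad=False):
    # Rule 2: the given Tensor can never require gradients.
    if data_requires_grad:
        raise RuntimeError(
            "input Tensor must not require grad; detach() it first "
            "(or use a custom autograd.Function for a differentiable boundary)"
        )
    # Rule 1: the requires_grad kwarg lets you create a leaf that requires grad.
    return {"is_leaf": True, "requires_grad": requires_grad}

leaf = make_subclass(False, requires_grad=True)
assert leaf["is_leaf"] and leaf["requires_grad"]

try:
    make_subclass(True)          # grad-requiring input is rejected outright
    raise AssertionError("should have raised")
except RuntimeError:
    pass
```

This keeps the autograd boundary explicit: the only way to get a grad-requiring subclass out of the constructor is as a fresh leaf.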
One thing that gives me pause about forcing the given Tensor to not require gradients is that it is a bit tiresome for AOTAutograd when you get a requires_grad input; we actually do want to turn these into leaves so that we can compute gradients only up to them. But I guess it is not too bad; it looks like `Tracer(x.detach(), requires_grad=x.requires_grad)` works.
```python
@staticmethod
def __new__(cls, elem):
    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
```
How does the new constructor work with wrapper Tensors? I guess I want to see an example of how this changes construction for the different types of tensor subclasses.

Someone might want to create a wrapper Tensor subclass (`DiagTensor`?) that holds a vector and would want `DiagTensor(blah)` to work. Does this mean `DiagTensor` would override `__new__`?
`WrapperTensor` still has to define `__new__`; in general, any tensor subclass whose metadata doesn't match the wrapped tensor's doesn't work here.
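For the `DiagTensor` example raised above, the pattern looks roughly like this hedged sketch. In real PyTorch such a wrapper would call `torch.Tensor._make_wrapper_subclass` inside `__new__` (since the wrapper's metadata differs from the held vector's); here plain Python objects stand in for tensors so the structure is visible without torch:

```python
class DiagTensor:
    """Mock wrapper 'tensor' representing an n x n diagonal matrix by its diagonal."""

    def __new__(cls, diag_vector):
        # The wrapper's metadata (an n x n matrix) does not match the held
        # vector's (length n), so the default subclass __new__ cannot be used:
        # the wrapper defines __new__ itself and stashes the inner payload.
        obj = super().__new__(cls)
        obj.diag = list(diag_vector)
        obj.shape = (len(obj.diag), len(obj.diag))
        return obj

d = DiagTensor([1.0, 2.0, 3.0])
assert d.shape == (3, 3)
assert d.diag == [1.0, 2.0, 3.0]
```

The point being made in the thread: this PR's change only helps subclasses whose metadata matches the input tensor's; wrappers like this still need their own `__new__`.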
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Based on the failing test case, I need to revise my plan.
So here is what I suggest we do.
So we will have these modes of use:
Why is the view False here? Or is the problem the fact that views that are leaves are ~broken today?
Under the hood, setting
OK, so there is a problem with the strategy I suggested, which can be seen with this sample program using subclass zoo: it fails, and it shouldn't. The problem is that inside
My current thinking is that when we are "past" the autograd layer (as is the case with

Something that doesn't work is to forbid torch dispatch from returning views; we can see from the default constructor for Tensor that the most obvious way of creating a tensor subclass involves creating a view of a temporary tensor that immediately ceases to exist (so you shouldn't think of it as a view). One thing that is annoying: if someone incorrectly returns a real view (as opposed to a fake view derived from a temporary tensor), we will lose the view information. Not sure if there is an easy way to detect this has occurred.
It is today. It only gets restored alongside the Autograd key by PythonTLSSnapshot, which is what you want, as autograd is re-enabled there.

nit: for "base_tensor.py": changing the "wrap" of the
Hmm, this doesn't seem like a complete explanation to me. The problem here is that there are two "levels" of autograd (one for the outer object which I got as input, and one for the inner objects I may be wrapping over), and I only want autograd to be re-enabled for the inner objects. With the current behavior, weird stuff like this can happen: this test fails, which seems suboptimal.
So my current thinking is that a user should be explicitly responsible for restoring TLS. A few justifications for this:
From offline discussion, we agreed on updating the TLS system: #75130. I think we want to revisit this issue in that context afterwards.
Is this PR still necessary?
No, it's wrong, so we won't do it.
Stack from ghstack:
Differential Revision: D34615319