Add a default implementation of __torch_dispatch__ #73684
ezyang wants to merge 3 commits into gh/ezyang/1088/base
Conversation
I was working on an explanation of how to call into the "super" implementation of some given ATen operation inside of __torch_dispatch__ (https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py) and I kept thinking to myself, "Why doesn't just calling super() on __torch_dispatch__ work?" Well, after this patch, it does! The idea is that if you don't actually unwrap the input tensors, you can call super().__torch_dispatch__ to get at the original behavior.

Internally, this is implemented by disabling PythonKey and then redispatching. This implementation of disabled_torch_dispatch is not /quite/ right, and some reasons why are commented in the code. There is then some extra work I have to do to make sure we recognize disabled_torch_dispatch as the "default" implementation (so we don't start slapping PythonKey on all tensors, including base Tensors), which is modeled the same way as how disabled_torch_function is done.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
[ghstack-poisoned]
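The delegation pattern this patch enables can be illustrated with a plain-Python sketch (no torch here; `Tensor`, `LoggingTensor`, and the way `func` is invoked are illustrative stand-ins for the real dispatcher machinery, not the actual PyTorch API):

```python
class Tensor:
    # Stand-in for the default __torch_dispatch__ added by this patch:
    # conceptually, it disables the Python handler and redispatches to
    # the underlying operator implementation.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)  # "redispatch" to the original behavior


class LoggingTensor(Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"calling {func.__name__}")
        # Because we did NOT unwrap the inputs, we can simply defer to the
        # default implementation via super().
        return super().__torch_dispatch__(func, types, args, kwargs)


def add(a, b):
    return a + b


result = LoggingTensor.__torch_dispatch__(add, (), (1, 2))
```

The key point mirrors the description above: the subclass adds its behavior (here, logging) and then falls through to the base implementation with an ordinary `super()` call instead of manually re-invoking the operation.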
💊 CI failures summary and remediations
As of commit 161563d (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 1 fixed upstream failure: these were probably caused by upstream breakages that were already fixed. Please rebase on the
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```python
def __new__(cls, elem):
    return torch.Tensor._make_subclass(cls, elem)

__torch_function__ = torch._C._disabled_torch_function_impl
```
A question unrelated to this PR: right now, if you want to "just use torch_dispatch", you need to manually disable torch_function like the above.
It seems like using both of them is the exception rather than the common case. Do you think we'd want to consider flipping the default, so that torch_function is disabled by default for your subclass if you write your own torch_dispatch? (Open question around how BC-breaking that is.)
We should do this, and I think literally no one will be BC broken by it.
```cpp
// dispatch and then go again, triggering segfault. TBH I'm thinking I want
// to delete this function entirely
```
You want to delete _make_wrapper_subclass?
Well, I can't, because it does legitimately solve a problem that cannot be solved any other way. But it feels like there should be a better way of doing this.
@pytorchbot merge this
Summary: Pull Request resolved: #73684 (see the PR description above). Test Plan: Imported from OSS. Reviewed By: jbschlosser. Differential Revision: D34590357. Pulled By: ezyang. fbshipit-source-id: 9471709addc826c7316874b629a6ed6b512ec4ae
Pull Request resolved: pytorch/pytorch#73684. Approved by: albanD
Stack from ghstack:
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D34590357