python fastpath for DTensor detach(), confirm that aliasing DTensorSpec is ok #160580
bdhirsh wants to merge 7 commits into gh/bdhirsh/666/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160580
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 00c5705 with merge base 25c170b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
**not for land**
My goal right now is to try to make the "vanilla" AccumulateGrad path for DTensor (that just calls detach) fast. I'm doing this in two steps:
(1) [this PR]: hardcode aten.detach in DTensor to re-use the input tensor's DTensorSpec, instead of running "real" sharding prop.
(2) [assuming success of 1]: move the detach() call into C++, try adding a DTensor dispatch key, and avoid dispatching back to python entirely (except for some code that probably needs to allocate a pyobject for the output DTensor, from C++)
I'm pushing this PR first to confirm that I don't break anything with my detach fastpath. I did some manual local testing to confirm that for normal usages of detach, the input and output DTensor have equal DTensorSpec objects. Technically, we previously would allocate a fresh DTensorSpec, and with this change we are just re-using the input tensor's DTensorSpec. So I'm mostly hoping that DTensorSpecs don't generally get mutated.
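For illustration only (toy stand-in classes, not the real DTensor/DTensorSpec), the fastpath in (1) amounts to something like:

```python
# Toy sketch of the detach fastpath: since detach is a pure alias op,
# the output can share the exact same spec object as the input,
# skipping sharding propagation entirely. All names here are invented.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    """Stand-in for DTensorSpec (placements + mesh metadata)."""
    placements: tuple
    mesh_shape: tuple


class ToyDTensor:
    """Stand-in for DTensor: a local tensor plus a sharding spec."""
    def __init__(self, local, spec):
        self._local = local
        self._spec = spec

    def detach(self):
        # Fastpath: alias the input's spec rather than allocating a
        # fresh one via sharding prop. This is only safe if specs are
        # never mutated in place -- the open question in this PR.
        return ToyDTensor(self._local, self._spec)


x = ToyDTensor([1.0, 2.0], ToySpec(("Shard(0)",), (2,)))
y = x.detach()
assert y._spec is x._spec  # aliased, not a fresh copy
```

The safety hinge is exactly the aliasing assertion at the end: if anything ever mutates a DTensorSpec in place, the detached tensor would silently observe the change.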
This by itself does seem to speed up `alias` by quite a bit (roughly 2.5x speedup, from ~336us -> 133us):
**aten.detach(plain_tensor)**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8da2921790>
_ = x.detach()
4.80 us
1 measurement, 100000 runs , 1 thread
```
**aten.detach(DTensor) [before this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f47cd68e750>
_ = x_dt.detach()
336.40 us
1 measurement, 1000 runs , 1 thread
```
**aten.detach(DTensor) [after this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f0a34c05520>
_ = x_dt.detach()
Median: 133.45 us
2 measurements, 1000 runs per measurement, 1 thread
```
benchmark script:
```
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
import torch.utils.benchmark as benchmark
fake_store = FakeStore()
dist.init_process_group("fake", store=fake_store, rank=0, world_size=2)
mesh = torch.distributed.device_mesh.init_device_mesh('cuda', (2,))
x = torch.randn(4, 4, requires_grad=True)
x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
t0 = benchmark.Timer(
stmt='_ = x_dt.detach()',
globals={'x_dt': x_dt},
)
print(t0.blocked_autorange())
dist.destroy_process_group()
```
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
How are you thinking about landing something like this? It seems like this change is self-contained enough and useful enough to merit landing as-is, unless there is an even faster version of it in C++ coming imminently. On the other hand, it also seems like there could be lots more cases we'd want to optimize, and that would probably lead to having a more extensible framework for opting more ops into this path. Also, it'd be good to have some way of measuring overhead in CI, but that's a side point. Pei worked on some actual benchmarking, but nobody has tried instruction counting yet, and that might actually be viable here (I haven't checked how expensive it is to do that).
I would be ok with landing it as-is (maybe even going back to just the detach() case, if that is the main source of slowness we have to worry about re: AccumulateGrad). I'm hoping that this is "enough" and we won't have to reach into making C++ changes, but in either case it would be easy enough to land this and remove it later.
Fair. If we want to add a lot of fastpaths in the future, this may get unwieldy pretty quickly. I am sort of hoping that we can draw the line at "AccumulateGrad is tough to compile, so let's try to make DTensor overheads for it fast, and compile everything else", but I'm not sure how likely it is that we can do that.
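The "more extensible framework" floated above could be sketched (hypothetical names, not an existing PyTorch API) as a registry of per-op fastpath handlers consulted before full sharding propagation:

```python
# Hypothetical sketch: a registry mapping op names to fastpath handlers.
# Ops with a registered handler skip full sharding prop entirely.
_FASTPATHS = {}


def register_fastpath(op_name):
    """Decorator to opt an op into the fastpath table."""
    def deco(fn):
        _FASTPATHS[op_name] = fn
        return fn
    return deco


@register_fastpath("aten.detach")
def _detach_fastpath(input_spec):
    # detach is a pure alias op: reuse the input spec as-is.
    return input_spec


def propagate(op_name, input_spec):
    handler = _FASTPATHS.get(op_name)
    if handler is not None:
        return handler(input_spec)  # cheap path
    raise NotImplementedError("full sharding prop (slow path)")
```

This keeps each opted-in op a one-line registration, instead of growing an if/elif chain inside the dispatch handler.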
…h__` and `__eq__` (#161234) The performance cost of `dict` lookups keyed by `OpSchema` is a significant minority of DTensor overhead. With this change we shave a net ~1% off the total running time of the benchmark from #160580, as measured by using cProfile and comparing cumulative time spent in propagate + OpSchema's `__post_init__`. (`__post_init__` grew from 2.5% to 6.4% (+3.9%) and propagate shrank from 12.5% to 7.8% (-4.7%)). Pull Request resolved: #161234 Approved by: https://github.com/wconstab ghstack dependencies: #161231
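The caching trick described here can be illustrated with a toy schema class (invented names, not the real OpSchema):

```python
# Toy illustration of caching __hash__: pay the hashing cost once at
# construction, so repeated dict lookups keyed by the schema are cheap.
class ToyOpSchema:
    __slots__ = ("op_name", "args_spec", "_hash")

    def __init__(self, op_name, args_spec):
        self.op_name = op_name
        self.args_spec = args_spec
        # Hash computed once, up front (the "__post_init__ grew" cost)...
        self._hash = hash((op_name, args_spec))

    def __hash__(self):
        # ...so every lookup is a single attribute read (the
        # "propagate shrank" win).
        return self._hash

    def __eq__(self, other):
        return (self.op_name, self.args_spec) == (other.op_name, other.args_spec)


cache = {}
s1 = ToyOpSchema("aten.detach", ("Shard(0)",))
s2 = ToyOpSchema("aten.detach", ("Shard(0)",))
cache[s1] = "cached_output_spec"
assert cache[s2] == "cached_output_spec"  # equal schemas hit the same entry
```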
If SymInt::maybe_as_int() returns non-empty, then we get an inline fast path. The philosophy here (as with the previous PR) is to preserve performance in the "plain old ints" case. Observed time spent in SymInt functions in computeStorageNBytes drop (and not cost-shift elsewhere in the function) after this change, profiling detach() using code similar to the benchmark from #160580 and Linux perf.
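A rough Python analogue of that C++ fast path (toy classes; the real SymInt lives in C++):

```python
# Toy analogue of SymInt::maybe_as_int(): when the value is a plain int,
# skip the symbolic machinery entirely and do ordinary int arithmetic.
class ToySymInt:
    def __init__(self, value):
        self._value = value  # a plain int, or some symbolic node

    def maybe_as_int(self):
        # Inline fast path: hand back the plain int if we have one.
        return self._value if isinstance(self._value, int) else None


def mul(a, b):
    ia, ib = a.maybe_as_int(), b.maybe_as_int()
    if ia is not None and ib is not None:
        return ToySymInt(ia * ib)  # hot path: plain-int multiply
    # Cold path stub: real code would dispatch to symbolic arithmetic.
    raise NotImplementedError("symbolic slow path")
```

In sizes/strides loops like computeStorageNBytes, the "plain old ints" branch is the one taken almost always, which is why inlining it pays off.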
…sor._make_dtensor to accelerate DTensor.__new__ further" This seems to be a (very, very roughly) ~8% improvement on a DTensor benchmark very similar to the benchmark from #160580 (120ish us -> 110ish us) cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
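The `__new__` fastpath idea can be sketched with toy classes (the `_make` classmethod is hypothetical, mirroring the spirit of `_make_dtensor`):

```python
# Toy sketch: a trusted alternate constructor that bypasses __init__,
# avoiding re-validation when the caller already holds a good spec.
class ToyWrapper:
    __slots__ = ("_local", "_spec")

    def __init__(self, local, spec):
        # Normal path: imagine argument validation / sharding checks here.
        if spec is None:
            raise ValueError("spec required")
        self._local = local
        self._spec = spec

    @classmethod
    def _make(cls, local, spec):
        # Fastpath: allocate without running __init__ and set fields
        # directly. Only safe for internal callers with trusted inputs.
        self = cls.__new__(cls)
        self._local = local
        self._spec = spec
        return self


t = ToyWrapper._make([1.0], "spec")
assert t._spec == "spec"
```

Skipping `__init__` trades away validation for construction speed, which is the same trade the detach fastpath makes at the dispatch level.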
@pytorchbot revert -m "this broke shampoo, yanking"
❌ 🤖 pytorchbot command failed: Try
@pytorchbot revert -m "this broke shampoo, yanking" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
@bdhirsh your PR has been successfully reverted.
…TensorSpec is ok (#160580)" This reverts commit 4b2d297. Reverted #160580 on behalf of https://github.com/bdhirsh due to: this broke shampoo, yanking.
are we expecting this to re-land?
Not until Brian is back from PTO (2 weeks) at least.
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as