python fastpath for DTensor detach(), confirm that aliasing DTensorSpec is ok #160580

Closed

bdhirsh wants to merge 7 commits into gh/bdhirsh/666/base from gh/bdhirsh/666/head

Conversation

bdhirsh (Collaborator) commented Aug 13, 2025

My goal right now is to try to make the "vanilla" AccumulateGrad path for DTensor (that just calls detach) fast. I'm doing this in two steps:

(1) [this PR]: hardcode aten.detach in DTensor to re-use the input tensor's DTensorSpec, instead of running "real" sharding prop.

(2) [assuming success of 1]: move the detach() call into C++, try adding a DTensor dispatch key, and avoid dispatching back to python entirely (except for some code that probably needs to allocate a pyobject for the output DTensor, from C++)

I'm pushing this PR first to confirm that I don't break anything with my detach fastpath. I did some manual local testing to confirm that, for normal usages of detach, the input and output DTensor have equal DTensorSpec objects. Technically, we previously would allocate a fresh DTensorSpec, and with this change we are just re-using the input tensor's DTensorSpec. So I'm mostly hoping that DTensorSpecs don't generally get mutated.
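A minimal sketch of the shape of the fastpath (illustrative only: `_local_tensor` and `_spec` are DTensor's private fields, but the dispatch plumbing and the constructor call are simplified stand-ins, not the exact internals):

```
import torch
from torch.distributed.tensor import DTensor

aten = torch.ops.aten

def dtensor_dispatch(op, args, kwargs):
    if op is aten.detach.default:
        inp = args[0]
        # detach never changes sharding, so skip sharding propagation and
        # alias the input's DTensorSpec onto the output. This is only safe
        # if DTensorSpecs are never mutated in place, which is the
        # assumption this PR is trying to validate.
        local = inp._local_tensor.detach()
        return DTensor(local, inp._spec, requires_grad=False)
    # otherwise: fall back to the normal sharding-prop dispatch path
    ...
```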

This by itself does seem to speed up `alias` by quite a bit (roughly a 2.5x speedup, from ~336us to ~133us):

**aten.detach(plain_tensor)**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8da2921790>
_ = x.detach()
  4.80 us
  1 measurement, 100000 runs , 1 thread
```

**aten.detach(DTensor) [before this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f47cd68e750>
_ = x_dt.detach()
  336.40 us
  1 measurement, 1000 runs , 1 thread
```

**aten.detach(DTensor) [after this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f0a34c05520>
_ = x_dt.detach()
  Median: 133.45 us
  2 measurements, 1000 runs per measurement, 1 thread
```

benchmark script:
```
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
import torch.utils.benchmark as benchmark

fake_store = FakeStore()
dist.init_process_group("fake", store=fake_store, rank=0, world_size=2)

mesh = torch.distributed.device_mesh.init_device_mesh('cuda', (2,))
x = torch.randn(4, 4, requires_grad=True)
x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)

t0 = benchmark.Timer(
    stmt='_ = x_dt.detach()',
    globals={'x_dt': x_dt},
)
print(t0.blocked_autorange())

dist.destroy_process_group()
```
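(The plain-tensor baseline above is presumably measured the same way, with an analogous Timer over the local tensor; a sketch reusing `x` from the script:)

```
t1 = benchmark.Timer(
    stmt='_ = x.detach()',
    globals={'x': x},
)
print(t1.blocked_autorange())
```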

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

pytorch-bot bot commented Aug 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160580

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 00c5705 with merge base 25c170b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Aug 13, 2025
…asing DTensorSpec is ok

ghstack-source-id: 90116b9
Pull Request resolved: #160580
pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Aug 13, 2025
bdhirsh added a commit that referenced this pull request Aug 20, 2025
…asing DTensorSpec is ok

ghstack-source-id: abbe0ad
Pull Request resolved: #160580
wconstab (Contributor) commented

How are you thinking about landing something like this?

It seems like this change is self-contained enough and useful enough to merit landing as-is, unless there is an even faster version of it in c++ coming imminently.

Otoh, it also seems like there could be lots more cases we'd want to optimize, and that would probably lead to having a more extensible framework for opting more ops into this path.

Also, it'd be good to have some way of measuring overhead in CI, but that's a side point. Pei worked on some actual benchmarking, but nobody has tried instruction counting yet, and that might actually be viable here (I haven't checked how expensive it is to do that).
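torch.utils.benchmark does ship Callgrind-based instruction counting, which is deterministic and so avoids noisy wall-clock thresholds; a sketch of what that measurement could look like (assuming valgrind is installed; collection runs the statement in a separate instrumented process):

```
import torch.utils.benchmark as benchmark

t = benchmark.Timer(
    stmt="_ = x.detach()",
    setup="x = torch.randn(4, 4, requires_grad=True)",  # torch is in Timer's globals by default
)
stats = t.collect_callgrind(number=100)
print(stats.counts(denoise=True))  # total instruction count, with known-noisy symbols filtered
```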

bdhirsh (Collaborator, Author) commented Aug 20, 2025

I would be ok with landing it as-is (maybe even scoping back to just the detach() case, if that is the main source of slowness we have to worry about re: AccumulateGrad).

I'm hoping that this is "enough" and we won't have to reach into making C++ changes, but in either case, it would be easy enough to land this and remove it later.

> Otoh, it also seems like there could be lots more cases we'd want to optimize, and that would probably lead to having a more extensible framework for opting more ops into this path.

Fair, if we want to add a lot of fastpaths in the future this may get unwieldy pretty quickly. I am sort of hoping that we can draw the line at "AccumulateGrad is tough to compile, so let's try to make DTensor overheads for it fast, and compile everything else". But I'm not sure how likely it is that we can do that.

bdhirsh added a commit that referenced this pull request Aug 20, 2025
…asing DTensorSpec is ok

ghstack-source-id: 1b7affa
Pull Request resolved: #160580
pytorchmergebot pushed a commit that referenced this pull request Aug 25, 2025
…h__` and `__eq__` (#161234)

The performance cost of `dict` lookups keyed by `OpSchema` is a
significant minority of DTensor overhead. With this change we shave a
net ~1% off the total running time of the benchmark from #160580, as
measured by using cProfile and comparing cumulative time spent in
propagate + OpSchema's `__post_init__`. (`__post_init__` grew from
2.5% to 6.4% (+3.9%) and propagate shrank from 12.5% to 7.8% (-4.7%)).

Pull Request resolved: #161234
Approved by: https://github.com/wconstab
ghstack dependencies: #161231
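A minimal sketch of the caching pattern that change describes (illustrative, not the actual OpSchema code): compute the hash once in `__post_init__` and return the cached value from `__hash__`, so repeated `dict` lookups keyed by the schema stop re-hashing the whole argument tree.

```
from dataclasses import dataclass
from typing import Any

@dataclass
class CachedHashSchema:
    op_name: str
    args_schema: tuple  # must itself be hashable

    def __post_init__(self) -> None:
        # Pay the hashing cost once, up front: __post_init__ gets slower,
        # but every later dict lookup keyed by this object gets cheaper.
        self._hash = hash((self.op_name, self.args_schema))

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, CachedHashSchema)
            and self.op_name == other.op_name
            and self.args_schema == other.args_schema
        )

# A sharding-prop cache is then an ordinary dict keyed by these objects:
cache = {CachedHashSchema("aten.detach.default", ()): "cached sharding decision"}
```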
swolchok added a commit that referenced this pull request Aug 27, 2025

If SymInt::maybe_as_int() returns non-empty, then we get an inline fast path. The philosophy here (as with the previous PR) is to preserve performance in the "plain old ints" case.

Observed time spent in SymInt functions in computeStorageNBytes to drop (and not cost shift elsewhere in the function) after this change, profiling detach() using code similar to the benchmark from #160580 and Linux perf.

ghstack-source-id: 67895a5
Pull Request resolved: #161586
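The pattern, transposed to Python for illustration (the actual change is in C++'s SymInt; `_sym_storage_nbytes` below is a hypothetical stand-in for the symbolic slow path):

```
def _sym_storage_nbytes(sizes, strides, itemsize, storage_offset):
    # Hypothetical symbolic slow path; real code would route through
    # SymInt arithmetic (guards, expression building, etc.).
    raise NotImplementedError

def storage_nbytes(sizes, strides, itemsize, storage_offset=0):
    # Fast path: every size/stride is a plain int -- the Python analogue
    # of SymInt::maybe_as_int() returning non-empty for each value -- so
    # ordinary integer arithmetic runs with no symbolic bookkeeping.
    if all(type(v) is int for v in sizes) and all(type(v) is int for v in strides):
        if any(sz == 0 for sz in sizes):
            return 0
        # bytes = (largest reachable element offset + 1) * itemsize
        max_offset = storage_offset + sum(
            (sz - 1) * st for sz, st in zip(sizes, strides)
        )
        return (max_offset + 1) * itemsize
    # Slow path: hand off to symbolic handling.
    return _sym_storage_nbytes(sizes, strides, itemsize, storage_offset)
```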
swolchok added a commit that referenced this pull request Aug 27, 2025
…sor._make_dtensor to accelerate DTensor.__new__ further

This seems to be a (very roughly) ~8% improvement on a DTensor benchmark very similar to the benchmark from #160580 (120ish us -> 110ish us).
bdhirsh (Collaborator, Author) commented Sep 13, 2025

@pytorchbot revert -m "this broke shampoo, yanking"

pytorch-bot bot commented Sep 13, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

bdhirsh (Collaborator, Author) commented Sep 13, 2025

@pytorchbot revert -m "this broke shampoo, yanking" -c nosignal

pytorchmergebot (Collaborator) commented

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot (Collaborator) commented

@bdhirsh your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 13, 2025
…TensorSpec is ok (#160580)"

This reverts commit 4b2d297.

Reverted #160580 on behalf of https://github.com/bdhirsh due to: this broke shampoo, yanking (see the revert comment on #160580)

pytorchmergebot bot added the Reverted and ci-no-td labels Sep 13, 2025

albanD removed their request for review September 15, 2025 13:27
swolchok (Contributor) commented
are we expecting this to re-land?

ezyang (Contributor) commented Sep 20, 2025

Not until Brian is back from PTO (2 weeks) at least

github-actions bot commented
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.


Labels

ci-no-td, ciflow/h100-distributed, ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor), Reverted, Stale
