
[DTensor] fix copy_ strategy #158538

Closed

wconstab wants to merge 8 commits into gh/wconstab/431/base from gh/wconstab/431/head

Conversation


@wconstab wconstab commented Jul 17, 2025

Stack from ghstack (oldest at bottom):

The previous strategy directly used the 'self' input strategy for the 'src' input. The fixed strategy correctly maps each self dim to the corresponding src dim, so it works even when the src input is broadcast.

E.g., for this program, broadcasting will occur on dims 0, 1, and 3 of self:

```
self = torch.ones((2, 3, 4, 5))
src = torch.ones((4, 1))
self.copy_(src)
```

These are the correct sharding combinations:

| self     | src         |
|----------|-------------|
| Shard(0) | Replicate() |
| Shard(1) | Replicate() |
| Shard(2) | Shard(0)    |
| Shard(3) | Shard(1)    |
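The mapping above follows right-aligned broadcasting: trailing dims of self line up with dims of src, and any leading dims exist only on self. A minimal standalone sketch of that mapping (a hypothetical helper for illustration, not the actual DTensor code):

```python
def map_self_dim_to_src_dim(self_dim: int, self_ndim: int, src_ndim: int):
    """Map a dim of `self` to the corresponding dim of `src` under
    right-aligned broadcasting; return None for dims src doesn't have."""
    offset = self_ndim - src_ndim
    if self_dim < offset:
        # This dim exists only on self; src must stay Replicate() here.
        return None
    return self_dim - offset

# self: (2, 3, 4, 5), src: (4, 1)
assert map_self_dim_to_src_dim(0, 4, 2) is None  # Shard(0) -> Replicate()
assert map_self_dim_to_src_dim(1, 4, 2) is None  # Shard(1) -> Replicate()
assert map_self_dim_to_src_dim(2, 4, 2) == 0     # Shard(2) -> Shard(0)
assert map_self_dim_to_src_dim(3, 4, 2) == 1     # Shard(3) -> Shard(1)
```

This is why reusing the 'self' strategy verbatim for 'src' was wrong: Shard(2) on self must become Shard(0) on src, not Shard(2).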

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta


pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158538

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 3e65cb3 with merge base 1e86fa2:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jul 17, 2025
wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: 06392d0
Pull Request resolved: #158538
wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: 555e5eb
Pull Request resolved: #158538
wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: d026b77
Pull Request resolved: #158538
wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: d026b77
Pull Request resolved: #158538
@wconstab wconstab requested review from XilunWu, wanchaol and zpcore and removed request for wanchaol July 17, 2025 20:32
for strategy in first_input_strategy.strategies
]
)
)
Member
The specs looks good to me, can you also add the redistribute_cost?

Contributor Author

@wconstab wconstab Jul 17, 2025

oh yea! thanks. i had it in my first version and then forgot about it.

Can we actually make it a required argument for creating an OpSpec? (Unless we do your proposal of autogenerating them always, which i prefer)

Member

Yea, I think we can make it a required field once we fix all currently supported ops in the list #157495.

wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: d339bbc
Pull Request resolved: #158538
@wconstab wconstab added the release notes: distributed (dtensor) release notes category label Jul 17, 2025

@zpcore zpcore left a comment


LGTM!


@XilunWu XilunWu left a comment


some nits

# self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
@with_comms
def test_copy_broadcast_redistribute(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
Contributor

i suggest we universally use the init_device_mesh API instead

Contributor Author

I'm just sticking with the convention in the file. I also want to change things to MultiProcContinuousTest (or threaded, via PR #158082, if that can be made to work), so I'll leave it alone in this PR.

Contributor

I closed PR #158082 since I don't have enough experience with the codebase yet to know how to get it to work, so you can do those changes now if you want to.

Probably better to split those changes into separate commits anyway.

def test_copy_broadcast_redistribute(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
# The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen
src_specs = [[Shard(1)], [Shard(1)], [Shard(1)]]
Contributor

src_spec = [Shard(1)]. Remove src_specs from the loop variable.

Contributor Author

i don't understand? seems ok to me.

src_tensor = torch.randn((64, 1))

dst_tensor = torch.zeros(16, 32, 64, 128)
dst_specs = [[Replicate()], [Shard(1)], [Shard(2)]]
Contributor

no need for [Shard(0)]?

Contributor Author

not really, it's the same edge case as Shard(1); both of them are going to trigger the Shard->Redistribute path. But I can add it, just for completeness.

Comment on lines +83 to +85
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
Contributor

since we're not tracking comm counts, use the _run_test_on_dtensor util?

Contributor Author

actually, I should check the comm counts, I was just lazy. I should ensure that a redistribute happened.

self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

@with_comms
def test_copy_broadcast(self):
Contributor

Do we need 2 separate tests for copy_ + broadcast? Can they merge?

Contributor Author

merged them

Contributor Author

note: I prefer separate tests (or using parametrize) so it's easy to run a subtest without commenting stuff out.

but it's horribly slow with the current multiproc test case, so I merged them for now

assert isinstance(op_schema.args_schema[0], OpStrategy)
assert isinstance(op_schema.args_schema[1], OpStrategy)
self_strategy: OpStrategy = op_schema.args_schema[0]
mesh = self_strategy.mesh
Contributor

TODO: consider supporting cross-mesh copy

Contributor Author

hmm, is this possible?
at least, the meshes would have to have compatible shapes.

if we already have this support in other places, then i should follow it here. Is it documented somewhere / example code to point to?

Contributor

I see people are supporting cross-mesh ops to some degree such as in: #157682 and #157049


@wanchaol wanchaol left a comment


Please see inlined comments; I think we should just register copy_ as a pointwise strategy in _tensor_ops.py

# that is invalid for dst tensor.
# It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we
# may broadcast a new dim to the left or right of 0 when copying.
def copy_inplace_strategy(op_schema: OpSchema) -> StrategyType:
Collaborator

Hmmm I think you probably want to remove the copy strategy, just register it as a pointwise strategy.

Contributor Author

hmm. i would be ok with that, but does pointwise strategy properly enforce the extra requirement of inplace_ ops, where broadcasting 'self' to match 'src' is NOT allowed?

Collaborator

Yeah, for inplace ops it just follows the first argument
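The constraint under discussion can be stated as: for an in-place op, `src` must be broadcastable to `self`'s shape, never the reverse, since `self` cannot be resized. A small illustrative check (an assumed helper for explanation, not DTensor's actual validation code):

```python
def src_broadcastable_to_self(self_shape, src_shape):
    """Return True if an in-place op like self.copy_(src) is legal:
    src may broadcast up to self, but self must never be broadcast."""
    # In-place ops cannot grow self, so src may not have more dims.
    if len(src_shape) > len(self_shape):
        return False
    # Right-aligned comparison: each src dim must equal self's dim or be 1.
    for self_dim, src_dim in zip(reversed(self_shape), reversed(src_shape)):
        if src_dim != 1 and src_dim != self_dim:
            return False
    return True

assert src_broadcastable_to_self((2, 3, 4, 5), (4, 1))      # legal copy_
assert not src_broadcastable_to_self((4, 1), (2, 3, 4, 5))  # would broadcast self
```

A pointwise strategy that always follows the first argument respects this asymmetry, which is why it fits in-place ops here.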

Contributor Author

ok, yea i think we can try pointwise. Wish I had known about it!

assert isinstance(first_input_strategy, OpStrategy)
return OpStrategy(
[
# DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input.
Collaborator

I don't quite like the fact that we are implementing the broadcasting semantics separately; it would be better if you could just reuse the pointwise strategy for copy_

wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: 86c3ff2
Pull Request resolved: #158538
wconstab added a commit that referenced this pull request Jul 17, 2025
ghstack-source-id: 92e0320
Pull Request resolved: #158538
@pytorchmergebot
Collaborator

@wconstab your PR has been reverted as part of the stack under #158490.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Jul 18, 2025
wconstab added a commit that referenced this pull request Jul 18, 2025
ghstack-source-id: 8aa8a68
Pull Request resolved: #158538
@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / cuda12.8-py3.10-gcc9-sm86 / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@wconstab
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable), inductor / cuda12.8-py3.10-gcc9-sm86 / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot pushed a commit that referenced this pull request Jul 22, 2025
This addresses reviews made for:
#158538
#108749

It replaced all the specific DeviceMesh constructor calls with the API provided by the test cases, to improve abstraction.

Pull Request resolved: #158675
Approved by: https://github.com/wconstab
@github-actions github-actions bot deleted the gh/wconstab/431/head branch August 18, 2025 02:21
pytorchmergebot pushed a commit that referenced this pull request Sep 10, 2025
Fixing issue introduced in #158538
where `aten.copy_.default` is registered as a pointwise op, but without linearity.

In particular, when both `src` and `dst` tensors have the same `Partial` placements, a direct copy should happen without a redistribute, instead of redistributing both to `Replicate` before making the copy.

This was discovered from silent incorrect results e.g. on `torch.einsum` backward.
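To illustrate why the direct copy is valid (a toy model of Partial placement, not DTensor code): under `Partial(sum)`, each rank holds a local contribution and the logical tensor value is the sum across ranks. Copying rank-locally between two such tensors preserves the logical value, so no communication is needed:

```python
# Toy model: one list entry per rank's local shard of a Partial(sum) tensor.
src_local = [1.5, 0.5]   # logical src value = 1.5 + 0.5 = 2.0
dst_local = [0.0, 0.0]   # dst also has Partial(sum) placement

# Direct per-rank copy: dst keeps its Partial placement, and its
# logical value now equals src's logical value, with zero comms.
dst_local = [s for s in src_local]

assert sum(dst_local) == sum(src_local) == 2.0
```

Forcing both sides to `Replicate` first (what the linearity-less pointwise registration did) is not only wasteful; mishandling it was what produced the silent incorrectness mentioned above.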

Pull Request resolved: #162460
Approved by: https://github.com/zpcore
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (dtensor) release notes category Reverted


6 participants