[DTensor] fix copy_ strategy #158538
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158538
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures
As of commit 3e65cb3 with merge base 1e86fa2:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k
The previous strategy directly used 'self' input strategy for 'src' input. The fixed strategy correctly maps the self dim to src dim so that it works even if the src input is broadcast. E.g. for this program, broadcasting will occur on dims 0,1,3 of self:

```
self = torch.ones((2,3,4,5))
src = torch.ones((4,1))
self.copy_(src)
```

These are the correct sharding combinations:

| self | src |
|----------|-------------|
| Shard(0) | Replicate() |
| Shard(1) | Replicate() |
| Shard(2) | Shard(0) |
| Shard(3) | Shard(1) |

ghstack-source-id: d026b77
Pull Request resolved: #158538
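As a cross-check, the dim-to-dim mapping described above can be sketched in plain Python (a hypothetical helper, not the actual DTensor implementation): each dim of self maps positionally to the right-aligned dim of src, and self dims with no counterpart in src must be Replicate() on src.

```python
def self_dim_to_src_dim(self_ndim: int, src_ndim: int) -> list:
    """Map each dim of `self` to the right-aligned dim of `src`.

    Dims of `self` with no counterpart in `src` (broadcast dims added
    on the left) map to None, meaning `src` must be Replicate() there;
    the remaining dims map positionally, so Shard(d) on `self`
    corresponds to Shard(d - offset) on `src`.
    """
    offset = self_ndim - src_ndim
    return [d - offset if d >= offset else None for d in range(self_ndim)]

# The example from the description: self is 4-D (2,3,4,5), src is 2-D (4,1).
# Dims 0 and 1 of self have no src counterpart -> Replicate(); dims 2 and 3
# map to src dims 0 and 1, matching the sharding combinations above.
print(self_dim_to_src_dim(4, 2))  # [None, None, 0, 1]
```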
```
for strategy in first_input_strategy.strategies
]
)
)
```
The specs look good to me, can you also add the redistribute_cost?
oh yea! thanks. i had it in my first version and then forgot about it.
Can we actually make it a required argument for creating an OpSpec? (Unless we do your proposal of autogenerating them always, which i prefer)
Yea, I think we can make it a required field once we fix all currently supported ops in the list #157495.
```
# self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

@with_comms
def test_copy_broadcast_redistribute(self):
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
```
i suggest we universally use the init_device_mesh API instead
i am just sticking with the convention in the file. I also want to change things to MultiProcContinuousTest (or threaded, via PR #158082, if that can be made to work), so I'll leave it alone in this PR.
I closed PR #158082 since I don't have the experience with the codebase yet to know how to get it to work,
so you can do those changes now if you want to.
Probably better to split up those changes into separate commits anyway.
```
def test_copy_broadcast_redistribute(self):
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    # The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen
    src_specs = [[Shard(1)], [Shard(1)], [Shard(1)]]
```
src_spec = [Shard(1)]. Remove src_specs from the loop variable.
i don't understand? seems ok to me.
```
src_tensor = torch.randn((64, 1))

dst_tensor = torch.zeros(16, 32, 64, 128)
dst_specs = [[Replicate()], [Shard(1)], [Shard(2)]]
```
not really, it's the same edge case as Shard(1); both of them are going to trigger the Shard->Redistribute path. But i can add it, just for completeness.
```
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
```
since we're not tracking comm counts, use the _run_test_on_dtensor util?
actually, i should check the comm counts, i was just lazy. I should ensure that a redistribute happened.
```
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

@with_comms
def test_copy_broadcast(self):
```
Do we need 2 separate tests for copy_ + broadcast? Can they merge?
note - i prefer separate tests (or using parameterize) so it's easy to run a subtest without commenting stuff out.
but it's horribly slow with the current multiproc test case, so i merged them for now
```
assert isinstance(op_schema.args_schema[0], OpStrategy)
assert isinstance(op_schema.args_schema[1], OpStrategy)
self_strategy: OpStrategy = op_schema.args_schema[0]
mesh = self_strategy.mesh
```
TODO: consider supporting cross-mesh copy
hmm, is this possible?
at least, the meshes would have to have compatible shapes.
if we already have this support in other places, then i should follow it here. Is it documented somewhere / example code to point to?
wanchaol left a comment:
Please see inlined comments, I think we should just register copy_ as a pointwise strategy in _tensor_ops.py
```
# that is invalid for dst tensor.
# It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we
# may broadcast a new dim to the left or right of 0 when copying.
def copy_inplace_strategy(op_schema: OpSchema) -> StrategyType:
```
Hmmm I think you probably want to remove the copy strategy, just register it as a pointwise strategy.
hmm. i would be ok with that, but does the pointwise strategy properly enforce the extra requirement of inplace_ ops, where broadcasting 'self' to match 'src' is NOT allowed?
Yeah, for inplace ops it just follows the first argument
ok, yea i think we can try pointwise. Wish I had known about it!
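To make "follow the first argument" concrete for an inplace op, here is a toy sketch in plain Python (hypothetical names, not the actual pointwise-strategy code): self's placements stay fixed, and each one is translated into the placement src must be redistributed to under right-aligned broadcasting.

```python
def src_placement_for(self_placement: str, self_ndim: int, src_ndim: int) -> str:
    """Given a placement on `self` (kept fixed for an inplace op), return
    the placement `src` must be redistributed to, in this toy model."""
    if not self_placement.startswith("Shard("):
        # Non-shard placements pass through unchanged in this sketch.
        return self_placement
    shard_dim = int(self_placement[len("Shard("):-1])
    offset = self_ndim - src_ndim
    if shard_dim < offset:
        # This dim of self was broadcast; src has no such dim.
        return "Replicate()"
    return f"Shard({shard_dim - offset})"

# self = (2,3,4,5), src = (4,1): reproduces the table from the PR description.
for d in range(4):
    print(f"self Shard({d}) -> src {src_placement_for(f'Shard({d})', 4, 2)}")
```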
```
assert isinstance(first_input_strategy, OpStrategy)
return OpStrategy(
    [
        # DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input.
```
I don't quite like the fact that we are implementing broadcasting semantic separately, it would be better if you could just reuse the pointwise strategy for copy_
This reverts commit 7b05bdd. Reverted #158538 on behalf of https://github.com/clee2000 due to broke lint? [GH job link](https://github.com/pytorch/pytorch/actions/runs/16361950974/job/46231492581) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/d8b084312b54e97bdbaf6a178fe2fc628a23243b) ([comment](#158490 (comment)))
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: inductor / cuda12.8-py3.10-gcc9-sm86 / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 3 checks: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable), inductor / cuda12.8-py3.10-gcc9-sm86 / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This addresses reviews made for #158538 and #108749. It replaces all the specific DeviceMesh constructor calls with the API provided by the test cases, to improve abstraction. Pull Request resolved: #158675 Approved by: https://github.com/wconstab
Fixing issue introduced in #158538 where `aten.copy_.default` is registered as a pointwise op, but without linearity. In particular, when both `src` and `dst` tensors have same `Partial` placements, direct copy should happen without redistribute, instead of redistributing both to `Replicate` before making the copy. This was discovered from silent incorrect results e.g. on `torch.einsum` backward. Pull Request resolved: #162460 Approved by: https://github.com/zpcore
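The linearity point can be illustrated with a toy model of Partial placements (plain Python, not DTensor; all names here are illustrative): each rank holds a partial-sum shard and the materialized tensor is the elementwise sum over ranks. Copying each rank's shard directly preserves the materialized value, so redistributing both sides to Replicate first is unnecessary.

```python
# Toy model: a "Partial" tensor is a list of per-rank partial sums, and
# materializing it (full_tensor) sums the shards elementwise across ranks.
def full_tensor(partials):
    return [sum(vals) for vals in zip(*partials)]

src_partials = [[1.0, 2.0], [3.0, 4.0]]   # rank 0 and rank 1 shards
dst_partials = [[0.0, 0.0], [5.0, 5.0]]

# Direct per-rank copy_, with no redistribute: dst takes src's shards as-is.
for rank in range(len(src_partials)):
    dst_partials[rank] = list(src_partials[rank])

# The materialized dst now equals the materialized src.
print(full_tensor(dst_partials), full_tensor(src_partials))
```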
Stack from ghstack (oldest at bottom):
The previous strategy directly used 'self' input strategy for 'src' input. The fixed strategy correctly maps the self dim to src dim so that it works even if the src input is broadcast.
E.g. for this program, broadcasting will occur on dims 0,1,3 of self:

```
self = torch.ones((2,3,4,5))
src = torch.ones((4,1))
self.copy_(src)
```

These are the correct sharding combinations:

| self | src |
|----------|-------------|
| Shard(0) | Replicate() |
| Shard(1) | Replicate() |
| Shard(2) | Shard(0) |
| Shard(3) | Shard(1) |
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta