[DTensor] fix redistribute cost crashing on non-participating ranks #172478
wconstab wants to merge 8 commits into gh/wconstab/500/base
Conversation
Previously, ranks not participating in redistribution would hit an assert in the redistribution planner checking that the rank was participating. The assert in question was added recently in #169548 by @aorenste, and I'm not sure whether patching in an early exit in this PR is the best fix or whether the original assert should be rethought. Also cc @pianpwk for discussion.

[ghstack-poisoned]
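As a rough illustration of the fix described above (not the actual planner code from this PR), here is a minimal sketch of an early exit for non-participating ranks. The `estimate_redistribute_cost` helper and its zero-cost return convention are assumptions for illustration; `DeviceMesh.get_coordinate()` returning `None` for a rank outside the mesh is existing DeviceMesh behavior.

```python
# Illustrative sketch only; names and cost convention are hypothetical.
from torch.distributed.device_mesh import DeviceMesh


def estimate_redistribute_cost(mesh: DeviceMesh, num_bytes: int) -> float:
    # get_coordinate() returns None when the current rank is not part of
    # the mesh; such ranks hold no local shard and join no collectives,
    # so return early instead of asserting that the rank participates.
    if mesh.get_coordinate() is None:
        return 0.0
    # Participating ranks fall through to the real cost model (elided here).
    return float(num_bytes)
```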
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172478
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending as of commit 29c7796 with merge base b731ffe.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ing ranks"

Previously, ranks not participating in redistribution would hit an assert in the redistribution planner checking that the rank was participating. The assert in question was added recently in #169548 by aorenste, and I'm not sure whether patching in an early exit in this PR is the best fix or whether the original assert should be rethought. Also cc pianpwk for discussion.

This PR fixes an error that occurs in this test:

```
@with_comms
def test_from_local_sub_mesh(self):
    mesh = DeviceMesh(self.device_type, [0, 2])
    local_tensor = torch.ones(3, 4)
    dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
    self.assertEqual(dtensor.size(), torch.Size([6, 4]))

    self.sub_mesh_assert_equal(
        mesh.mesh,
        torch.ones(3, 4),
        torch.tensor([]),
        dtensor.to_local(),
    )

    # test dtensor created in submesh, the operation should only
    # be applied to the local shard inside the mesh, not the whole
    # world, so only 0/2 really run the computation
    dtensor = dtensor + 2

    self.sub_mesh_assert_equal(
        mesh.mesh,
        torch.ones(3, 4) + 2,
        torch.tensor([]),
        dtensor.to_local(),
    )
```

After looking at the test, I am very confused about why we support this behavior in the first place. aorenste suggested maybe we should just make DTensor.from_local error out on ranks that aren't included in the mesh. I am not sure why we want to allow Python code on non-participating ranks to go through DTensor dispatch and return a DTensor object that is defunct.

Claude summarized the behavior: we run sharding prop and shape prop on every rank, including non-participating ones:

| Question | Answer |
| --- | --- |
| Value of `dtensor + 2` on excluded ranks | Empty tensor `torch.tensor([])` |
| Has global shape? | Yes - `dtensor.size()` returns `(6, 4)` |
| Has placements? | Yes - same as participating ranks |
| Runs shape propagation? | Yes - output spec is computed, just no local computation |

The design ensures all ranks can query DTensor properties consistently while only participating ranks do actual computation.

[ghstack-poisoned]
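To make the table above concrete, here is a small sketch of that behavior from the caller's side. It assumes the default process group is already initialized across 4 ranks (e.g. via torchrun with a gloo/cpu setup), so that ranks 1 and 3 fall outside the sub-mesh; the assertions mirror what the test above checks.

```python
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

# Assumes torch.distributed is already initialized with world_size=4,
# so ranks 1 and 3 are not part of this sub-mesh.
mesh = DeviceMesh("cpu", [0, 2])
dtensor = DTensor.from_local(torch.ones(3, 4), mesh, [Shard(0)])

# Metadata is consistent on every rank, participating or not.
assert dtensor.size() == torch.Size([6, 4])

if mesh.get_coordinate() is None:
    # Excluded ranks carry an empty local tensor, per the table above.
    assert dtensor.to_local().numel() == 0
else:
    # Participating ranks hold the real 3x4 local shard.
    assert dtensor.to_local().shape == torch.Size([3, 4])
```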
ghstack-source-id: 32a07ce
Pull Request resolved: pytorch/pytorch#172478
ghstack-source-id: 7913d7e
Pull Request resolved: pytorch/pytorch#172478
pianpwk
left a comment
Is there a possibility that shard prop for non-participating ranks might now cache suboptimal redistribute decisions?
I think the answer should be no: we include the mesh in the cache key, so if we later tried to run the same op including more ranks, that would imply a different mesh and we would not get a cache hit.
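As a toy illustration of that argument (not DTensor's actual ShardingPropagator cache): because the mesh is part of the cache key, a later call that spans more ranks, and therefore a different mesh, cannot reuse a stale entry.

```python
from functools import lru_cache

# Toy stand-in for sharding/shape propagation, cached on (op, mesh, placements).
@lru_cache(maxsize=None)
def propagate(op_name: str, mesh_ranks: tuple, placements: tuple):
    return {"op": op_name, "mesh": mesh_ranks, "placements": placements}

sub_mesh = propagate("aten.add.Tensor", (0, 2), ("Shard(0)",))
full_mesh = propagate("aten.add.Tensor", (0, 1, 2, 3), ("Shard(0)",))

# Different mesh -> different cache key -> no stale hit from the sub-mesh run.
assert sub_mesh != full_mesh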
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ytorch#172478)

Previously, ranks not participating in redistribution would hit an assert in the redistribution planner checking that the rank was participating. The assert in question was added recently in pytorch#169548 by aorenste, and I'm not sure whether patching in an early exit in this PR is the best fix or whether the original assert should be rethought. Also cc pianpwk for discussion.

Pull Request resolved: pytorch#172478
Approved by: https://github.com/pianpwk
Stack from ghstack (oldest at bottom):