
[DTensor] fix redistribute cost crashing on non-participating ranks #172478

Closed
wconstab wants to merge 8 commits into gh/wconstab/500/base from gh/wconstab/500/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Jan 14, 2026

Stack from ghstack (oldest at bottom):

Previously, ranks not participating in a redistribution would hit an
assert in the redistribution planner checking that the rank was participating.

The assert in question was added recently in #169548 by @aorenste, and I'm not sure whether
patching an early exit in this PR is the best fix or whether the original
assert should be rethought. Also cc @pianpwk for discussion.
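
For context, here is a minimal sketch of the kind of early-exit guard described above, assuming the planner can see the target DeviceMesh on every rank; the function name and surrounding structure are hypothetical, not the actual planner code:

```python
# Hypothetical sketch only: the function name and structure below are
# illustrative, not DTensor's actual redistribution planner API.
from torch.distributed.device_mesh import DeviceMesh


def plan_redistribute(mesh: DeviceMesh, src_placements, dst_placements):
    # DeviceMesh.get_coordinate() returns None on ranks that are not part of
    # the mesh, so a non-participating rank can return early instead of
    # reaching an assert that expects a valid mesh coordinate.
    if mesh.get_coordinate() is None:
        return None  # this rank holds no shard; nothing to plan locally
    # ... cost/plan computation for participating ranks would go here ...
    ...
```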

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jan 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172478

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 29c7796 with merge base b731ffe:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ing ranks"



Previously, ranks not participating in a redistribution would hit an
assert in the redistribution planner checking that the rank was participating.

The assert in question was added recently in #169548 by aorenste, and I'm not sure whether
patching an early exit in this PR is the best fix or whether the original
assert should be rethought. Also cc pianpwk for discussion.

This PR fixes an error that surfaces in this test:
```python
    @with_comms
    def test_from_local_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])
        local_tensor = torch.ones(3, 4)

        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        self.assertEqual(dtensor.size(), torch.Size([6, 4]))

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4),
            torch.tensor([]),
            dtensor.to_local(),
        )

        # test dtensor created in submesh, the operation should only
        # be applied to the local shard inside the mesh, not the whole
        # world, so only 0/2 really run the computation
        dtensor = dtensor + 2

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4) + 2,
            torch.tensor([]),
            dtensor.to_local(),
        )
```
After looking at the test, I am confused about why we support this behavior in the first place. aorenste suggested that maybe we should just make DTensor.from_local error out on ranks that aren't included in the mesh. I am not sure why we want to allow Python code to run on non-participating ranks, go through DTensor dispatch, and return a DTensor object that is effectively defunct.

Claude summarized the behavior:
* we run sharding prop and shape prop on every rank (including non-participating ranks):
| Question                                 | Answer                                                   |
|------------------------------------------|----------------------------------------------------------|
| Value of `dtensor + 2` on excluded ranks | Empty tensor `torch.tensor([])`                          |
| Has global shape?                        | Yes - `dtensor.size()` returns `(6, 4)`                  |
| Has placements?                          | Yes - same as participating ranks                        |
| Runs shape propagation?                  | Yes - output spec is computed, just no local computation |

The design ensures all ranks can query DTensor properties consistently while only participating ranks do actual computation.
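
An illustrative-only snippet of that behavior, mirroring the test above; it assumes a multi-rank job with an initialized process group in which ranks 1 and 3 are left out of the mesh:

```python
# Illustrative only: mirrors test_from_local_sub_mesh above. Assumes a 4-rank
# job with an initialized process group; ranks 1 and 3 are not in the mesh.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

mesh = DeviceMesh("cpu", [0, 2])                  # ranks 1 and 3 are excluded
dtensor = DTensor.from_local(torch.ones(3, 4), mesh, [Shard(0)])

# Metadata is consistent on every rank, including the excluded ones.
assert dtensor.size() == torch.Size([6, 4])
print(dtensor.placements)                         # same placements everywhere

out = dtensor + 2                                 # dispatch runs on all ranks
local = out.to_local()                            # (3, 4) tensor of 3s on ranks 0/2,
                                                  # an empty tensor on ranks 1/3
```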


[ghstack-poisoned]
@wconstab wconstab requested review from pianpwk and zpcore January 26, 2026 04:03
Contributor

@pianpwk pianpwk left a comment


Is there a possibility that shard prop for non-participating ranks might now cache suboptimal redistribute decisions?

@wconstab
Contributor Author

> Is there a possibility that shard prop for non-participating ranks might now cache suboptimal redistribute decisions?

I think the answer should be no: we include the mesh in the cache key, so if we later tried to run the same op including more ranks, that would imply a different mesh and we would not get a cache hit.
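
As a toy illustration of that argument (the dataclass below is hypothetical and is not DTensor's actual sharding-propagation cache), a key that embeds the mesh cannot collide across meshes:

```python
# Toy illustration only: this is not DTensor's real cache; it just shows that
# including the mesh in the key keeps submesh and full-mesh plans separate.
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyShardingCacheKey:
    op_name: str
    mesh_ranks: Tuple[int, ...]   # stands in for the DeviceMesh identity
    placements: Tuple[str, ...]


submesh_key = ToyShardingCacheKey("aten.add.Tensor", (0, 2), ("Shard(0)",))
fullmesh_key = ToyShardingCacheKey("aten.add.Tensor", (0, 1, 2, 3), ("Shard(0)",))

cache = {submesh_key: "plan computed while only ranks 0/2 participated"}

# Running the same op on the full mesh produces a different key, so the
# submesh decision is never reused by mistake.
assert fullmesh_key not in cache
```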

@wconstab
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk label Jan 26, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

riccardofelluga pushed a commit to riccardofelluga/pytorch that referenced this pull request Jan 27, 2026
Pull Request resolved: pytorch#172478
Approved by: https://github.com/pianpwk
@github-actions github-actions Bot deleted the gh/wconstab/500/head branch February 26, 2026 02:23

Labels

ciflow/inductor, ciflow/trunk, Merged, release notes: distributed (dtensor)


3 participants