[coor-slicing] DeviceMesh.is_current_rank_part_of_mesh #169548

Closed
aorenste wants to merge 11 commits into gh/aorenste/155/base from gh/aorenste/155/head

Conversation

aorenste (Contributor) commented Dec 4, 2025

Adds two methods to DeviceMesh:

  • is_current_rank_part_of_mesh
    There are a number of places where we only care whether the current rank is part of the DeviceMesh and don't actually need any other information. So instead of getting all the mesh coordinates and checking for None, we can just use a predicate that says whether we are part of the mesh or not.
  • sym_get_coordinate
    Morally equivalent to get_coordinate()[i]. Instead of getting all the mesh coordinates as a list and extracting the one we want, this lets us specify the mesh dimension and get just the coordinate for that dimension. Right now it only returns int, but in the future it can also return a SymInt.

Today both of these are a simple lookup in the _coordinate_on_dim array, but a later PR will specialize them to properly limit their scope and make compile-on-one-rank happier. (A usage sketch follows below.)
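
A minimal usage sketch, assuming the signatures suggested by the description above (in particular that sym_get_coordinate takes a mesh-dim index); not taken verbatim from the diff:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized with 8 ranks.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Before: materialize all coordinates just to check membership.
coords = mesh.get_coordinate()   # e.g. [0, 3], or None if this rank is not in the mesh
participates = coords is not None

# After: a direct predicate (name per this PR).
if mesh.is_current_rank_part_of_mesh():
    # Assumed signature: takes a mesh-dim index and returns an int
    # (possibly a SymInt in a later PR) -- morally get_coordinate()[1].
    tp_coord = mesh.sym_get_coordinate(1)
```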

Also, the LocalTensorMode ad-hoc method patching was becoming unwieldy (and I had to add more entries to it), so I automated it based on a list instead of easy-to-miss one-offs (see the sketch below).
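
The list-driven patching looks roughly like the following (an illustrative pattern only; the list name and helpers are hypothetical, not the actual LocalTensorMode code):

```python
import contextlib
from typing import Callable

# Hypothetical registry of (class, method name, replacement) triples to patch
# while the mode is active, replacing per-method one-off patch/unpatch code.
_PATCHED_METHODS: list[tuple[type, str, Callable]] = []

def register_patch(cls: type, name: str):
    """Collect patches in a single list so none are easy to miss."""
    def decorator(fn: Callable):
        _PATCHED_METHODS.append((cls, name, fn))
        return fn
    return decorator

@contextlib.contextmanager
def patched_methods():
    """Apply every registered patch, then restore the originals on exit."""
    originals = [(cls, name, getattr(cls, name)) for cls, name, _ in _PATCHED_METHODS]
    try:
        for cls, name, fn in _PATCHED_METHODS:
            setattr(cls, name, fn)
        yield
    finally:
        for cls, name, orig in originals:
            setattr(cls, name, orig)
```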

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

pytorch-bot (Bot) commented Dec 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169548

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e2055ee with merge base d26db89:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aorenste added the "topic: not user facing" label Dec 7, 2025
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
ghstack-source-id: c24ef48
Pull Request resolved: pytorch/pytorch#169548
aorenste changed the title from "WIP: DeviceMesh.is_part_of_mesh" to "DeviceMesh.is_current_rank_part_of_mesh" on Dec 10, 2025
aorenste marked this pull request as ready for review December 10, 2025 15:35
aorenste added a commit that referenced this pull request Dec 19, 2025
ghstack-source-id: cd33ca4
Pull Request resolved: #169548
aorenste added a commit that referenced this pull request Dec 19, 2025
ghstack-source-id: d85ab1a
Pull Request resolved: #169548
aorenste (Contributor, Author) commented Jan 9, 2026

@pytorchbot merge

pytorch-bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) Jan 9, 2026
pytorchmergebot (Collaborator) commented
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
Pull Request resolved: pytorch#169548
Approved by: https://github.com/zpcore, https://github.com/bobrenjc93
wconstab added a commit that referenced this pull request Jan 14, 2026
Previously, ranks not participating in redistribution would hit an assert in the redistribution planner that assumes the rank is participating.

The assert in question was added recently in #169548 by @aorenste, and I'm not sure whether patching an early exit in this PR is the best fix or whether the original assert should be rethought. Also cc @pianpwk for discussion.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 15, 2026
…n non-participating ranks"



Previously, ranks not participating in redistribution would hit an assert in the redistribution planner that assumes the rank is participating.

The assert in question was added recently in #169548 by aorenste, and I'm not sure whether patching an early exit in this PR is the best fix or whether the original assert should be rethought. Also cc pianpwk for discussion.

This PR fixes an error that occurs in this test:
```
    @with_comms
    def test_from_local_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])
        local_tensor = torch.ones(3, 4)

        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        self.assertEqual(dtensor.size(), torch.Size([6, 4]))

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4),
            torch.tensor([]),
            dtensor.to_local(),
        )

        # test dtensor created in submesh, the operation should only
        # be applied to the local shard inside the mesh, not the whole
        # world, so only 0/2 really run the computation
        dtensor = dtensor + 2

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4) + 2,
            torch.tensor([]),
            dtensor.to_local(),
        )
```
After looking at the test, I am very confused about why we support this behavior in the first place. aorenste suggested maybe we should just make DTensor.from_local error out on ranks that aren't included in the mesh. I am not sure why we want to allow Python code to run on non-participating ranks, go through DTensor dispatch, and return a DTensor object that is defunct.

Claude summarized the behavior:
* We run sharding prop and shape prop on every rank (including non-participating ranks):

| Question                               | Answer                                                   |
|----------------------------------------|----------------------------------------------------------|
| Value of dtensor + 2 on excluded ranks | Empty tensor torch.tensor([])                            |
| Has global shape?                      | Yes - dtensor.size() returns (6, 4)                      |
| Has placements?                        | Yes - same as participating ranks                        |
| Runs shape propagation?                | Yes - output spec is computed, just no local computation |

The design ensures all ranks can query DTensor properties consistently while only participating ranks do actual computation.


[ghstack-poisoned]
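
For context, a sketch of how the predicate added in #169548 could make the non-participating case explicit at the call site; the guard below is illustrative, not code from either PR:

```python
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

# Assumes a world of at least 3 ranks with a process group already set up,
# and a sub-mesh that only contains ranks 0 and 2 (as in the test above).
mesh = DeviceMesh("cpu", [0, 2])

if mesh.is_current_rank_part_of_mesh():    # predicate added in #169548
    local_tensor = torch.ones(3, 4)
    dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
    dtensor = dtensor + 2                  # only participating ranks compute
else:
    dtensor = None                         # skip DTensor dispatch entirely
```
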
SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
ghstack-source-id: 6b55006
Pull Request resolved: pytorch/pytorch#169548
pytorchmergebot pushed a commit that referenced this pull request Jan 27, 2026
…172478)

Pull Request resolved: #172478
Approved by: https://github.com/pianpwk
@github-actions github-actions Bot deleted the gh/aorenste/155/head branch February 9, 2026 02:23