[coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh #169549

Closed

aorenste wants to merge 10 commits into gh/aorenste/156/base from gh/aorenste/156/head

Conversation

@aorenste (Contributor) commented Dec 4, 2025

For compile-on-one-rank we need to be able to compute `_coordinate_on_dim` from the Tensor and the rank (because we don't have the DeviceMesh in the graph), so this PR factors the `_coordinate_on_dim` logic out into its own function.

It also defines a `RankType` alias to represent rank-like types (the union of `int` and `SymInt`).

Another issue is that `_compute_local_shape_and_global_offset()` is passed an entire coordinate array but really only needs a single value, so this PR changes it to take a lambda that returns the desired coordinate. Often this is just `DeviceMesh.sym_get_coordinate` (from the previous PR).
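As a rough sketch of the shape of this refactor (the names `_compute_coordinates_from_mesh`, `RankType`, and the callback-taking `_compute_local_shape_and_global_offset()` come from this PR; the bodies and exact signatures below are illustrative guesses, not the merged code):

```python
from typing import Callable, Optional, Union

import torch
from torch import SymInt

# Alias for rank-like values, as described above.
RankType = Union[int, SymInt]


def _compute_coordinates_from_mesh(
    mesh: torch.Tensor, rank: RankType
) -> Optional[list[int]]:
    # Locate `rank` in the mesh tensor and return its N-D coordinate, or
    # None if this rank is not part of the mesh. This needs only the raw
    # tensor and the rank, not a DeviceMesh instance.
    positions = (mesh == rank).nonzero()
    if positions.numel() == 0:
        return None
    return positions[0].tolist()


def _compute_local_shape_and_global_offset(
    global_shape, mesh_shape, placements,
    get_coordinate: Callable[[int], RankType],  # was: a full coordinate list
):
    # Callers pass a callback so only the coordinate that is actually
    # needed gets computed, e.g. get_coordinate=mesh.sym_get_coordinate.
    ...
```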

Stack from ghstack (oldest at bottom):

@pytorch-bot (Bot) commented Dec 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169549

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 59 Pending

As of commit 514595a with merge base dc48fef:
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
@aorenste added the topic: not user facing label Dec 10, 2025
@aorenste changed the title from WIP: factor out DeviceMesh._compute_coordinates_from_mesh to [compile-on-one-rank] Factor out DeviceMesh._compute_coordinates_from_mesh Dec 10, 2025
@aorenste changed the title from [compile-on-one-rank] Factor out DeviceMesh._compute_coordinates_from_mesh to Factor out DeviceMesh._compute_coordinates_from_mesh Dec 10, 2025
@aorenste changed the title from Factor out DeviceMesh._compute_coordinates_from_mesh to [coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh Jan 6, 2026
@aorenste aorenste marked this pull request as ready for review January 7, 2026 14:35
@aorenste aorenste requested a review from ezyang January 7, 2026 14:35
@ezyang ezyang requested a review from dzmitry-huba January 8, 2026 15:02
@ezyang (Contributor) commented Jan 8, 2026

> Because we don't have the DeviceMesh in the graph

Confusion here. Won't we have it in the graph after @angelayi's stuff? #169867

@ezyang ezyang requested a review from fduwjj January 8, 2026 15:13
@aorenste (Contributor, Author) commented Jan 8, 2026

> Because we don't have the DeviceMesh in the graph
>
> Confusion here. Won't we have it in the graph after @angelayi's stuff? #169867

I think that's only the case if you refer to a DeviceMesh in the compiled region where Dynamo can see it (like in a `DTensor.from_local()` call). But if you're doing DTensor ops that don't take a DeviceMesh as a parameter, they don't decompose until after Dynamo is done, and Dynamo won't hoist the DeviceMesh as a parameter because it never sees it.

(Although I fully concede that I could be missing some detail of Angela's PR stack that makes it work.)
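A minimal illustration of the distinction (hypothetical code, assuming an initialized process group; neither function is from the PR):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

mesh = init_device_mesh("cuda", (8,))  # example world size


@torch.compile
def explicit_mesh(x: torch.Tensor) -> DTensor:
    # The mesh is referenced inside the compiled region, so Dynamo sees it
    # (this is the case pytorch#169867 addresses).
    return DTensor.from_local(x, mesh, [Shard(0)])


@torch.compile
def dtensor_only(dt: DTensor) -> DTensor:
    # Only DTensor ops here: the DeviceMesh lives inside the DTensor's spec,
    # the op doesn't decompose until after Dynamo is done, and Dynamo never
    # sees (or hoists) the mesh. This is the case the PR above targets.
    return dt + dt
```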

@pytorchmergebot (Collaborator) commented

Starting merge as part of PR stack under #169551

pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2026
For compile-on-one-rank we need to be able to compute the DeviceMesh rank Tensor from the raw Tensor and the current rank, so this PR factors `DeviceMesh._get_mesh_tensor_from_full_mesh()` out into a static method.

Pull Request resolved: #169550
Approved by: https://github.com/ezyang
ghstack dependencies: #169549
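A hypothetical sketch of the static-method shape this describes; only the method name comes from the commit, and the body semantics here are an assumption:

```python
import torch


class DeviceMesh:  # illustrative stand-in, not the real class
    @staticmethod
    def _get_mesh_tensor_from_full_mesh(
        full_mesh: torch.Tensor, rank: int
    ) -> torch.Tensor:
        # As a static method this needs only the raw mesh tensor and the
        # current rank, so compiled code can call it without holding a
        # DeviceMesh instance. Assumed semantics: return the innermost
        # slice of the full mesh that contains `rank`.
        flat = full_mesh.reshape(-1, full_mesh.size(-1))
        row = (flat == rank).any(dim=-1).nonzero()[0, 0]
        return flat[row]
```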
pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2026
`Placement._split_tensor()` computes and returns too much information; in general, callers invoke it and then throw away most of the results. Added `Placement._select_split_tensor()`, which lets the caller specify which parts they want so we compute only those bits. In essence it is the combination of `Placement._split_tensor()` and `Shard._select_shard()`.

Pull Request resolved: #169551
Approved by: https://github.com/ezyang
ghstack dependencies: #169549, #169550
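A sketch of the select-what-you-compute pattern this describes; the names, flags, and return type below are illustrative, not the actual `_select_split_tensor()` API:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SplitSelection:
    shards: Optional[list[torch.Tensor]] = None
    pad_sizes: Optional[list[int]] = None


def select_split_tensor(
    tensor: torch.Tensor,
    num_chunks: int,
    *,
    want_shards: bool = False,
    want_pad_sizes: bool = False,
) -> SplitSelection:
    # Compute only the pieces the caller asked for, instead of always
    # materializing every output the way a _split_tensor()-style helper does.
    out = SplitSelection()
    if want_shards or want_pad_sizes:
        chunks = list(torch.chunk(tensor, num_chunks, dim=0))
        if want_shards:
            out.shards = chunks
        if want_pad_sizes:
            full = -(-tensor.size(0) // num_chunks)  # ceil division
            out.pad_sizes = [full - c.size(0) for c in chunks]
    return out
```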
pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2026
…atible (#172176)

Fix for #169549: an internal user was calling `_compute_local_shape_and_global_offset()` directly, so I figured it was safer to make the API backward compatible.

Pull Request resolved: #172176
Approved by: https://github.com/bobrenjc93
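One way such a backward-compatible shim could look, accepting either the legacy coordinate sequence or the newer per-dim callback (illustrative; the actual change in #172176 may differ):

```python
from typing import Callable, Sequence, Union


def _compute_local_shape_and_global_offset(
    global_shape,
    mesh_shape,
    my_coordinate: Union[Sequence[int], Callable[[int], int]],
    placements,
):
    # Normalize both calling conventions to the callback form.
    if callable(my_coordinate):
        get_coordinate = my_coordinate
    else:
        coords = list(my_coordinate)  # legacy callers pass the full list

        def get_coordinate(dim: int) -> int:
            return coords[dim]

    ...  # the rest of the computation uses get_coordinate(dim) only
```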
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
…ard-compatible (pytorch#172176)"

This reverts commit 084f69f.

Reverted pytorch#172176 on behalf of https://github.com/jeanschmidt due to sorry, need to revert in order to revert pytorch#169549, please feel free to re-merge this change once rebased
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
…m_mesh (pytorch#169549)"

This reverts commit 75ce1e9.

Reverted pytorch#169549 on behalf of https://github.com/jeanschmidt due to seems to be breaking internal signals, see D90448078
SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
@github-actions github-actions Bot deleted the gh/aorenste/156/head branch February 12, 2026 02:23

Labels

ci-no-td (Do not run TD on this PR)
ciflow/inductor
Merged
Reverted
suppress-bc-linter (Suppresses the failures of the API backward-compatibility linter (Lint/bc_linter))
topic: not user facing (topic category)


4 participants