
[coor-slicing] Add SymInt support for DTensor mesh coordinate computation in PT2 #169552

Closed

aorenste wants to merge 15 commits into gh/aorenste/159/base from gh/aorenste/159/head

Conversation

@aorenste
Contributor

aorenste commented Dec 4, 2025

This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware.

  1. New custom op `device_mesh::_runtime_compute_coordinate_on_dim` - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants (see the first sketch after this list).
  2. `DeviceMesh.sym_get_coordinate` - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime.
  3. `Shard._select_split_tensor` - Extended to handle SymInt indices by using `torch.narrow` with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning (see the second sketch after this list).
  4. `Shard.local_shard_size_and_offset` - Updated type hints to properly reflect SymInt return types.
  5. New config - Adds a config flag to enable compile_on_one_rank.
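
Below is a minimal, hypothetical sketch of the pattern items 1 and 2 describe, not the PR's actual code: the `demo_mesh` namespace, the schema, the argument names (`rank_map`, `dim_size`, `stride`), and the `sym_get_coordinate` wrapper are all illustrative assumptions. The key idea is that the eager kernel does the real lookup while the fake-mode kernel returns a fresh unbacked SymInt, so tracing records a runtime computation instead of baking in a constant:

```python
import torch
from torch._guards import detect_fake_mode  # private helper, used here for illustration

# Hypothetical schema: rank_map is the flattened mesh rank tensor;
# dim_size/stride describe the mesh dimension we want a coordinate on.
torch.library.define(
    "demo_mesh::_runtime_compute_coordinate_on_dim",
    "(Tensor rank_map, int rank, int dim_size, int stride) -> SymInt",
)

@torch.library.impl(
    "demo_mesh::_runtime_compute_coordinate_on_dim", "CompositeExplicitAutograd"
)
def _coordinate_eager(rank_map, rank, dim_size, stride):
    # Real computation: find this rank's linear index in the mesh, then
    # project it onto one mesh dimension with stride arithmetic.
    linear = int((rank_map == rank).nonzero()[0].item())
    return (linear // stride) % dim_size

@torch.library.register_fake("demo_mesh::_runtime_compute_coordinate_on_dim")
def _coordinate_fake(rank_map, rank, dim_size, stride):
    # Under tracing the coordinate is unknowable, so hand back a fresh
    # unbacked SymInt that stands in for the runtime result.
    return torch.library.get_ctx().new_dynamic_size()

def sym_get_coordinate(rank_map, rank, dim_size, stride):
    # In fake mode, defer to the custom op so the lookup stays in the graph;
    # eagerly, compute the concrete coordinate directly.
    if detect_fake_mode() is not None:
        return torch.ops.demo_mesh._runtime_compute_coordinate_on_dim(
            rank_map, rank, dim_size, stride
        )
    return _coordinate_eager(rank_map, rank, dim_size, stride)
```

Because the fake kernel produces an unbacked SymInt, downstream shape reasoning cannot assume any concrete value for the coordinate; that is the "unbacked symint stuff" raised in the review discussion below.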
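
Item 3's `torch.narrow` trick can be illustrated the same way. This `select_split` helper and its clamping are assumptions for the sketch, not the actual `Shard._select_split_tensor` code, and `idx` is assumed to be a valid shard index:

```python
import torch

def select_split(tensor: torch.Tensor, idx, split_size, dim: int = 0):
    # torch.narrow accepts SymInt start/length, so idx may be symbolic.
    # Plain list indexing, e.g. tensor.tensor_split(...)[idx], would force
    # idx to a concrete int and bake the chosen shard into the trace.
    start = idx * split_size
    # The last shard may be short when the dim is not evenly divisible.
    length = torch.sym_max(torch.sym_min(split_size, tensor.size(dim) - start), 0)
    return torch.narrow(tensor, dim, start, length)

x = torch.arange(10)
print(select_split(x, 1, 4))  # tensor([4, 5, 6, 7])
print(select_split(x, 2, 4))  # tensor([8, 9]) -- short final shard
```

Since `narrow` returns a view, the symbolic path keeps the same no-copy behavior as indexing into the eager splits.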

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

aorenste mentioned this pull request Dec 4, 2025
@pytorch-bot

pytorch-bot commented Dec 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169552

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (9 Unrelated Failures)

As of commit f8a63b8 with merge base ccc09a8:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aorenste added a commit that referenced this pull request Dec 4, 2025
ghstack-source-id: 62fda46
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Dec 5, 2025
ghstack-source-id: f9383e1
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Dec 8, 2025
ghstack-source-id: 9fcce86
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Dec 8, 2025
ghstack-source-id: e2829f8
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Dec 10, 2025
ghstack-source-id: 173e6e5
Pull Request resolved: #169552
aorenste changed the title from "WIP: [compile-one-rank] slicing" to "[coor-slicing] Add dynamic slicing" Jan 6, 2026
aorenste changed the title from "[coor-slicing] Add dynamic slicing" to "[coor-slicing] Add SymInt support for DTensor mesh coordinate computation in PT2" Jan 6, 2026
aorenste added the topic: not user facing and release notes: distributed (dtensor) labels and removed the topic: not user facing label Jan 6, 2026
aorenste added a commit that referenced this pull request Jan 7, 2026
ghstack-source-id: f8bef2f
Pull Request resolved: #169552
@ezyang
Contributor

ezyang commented Jan 8, 2026

Is it difficult to have tests at this stage for the PR stack? I feel you now have enough kit for tests

Contributor

ezyang left a comment

I don't consider the unbacked symint stuff blocking, but this is more a question for @laithsakka

@aorenste
Contributor Author

aorenste commented Jan 8, 2026

Is it difficult to have tests at this stage for the PR stack? I feel you now have enough kit for tests

It's reasonable to ask for tests at this stage. It's a little tricky because we still have the redistribute targets as non-CooR so I have to make sure the tests only rely on slicing - but I'll put something together.

aorenste added a commit that referenced this pull request Jan 8, 2026
ghstack-source-id: b110b06
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 9, 2026
ghstack-source-id: 14647f4
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 9, 2026
ghstack-source-id: cc37d97
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 16, 2026
ghstack-source-id: 0c5a59d
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 16, 2026
ghstack-source-id: 6529012
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 16, 2026
ghstack-source-id: e3574bb
Pull Request resolved: #169552
aorenste added a commit that referenced this pull request Jan 16, 2026
ghstack-source-id: 7fc0002
Pull Request resolved: #169552
aorenste added the ciflow/trunk label (trigger trunk jobs on your pull request) Jan 16, 2026
aorenste added a commit that referenced this pull request Jan 17, 2026
ghstack-source-id: e09df08
Pull Request resolved: #169552
@aorenste
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

apakbin pushed a commit to apakbin/pytorch that referenced this pull request Jan 19, 2026

Pull Request resolved: pytorch#169552
Approved by: https://github.com/ezyang
github-actions bot deleted the gh/aorenste/159/head branch February 17, 2026 02:22