[coor-slicing] Add SymInt support for DTensor mesh coordinate computation in PT2 #169552
aorenste wants to merge 15 commits into gh/aorenste/159/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169552
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (9 Unrelated Failures) As of commit f8a63b8 with merge base ccc09a8:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
Is it difficult to have tests at this stage for the PR stack? I feel you now have enough kit for tests.
ezyang left a comment:
I don't consider the unbacked symint stuff blocking, but this is more a question for @laithsakka
It's reasonable to ask for tests at this stage. It's a little tricky because we still have the redistribute targets as non-CooR so I have to make sure the tests only rely on slicing - but I'll put something together.
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op device_mesh::_runtime_compute_coordinate_on_dim - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. DeviceMesh.sym_get_coordinate - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. Shard._select_split_tensor - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. Shard.local_shard_size_and_offset - Updated type hints to properly reflect SymInt return types. 5. Dynamo config - Always enables allow_dynamic_output_shape_ops (with TODO to reconsider if still needed). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op `device_mesh::_runtime_compute_coordinate_on_dim` - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. `DeviceMesh.sym_get_coordinate` - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. `Shard._select_split_tensor` - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. `Shard.local_shard_size_and_offset` - Updated type hints to properly reflect SymInt return types. 5. Adds a config to enable compile_on_one_rank cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op `device_mesh::_runtime_compute_coordinate_on_dim` - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. `DeviceMesh.sym_get_coordinate` - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. `Shard._select_split_tensor` - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. `Shard.local_shard_size_and_offset` - Updated type hints to properly reflect SymInt return types. 5. Adds a config to enable compile_on_one_rank cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
…ate computation in PT2" This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware. 1. New custom op `device_mesh::_runtime_compute_coordinate_on_dim` - An operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants. 2. `DeviceMesh.sym_get_coordinate` - Now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime. 3. `Shard._select_split_tensor` - Extended to handle SymInt indices by using torch.narrow with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning. 4. `Shard.local_shard_size_and_offset` - Updated type hints to properly reflect SymInt return types. 5. Adds a config to enable compile_on_one_rank cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merged: [coor-slicing] Add SymInt support for DTensor mesh coordinate computation in PT2 (pytorch#169552). Pull Request resolved: pytorch#169552. Approved by: https://github.com/ezyang
This change enables compile-on-one-rank for DTensor slicing by making mesh coordinate lookups symbolic-aware.
1. New custom op `device_mesh::_runtime_compute_coordinate_on_dim` - an operator that computes mesh coordinates at runtime, allowing coordinate lookups to be deferred during tracing rather than baked in as constants (see the first sketch below).
2. `DeviceMesh.sym_get_coordinate` - now uses the custom op when in fake mode (tracing), lifting the rank map as a graph constant and deferring the actual coordinate computation to runtime.
3. `Shard._select_split_tensor` - extended to handle SymInt indices by using `torch.narrow` with symbolic start/length instead of list indexing, enabling symbolic tensor partitioning (see the second sketch below).
4. `Shard.local_shard_size_and_offset` - updated type hints to properly reflect SymInt return types.

Stack from ghstack (oldest at bottom):
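As a rough illustration of item 1, here is a minimal sketch of how such a runtime-coordinate op could be registered. The `demo` namespace, argument names, 1-D rank map, and simplified coordinate math are all assumptions for illustration; the PR's actual schema and implementation may differ.

```python
import torch

# Hypothetical sketch only: namespace, signature, and the 1-D coordinate
# math are assumptions, not the PR's actual code.
_lib = torch.library.Library("demo", "DEF")
_lib.define(
    "_runtime_compute_coordinate_on_dim(Tensor rank_map, int rank, int dim_size) -> SymInt"
)

def _coord_eager(rank_map, rank, dim_size):
    # Runtime path: find this rank's linear index in the mesh's rank map,
    # then reduce it to a coordinate along one mesh dimension. (A real
    # multi-dim mesh would also divide by the dimension's stride.)
    linear = int((rank_map == rank).nonzero()[0].item())
    return linear % dim_size

_lib.impl("_runtime_compute_coordinate_on_dim", _coord_eager, "CompositeExplicitAutograd")

@torch.library.register_fake("demo::_runtime_compute_coordinate_on_dim")
def _coord_fake(rank_map, rank, dim_size):
    # Tracing path: return a fresh data-dependent SymInt instead of a
    # constant, so the coordinate stays symbolic in the compiled graph.
    ctx = torch.library.get_ctx()
    return ctx.new_dynamic_size()
```

This is why the graph traced on one rank can be reused on all ranks: the coordinate is no longer a baked-in constant, only the rank map (a graph constant) and the op call appear in the graph.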
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo
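And for item 3, a minimal sketch of the indexing change, using a hypothetical `select_shard` helper rather than the PR's actual `Shard._select_split_tensor`: plain list indexing does not work once the coordinate is a SymInt, but `torch.narrow` accepts a symbolic start and length, so the same code path can serve both eager and traced execution.

```python
import torch

def select_shard(full: torch.Tensor, dim: int, shard_offset, shard_size):
    # shard_offset / shard_size may be plain ints or torch.SymInt values;
    # torch.narrow handles both, unlike indexing into a list of split pieces.
    return torch.narrow(full, dim, shard_offset, shard_size)

# Eager usage: rank 1 of 4 takes rows [4, 8) of a 16-row tensor.
x = torch.arange(32).reshape(16, 2)
print(select_shard(x, 0, 4, 4))
```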