[coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh #169549
aorenste wants to merge 10 commits into gh/aorenste/156/base from …
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169549
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 59 Pending. As of commit 514595a with merge base dc48fef. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk: …
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…from_mesh" For compile on one rank we need to be able to compute the _coordinate_on_dim based on the Tensor and rank (because we don't have the DeviceMesh in the graph). So this PR factors out the _coordinate_on_dim logic into its own function. Also defined a `RankType` alias to represent rank-like types (union of `int` and `SymInt`). Another issue is that _compute_local_shape_and_global_offset() is passed an entire coordinate array but really only needs a single value - so change it to take a lambda to return the desired coordinate. Often this is just DeviceMesh.sym_get_coordinate (from the previous PR) [ghstack-poisoned]
I think that's only if you refer to a DeviceMesh in the compiled region where Dynamo can see it (like in a …). (Although I fully concede that I could be missing some detail about Angela's PR stack that makes it work.)
…from_mesh" For compile on one rank we need to be able to compute the _coordinate_on_dim based on the Tensor and rank (because we don't have the DeviceMesh in the graph). So this PR factors out the _coordinate_on_dim logic into its own function. Also defined a `RankType` alias to represent rank-like types (union of `int` and `SymInt`). Another issue is that _compute_local_shape_and_global_offset() is passed an entire coordinate array but really only needs a single value - so change it to take a lambda to return the desired coordinate. Often this is just DeviceMesh.sym_get_coordinate (from the previous PR) [ghstack-poisoned]
|
Starting merge as part of PR stack under #169551
For compile on one rank we need to be able to compute the DeviceMesh rank Tensor based on the raw Tensor and the current rank, so this PR factors `DeviceMesh._get_mesh_tensor_from_full_mesh()` out into a static method.
Pull Request resolved: #169550
Approved by: https://github.com/ezyang
ghstack dependencies: #169549
`Placement._split_tensor()` computes and returns too much information; in general, most callers call it and then throw away most of the results. Added `Placement._select_split_tensor()`, which lets the caller say which parts they want so we can compute only those bits. In essence it is the combination of `Placement._split_tensor()` and `Shard._select_shard()`.
Pull Request resolved: #169551
Approved by: https://github.com/ezyang
ghstack dependencies: #169549, #169550
Make `_compute_local_shape_and_global_offset()` backward-compatible (#172176)
Fix for #169549: an internal user was calling `_compute_local_shape_and_global_offset()` directly, so I figured it was safer to make the API backward compatible.
Pull Request resolved: #172176
Approved by: https://github.com/bobrenjc93
[coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh (pytorch#169549)
Pull Request resolved: pytorch#169549
Approved by: https://github.com/ezyang
Revert "Make `_compute_local_shape_and_global_offset()` backward-compatible (pytorch#172176)". This reverts commit 084f69f. Reverted pytorch#172176 on behalf of https://github.com/jeanschmidt: sorry, need to revert in order to revert pytorch#169549, please feel free to re-merge this change once rebased (see comment on pytorch#172176).
Revert "[coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh (pytorch#169549)". This reverts commit 75ce1e9. Reverted pytorch#169549 on behalf of https://github.com/jeanschmidt: seems to be breaking internal signals, see D90448078 (see comment on pytorch#169549).
[coor-slicing] Factor out DeviceMesh._compute_coordinates_from_mesh (pytorch#169549)
Pull Request resolved: pytorch#169549
Approved by: https://github.com/ezyang

Pull Request resolved: pytorch#169550
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#169549

Pull Request resolved: pytorch#169551
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#169549, pytorch#169550

Make `_compute_local_shape_and_global_offset()` backward-compatible (pytorch#172176)
Pull Request resolved: pytorch#172176
Approved by: https://github.com/bobrenjc93
For compile on one rank we need to be able to compute the `_coordinate_on_dim` based on the Tensor and rank (because we don't have the DeviceMesh in the graph), so this PR factors the `_coordinate_on_dim` logic out into its own function.

Also defined a `RankType` alias to represent rank-like types (the union of `int` and `SymInt`).

Another issue is that `_compute_local_shape_and_global_offset()` is passed an entire coordinate array but really only needs a single value, so change it to take a lambda that returns the desired coordinate. Often this is just `DeviceMesh.sym_get_coordinate` (from the previous PR).
Stack from ghstack (oldest at bottom):