fix redistribute() handling for finding flattened device mesh dims under compile#173873
fix redistribute() handling for finding flattened device mesh dims under compile#173873bdhirsh wants to merge 2 commits intogh/bdhirsh/699/basefrom
Conversation
…der compile [ghstack-poisoned]
This PR needs a
|
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173873
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 7 Unrelated FailuresAs of commit ee61515 with merge base 969986a ( NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| expected_layout = submesh._layout.coalesce() | ||
| # Compute expected layout WITHOUT creating a submesh (avoids tracing issues) | ||
| # _get_slice_mesh_layout does pure layout math, no tensor operations | ||
| sliced_layout = root_mesh._get_slice_mesh_layout(dim_names) |
There was a problem hiding this comment.
should this be mesh instead of root_mesh? (for the same reason as #173790)
…esh dims under compile" Co-authored with claude. I noticed after #172610 that DTensor's new redistribute call that looks for flattened device meshes can crash under torch.compile/tracing. It looks like `submesh = mesh[dim_names]` will try to construct a fresh DeviceMesh, and ends up calling `.item()` (full stacktrace of the error below). I'm not 100% familiar with the `DeviceMesh` API's, but claude seemed to find an alternative way to "look for an existing flattened device mesh" that didn't need to call `.item` Stacktrace: ``` output = redistribute_local_tensor( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 1452, in redistribute_local_tensor optimized_transform_infos = _optimize_transform_infos( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 475, in _optimize_transform_infos flattened, failure_reason = try_create_flattened(group) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 381, in try_create_flattened flattened_mesh = _get_flattened_mesh_by_layout(device_mesh, sorted_mesh_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 189, in _get_flattened_mesh_by_layout submesh = mesh[dim_names] ~~~~^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 669, in __getitem__ submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 758, in _create_sub_mesh res_submesh = DeviceMesh( ^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 258, in __init__ if self._layout.numel() != self.mesh.numel(): ^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 360, in mesh return self._get_mesh_tensor_from_full_mesh(full_mesh) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 349, in _get_mesh_tensor_from_full_mesh return full_mesh[my_coords[0, 0]] ~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1625, in __torch_function__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_subclasses/functional_tensor.py", line 625, in __torch_dispatch__ outs_unwrapped = func._op_dk( ^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/utils/_stats.py", line 29, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1756, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1139, in proxy_call raise RuntimeError( torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: RuntimeError: It appears that you're trying to get value out of a tracing tensor with aten._local_scalar_dense.default - erroring out! It's likely that this is caused by data-dependent control flow or similar. It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' in your make_fx call. ``` [ghstack-poisoned]
wconstab
left a comment
There was a problem hiding this comment.
lgtm. also i am happy to help make an internal diff of this to land asap if this is blocking, i'm not sure it is. thanks for the fix!
…der compile Summary: internal-first land of #173873 Co-authored with claude. I noticed after #172610 that DTensor's new redistribute call that looks for flattened device meshes can crash under torch.compile/tracing. It looks like submesh = mesh[dim_names] will try to construct a fresh DeviceMesh, and ends up calling .item() (full stacktrace of the error below). I'm not 100% familiar with the DeviceMesh API's, but claude seemed to find an alternative way to "look for an existing flattened device mesh" that didn't need to call .item Stacktrace: output = redistribute_local_tensor( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 1452, in redistribute_local_tensor optimized_transform_infos = _optimize_transform_infos( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 475, in _optimize_transform_infos flattened, failure_reason = try_create_flattened(group) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 381, in try_create_flattened flattened_mesh = _get_flattened_mesh_by_layout(device_mesh, sorted_mesh_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 189, in _get_flattened_mesh_by_layout submesh = mesh[dim_names] ~~~~^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 669, in __getitem__ submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 758, in _create_sub_mesh res_submesh = DeviceMesh( ^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 258, in __init__ if self._layout.numel() != self.mesh.numel(): ^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 360, in mesh return self._get_mesh_tensor_from_full_mesh(full_mesh) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 349, in _get_mesh_tensor_from_full_mesh return full_mesh[my_coords[0, 0]] ~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1625, in __torch_function__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_subclasses/functional_tensor.py", line 625, in __torch_dispatch__ outs_unwrapped = func._op_dk( ^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/utils/_stats.py", line 29, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1756, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1139, in proxy_call raise RuntimeError( torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: RuntimeError: It appears that you're trying to get value out of a tracing tensor with aten._local_scalar_dense.default - erroring out! It's likely that this is caused by data-dependent control flow or similar. It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' in your make_fx call. Test Plan: python test/distributed/tensor/test_dtensor_compile.py -k test_compile_redistribute_flattened_mesh Differential Revision: D91852906
|
No, it just means we need to work together to figure this out! I think your front end/backend proposal may help. We'll have to think through which APIs we want traced into the graph, and then figure out how to do it. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…der compile (pytorch#173873) Co-authored with claude. I noticed after pytorch#172610 that DTensor's new redistribute call that looks for flattened device meshes can crash under torch.compile/tracing. It looks like `submesh = mesh[dim_names]` will try to construct a fresh DeviceMesh, and ends up calling `.item()` (full stacktrace of the error below). I'm not 100% familiar with the `DeviceMesh` API's, but claude seemed to find an alternative way to "look for an existing flattened device mesh" that didn't need to call `.item` Stacktrace: ``` output = redistribute_local_tensor( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 1452, in redistribute_local_tensor optimized_transform_infos = _optimize_transform_infos( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 475, in _optimize_transform_infos flattened, failure_reason = try_create_flattened(group) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 381, in try_create_flattened flattened_mesh = _get_flattened_mesh_by_layout(device_mesh, sorted_mesh_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 189, in _get_flattened_mesh_by_layout submesh = mesh[dim_names] ~~~~^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 669, in __getitem__ submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 758, in _create_sub_mesh res_submesh = DeviceMesh( ^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 258, in __init__ if self._layout.numel() != self.mesh.numel(): ^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 360, in mesh return self._get_mesh_tensor_from_full_mesh(full_mesh) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 349, in _get_mesh_tensor_from_full_mesh return full_mesh[my_coords[0, 0]] ~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1625, in __torch_function__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_subclasses/functional_tensor.py", line 625, in __torch_dispatch__ outs_unwrapped = func._op_dk( ^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/utils/_stats.py", line 29, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1756, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1139, in proxy_call raise RuntimeError( torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: RuntimeError: It appears that you're trying to get value out of a tracing tensor with aten._local_scalar_dense.default - erroring out! It's likely that this is caused by data-dependent control flow or similar. It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' in your make_fx call. ``` Pull Request resolved: pytorch#173873 Approved by: https://github.com/wconstab, https://github.com/fegin
|
@pytorchbot revert -m="Diff reverted internally" -c="ghfirst" This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).) |
|
@pytorchbot successfully started a revert job. Check the current status here. |
… dims under compile (#173873)" This reverts commit 2517bc4. Reverted #173873 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](#173873 (comment)))
|
@bdhirsh your PR has been successfully reverted. |
|
sorry for the churn, please feel free to rebase and reland |
|
I will take care of this |
Summary: Reland of #172610 - includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev) Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info. Example: For a (2,2,2) mesh with dims (A,B,C) and placements when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) - the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. Compared with earlier attempt #172119, this PR - includes optimization for comms other than all_reduce - explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it - therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds - Warns once per mesh shape for missing flattened meshes - Won't optimize reduce_scatters when they shard an uneven sized tensor dim Details/Limitations - all_to_all is never merged (left for possible future work, but not obvious how to do it in general) - reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns - reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness. - groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list - flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization - DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials. Fixes #171916 Note: initial attempt used stable sort with a __lt__ method in TransformInfo comparing comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh time affects results. Differential Revision: D92540256
|
squashed into #174630 |
Summary: Reland of #172610 - includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev) Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info. Example: For a (2,2,2) mesh with dims (A,B,C) and placements when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) - the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. Compared with earlier attempt #172119, this PR - includes optimization for comms other than all_reduce - explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it - therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds - Warns once per mesh shape for missing flattened meshes - Won't optimize reduce_scatters when they shard an uneven sized tensor dim Details/Limitations - all_to_all is never merged (left for possible future work, but not obvious how to do it in general) - reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns - reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness. - groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list - flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization - DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials. Fixes #171916 Note: initial attempt used stable sort with a __lt__ method in TransformInfo comparing comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh time affects results. Differential Revision: D92540256
Reland of #172610: same code as previous land except: - includes #173873 (credit @bdhirsh) - includes #173790 (credit @IvanKobzarev) - includes #173436 - adds disable contextmanager + test Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info. Example: For a (2,2,2) mesh with dims (A,B,C) and placements when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) - the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. Compared with earlier attempt #172119, this PR - includes optimization for comms other than all_reduce - explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it - therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds - Warns once per mesh shape for missing flattened meshes - Won't optimize reduce_scatters when they shard an uneven sized tensor dim Details/Limitations - all_to_all is never merged (left for possible future work, but not obvious how to do it in general) - reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns - reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness. - groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list - flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization - DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials. Fixes #171916 Note: initial attempt used stable sort with a __lt__ method in TransformInfo comparing comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh time affects results. Differential Revision: D92540256 Pull Request resolved: #174630 Approved by: https://github.com/zpcore
… dims under compile (pytorch#173873)" This reverts commit 2517bc4. Reverted pytorch#173873 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#173873 (comment)))
… dims under compile (pytorch#173873)" This reverts commit 2517bc4. Reverted pytorch#173873 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#173873 (comment)))
Co-authored with claude. I noticed after #172610 that DTensor's new redistribute call that looks for flattened device meshes can crash under torch.compile/tracing. It looks like
submesh = mesh[dim_names]will try to construct a fresh DeviceMesh, and ends up calling.item()(full stacktrace of the error below).I'm not 100% familiar with the
DeviceMeshAPI's, but claude seemed to find an alternative way to "look for an existing flattened device mesh" that didn't need to call.itemStacktrace:
Stack from ghstack (oldest at bottom):