[dtensor] Add single_dim_strategy infrastructure for foreach/fused ops #177186
anshul-si wants to merge 21 commits into gh/anshul-si/106/base from
Conversation
Add infrastructure support in the single_dim_strategy path for list-based ops (foreach, fused, amp_foreach):

- single_dim_strategy.py: Add a cross_mesh_indices field to _SingleDimStrategyInfo, rename _translate_foreach_op_schema to _translate_list_op_schema with support for fused/amp_foreach prefixes, and extend detection to match all list-op prefixes
- utils.py: Add a cross_mesh_indices parameter to expand_to_full_mesh_op_strategy to preserve the original mesh for cross-mesh inputs (e.g. state_steps in fused adam)
- _sharding_prop.py: Return -1 for List[Tensor] return types in _get_expected_num_tensor_outputs (dynamic count), skip validation when the expected count is -1, and handle output_specs=None in TupleStrategy processing for void-returning in-place list ops
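The -1 sentinel described for _sharding_prop.py can be sketched as follows. This is an illustrative standalone sketch, not the actual PyTorch internals: the function and type names here (get_expected_num_tensor_outputs, validate_outputs, the "Tensor[]" string encoding) are assumptions made for the example.

```python
# Sketch of the dynamic-output-count sentinel: a List[Tensor] return type
# means the number of tensor outputs is unknown until runtime, so the
# helper returns -1 and the caller skips the count check entirely.

def get_expected_num_tensor_outputs(return_types):
    """Return the static tensor-output count, or -1 if dynamic (List[Tensor])."""
    count = 0
    for rt in return_types:
        if rt == "Tensor[]":   # List[Tensor]: length depends on the input list
            return -1
        if rt == "Tensor":
            count += 1
    return count


def validate_outputs(return_types, actual_num_outputs):
    """Skip validation when the expected count is the -1 sentinel."""
    expected = get_expected_num_tensor_outputs(return_types)
    if expected == -1:
        return True            # dynamic count: nothing to validate statically
    return expected == actual_num_outputs
```

This keeps the common single-output validation strict while letting foreach-style ops (whose output list length mirrors the input list length) pass through.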
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177186

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures) As of commit af4621e with merge base 549e5c7:

- BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
- UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review this pr
Claude finished @wconstab's task in 2m 13s (View job). Review of PR #177186: [dtensor] Add single_dim_strategy infrastructure for foreach/fused ops
Overall this is well-structured work that extends the single_dim_strategy path to handle foreach, amp_foreach, and fused ops. The "Note [Multi-mesh args]" comment in
# Note [Multi-mesh args]
#
# Some ops accept args whose DTensor lives on a different DeviceMesh
i would rephrase this a bit more to emphasize that we noticed a special case exists in practice, not that it is a standard / encouraged design pattern.
and below, i would not say they must be replicate, i would say 'in these cases, the args are replicas, and we have no way to validate/enforce this so we just make the assumption' or something?
# args so the propagator does not try to redistribute them onto
# the compute mesh (which would fail or produce wrong results).
#
# This is distinct from the *element_mesh* handling in
maybe just call this 'foreach mesh handling' where tensors are free to have different meshes across the foreach list as long as they have matching meshes with their peers or something
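The distinction discussed in this thread can be sketched in isolation. This is an illustrative example, not DTensor internals: the function name partition_args and the string mesh labels are assumptions; the real cross_mesh_indices plumbing lives in expand_to_full_mesh_op_strategy.

```python
# Sketch: cross_mesh_indices marks arg positions whose tensor lives on a
# different mesh than the compute mesh (e.g. state_steps in fused adam).
# Those args keep their original mesh and are assumed to be replicas,
# since the propagator has no way to validate placements across meshes.

def partition_args(arg_meshes, compute_mesh, cross_mesh_indices):
    """Split arg indices into those redistributed on the compute mesh
    and those preserved untouched on their original mesh."""
    redistribute, preserve = [], []
    for i, mesh in enumerate(arg_meshes):
        if i in cross_mesh_indices:
            preserve.append(i)      # assumed replicated; keep original mesh
        else:
            # ordinary args must already be on the compute mesh
            assert mesh == compute_mesh, f"arg {i} on unexpected mesh {mesh}"
            redistribute.append(i)
    return redistribute, preserve
```

Per the review comments above, the replication of cross-mesh args is an unenforced assumption, which is why the sketch simply routes them around the redistribution step rather than checking anything about them.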
lgtm overall, i think it's worth fixing #4 from claude and doing something about the -1 return that i commented about, but you could do that later.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…7187) Enable foreach and fused ops to use the single_dim_strategy path:

- Add foreach ops to existing category lists (binary_additive_ops, binary_mul_ops, etc.) and the pointwise_ops list
- Add fused ops to the pointwise_ops list for unified registration
- Add an _is_list_op() helper to detect foreach/fused/amp_foreach ops
- Modify _register_single_dim_pointwise to use needs_pytree for list ops and cross_mesh_indices for fused ops
- Remove the separate register_single_dim_strategy loop for fused ops

Tests cover multi-tensor foreach lists, mixed placements across list elements, same-mesh fused adam, and cross-mesh fused adam exercising the cross_mesh_indices code path.

Pull Request resolved: #177187
Approved by: https://github.com/wconstab
ghstack dependencies: #177186
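The _is_list_op() helper mentioned above amounts to prefix matching on the ATen op name. A minimal sketch, assuming the op name is available as a plain string; the exact prefix spellings here are assumptions based on the three categories named in the PR (foreach, fused, amp_foreach), not the real implementation.

```python
# Sketch of prefix-based list-op detection: foreach, fused, and
# amp_foreach ops all take TensorList arguments, so a single helper
# can gate the pytree/cross-mesh handling for all three families.

LIST_OP_PREFIXES = ("_foreach_", "_fused_", "_amp_foreach_")


def is_list_op(op_name: str) -> bool:
    """Detect foreach/fused/amp_foreach ops by their ATen name prefix."""
    return op_name.startswith(LIST_OP_PREFIXES)
```

str.startswith accepts a tuple, so adding another list-op family is a one-line change to LIST_OP_PREFIXES.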
The old strategy registration path (register_op_strategy-based functions and data structures) was superseded by the register_single_dim_strategy infrastructure but never cleaned up. This removes ~550 lines of dead code including: pointwise_strategy, linear_pointwise_strategy, copy_strategy, common_pointwise_strategy, the single_mesh_dim_* strategy functions, list_pointwise_strategy, list_linear_pointwise_strategy, the for_each_ops/for_each_linearity_ops/fused_ops lists, and their associated helper sets and unused imports.

Authored with Claude.

Pull Request resolved: #177208
Approved by: https://github.com/Skylion007
ghstack dependencies: #177186, #177187