[dtensor] Add single_dim_strategy infrastructure for foreach/fused ops #177186

Closed

anshul-si wants to merge 21 commits into gh/anshul-si/106/base from gh/anshul-si/106/head

Conversation

@anshul-si
Contributor

@anshul-si anshul-si commented Mar 11, 2026

Stack from ghstack (oldest at bottom):

Add infrastructure support in the single_dim_strategy path for
list-based ops (foreach, fused, amp_foreach):

  • single_dim_strategy.py: Add cross_mesh_indices field to
    _SingleDimStrategyInfo, rename _translate_foreach_op_schema to
    _translate_list_op_schema with support for fused/amp_foreach prefixes,
    extend detection to match all list op prefixes
  • utils.py: Add cross_mesh_indices parameter to
    expand_to_full_mesh_op_strategy to preserve original mesh for
    cross-mesh inputs (e.g. state_steps in fused adam)
  • _sharding_prop.py: Return -1 for List[Tensor] return types in
    _get_expected_num_tensor_outputs (dynamic count), skip validation
    when expected is -1, handle output_specs=None in TupleStrategy
    processing for void-returning inplace list ops
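
As a minimal self-contained sketch of the cross-mesh idea in the utils.py bullet above: inputs flagged by cross_mesh_indices keep their own spec rather than being redistributed onto the compute mesh. ToySpec and the function names below are illustrative stand-ins, not PyTorch's actual types.

# Toy sketch of the cross_mesh_indices idea; ToySpec stands in for
# DTensorSpec and every name here is illustrative, not PyTorch's API.
from dataclasses import dataclass

@dataclass(frozen=True)
class ToySpec:
    mesh: str                    # stand-in for a DeviceMesh
    placements: tuple[str, ...]  # stand-in for placements

def build_input_specs(arg_specs, compute_mesh, cross_mesh_indices, chosen_placements):
    specs = []
    for i, spec in enumerate(arg_specs):
        if i in cross_mesh_indices:
            # Cross-mesh input (e.g. state_steps in fused adam): keep its
            # original mesh and placements; do not redistribute it.
            specs.append(spec)
        else:
            specs.append(ToySpec(compute_mesh, chosen_placements))
    return specs

params = ToySpec("mesh_2d", ("Shard(0)",))
steps = ToySpec("cpu_mesh", ("Replicate()",))
print(build_input_specs([params, steps], "mesh_2d", {1}, ("Shard(0)",)))
# -> steps keeps cpu_mesh while params is placed on the compute mesh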

@pytorch-bot

pytorch-bot Bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177186

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit af4621e with merge base 549e5c7:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@anshul-si anshul-si added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 11, 2026
@anshul-si anshul-si requested review from pianpwk and wconstab and removed request for wconstab March 12, 2026 01:06
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
@wconstab
Contributor

@claude review this pr

@claude

claude Bot commented Mar 20, 2026

Claude finished @wconstab's task in 2m 13s


Review of PR #177186: [dtensor] Add single_dim_strategy infrastructure for foreach/fused ops

  • Read PR diff and changed files
  • Analyze single_dim_strategy.py changes
  • Analyze utils.py changes
  • Analyze _sharding_prop.py changes
  • Post review feedback

Overall this is well-structured work that extends the single_dim_strategy path to handle foreach, amp_foreach, and fused ops. The "Note [Multi-mesh args]" comment in utils.py is thorough and addresses wconstab's earlier feedback well. A few observations:


single_dim_strategy.py

1. Renaming from cross_mesh_indices to different_mesh_args is good. The new name is clearer about what it represents. The comment referencing "See Note [Multi-mesh args]" connects the dots nicely.

2. Mesh validation in _PreparedSingleDimStrategy.__init__ — silent skip for non-OpStrategy different_mesh_args (single_dim_strategy.py:331-335)

The remapping logic silently drops different_mesh_args indices that don't correspond to OpStrategy args (if i in schema_to_strategy). This is presumably intentional for cases where a cross-mesh arg might be an empty list at runtime, but it could mask bugs where someone passes an incorrect index. Consider adding a debug-level log or comment explaining when an index would legitimately be missing from schema_to_strategy.

3. _translate_list_op_schema — fused op path returns op_schema.kwargs_schema unmodified (single_dim_strategy.py:576-583)

For the fused op early-return path, the new OpSchema is constructed with kwargs_schema=op_schema.kwargs_schema (the original kwargs). In contrast, foreach ops go through the full translation that resolves a per-element op. This means fused ops carry the full kwargs dict through to the per-element strategy. This seems correct for fused ops (their kwargs are shared across all elements), but it's worth a brief comment explaining why kwargs aren't per-element-translated for fused ops.
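
An illustrative version of the comment this point asks for; the wording and placement are suggestions, not the merged code:

# Hypothetical comment for the fused-op early return in
# _translate_list_op_schema (wording is a suggestion, not the PR's code):
#
# Fused ops carry shared scalar hyperparameters (e.g. lr, beta1) in
# kwargs_schema; they apply to every list element identically, so there
# is nothing per-element to translate and kwargs pass through unchanged.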

4. Prefix stripping uses in rather than startswith (single_dim_strategy.py:571-574)

if "_foreach_" in op_name:
    base_op_name = op_name.replace("_foreach_", "")
elif "_amp_foreach_" in op_name:
    base_op_name = op_name.replace("_amp_foreach_", "")

Using in could match false positives if an op name ever contained _foreach_ in a non-prefix position. The outer detection at line 653 already uses startswith, so these inner checks are unlikely to hit false positives in practice, but using startswith here too would be more defensive and consistent.

5. removesuffix("_") for inplace detection (single_dim_strategy.py:586)

This correctly strips trailing _ for inplace foreach ops. Minor note: this relies on the convention that the base name without the prefix always ends with exactly one _ for inplace variants. That holds for all current foreach ops, so this is fine.
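
A minimal sketch folding in the suggestions from points 4 and 5; the helper name and prefix list are assumptions based on aten naming, not the PR's code. With startswith, the _foreach_ branch can no longer match _amp_foreach_ names the way the in-based check could:

# Hypothetical helper, not the PR's code: resolve the base op name by
# checking the most specific list-op prefix first, then stripping the
# trailing underscore of inplace variants.
def _base_op_name(op_name: str) -> str:
    for prefix in ("_amp_foreach_", "_fused_", "_foreach_"):
        if op_name.startswith(prefix):
            return op_name[len(prefix):].removesuffix("_")
    return op_name

assert _base_op_name("_foreach_add_") == "add"
assert _base_op_name("_foreach_mul") == "mul"
assert _base_op_name("_amp_foreach_non_finite_check_and_unscale_") == "non_finite_check_and_unscale"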


utils.py

6. The "Note [Multi-mesh args]" is well-written (utils.py:481-509)

This addresses the reviewer feedback about documenting the cross-mesh concept. The explanation of why Replicate is required and how it differs from element_mesh is clear.

7. Cross-mesh spec construction copies original_spec with its own mesh (utils.py:521-525)

input_specs[idx] = DTensorSpec(
    mesh=original_spec.mesh,
    placements=original_spec.placements,
    tensor_meta=original_spec.tensor_meta,
)

This reconstructs a DTensorSpec from original_spec but with explicit field assignment. Since original_spec already has all these fields set, this is effectively a copy. This is fine for clarity, but note that if DTensorSpec ever gains additional fields that matter for redistribution, this would need updating. A shallow_copy_with_tensor_meta-style method might be more robust long-term (though not necessary now).

8. Placement inside the strategy loop (utils.py:510-525)

The different_mesh_args processing happens inside the per-strategy-combination loop (for strategy_comb in strategy_combs). Since the original mesh doesn't change between combinations, this does redundant work on every iteration. However, the cost is negligible (iterating a small list), so this is not a real concern — just noting it for awareness.


_sharding_prop.py

9. Returning -1 from _get_expected_num_tensor_outputs for List[Tensor] (_sharding_prop.py:74-75)

The previous behavior returned 1 for List[Tensor], which was a known approximation. Returning -1 and skipping validation is a cleaner approach. One consideration: any downstream code that uses the return value of _get_expected_num_tensor_outputs arithmetically (e.g., as an array size) would break. I only see it used in _validate_tensor_meta_count, so this is safe.
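
A minimal sketch of the sentinel convention, with hypothetical names mirroring the ones above rather than the actual implementation:

# Hypothetical sketch of the -1 convention: List[Tensor] returns have a
# runtime-dependent length, so signal "dynamic" instead of guessing.
def get_expected_num_tensor_outputs(return_types: list[str]) -> int:
    if "List[Tensor]" in return_types:
        return -1  # dynamic count
    return sum(t == "Tensor" for t in return_types)

def validate_tensor_meta_count(expected: int, actual: int) -> None:
    if expected == -1:
        return  # dynamic count: skip validation entirely
    assert expected == actual, f"expected {expected} tensor outputs, got {actual}"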

10. output_specs is None guard in TupleStrategy processing (_sharding_prop.py:818-819)

if selected_strategy.output_specs is not None:
    out_spec_list.append(selected_strategy.output_spec)

This handles void-returning inplace list ops. The consequence is that out_spec_list could be empty, which flows to tuple(out_spec_list) if out_tensor_meta is not None else None at line 874. For inplace ops, out_tensor_meta should be None, so the output sharding will be None. This seems correct — but it's worth verifying that the caller handles output_specs=None correctly for TupleStrategy. If out_tensor_meta is None but some children do have output specs (a partially-void tuple), out_spec_list would be shorter than expected. Is that possible with the current set of ops?


Summary

The PR is well-organized, addresses reviewer feedback (renaming, documentation), and the code changes are consistent with each other. The main suggestions are:

  1. Consider logging or commenting on silently dropped different_mesh_args indices during remapping
  2. Use startswith instead of in for prefix checks to be consistent with the outer detection
  3. Add a brief comment in the fused-op early return path of _translate_list_op_schema explaining why kwargs pass through unmodified
  4. Verify the partially-void TupleStrategy scenario in _sharding_prop.py is handled correctly

None of these are blocking.



# Note [Multi-mesh args]
#
# Some ops accept args whose DTensor lives on a different DeviceMesh
Contributor

i would rephrase this a bit more to emphasize that we noticed a special case exists in practice, not that it is a standard / encouraged design pattern.

and below, i would not say they must be replicate, i would say 'in these cases, the args are replicas, and we have no way to validate/enforce this so we just make the assumption' or something?

# args so the propagator does not try to redistribute them onto
# the compute mesh (which would fail or produce wrong results).
#
# This is distinct from the *element_mesh* handling in
Contributor

maybe just call this 'foreach mesh handling' where tensors are free to have different meshes across the foreach list as long as they have matching meshes with their peers or something
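
One possible rewording that folds in both suggestions (illustrative text only, not necessarily what landed):

# Note [Multi-mesh args]
#
# In practice we have observed a special case (not a standard or
# encouraged design pattern) where some args' DTensors live on a
# different DeviceMesh than the compute mesh, e.g. state_steps in fused
# adam. In these cases the args are replicas; we have no way to validate
# or enforce this, so we make that assumption and skip redistributing
# them onto the compute mesh.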

@wconstab
Contributor

lgtm overall, i think it's worth fixing #4 from claude and doing something about the -1 return that i commented about, but you could do that later.

@anshul-si
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Mar 24, 2026

Enable foreach and fused ops to use the single_dim_strategy path:

- Add foreach ops to existing category lists (binary_additive_ops,
  binary_mul_ops, etc.) and the pointwise_ops list
- Add fused ops to pointwise_ops list for unified registration
- Add _is_list_op() helper to detect foreach/fused/amp_foreach ops
- Modify _register_single_dim_pointwise to use needs_pytree for list
  ops and cross_mesh_indices for fused ops
- Remove separate register_single_dim_strategy loop for fused ops

Tests cover multi-tensor foreach lists, mixed placements across list
elements, same-mesh fused adam, and cross-mesh fused adam exercising
the cross_mesh_indices code path.
Pull Request resolved: #177187
Approved by: https://github.com/wconstab
ghstack dependencies: #177186
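
A sketch of what an _is_list_op() helper like the one named above could look like; the prefix set is an assumption from aten naming conventions, not necessarily the merged signature:

# Hypothetical detector for list-based ops; the prefix set is assumed.
_LIST_OP_PREFIXES = ("_foreach_", "_fused_", "_amp_foreach_")

def _is_list_op(op_name: str) -> bool:
    # e.g. "_foreach_add", "_fused_adam_",
    # "_amp_foreach_non_finite_check_and_unscale_"
    return op_name.startswith(_LIST_OP_PREFIXES)

assert _is_list_op("_fused_adam_")
assert not _is_list_op("add_")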
pytorchmergebot pushed a commit that referenced this pull request Mar 24, 2026
The old strategy registration path (register_op_strategy-based functions
and data structures) was superseded by the register_single_dim_strategy
infrastructure but never cleaned up. This removes ~550 lines of dead
code including: pointwise_strategy, linear_pointwise_strategy,
copy_strategy, common_pointwise_strategy, single_mesh_dim_* strategy
functions, list_pointwise_strategy, list_linear_pointwise_strategy,
for_each_ops/for_each_linearity_ops/fused_ops lists, and their
associated helper sets and unused imports.

Authored with Claude.
Pull Request resolved: #177208
Approved by: https://github.com/Skylion007
ghstack dependencies: #177186, #177187
@github-actions github-actions Bot deleted the gh/anshul-si/106/head branch April 23, 2026 02:24