[DTensor] Enable Dijkstra search in sharding propagation #175999

Open

wconstab wants to merge 9 commits into gh/wconstab/551/base from gh/wconstab/551/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Feb 27, 2026

Stack from ghstack (oldest at bottom):

Wire _dijkstra_expand_single_dim_strategy_to_mesh into the sharding
propagation path. For ops with single-dim strategies, try the PQ search
first; fall back to full O(S^N) expansion when it returns None
(StridedShard, symbolic shapes, or TupleStrategy inputs).

Authored with Claude.
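The try-first, fall-back dispatch described above can be sketched as follows. This is an illustrative toy, not the actual PyTorch internals; `dijkstra_expand`, `full_expand`, and `propagate` are hypothetical stand-ins for the real functions:

```python
from itertools import product

def full_expand(choices, mesh_ndim):
    # Exhaustive O(S^N) cross-product: S per-dim placement choices
    # enumerated over N mesh dimensions.
    return list(product(choices, repeat=mesh_ndim))

def dijkstra_expand(choices, mesh_ndim, supported=True):
    # Stand-in for the PQ search: returns None for unsupported inputs
    # (StridedShard, symbolic shapes, TupleStrategy in the real code).
    if not supported:
        return None
    return [tuple(choices[:1] * mesh_ndim)]

def propagate(choices, mesh_ndim, supported=True):
    # Try the cheap search first; fall back to full expansion on None.
    result = dijkstra_expand(choices, mesh_ndim, supported)
    if result is None:
        result = full_expand(choices, mesh_ndim)
    return result

choices = ["Shard(0)", "Replicate"]  # S = 2
assert len(propagate(choices, 3, supported=False)) == 2 ** 3  # O(S^N) path
```

The point of the pattern is that the search is a pure optimization: any input class it cannot handle degrades to the existing exhaustive path.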

@pytorch-bot

pytorch-bot bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175999

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures, 2 Unrelated Failures

As of commit 8700de5 with merge base 1011d3f:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Feb 27, 2026
Wire _dijkstra_expand_single_dim_strategy_to_mesh into the sharding
propagation path. For ops with single-dim strategies, try the PQ search
first; fall back to full O(S^N) expansion when it returns None
(StridedShard, symbolic shapes, or TupleStrategy inputs).

Authored with Claude.

ghstack-source-id: da66f22
Pull Request resolved: #175999
wconstab added a commit that referenced this pull request Feb 27, 2026
ghstack-source-id: b105032
Pull Request resolved: #175999
wconstab added a commit that referenced this pull request Feb 27, 2026
ghstack-source-id: 949e0c9
Pull Request resolved: #175999
@wconstab wconstab added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2026
@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorch-auto-revert

@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed as having caused regressions in:

Please investigate and fix the issues.
@claude Can you please read this revert comment, follow the links and read the errors, and then give a brief diagnosis of the cause of the error? If you judge the error to be a legitimate reason for a revert, please provide brief guidance on how the author could fix it.

@claude

claude bot commented Mar 10, 2026

Claude finished @pytorch-auto-revert[bot]'s task in 8m 26s


Revert Diagnostics for PR #175999

  • Read the PR changes and source code
  • Read the Dijkstra expansion implementation
  • Analyze the failing test patterns
  • Diagnose the root cause
  • Provide fix guidance

Failing Tests

  • test_dtensor_op_db_baddbmm_cpu_float32 (in TestLocalDTensorOps)
  • test_compiled_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (in TestCompiledDTensorOps)

Is the Revert Legitimate?

Yes. The Dijkstra-based search path introduced by this PR has behavioral differences from the full O(S^N) expansion that cause real test failures.

Root Cause Analysis

The PR wires _dijkstra_expand_single_dim_strategy_to_mesh into the sharding propagation path at _sharding_prop.py:697-703, attempting it first and falling back to full expansion only when the Dijkstra function returns None. There are two classes of issues:

1. Hard failure: AssertionError when no valid strategy is found

The Dijkstra function at single_dim_strategy.py:936-940 raises AssertionError("No valid strategy found...") when the priority queue is exhausted without finding a matching state. By contrast, the full expansion path in expand_to_full_mesh_op_strategy (utils.py:559) returns OpStrategy(all_strategies) which may be an empty list — the downstream _select_min_cost_strategy would then crash on min() of an empty sequence, but the error handling is different. More importantly, the full expansion path filters strategies more gracefully through multiple checks (shardability, inplace, out-variant, mixed-partial, etc.) and the strategies that pass form a valid set. The Dijkstra search space exploration may miss valid strategies due to:

  • try_propagate shardability check is too strict (single_dim_strategy.py:389-393): It calls is_tensor_shardable(spec.tensor_meta.shape, spec) without passing allow_unbacked_sharding or checking allow_uneven_sharding. The full expansion uses is_tensor_shardable(..., allow_unbacked_sharding=allow_unbacked_sharding) and additionally has an allow_uneven_sharding fallback (utils.py:518-528). For baddbmm, which is registered with allow_unbacked_sharding=True, this mismatch could cause valid strategies to be incorrectly rejected.

  • Neighbor generation limitations (_get_neighbor_placements): The transition rules at single_dim_strategy.py:654-706 model one-shot placement changes. If the optimal strategy requires multi-step transitions that pass through states rejected by is_tensor_shardable, the search may fail to find it.
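The predicate mismatch in the first bullet can be modeled with a toy (the `shardable` function and its flags are hypothetical; they track the `is_tensor_shardable` / `allow_uneven_sharding` semantics only loosely):

```python
# Toy model of the shardability mismatch described above: a search that
# filters candidates with a stricter predicate than the exhaustive path
# will reject strategies the exhaustive path would have accepted.
def shardable(dim_size, num_shards, allow_uneven=False):
    # Require even divisibility unless uneven sharding is permitted.
    return dim_size % num_shards == 0 or allow_uneven

# Full-expansion path passes the relaxed flag and accepts the candidate;
# the search path omits it and prunes the same candidate, so the two
# paths can disagree on which strategies are valid.
assert shardable(10, 4, allow_uneven=True)   # accepted by full expansion
assert not shardable(10, 4)                  # pruned by the search
```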

2. Indirect failures via decomposition

For poisson_nll_loss: this op is not registered with register_single_dim_strategy at all — it's not in the DTensor ops registry. It decomposes at the ATen level into constituent ops (mul, log, exp, etc.), some of which DO use single-dim strategy and would now go through the Dijkstra path. The failure would propagate up from a decomposed sub-op hitting the issues described above.
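As a rough sketch of why a decomposed op exercises the new path: with `log_input=False` and no Stirling term, Poisson NLL reduces to elementwise sub/mul/log (formula per the torch.nn.functional docs; the code below is plain Python, not DTensor):

```python
import math

# Elementwise decomposition of poisson_nll_loss (log_input=False, no
# Stirling approximation): loss_i = input_i - target_i * log(input_i + eps).
# Each constituent op (sub, mul, log) carries its own sharding strategy,
# so a failure in one sub-op surfaces as a failure of the composite op.
def poisson_nll(inputs, targets, eps=1e-8):
    return [i - t * math.log(i + eps) for i, t in zip(inputs, targets)]

losses = poisson_nll([1.0, 2.0], [1.0, 1.0])
assert abs(losses[0] - 1.0) < 1e-6  # 1 - log(1 + eps) ~= 1.0
```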

Suggested Fixes

  1. Pass allow_unbacked_sharding and handle allow_uneven_sharding in try_propagate: The _PreparedSingleDimStrategy already stores allow_unbacked_sharding and allow_uneven_sharding. The try_propagate method at line 389-393 should use them:
# In try_propagate, replace lines 389-394:
if not all(
    is_tensor_shardable(
        spec.tensor_meta.shape, spec,
        allow_unbacked_sharding=self.allow_unbacked_sharding
    )
    or (
        self.allow_uneven_sharding
        and input_spec.placements == spec.placements
    )
    for spec, input_spec in zip(arg_specs, input_specs)
    if spec.tensor_meta is not None
):
    return None
  2. Convert the hard AssertionError to return None: Instead of raising at line 936-940, return None to let the caller fall back to full expansion. This ensures the Dijkstra path is a pure optimization that gracefully degrades:
# Replace lines 936-940:
# No strategy found via Dijkstra; fall back to full expansion
return None
  3. Add the failing ops to test coverage: After fixing, verify that both baddbmm and poisson_nll_loss pass with the Dijkstra path enabled, and also that FORCE_FULLY_EXPAND_SINGLE_DIM = True produces the same results (to validate equivalence).

Fix #2 is the most conservative and least risky change — it turns Dijkstra misses into graceful fallbacks rather than hard failures. Fix #1 addresses the root cause of why valid strategies are being rejected.
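The graceful-miss behavior suggested above can be shown with a minimal Dijkstra over a toy graph (not DTensor placements): exhaust the priority queue, then return None instead of asserting, so the caller can fall back:

```python
import heapq

# Minimal Dijkstra that signals a miss with None rather than raising,
# matching the "pure optimization that gracefully degrades" behavior.
def dijkstra(graph, start, goal):
    pq = [(0, start)]          # (cost-so-far, node) priority queue
    seen = set()
    while pq:
        cost, node = heapq.heappop(pq)
        if node == goal:
            return cost
        if node in seen:
            continue
        seen.add(node)
        for nbr, weight in graph.get(node, []):
            heapq.heappush(pq, (cost + weight, nbr))
    return None  # queue exhausted: let the caller fall back

g = {"a": [("b", 1)], "b": [("c", 2)]}
assert dijkstra(g, "a", "c") == 3
assert dijkstra(g, "a", "zzz") is None  # graceful miss, no AssertionError
```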


@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
…5999)"

This reverts commit f13ba06.

Reverted #175999 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable
@pytorchmergebot
Collaborator

@wconstab your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 10, 2026
wconstab added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 9867cd0
Pull Request resolved: #175999
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: e264f02
Pull Request resolved: pytorch/pytorch#175999
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: dbbb2c1
Pull Request resolved: pytorch/pytorch#175999
wconstab added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: 31ac13d
Pull Request resolved: #175999
@wconstab
Contributor Author

@pytorchbot rebase

@wconstab
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/23052976057

wconstab added a commit that referenced this pull request Mar 13, 2026
ghstack-source-id: 1e047be
Pull Request resolved: #175999

Labels

ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/torchtitan (Run TorchTitan integration tests), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: distributed (dtensor), Reverted


4 participants