
[DTensor] Fix double-shard validation in propagate_shape_and_sharding #177973

Closed
stmcgovern wants to merge 1 commit into pytorch:main from stmcgovern:fix/dtensor-view-ops-preexisting-bugs

Conversation

@stmcgovern
Collaborator

Fixes #177972.

`shard.dim == in_dim` compares an int (`Shard.dim`) to an `InputDim` dataclass, so it is always False. That makes the `[Shard(0), Shard(0)]` double-sharding `submesh_size` calculation in the Split handler dead code. Fix: compare against `in_dim.input_dim`. This activates the previously dead validation, which correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with `[Shard(0), Shard(0)]` on mesh (2,3)) that used to silently produce incorrect sharding.

Also remove trailing commas in two `raise RuntimeError(...)` calls so the error messages are built as single concatenated strings rather than tuples of arguments.
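For illustration, here is a minimal self-contained sketch of the comparison gotcha; the `InputDim` below is a simplified stand-in for the dataclass in `_view_ops.py`, not an import of the real class:

```python
from dataclasses import dataclass

@dataclass
class InputDim:  # simplified stand-in, not torch.distributed's class
    input_dim: int

shard_dim = 0          # Shard.dim is a plain int
in_dim = InputDim(0)

# int.__eq__ returns NotImplemented for an InputDim operand, and the dataclass
# __eq__ rejects non-InputDim operands, so the comparison falls back to
# identity and always evaluates to False.
print(shard_dim == in_dim)            # False -> the validation branch never runs
print(shard_dim == in_dim.input_dim)  # True  -> the intended int-to-int comparison
```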

@pytorch-bot

pytorch-bot Bot commented Mar 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177973

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit fcc9a47 with merge base d428a3f:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • Lint OSDC (unstable) / lintrunner-noclang-all / lint (gh)
    Error computing the main repository mapping: Encountered error while reading extension file 'requirements.bzl': no such package '@pip_deps//': no such package '@python3_10_x86_64-unknown-linux-gnu//': The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the release notes: distributed (dtensor) release notes category label Mar 20, 2026
@stmcgovern stmcgovern requested review from tianyu-l and zpcore and removed request for zpcore March 20, 2026 16:28
1. `shard.dim == in_dim` compares int to InputDim dataclass (always
   False), making the [Shard(0), Shard(0)] submesh validation dead code.
   Incompatible double-sharding configs were silently accepted.
   Fixes pytorch#177972.

2. Trailing commas in two RuntimeError calls made str(e) print as a
   tuple instead of a concatenated string.
@stmcgovern stmcgovern force-pushed the fix/dtensor-view-ops-preexisting-bugs branch from 74f3862 to fcc9a47 Compare March 20, 2026 16:29
@stmcgovern stmcgovern requested a review from wconstab March 20, 2026 16:30
@wconstab
Contributor

@claude review this PR

@claude

claude Bot commented Mar 20, 2026

Claude finished @wconstab's task in 2m 1s — View job


PR Review: [DTensor] Fix double-shard validation in propagate_shape_and_sharding

  • Read the PR diff and changed files
  • Analyze the bug fix in _view_ops.py
  • Review the new test in test_view_ops.py
  • Post review feedback

Summary

This PR fixes two distinct bugs in propagate_shape_and_sharding:

  1. Type mismatch comparison (_view_ops.py:675-679): shard.dim == in_dim compared an int (Shard.dim) to an InputDim dataclass, which always evaluates to False. This made submesh_size always stay 1, turning the double-sharding validation (e.g. [Shard(0), Shard(0)]) into dead code. The fix correctly compares shard.dim == in_dim.input_dim.

  2. Trailing commas in raise statements (_view_ops.py:620-623, 668-671): RuntimeError("msg1", "msg2",) passes two separate args to the exception constructor rather than raising a single concatenated message. While the error would still be raised, its str() representation would look like ("msg1", "msg2") — a tuple representation. Removing the commas enables Python's implicit string literal concatenation, producing a clean single-string message.

Analysis

Bug fix (comparison): Correct and well-targeted. The InputDim dataclass at line 61-64 has an input_dim: int field, and Shard.dim is an int. The old shard.dim == in_dim was always False: int.__eq__(InputDim(...)) returns NotImplemented, InputDim's generated __eq__ rejects non-InputDim operands, and the comparison falls back to identity, which is False. The fix to shard.dim == in_dim.input_dim compares int to int as intended.

Bug fix (trailing commas): Correct. This is a subtle Python gotcha — the trailing comma inside raise ExceptionType(...) creates a tuple of args rather than a single concatenated string.
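To see the gotcha in isolation, a tiny standalone snippet (the messages are abbreviated placeholders, not the actual error text in _view_ops.py):

```python
# Trailing comma: two positional args, so str(e) is the repr of the args tuple.
e = RuntimeError(
    "Attempted to flatten multiple dimensions. ",
    "It cannot be performed without redistribution.",
)
print(str(e))
# ('Attempted to flatten multiple dimensions. ', 'It cannot be performed without redistribution.')

# No comma between the literals: implicit concatenation yields one argument.
e = RuntimeError(
    "Attempted to flatten multiple dimensions. "
    "It cannot be performed without redistribution."
)
print(str(e))
# Attempted to flatten multiple dimensions. It cannot be performed without redistribution.
```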

Test (test_view_ops.py:251-274): The test covers both the positive case (compatible double-sharding) and the negative case (incompatible double-sharding that should now correctly raise). The test calls propagate_shape_and_sharding directly without @with_comms, which is appropriate since this is a pure function. The test arithmetic in the comments is correct (see the sketch after this list):

  • Compatible: (24,)→(6,4) with mesh (2,3) → submesh_size=6, 6 % 6 == 0 ✓
  • Incompatible: (12,)→(3,4) with mesh (2,3) → submesh_size=6, 3 % 6 != 0 → error ✓
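A back-of-envelope sketch of that arithmetic; the helper below is hypothetical and only mirrors the divisibility check, not the real propagate_shape_and_sharding logic:

```python
from math import prod

def split_divisible(out_dim_size: int, mesh_dims_on_input_dim: tuple[int, ...]) -> bool:
    # submesh_size is the product of the mesh dims sharding the flattened input dim,
    # e.g. 2 * 3 = 6 for [Shard(0), Shard(0)] on a (2, 3) mesh.
    submesh_size = prod(mesh_dims_on_input_dim)
    return out_dim_size % submesh_size == 0

print(split_divisible(6, (2, 3)))  # True:  (24,) -> (6, 4) is compatible
print(split_divisible(3, (2, 3)))  # False: (12,) -> (3, 4) should now raise
```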

Verdict

Clean, minimal, well-scoped fix. The changes are correct and the test coverage is adequate. LGTM.


@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 23, 2026
     raise RuntimeError(
-        f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. ",
-        "It cannot be performed without redistribution, which is disallowed by the current operator.",
+        f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. "
Contributor


nice catch

-    if isinstance(shard, Shard | _StridedShard) and shard.dim == in_dim:
+    if (
+        isinstance(shard, Shard | _StridedShard)
+        and shard.dim == in_dim.input_dim
Contributor


nice catch.
you know, i think i've seen this kind of error before.
i wonder if it's worth overriding InputDim.operator== to raise TypeErrors. though i guess if int is on the LHS we'd enter int's operator== :/ not much we can do about that i guess.

Collaborator Author


I think there is a way to prevent this from re-occurring. I'll follow up.

@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 25, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / build

Details for Dev Infra team Raised by workflow job

@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / build

Details for Dev Infra team Raised by workflow job

@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / build

Details for Dev Infra team Raised by workflow job

@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

fmassa added a commit to meta-pytorch/autoparallel that referenced this pull request Mar 26, 2026
PyTorch PR pytorch/pytorch#177973 fixed a latent bug where a comparison shard.dim == in_dim was always False (comparing int to InputDim), making the double-shard submesh_size validation in propagate_shape_and_sharding
dead code. The fix activated that validation, which now raises an AssertionError when a sharded dimension is not divisible by the mesh size (e.g. unflatten nheads=48 on mesh dim size 32).

This is overly strict for our use case — we call propagate_shape_and_sharding with strict_view=False to enumerate all possible sharding strategies, and incompatible ones should be silently skipped rather than raising. We now catch the
AssertionError and skip the iteration, since the replicated variant is already covered by another strategy.

Authored with Claude.
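The workaround is presumably shaped roughly like the sketch below; the loop and variable names are illustrative guesses, and only the strict_view=False call and the AssertionError handling come from the commit message:

```python
for candidate in candidate_input_placements:  # hypothetical enumeration loop
    try:
        out = propagate_shape_and_sharding(*candidate, strict_view=False)
    except AssertionError:
        # Incompatible view sharding for this candidate on this mesh; skip it.
        # The replicated variant is already covered by another strategy.
        continue
    strategies.append(out)
```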
Copilot AI pushed a commit that referenced this pull request Mar 27, 2026
…#177973)

Fixes #177972.

 `shard.dim == in_dim` compares an int (Shard.dim) to an InputDim dataclass, which is always False. This makes the [Shard(0), Shard(0)] double-sharding submesh_size calculation dead code in the Split handler. Fix: compare against in_dim.input_dim. This activates previously-dead validation that correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with [Shard(0), Shard(0)] on mesh (2,3)) which were previously silently producing incorrect sharding.

 Also remove trailing commas so that the error messages are treated as strings and not tuples.

Pull Request resolved: #177973
Approved by: https://github.com/wconstab

Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
fmassa added a commit to meta-pytorch/autoparallel that referenced this pull request Mar 27, 2026
* Work around upstream PyTorch view op assertion in propagation rules

PyTorch PR pytorch/pytorch#177973 fixed a latent bug where a comparison shard.dim == in_dim was always False (comparing int to InputDim), making the double-shard submesh_size validation in propagate_shape_and_sharding
dead code. The fix activated that validation, which now raises an AssertionError when a sharded dimension is not divisible by the mesh size (e.g. unflatten nheads=48 on mesh dim size 32).

This is overly strict for our use case — we call propagate_shape_and_sharding with strict_view=False to enumerate all possible sharding strategies, and incompatible ones should be silently skipped rather than raising. We now catch the
AssertionError and skip the iteration, since the replicated variant is already covered by another strategy.

Authored with Claude.

* Fix compatibility with latest PyTorch nightly by replacing removed pointwise_strategy

PyTorch removed pointwise_strategy from torch.distributed.tensor._ops._pointwise_ops in pytorch/pytorch#177208 as part of a dead code cleanup. Our custom native_layer_norm and native_layer_norm_backward rules imported
this function, causing an ImportError on PyTorch nightlies after 2025-03-23.

This replaces the import with a local _pointwise_strategy that recomposes the same behavior from PyTorch's new single-dim strategy primitives (_common_pointwise_single_dim_strategy, _fill_single_dim_strategy_placeholders,
expand_to_full_mesh_op_strategy).

Authored with Claude.

* Bugfix

The issue was that _get_unique_placements from PyTorch's single-dim strategy path asserts each OpStrategy has exactly 1 strategy — but in autoparallel, input OpStrategy objects have multiple strategies (all possible placements). Replaced it
with an inline loop that collects unique placements from all strategies across all inputs.

* One more fix

The fix: _common_pointwise_single_dim_strategy produces num_outputs output entries per strategy row (e.g., 3 for native_layer_norm's out/mean/rstd). Since all pointwise outputs share the same placement, we collapse to a single output entry
before passing to expand_to_full_mesh_op_strategy with input_index=1. This way strategy.output_specs is a single DTensorSpec, matching the old pointwise_strategy behavior — the callers then construct the per-output specs (out, mean, rstd)
themselves.

* Fix stale tensor_meta when reusing sharding placements across mesh configurations in example_dcp

When reconstructing sharding placements from the serialized map for a different mesh, we were only updating .mesh on each DTensorSpec but leaving tensor_meta unchanged. Since the batch size scales with the mesh (bs = 8 * mesh.shape[0]), the
tensor shapes differ between the two phases (256 vs 16 on dim 0), so the specs carried stale shapes from the original configuration.

This was previously harmless — PyTorch's _maybe_unpad_tensor only used logical_dim_size (derived from tensor_meta.shape) to compute the unpad amount, and with even sharding that was a no-op. PyTorch PR
pytorch/pytorch#178210 added a torch._check(orig_size >= logical_dim_size) assertion that now catches the mismatch.

Fix: pull correct tensor_meta from the new AutoParallel instance's computed strategies, which reflect the actual shapes for the current mesh and batch size.

Authored with Claude.
stmcgovern added a commit to stmcgovern/pytorch that referenced this pull request Mar 27, 2026
InputDim is a dataclass wrapping an int (input_dim), which makes it
easy to accidentally compare shard.dim (an int) with an InputDim
instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim)
in propagate_shape_and_sharding was fixed in pytorch#177973. This
change adds a structural guard: InputDim.__eq__ raises TypeError when
compared with non-DimSpec types (like int), directing the developer to
use .input_dim instead. A matching __hash__ preserves the hash
contract. This prevents the bug class from recurring.
stmcgovern added a commit to stmcgovern/pytorch that referenced this pull request Mar 27, 2026
InputDim is a dataclass wrapping an int (input_dim), which makes it
easy to accidentally compare shard.dim (an int) with an InputDim
instance instead of InputDim.input_dim. This comparison silently
returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) survived over
3 years in propagate_shape_and_sharding before being fixed(pytorch#177973). This
change adds a structural guard: InputDim.__eq__ raises TypeError when
compared with non-DimSpec types (like int), directing the developer to
use .input_dim instead. A matching __hash__ (salted with the class to
avoid collisions with raw ints in dicts/sets) preserves the hash
contract. This prevents the bug class from recurring.
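A rough sketch of the structural guard these commits describe (simplified; the real DimSpec hierarchy, error text, and hash salt in #178599 may differ):

```python
from dataclasses import dataclass

class DimSpec:
    """Simplified stand-in for the real DimSpec base class."""

@dataclass
class InputDim(DimSpec):
    input_dim: int

    def __eq__(self, other: object) -> bool:
        # Reject cross-type comparisons loudly instead of silently returning False.
        if not isinstance(other, DimSpec):
            raise TypeError(
                f"Cannot compare InputDim with {type(other).__name__}; "
                "compare against InputDim.input_dim instead."
            )
        return isinstance(other, InputDim) and self.input_dim == other.input_dim

    def __hash__(self) -> int:
        # Salt with the class so InputDim(0) and the raw int 0 do not collide
        # in dicts/sets, preserving the hash/eq contract.
        return hash((InputDim, self.input_dim))

# InputDim(0) == 0 now raises TypeError instead of silently evaluating to False.
# This works even with the int on the LHS: int.__eq__ returns NotImplemented,
# so Python falls back to the reflected InputDim.__eq__, which raises.
```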
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…pytorch#177973)

Fixes pytorch#177972.

 `shard.dim == in_dim` compares an int (Shard.dim) to an InputDim dataclass, which is always False. This makes the [Shard(0), Shard(0)] double-sharding submesh_size calculation dead code in the Split handler. Fix: compare against in_dim.input_dim. This activates previously-dead validation that correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with [Shard(0), Shard(0)] on mesh (2,3)) which were previously silently producing incorrect sharding.

 Also remove trailing commas so that the error messages are treated as strings and not tuples.

Pull Request resolved: pytorch#177973
Approved by: https://github.com/wconstab
pytorchmergebot pushed a commit that referenced this pull request Apr 1, 2026
…gs (#178599)

Follow on for #177972 .

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in #177973. This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring. @wconstab
Pull Request resolved: #178599
Approved by: https://github.com/wconstab
pytorch-bot Bot pushed a commit that referenced this pull request Apr 2, 2026
…#177973)

Fixes #177972.

 `shard.dim == in_dim` compares an int (Shard.dim) to an InputDim dataclass, which is always False. This makes the [Shard(0), Shard(0)] double-sharding submesh_size calculation dead code in the Split handler. Fix: compare against in_dim.input_dim. This activates previously-dead validation that correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with [Shard(0), Shard(0)] on mesh (2,3)) which were previously silently producing incorrect sharding.

 Also remove trailing commas so that the error messages are treated as strings and not tuples.

Pull Request resolved: #177973
Approved by: https://github.com/wconstab
IvanKobzarev pushed a commit to IvanKobzarev/pytorch that referenced this pull request Apr 3, 2026
…gs (pytorch#178599)

Follow on for pytorch#177972 .

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in pytorch#177973. This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring. @wconstab
Pull Request resolved: pytorch#178599
Approved by: https://github.com/wconstab
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
…pytorch#177973)

Fixes pytorch#177972.

 `shard.dim == in_dim` compares an int (Shard.dim) to an InputDim dataclass, which is always False. This makes the [Shard(0), Shard(0)] double-sharding submesh_size calculation dead code in the Split handler. Fix: compare against in_dim.input_dim. This activates previously-dead validation that correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with [Shard(0), Shard(0)] on mesh (2,3)) which were previously silently producing incorrect sharding.

 Also remove trailing commas so that the error messages are treated as strings and not tuples.

Pull Request resolved: pytorch#177973
Approved by: https://github.com/wconstab
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
…gs (pytorch#178599)

Follow on for pytorch#177972 .

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in pytorch#177973. This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring. @wconstab
Pull Request resolved: pytorch#178599
Approved by: https://github.com/wconstab
bobrenjc93 pushed a commit to bobrenjc93/pytorch that referenced this pull request Apr 9, 2026
…gs (pytorch#178599)

Follow on for pytorch#177972 .

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in pytorch#177973. This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring. @wconstab
Pull Request resolved: pytorch#178599
Approved by: https://github.com/wconstab
bobrenjc93 pushed a commit to bobrenjc93/pytorch that referenced this pull request Apr 10, 2026
…gs (pytorch#178599)

Follow on for pytorch#177972 .

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always
 returns False, making downstream logic dead code.

The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in pytorch#177973. This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring. @wconstab
Pull Request resolved: pytorch#178599
Approved by: https://github.com/wconstab

Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • open source
  • release notes: distributed (dtensor) (release notes category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[DTensor] Dead double-shard validation in propagate_shape_and_sharding

5 participants