[DTensor] Fix double-shard validation in propagate_shape_and_sharding #177973
stmcgovern wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177973
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure) As of commit fcc9a47 with merge base d428a3f. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
1. `shard.dim == in_dim` compares an int to an InputDim dataclass, so it is always False, making the [Shard(0), Shard(0)] submesh validation dead code. Incompatible double-sharding configs were silently accepted. Fixes pytorch#177972.
2. Trailing commas in two RuntimeError calls made `str(e)` print as a tuple instead of a concatenated string.
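A minimal standalone repro of bug 1. This `InputDim` is a simplified stand-in for the dataclass in `torch.distributed.tensor._ops._view_ops`; the real class has more structure, but the comparison semantics are the same:

```python
# Why `shard.dim == in_dim` was always False: Shard.dim is a plain int,
# in_dim is an InputDim dataclass. int.__eq__(InputDim) returns
# NotImplemented, the reflected dataclass __eq__ also returns
# NotImplemented for a foreign type, so Python falls back to False.
from dataclasses import dataclass


@dataclass(frozen=True)
class InputDim:  # simplified stand-in for the _view_ops dataclass
    input_dim: int


shard_dim = 0                  # Shard(0).dim is a plain int
in_dim = InputDim(input_dim=0)

print(shard_dim == in_dim)            # False: the guarded branch is dead code
print(shard_dim == in_dim.input_dim)  # True: the fix
```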
Force-pushed: 74f3862 to fcc9a47 (Compare)
@claude review this PR
Claude finished @wconstab's task in 2m 1s. View job

PR Review: [DTensor] Fix double-shard validation in propagate_shape_and_sharding

Summary
This PR fixes two distinct bugs in propagate_shape_and_sharding.

Analysis
Bug fix (comparison): Correct and well-targeted.
Bug fix (trailing commas): Correct. This is a subtle Python gotcha: the trailing comma inside the RuntimeError call turns the message strings into a tuple of arguments, so str(e) renders the tuple rather than one message.
Test: covers the newly activated validation path.

Verdict
Clean, minimal, well-scoped fix. The changes are correct and the test coverage is adequate. LGTM.
```diff
 raise RuntimeError(
-    f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. ",
-    "It cannot be performed without redistribution, which is disallowed by the current operator.",
+    f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. "
+    "It cannot be performed without redistribution, which is disallowed by the current operator."
```
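In isolation, the tuple gotcha in the hunk above behaves like this (a minimal repro with shortened strings, not the PR's exact messages):

```python
# With a comma between the literals, RuntimeError receives two arguments
# and str(e) renders the args tuple instead of one message.
try:
    raise RuntimeError(
        "Attempted to flatten multiple dimensions. ",  # note the comma
        "It cannot be performed without redistribution.",
    )
except RuntimeError as e:
    msg_buggy = str(e)

# Without the separating comma, adjacent string literals concatenate
# into a single argument at compile time.
try:
    raise RuntimeError(
        "Attempted to flatten multiple dimensions. "
        "It cannot be performed without redistribution."
    )
except RuntimeError as e:
    msg_fixed = str(e)

print(msg_buggy)  # ('Attempted to flatten multiple dimensions. ', 'It cannot be performed without redistribution.')
print(msg_fixed)  # Attempted to flatten multiple dimensions. It cannot be performed without redistribution.
```

Note that a single trailing comma after the last argument is harmless; the bug is the comma *between* the literals, which turns intended concatenation into two arguments.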
```diff
-if isinstance(shard, Shard | _StridedShard) and shard.dim == in_dim:
+if (
+    isinstance(shard, Shard | _StridedShard)
+    and shard.dim == in_dim.input_dim
+):
```
nice catch.
You know, I think I've seen this kind of error before.
I wonder if it's worth overriding InputDim.__eq__ to raise TypeError. Though I guess if an int is on the LHS we'd enter int's __eq__ first :/ not much we can do about that, I guess.
I think there is a way to prevent this from re-occurring. I'll follow up.
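The follow-up guard discussed here could look like the following sketch. This is a hypothetical simplification: the real DimSpec hierarchy lives in torch.distributed.tensor._ops._view_ops and has more members. Note that an int on the LHS is still caught, because int.__eq__ returns NotImplemented and Python then invokes the reflected InputDim.__eq__, which raises:

```python
# Sketch of a structural guard: InputDim.__eq__ raises on comparison
# with non-DimSpec types (like int), so the dead-code pattern fails
# loudly instead of silently returning False.
from dataclasses import dataclass


class DimSpec:  # stand-in for the real base class
    pass


@dataclass(frozen=True, eq=False)
class InputDim(DimSpec):
    input_dim: int

    def __eq__(self, other):
        if not isinstance(other, DimSpec):
            raise TypeError(
                f"Cannot compare InputDim with {type(other).__name__}; "
                "did you mean to compare against .input_dim?"
            )
        return isinstance(other, InputDim) and self.input_dim == other.input_dim

    # Salt the hash with the class so InputDim(0) and the raw int 0
    # don't collide in dicts/sets, preserving the hash contract.
    def __hash__(self):
        return hash((InputDim, self.input_dim))


assert InputDim(0) == InputDim(0)
try:
    0 == InputDim(0)  # int.__eq__ -> NotImplemented, reflected __eq__ raises
except TypeError:
    print("TypeError raised")
```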
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed. Reason: 1 job has failed: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / build. Details for Dev Infra team: raised by workflow job
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
[DTensor] Fix double-shard validation in propagate_shape_and_sharding (#177973)

Fixes #177972.

`shard.dim == in_dim` compares an int (Shard.dim) to an InputDim dataclass, which is always False. This makes the [Shard(0), Shard(0)] double-sharding submesh_size calculation dead code in the Split handler.

Fix: compare against in_dim.input_dim. This activates previously-dead validation that correctly rejects incompatible double-sharding configs (e.g. reshape (12,)→(3,4) with [Shard(0), Shard(0)] on mesh (2,3)), which were previously silently producing incorrect sharding.

Also remove trailing commas so that the error messages are treated as strings and not tuples.

Pull Request resolved: #177973
Approved by: https://github.com/wconstab
Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
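The arithmetic behind the reshape example in the commit message can be checked without torch. This is a back-of-envelope sketch of the divisibility idea behind the now-active submesh validation, not the library's exact code:

```python
# With [Shard(0), Shard(0)] on a (2, 3) mesh, input dim 0 is sharded
# over a submesh of size 2 * 3 = 6. Reshaping (12,) -> (3, 4) puts
# size 3 on the sharded output dim, and 3 is not divisible by 6, so
# the config is incompatible and must be rejected (previously it was
# silently accepted because the guarding comparison was always False).
mesh_shape = (2, 3)
submesh_size = 1
for dim_size in mesh_shape:  # both mesh dims shard input dim 0
    submesh_size *= dim_size

out_split = (3, 4)           # reshape (12,) -> (3, 4)
sharded_out = out_split[0]   # Shard(0) lands on the first output dim
compatible = sharded_out % submesh_size == 0
print(submesh_size, compatible)  # 6 False
```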
* Work around upstream PyTorch view op assertion in propagation rules

  PyTorch PR pytorch/pytorch#177973 fixed a latent bug where a comparison shard.dim == in_dim was always False (comparing int to InputDim), making the double-shard submesh_size validation in propagate_shape_and_sharding dead code. The fix activated that validation, which now raises an AssertionError when a sharded dimension is not divisible by the mesh size (e.g. unflatten nheads=48 on mesh dim size 32). This is overly strict for our use case: we call propagate_shape_and_sharding with strict_view=False to enumerate all possible sharding strategies, and incompatible ones should be silently skipped rather than raising. We now catch the AssertionError and skip the iteration, since the replicated variant is already covered by another strategy. Authored with Claude.

* Fix compatibility with latest PyTorch nightly by replacing removed pointwise_strategy

  PyTorch removed pointwise_strategy from torch.distributed.tensor._ops._pointwise_ops in pytorch/pytorch#177208 as part of a dead-code cleanup. Our custom native_layer_norm and native_layer_norm_backward rules imported this function, causing an ImportError on PyTorch nightlies after 2025-03-23. This replaces the import with a local _pointwise_strategy that recomposes the same behavior from PyTorch's new single-dim strategy primitives (_common_pointwise_single_dim_strategy, _fill_single_dim_strategy_placeholders, expand_to_full_mesh_op_strategy). Authored with Claude.

* Bugfix

  The issue was that _get_unique_placements from PyTorch's single-dim strategy path asserts each OpStrategy has exactly 1 strategy, but in autoparallel, input OpStrategy objects have multiple strategies (all possible placements). Replaced it with an inline loop that collects unique placements from all strategies across all inputs.

* One more fix

  The fix: _common_pointwise_single_dim_strategy produces num_outputs output entries per strategy row (e.g. 3 for native_layer_norm's out/mean/rstd). Since all pointwise outputs share the same placement, we collapse to a single output entry before passing to expand_to_full_mesh_op_strategy with input_index=1. This way strategy.output_specs is a single DTensorSpec, matching the old pointwise_strategy behavior; the callers then construct the per-output specs (out, mean, rstd) themselves.

* Fix stale tensor_meta when reusing sharding placements across mesh configurations in example_dcp

  When reconstructing sharding placements from the serialized map for a different mesh, we were only updating .mesh on each DTensorSpec but leaving tensor_meta unchanged. Since the batch size scales with the mesh (bs = 8 * mesh.shape[0]), the tensor shapes differ between the two phases (256 vs 16 on dim 0), so the specs carried stale shapes from the original configuration. This was previously harmless: PyTorch's _maybe_unpad_tensor only used logical_dim_size (derived from tensor_meta.shape) to compute the unpad amount, and with even sharding that was a no-op. PyTorch PR pytorch/pytorch#178210 added a torch._check(orig_size >= logical_dim_size) assertion that now catches the mismatch. Fix: pull correct tensor_meta from the new AutoParallel instance's computed strategies, which reflect the actual shapes for the current mesh and batch size. Authored with Claude.
InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison silently returns False, making downstream logic dead code. The original instance of this bug (shard.dim == in_dim) survived over 3 years in propagate_shape_and_sharding before being fixed (pytorch#177973).

This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ (salted with the class to avoid collisions with raw ints in dicts/sets) preserves the hash contract. This prevents the bug class from recurring.
…gs (#178599)

Follow-on for #177972.

InputDim is a dataclass wrapping an int (input_dim), which makes it easy to accidentally compare shard.dim (an int) with an InputDim instance instead of InputDim.input_dim. This comparison always returns False, making downstream logic dead code. The original instance of this bug (shard.dim == in_dim) in propagate_shape_and_sharding was fixed in #177973.

This change adds a structural guard: InputDim.__eq__ raises TypeError when compared with non-DimSpec types (like int), directing the developer to use .input_dim instead. A matching __hash__ preserves the hash contract. This prevents the bug class from recurring.

@wconstab

Pull Request resolved: #178599
Approved by: https://github.com/wconstab