[DTensor] Add sharding strategy for aten.squeeze.dims #173563
stmcgovern wants to merge 8 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173563
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit eeac88f with merge base d428a3f. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks correct to me. Though there are a couple of trivial rules missing for P(max) → R or P(min) → R when we squeeze away a size-1 partial; can you also fill those in?
(force-pushed from 8b8811a to 56c897e)
wconstab left a comment:
I think I can be convinced that it's a good plan to remap all the squeeze variants to .dims when we need to normalize dims to avoid nonzero-dim cases, but it would help me if you updated the PR description to explain this a bit more clearly. Also, I didn't understand whether you're fixing the FIXME before landing or whether you're not able to for some reason.
(force-pushed from 988b42f to 9663d09)
(force-pushed from 9663d09 to 0fab03a)
(force-pushed from 1787361 to 6801a99)
I looked into adding this, but I don't really understand why it's needed by squeeze. The (recent) strategy validator does not expect/ask for it. Was it showing as missing in earlier strategy validator implementations? @wconstab
I think those P → R rules should be invalid; probably the validator is better now. @claude can you evaluate this PR overall and specifically say whether it has the complete and correct set of sharding rules for squeeze?
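A quick toy illustration (my reading, not from the PR) of why a P(max) → R rule on squeeze would be unsound: squeeze only reshapes, it never combines values across ranks, so shards that disagree under Partial(max) still disagree afterward and cannot be declared Replicate without an actual reduction.

```python
# Toy model of a Partial(max) tensor: each rank holds one shard, and the
# "true" global value is the elementwise max across ranks.
rank_shards = [[3.0], [5.0]]                  # global shape (1,), P(max)

# Squeeze the size-1 dim on every rank; values are untouched.
squeezed = [shard[0] for shard in rank_shards]

print(squeezed)       # [3.0, 5.0] -- ranks still disagree, so not Replicate
print(max(squeezed))  # 5.0 -- the value a real P(max) -> R reduction would need
```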
Claude finished @wconstab's task in 4m 41s. View job

Evaluation: Squeeze Sharding Rules

Overview
This PR introduces a two-level approach:

Sharding Rules — Complete and Correct
The rules produced by this pipeline cover all placement types:

P(max/min) → R is NOT needed
I agree with the author's assessment. The validator is correct to not expect it.

Op variant coverage
All 6 ATen squeeze variants are covered:

Symbolic shape handling
The implementation is conservative with symbolic shapes.

One issue worth noting
The inplace squeeze_ special case:

```python
if op_call == aten.squeeze_.dim:
    args[0]._spec = output_spec
    return return_and_correct_aliasing(op_call, args, kwargs, args[0])
```

This only matches the .dim overload. In practice this is a narrow edge case; a broader check would be:

```python
if op_call in (aten.squeeze_.dim, aten.squeeze_.default, aten.squeeze_.dims):
```

Verdict
The sharding rules are complete and correct for all squeeze variants. The DimMap + global-singleton rewrite approach is sound and properly handles the previously-broken FIXME case. The test coverage (test_squeeze_variants) covers all variants.
(force-pushed from f504b95 to 0d4fcec)
Added this from the Claude review and fixed CI failures (unexpected batch norm successes).
I think it makes sense to bug-fix the existing issue with #175798 before the strategy changes in this PR.
wconstab left a comment:
LGTM, thanks! I'll take a look at the other PR you referenced now.
(force-pushed from 5304570 to e203dc1)
… FIXME

Fix the long-standing FIXME in dim_squeeze where squeeze with dim=None could incorrectly remove sharded dimensions whose local size happened to be 1 (despite global size > 1). The fix canonicalizes all squeeze variants to squeeze.dims at the sharding-propagation level, using the global shape to determine which dimensions are truly singleton.

The core change is _adjust_squeeze_to_global_singletons in sharding_prop, which rewrites the op before dim_squeeze ever sees it. The dispatch layer uses an effective_op pattern to execute the rewritten op and appends any extra args (like the dims tuple) from the rewritten schema. Also registers the missing aten.squeeze_.default strategy.
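The canonicalization step can be sketched roughly like this (a simplified illustration with hypothetical names, not the actual `_adjust_squeeze_to_global_singletons` code): every squeeze overload is reduced to an explicit dims tuple, and a dim survives only if its global size is 1.

```python
# Simplified sketch of the canonicalization idea: map any squeeze overload
# to squeeze.dims, keeping only dims that are singleton in the *global*
# shape. Function and argument names here are illustrative.

def to_global_singleton_dims(global_shape, dim=None):
    ndim = len(global_shape)
    if dim is None:                      # aten.squeeze.default
        candidates = list(range(ndim))
    elif isinstance(dim, int):           # aten.squeeze.dim (may be negative)
        candidates = [dim % ndim]
    else:                                # aten.squeeze.dims
        candidates = sorted({d % ndim for d in dim})
    # Only globally size-1 dims are truly squeezable.
    return tuple(d for d in candidates if global_shape[d] == 1)

# A dim-0 shard of a (4, 1) tensor has local shape (1, 1), but only dim 1
# is squeezable when judged by the global shape:
print(to_global_singleton_dims((4, 1)))          # (1,)
print(to_global_singleton_dims((4, 1), -1))      # (1,)
print(to_global_singleton_dims((4, 1), [0, 1]))  # (1,)
```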
(force-pushed from e203dc1 to 4393ffb)
batch_norm ops (_native_batch_norm_legit, native_batch_norm, nn.functional.batch_norm) now pass in eager, local, multi-threaded, and compiled modes thanks to squeeze.dims unblocking their decomposition path. They only fail in unbacked tests (torch.compile + mark_unbacked) due to symbolic shape limitations in DecompShardingStrategy. Move xfails from dtensor_fails_no_strategy to ops_unbacked_dtensor_dde.

squeeze_copy errors in local crossref (shape mismatch) but passes in multi-threaded; xfail it inline for TestLocalDTensorOps only.
(force-pushed from 4393ffb to cfb8978)
The squeeze.dims strategy registration enables DecompShardingStrategy to propagate sharding through op decompositions that internally use squeeze.dims. Previously these ops failed because squeeze.dims had no sharding strategy, blocking the entire decomposition path. Remove 10 ops from dtensor_fails_no_strategy: nansum, adaptive_avg_pool1d/2d/3d, adaptive_max_pool1d/2d/3d, avg_pool1d/2d/3d.
Restore the squeeze_copy xfail in dtensor_fails_no_strategy (no strategy exists for squeeze_copy). Gate effective_op behind needs_redistribute and use a self._squeeze_inplace_ops set to avoid hot-path overhead. Add CommDebugMode assertions verifying that squeeze ops are communication-free.
Since that PR needs more work, I think landing this current one is the first step. CI now passes after multiple rounds of unexpected successes and getting the correct tests to skip/run for ops.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#173563
Approved by: https://github.com/wconstab
Fixes #173521
Fixes #166124
Extend `dim_squeeze` to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path.

Fix the long-standing FIXME in dim_squeeze where squeeze(dim=None) could incorrectly remove sharded dimensions whose local size happened to be 1 (despite global size > 1). Canonicalizes all squeeze variants to squeeze.dims at the sharding-propagation level, using the global shape to determine which dimensions are truly singleton.
Strategy validator: 74 correct, 0 incorrect, 0 missing. This is without the P(max/min) → R rules mentioned below.
- Add test_squeeze_variants to test all squeeze variants with DTensor

~~Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR #166862).~~ That PR is/will be closed in favor of this approach that avoids a custom handler.
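To make the fixed failure mode concrete, here is a toy sketch (plain Python, not DTensor code) of why squeezable dims must be decided from the global shape: a sharded dim can look singleton locally while its global size is greater than 1.

```python
# Global shape (4, 1), sharded on dim 0 across 4 ranks: each rank's local
# shape is (1, 1). Deciding squeeze(dim=None) from the local shape would
# wrongly drop the sharded dim 0 as well.

def naive_squeeze_shape(shape):
    """Drop every size-1 dim, the way a local-shape-only squeeze would."""
    return tuple(s for s in shape if s != 1)

global_shape = (4, 1)
world_size = 4
local_shape = (global_shape[0] // world_size, global_shape[1])  # (1, 1)

print(naive_squeeze_shape(local_shape))   # () -- local view drops both dims
print(naive_squeeze_shape(global_shape))  # (4,) -- only dim 1 should go away
```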