[DTensor] constant_pad_nd non-replicate strategy#175656
pianpwk wants to merge 6 commits into gh/pianpwk/101/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175656
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures. As of commit 30cbb04 with merge base ea9fce2:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
[ghstack-poisoned]
]
)
@register_single_dim_strategy(aten.constant_pad_nd.default)
def constant_pad_nd_single_dim_strategy(
from claude review:
● Found the issues. Here's my review:
---
1. BUG: Boolean logic error in padded_dims detection (line 675)
if not guard_or_false(pad[i * 2] == 0) and guard_or_false(pad[i * 2 + 1] == 0):
Due to Python operator precedence (not binds tighter than and), this evaluates as:
(not guard_or_false(pad_left == 0)) and (guard_or_false(pad_right == 0))
Which means: "left is non-zero AND right IS zero" — i.e. only detects left-only padding. It misses:
┌──────────┬───────────┬─────────────────────┬────────────────┐
│ pad_left │ pad_right │ Current result      │ Correct result │
├──────────┼───────────┼─────────────────────┼────────────────┤
│ 0        │ 0         │ not added (correct) │ not added      │
│ 1        │ 0         │ added               │ added          │
│ 0        │ 1         │ not added (wrong)   │ added          │
│ 1        │ 1         │ not added (wrong)   │ added          │
└──────────┴───────────┴─────────────────────┴────────────────┘
The most common case ([1, 1] — pad both sides) is not detected! This means sharding is incorrectly allowed on padded
dimensions.
The fix should be:
if not (guard_or_false(pad[i * 2] == 0) and guard_or_false(pad[i * 2 + 1] == 0)):
Note: the unit test test_constant_pad_nd_bans_shard_on_padded_dim uses pad=[1,1], which would fail with this bug —
it asserts assertNotIn(1, shard_dims) but the buggy code would include dim 1.
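The precedence issue is easy to see in isolation. Here is a minimal sketch (with `guard_or_false` replaced by a plain boolean check and the pad-to-dim mapping simplified, so this is illustrative rather than the actual strategy code):

```python
def padded_dims(pad, ndim, fixed):
    """Return dims with nonzero padding; `pad` pairs run from the last dim inward."""
    dims = set()
    for i in range(len(pad) // 2):
        dim = ndim - 1 - i
        if fixed:
            # negate the whole "both sides are zero" conjunction
            if not (pad[i * 2] == 0 and pad[i * 2 + 1] == 0):
                dims.add(dim)
        else:
            # buggy: `not` binds tighter than `and`
            if not pad[i * 2] == 0 and pad[i * 2 + 1] == 0:
                dims.add(dim)
    return dims

print(padded_dims([1, 1], ndim=2, fixed=False))  # set() -> dim 1 missed
print(padded_dims([1, 1], ndim=2, fixed=True))   # {1}   -> dim 1 correctly banned
```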
2. Missing Partial rules — P(avg), P(max), P(min) work for ANY pad value
The code only generates [Partial(), Partial()] (i.e. P(sum)→P(sum)) when value==0. But the analysis is incomplete:
For any pad value v, consider rank-local computation pad(A_i, pad, v):
- At non-padded positions: reduce(output_i[j]) = reduce(A_i[j]) — correct by definition of P(x)
- At padded positions: every rank writes the same constant v, so:
- avg(v, v, ..., v) = v — P(avg) works for any v
- max(v, v, ..., v) = v — P(max) works for any v
- min(v, v, ..., v) = v — P(min) works for any v
- sum(v, v, ..., v) = N*v ≠ v — P(sum) only works when v=0
So the complete rules are:
- Always (any value): P(avg)→P(avg), P(max)→P(max), P(min)→P(min)
- Only when value=0: P(sum)→P(sum)
The current code misses the first group entirely. This won't produce incorrect results (inputs with those placements
would fall back to Replicate), but it's a missed optimization — unnecessary all-gathers.
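This is easy to sanity-check numerically without DTensor by simulating two ranks locally (a sketch; `F.pad` plays the role of each rank's local pad):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a0, a1 = torch.randn(2, 3), torch.randn(2, 3)  # local shards of a Partial tensor
pad, v = [1, 1], 5.0  # nonzero fill value on the last dim

# Idempotent reductions commute with padding for ANY fill value:
for red in (torch.maximum, torch.minimum):
    pad_then_reduce = red(F.pad(a0, pad, value=v), F.pad(a1, pad, value=v))
    reduce_then_pad = F.pad(red(a0, a1), pad, value=v)
    assert torch.equal(pad_then_reduce, reduce_then_pad)

# avg: (v + v) / 2 == v at padded positions, so it also holds for any v
assert torch.equal((F.pad(a0, pad, value=v) + F.pad(a1, pad, value=v)) / 2,
                   F.pad((a0 + a1) / 2, pad, value=v))

# sum: v + v != v unless v == 0
assert not torch.equal(F.pad(a0, pad, value=v) + F.pad(a1, pad, value=v),
                       F.pad(a0 + a1, pad, value=v))
assert torch.equal(F.pad(a0, pad) + F.pad(a1, pad), F.pad(a0 + a1, pad))  # value=0 case
```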
3. Minor: the value==0 comment is imprecise
Line 685 says "Linearity: pad(a+b, 0) == pad(a, 0) + pad(b, 0)". This is correct for P(sum), but the linearity
framing obscures the avg/max/min cases which hold for different reasons (idempotency of the reduce on identical
values, not linearity).
---
Summary: One correctness bug (#1 — wrong boolean logic lets sharding through on padded dims), one completeness gap
(#2 — missing P(avg/max/min) rules that hold unconditionally).
also i dunno if the sharding validator runs. i tried to run it but my pytorch needs a rebuild. if you can run it, and also ask claude to analyze any 'missing rules', that's a good thing to paste on these PRs by default
Just updated the PR, I guess the min/max/avg rules with padding are interesting.
There are a lot of false positives when the output tensor is 0-sized (allclose reports True); I'll put that fix up in a followup:
(pytorch-3048) [pianpwk@devvm3048.dkl0 /data/users/pianpwk/pytorch (95594238)]$ python -m torch.distributed.tensor._ops.strategy_validation --op constant_pad_nd
Testing ops: aten.constant_pad_nd
Device: cuda, Dtype: torch.float32, World size: 2
[1/1] aten.constant_pad_nd — Samples: 35 (16 skipped), Combinations: 2280
----------------------------------------------------------------------
Possibly missing (valid in ground truth but no DTensor rule)
[aten.constant_pad_nd.default]
P(avg) -> P(max)
P(avg) -> P(min)
P(avg) -> P(sum)
P(avg) -> R
P(avg) -> S(0)
P(avg) -> S(1)
P(max) -> P(avg)
P(max) -> P(min)
P(max) -> P(sum)
P(max) -> R
P(max) -> S(0)
P(max) -> S(1)
P(min) -> P(avg)
P(min) -> P(max)
P(min) -> P(sum)
P(min) -> R
P(min) -> S(0)
P(min) -> S(1)
P(sum) -> P(avg)
P(sum) -> P(max)
P(sum) -> P(min)
P(sum) -> P(sum)
P(sum) -> R
P(sum) -> S(0)
P(sum) -> S(1)
S(0) -> S(0)
======================================================================
Summary
======================================================================
Op Correct Incorrect Missing Time
---------------------------------------------------------
aten.constant_pad_nd 149 0 26 38.2s
---------------------------------------------------------
Total 149 0 26 38.2s
Once that's added, the rules are correct:
(pytorch-3048) [pianpwk@devvm3048.dkl0 /data/users/pianpwk/pytorch (95594238)]$ python -m torch.distributed.tensor._ops.strategy_validation --op constant_pad_nd
Testing ops: aten.constant_pad_nd
Device: cuda, Dtype: torch.float32, World size: 2
[1/1] aten.constant_pad_nd — Samples: 33 (18 skipped), Combinations: 2196
----------------------------------------------------------------------
======================================================================
Summary
======================================================================
Op Correct Incorrect Missing Time
---------------------------------------------------------
aten.constant_pad_nd 143 0 0 35.3s
---------------------------------------------------------
Total 143 0 0 35.3s
tensor_metas = tuple(
    TensorMeta(shape=t.shape, stride=t.stride(), dtype=t.dtype) for _, t in tensors
)
args_meta = tensor_metas + non_tensor_args
seems we need these for the padding amounts and value
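For context, `constant_pad_nd` takes one tensor plus non-tensor args (the pad widths and the fill value), so `args_meta` ends up looking roughly like this (a hypothetical illustration, not the actual call site):

```python
import torch
from torch.distributed.tensor._dtensor_spec import TensorMeta  # import path may vary by version

t = torch.randn(4, 8)
non_tensor_args = ([1, 1], 0.0)  # pad widths and fill value for constant_pad_nd
tensor_metas = (TensorMeta(shape=t.shape, stride=t.stride(), dtype=t.dtype),)
args_meta = tensor_metas + non_tensor_args  # tensor metadata first, then pad/value
```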
yea, take a look at #175821 too; not sure whether it helps to land mine first and remove this, or land yours first and rebase mine.
ah, I think yours is closer to landing
|
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed. Reason: 1 job has failed: trunk / linux-jammy-rocm-py3.10 / test (distributed, 1, 3, linux.rocm.gpu.gfx950.4). Details for Dev Infra team: Raised by workflow job
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 3 checks: inductor / inductor-cpu-test / test (cpu_inductor_torchbench, 1, 2, linux.2xlarge.amx, unstable), inductor / unit-test / inductor-test / test (inductor, 1, 2, linux.g5.4xlarge.nvidia.gpu), trunk / linux-jammy-rocm-py3.10 / test (distributed, 1, 3, linux.rocm.gpu.gfx950.4). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Upstreaming from autoparallel: https://github.com/meta-pytorch/autoparallel/blob/454780d2a27456a380c0d8e997c8fc2cf82ef5d8/autoparallel/shardings/propagation_rules.py#L630

The previous strategy required full-Replicate: we can pass through on non-padded dims, and allow Partial inputs when the pad value = 0 (arguable whether we should fix this). Rewritten as a single-dim strategy.

Pull Request resolved: pytorch#175656
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#175776
…_layer_norm, and native_layer_norm_backward

These three rules were carried as local overrides in autoparallel while upstream PyTorch lacked proper handling:
- constant_pad_nd: non-replicate strategy filtering on padded dims (upstreamed in pytorch/pytorch#175656)
- native_layer_norm forward: correct per-output shapes and contiguous strides (upstreamed in pytorch/pytorch#175652)
- native_layer_norm backward: contiguous stride handling for grad_input (upstreamed in a companion PR to pytorch/pytorch)

With all three fixes now in upstream PyTorch, the overrides can be removed and autoparallel defers to the upstream register_op_strategy implementations.

Authored with Claude.
Stack from ghstack (oldest at bottom):
Upstreaming from autoparallel: https://github.com/meta-pytorch/autoparallel/blob/454780d2a27456a380c0d8e997c8fc2cf82ef5d8/autoparallel/shardings/propagation_rules.py#L630
The previous strategy required full-Replicate: we can pass through on non-padded dims, and allow Partial inputs when the pad value = 0 (arguable whether we should fix this). Rewritten as a single-dim strategy.
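Putting the review discussion together, the resulting rule set looks roughly like this (an illustrative sketch of the logic with placeholder names, not the upstream code):

```python
def constant_pad_nd_rules(ndim, padded_dims, pad_value):
    """Input -> output placement pairs, per the discussion above (illustrative)."""
    rules = [("Replicate", "Replicate")]
    # Shard passthrough on any tensor dim that receives no padding:
    for d in range(ndim):
        if d not in padded_dims:
            rules.append((f"Shard({d})", f"Shard({d})"))
    # Idempotent reductions commute with padding for any fill value:
    for red in ("avg", "max", "min"):
        rules.append((f"Partial({red})", f"Partial({red})"))
    if pad_value == 0:
        rules.append(("Partial(sum)", "Partial(sum)"))  # linearity only at value == 0
    return rules

# pad=[1, 1] pads the last dim of a 2-D tensor, so Shard(1) is banned:
print(constant_pad_nd_rules(ndim=2, padded_dims={1}, pad_value=0.0))
```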