
[shard prop] single-dim rules for conv, uniform, scatter, index ops #179185

Closed
pianpwk wants to merge 6 commits into gh/pianpwk/129/base from gh/pianpwk/129/head

Conversation

@pianpwk
Contributor

@pianpwk pianpwk commented Apr 2, 2026

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179185

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ 1 Pending, 11 Unrelated Failures

As of commit ae11809 with merge base 278936a:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the ciflow/dtensor, ciflow/inductor, ciflow/torchtitan, and release notes: distributed (dtensor) labels Apr 2, 2026
@pianpwk pianpwk changed the title from "[shard prop] single-dim rules for LayerNorm, RMSNorm FW/BW" to "[shard prop] single-dim rules for conv, uniform, scatter, index ops" Apr 2, 2026
[ghstack-poisoned]
pianpwk added a commit that referenced this pull request Apr 2, 2026
pianpwk added a commit that referenced this pull request Apr 8, 2026
@anshul-si
Contributor

@claude review this pr. consider edge cases that could break these strategies

@claude

claude Bot commented Apr 9, 2026

Claude finished @anshul-si's task in 7m 58s


Review of PR #179185

  • Gather context and read changed files
  • Understand the single_dim_strategy infrastructure
  • Analyze conv ops sharding rules for edge cases
  • Analyze random ops sharding rules for edge cases
  • Analyze scatter/index ops sharding rules for edge cases
  • Review tests for coverage gaps
  • Post final review

Overall this PR adds correct and well-structured single-dim sharding strategies. The strategies are conservative (batch-dim-only for conv, skip-active-dims for scatter ops), which is the right approach. I found one substantive issue and a few minor observations.

Issue: Conv strategy doesn't cover channel-dim sharding

_conv_ops.py:164-178 — convolution_single_dim_strategy only provides a single rule: shard input/output on dim 0 (batch), replicate weight and bias.

This is correct for data parallelism, but the existing convolution_rules prop_rule (line 69: output_dim_map = input_spec.dim_map) propagates the full dim_map from input to output, meaning it already supports sharding on any dimension — including the channel dimension (dim 1) where input and output channels are both sharded identically.

For the single-dim strategy, sharding on spatial dims (dims ≥ 2) would require halo exchanges, so excluding those is correct. But sharding on the channel dim (dim 1) would also be valid: with groups == in_channels (depthwise conv), each output channel reads exactly one input channel, so channels shard cleanly; with groups == 1 (where weight.shape[1] == in_channels), the inner product across input channels is split, so the output would come back as Partial("sum"). The current batch-only approach is conservative and safe, but this is worth noting for potential future expansion.

Verdict: Not a bug, just a coverage gap. Fine for now.
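
To make the gap concrete, here is a minimal sketch of the three channel-dim cases using DTensor placement types. The dict-of-placements form is purely illustrative; the real single-dim strategies are built with DTensor's strategy machinery, not plain dicts.

from torch.distributed.tensor import Partial, Replicate, Shard

# Rule the PR actually ships: batch-dim data parallelism.
batch_rule = dict(input=Shard(0), weight=Replicate(), bias=Replicate(), output=Shard(0))

# Hypothetical depthwise rule (groups == in_channels): each output channel
# reads exactly one input channel, so channels shard cleanly through.
depthwise_rule = dict(input=Shard(1), weight=Shard(0), bias=Shard(0), output=Shard(1))

# Hypothetical groups == 1 rule: the reduction over input channels is split,
# so the output materializes as a pending sum. Bias is elided here; adding a
# replicated bias before the reduce would over-count it.
dense_rule = dict(input=Shard(1), weight=Shard(1), output=Partial("sum"))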

Observation: convolution_backward_single_dim_strategy — args_schema[3] is SymInt[]?

_conv_ops.py:190 — bias_sizes = args_schema[3] corresponds to SymInt[]? bias_sizes in the native schema, not a TensorMeta. The is not None check is correct for detecting whether bias exists. Just noting this is a non-tensor arg being accessed positionally — the RuntimeSchemaInfo(3) declaration correctly indicates args from index 3 onward should be hashed for caching.

Edge case: select_scatter with 1-D input

_tensor_ops.py:601-627 — For a 1-D tensor with dim=0, the loop skips d=0 (since d == dim), returning an empty strategy list. The infrastructure adds the all-Replicate fallback automatically, so this is handled correctly. No issue here.
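
To make the fallback concrete, a self-contained sketch assuming the loop structure the description implies (the actual code in _tensor_ops.py may differ):

def candidate_shard_dims(ndim: int, scatter_dim: int) -> list[int]:
    # Every dim except the one select_scatter writes into is shardable.
    return [d for d in range(ndim) if d != scatter_dim]

# 1-D input with dim=0: no candidate dims, so the strategy list is empty and
# the infrastructure's all-Replicate fallback is the only option.
assert candidate_shard_dims(ndim=1, scatter_dim=0) == []
assert candidate_shard_dims(ndim=3, scatter_dim=1) == [0, 2]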

Edge case: diagonal_scatter dimension mapping verification

_tensor_ops.py:630-662 — I traced through the dim mapping for a concrete example (self shape [A,B,C,D,E], dim1=1, dim2=3):

  • src = self.diagonal(0,1,3) has shape [A, C, E, min(B,D)] (dims 1,3 removed; diagonal appended)
  • min_d=1, max_d=3
  • d=0: removed=0, src_dim=0 → maps to A ✓
  • d=2: removed=1, src_dim=1 → maps to C ✓
  • d=4: removed=2, src_dim=2 → maps to E ✓

The diagonal dimension itself (appended at the end of src) is correctly excluded from sharding since it mixes elements from both dim1 and dim2. The mapping is correct.
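
The trace above can also be checked directly with plain torch (no DTensor needed); src_dim_for is a hypothetical helper reproducing the described mapping, not code from the PR:

import torch

A, B, C, D, E = 2, 3, 4, 5, 6
t = torch.randn(A, B, C, D, E)
src = t.diagonal(offset=0, dim1=1, dim2=3)
# dims 1 and 3 are removed; the diagonal (size min(B, D)) is appended last
assert src.shape == (A, C, E, min(B, D))

def src_dim_for(d: int, dim1: int = 1, dim2: int = 3) -> int:
    # each surviving self dim shifts left by the number of removed dims before it
    removed = sum(1 for r in (dim1, dim2) if r < d)
    return d - removed

assert [src_dim_for(d) for d in (0, 2, 4)] == [0, 1, 2]  # -> A, C, E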

Edge case: diagonal_scatter with non-default dim1/dim2 not provided as args

_tensor_ops.py:644-645 — The code uses len(args_schema) > 3 and len(args_schema) > 4 to fall back to defaults dim1=0, dim2=1. This relies on the args being positional. The ATen schema diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) has them as positional with defaults, so PyTorch may or may not include them in args depending on whether defaults were explicitly passed. If dim1/dim2 are passed as kwargs instead, they'd be in kwargs_schema, not args_schema. However, at the ATen level, the dispatcher typically normalizes these to positional args, so this should be fine in practice. Worth noting as a potential fragility.
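
A hedged paraphrase of that fallback (read_diagonal_scatter_dims is a hypothetical name, not the verbatim code from _tensor_ops.py):

def read_diagonal_scatter_dims(args_schema: tuple) -> tuple[int, int, int]:
    # ATen: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1)
    # The args are positional-with-defaults, so trailing args may simply be
    # absent; dims passed as kwargs would land in kwargs_schema instead and
    # silently fall back to the defaults here.
    offset = args_schema[2] if len(args_schema) > 2 else 0
    dim1 = args_schema[3] if len(args_schema) > 3 else 0
    dim2 = args_schema[4] if len(args_schema) > 4 else 1
    return offset, dim1, dim2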

Nit: uniform_single_dim_strategy — 0-d tensor returns empty strategies

_random_ops.py:54-67 — For a 0-d (scalar) tensor, len(input_meta.shape) is 0, so no strategies are returned. The all-Replicate fallback handles this. This is correct — a scalar tensor can only be replicated.

Test changes

The test diff removes xfail("diagonal_scatter") and xfail("select_scatter") from dtensor_fails_no_strategy, which is correct since these ops now have strategies.

The additions of xfail("exponential") and 22 xfail("special.*") entries appear to be rebase-related additions for ops that were added to the op_db between the original PR base and the updated base. They maintain alphabetical ordering in the list. This seems like a clean rebase artifact.

Summary

The strategies are mathematically sound and handle edge cases properly through the infrastructure's all-Replicate fallback. The diagonal_scatter dimension mapping logic is the trickiest part and is implemented correctly. No blocking issues found.


Contributor

@anshul-si anshul-si left a comment


LGTM, ty!



@register_single_dim_strategy([aten.uniform.default])
def uniform_single_dim_strategy(
Contributor

@anshul-si anshul-si Apr 9, 2026


uniform single dim strategy will be merged in #178457.

Contributor Author


cool, will delete

pianpwk added a commit that referenced this pull request Apr 9, 2026
@pianpwk pianpwk requested review from anshul-si and zpcore April 9, 2026 18:56
@pianpwk
Contributor Author

pianpwk commented Apr 9, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk label Apr 9, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorch-auto-revert

@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed to have caused regression in:

Note: This revert was accelerated by the AI advisor: revert (confidence=0.95)

Please investigate and fix the issues.
@claude Can you please read this revert comment, follow the links and read the errors, to then give a brief diagnostics on the cause of the error? If you judge the error to be legitimate reason for a revert, please provide brief guidance on how the author could fix it.

@claude

claude Bot commented Apr 10, 2026

Claude finished @pytorch-auto-revert[bot]'s task in 3m 6s


Diagnosing CI Revert for PR #179185

  • Gather context from revert comment
  • Read CI failure logs
  • Analyze the root cause
  • Provide diagnosis and fix guidance

Verdict: The revert is legitimate. The fix is straightforward.

Root Cause

The failing test is:

TestUnbackedDTensorOpsCPU::test_unbacked_dtensor_op_db_select_scatter_cpu_float32

The PR added a select_scatter single-dim sharding strategy and removed select_scatter from the dtensor_fails_no_strategy xfail set in test/distributed/tensor/test_dtensor_ops.py. This is correct for the standard DTensor tests. However, dtensor_fails_no_strategy is also used to skip the unbacked symint test (test_unbacked_dtensor_op_db) at line 997:

ops_dde_xfail
| ops_unbacked_dtensor_dde
| dtensor_fails_no_strategy   # <-- select_scatter was here, shielding it from unbacked tests
| ops_unbacked_skip,

By removing select_scatter from dtensor_fails_no_strategy, the unbacked symint test now runs select_scatter for the first time. During sharding propagation, is_tensor_shardable() (torch/distributed/tensor/_ops/utils.py:222) evaluates shape[shard_dim] < num_shards with unbacked symbolic dimensions, triggering:

GuardOnDataDependentSymNode: Could not guard on data-dependent expression u0 < 4

This is a data-dependent error (DDE) under torch.compile(fullgraph=True) — the sharding propagation cannot guard on whether an unbacked symbolic dimension is large enough to shard.
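
A minimal, hypothetical repro of this failure mode in isolation (not the actual test; the dynamo config flag is an assumption needed so nonzero() produces an unbacked SymInt instead of a graph break):

import torch

# Assumed: lets nonzero() return an unbacked SymInt rather than graph-breaking.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def f(x):
    u0 = x.nonzero().shape[0]  # unbacked SymInt, unknown at trace time
    if u0 < 4:  # analogous to is_tensor_shardable's shape[shard_dim] < num_shards
        return x + 1
    return x - 1

# f(torch.randn(8))  # raises: Could not guard on data-dependent expression u0 < 4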

Fix

Add select_scatter to the ops_unbacked_dtensor_dde xfail set in test/distributed/tensor/test_dtensor_ops.py (around line 775):

ops_unbacked_dtensor_dde = {
    ...
    xfail("select_scatter"),  # sharding propagation triggers DDE on unbacked dims
    ...
}

Similarly, check if diagonal_scatter (the other op the PR removed from dtensor_fails_no_strategy) has the same issue with unbacked symints. If it does, add it to ops_unbacked_dtensor_dde as well.

The diagonal_scatter strategy has the same pattern — it produces _ShardingPlaceholder entries that eventually go through is_tensor_shardable, so it likely needs the same xfail treatment.


@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@pianpwk your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Apr 10, 2026
…ex ops (#179185)"

This reverts commit 6279179.

Reverted #179185 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](#179185 (comment)))
@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td labels Apr 10, 2026
pianpwk added a commit that referenced this pull request Apr 10, 2026
pianpwk added a commit that referenced this pull request Apr 13, 2026
@pianpwk
Contributor Author

pianpwk commented Apr 13, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.


Labels

ci-no-td, ciflow/dtensor, ciflow/inductor, ciflow/torchtitan, ciflow/trunk, Merged, release notes: distributed (dtensor), Reverted
