[DTensor] Strategy Validation (1/3): placement utilities and data structures #174798

Closed
wconstab wants to merge 7 commits into gh/wconstab/528/base from gh/wconstab/528/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Feb 11, 2026

Stack from ghstack (oldest at bottom):

Adds the foundation layer for a DTensor sharding rule validation tool.
This commit provides placement parsing/normalization (handling trivial
shards on size-1 dims), equivalence checking, placement enumeration for
inputs and outputs, and pytree-based tensor extraction from SampleInput.

Read this commit first to understand the core data types
(PlacementCombination, Discrepancy, ComparisonStats) that the rest of
the stack builds on.

Authored with Claude.
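
For orientation, here is a minimal sketch of what these three data types could look like; the class names come from the PR description above, but every field shown is an illustrative assumption, not the actual definitions in strategy_validation.py:

# Hypothetical sketch only: the class names are from the PR description;
# the fields are assumptions for illustration, not the real definitions.
from dataclasses import dataclass, field

from torch.distributed.tensor.placement_types import Placement


@dataclass(frozen=True)
class PlacementCombination:
    """One assignment of placements to every tensor input and output."""
    input_placements: tuple[tuple[Placement, ...], ...]
    output_placements: tuple[Placement, ...]


@dataclass
class Discrepancy:
    """A combination where DTensor's registered rule disagrees with ground truth."""
    combo: PlacementCombination
    reason: str  # e.g. "incorrect rule" or "missing rule"


@dataclass
class ComparisonStats:
    """Aggregate counts over all combinations tested for one operator."""
    total: int = 0
    matched: int = 0
    discrepancies: list[Discrepancy] = field(default_factory=list)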

@pytorch-bot

pytorch-bot Bot commented Feb 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174798

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (11 Unrelated Failures)

As of commit 81f2bd6 with merge base c24b6a2:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wconstab wconstab changed the title [DTensor] Add sharding rule validator: placement utilities and data structures [DTensor] Strategy Validation (1/3): placement utilities and data structures Feb 11, 2026
@wconstab wconstab requested a review from pianpwk February 12, 2026 00:33
Check if two placements are equivalent for a given tensor shape.

Shard(dim) is equivalent to Replicate() when shape[dim] == 1, because
sharding a size-1 dimension produces the same result as replicating.
Contributor
probably this is contextual from later PRs' usages, but why is this true? and could we add a comment explaining?

Contributor Author
this is total nonsense. haha. let me see if claude can figure out its own code

Contributor Author

@wconstab wconstab Feb 12, 2026

claude's explanation:
The equivalence is a validator noise-reduction heuristic: when brute-force ground-truth testing finds that Shard(0) on a [1, 4] tensor validates as correct (because tensor_split produces a [1, 4] chunk and a [0, 4] chunk, the op runs on each, and concatenation recovers the original), this is "technically correct" but uninteresting. If DTensor's rule says Replicate() for that input, we don't want to report it as a "missing rule."

but also, this function appears to be no longer needed after adding normalize_placement and is_trivial_shard later on, so i am trying to remove it
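
To make the [1, 4] example concrete, here is a tiny standalone sketch of the splitting behavior, using torch.tensor_split directly rather than the validator's actual code:

import torch

t = torch.arange(4.0).reshape(1, 4)
# "Sharding" the size-1 dim 0 across 2 ranks: rank 0 gets the whole tensor,
# rank 1 gets an empty [0, 4] chunk.
chunks = torch.tensor_split(t, 2, dim=0)
print([tuple(c.shape) for c in chunks])  # [(1, 4), (0, 4)]
# Concatenation recovers the original, so from the validator's point of view
# Shard(0) here behaves exactly like Replicate().
assert torch.equal(torch.cat(chunks, dim=0), t)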



def has_equivalent_rule(
combo_key: tuple,
Member
This one is tuple[str, str]. Shall we complete those type hints, including in other functions, for easy review?

Contributor Author
deleted has_equivalent_rule. checking for any other missing hints.

1. Running operators on full tensors to get ground truth
2. Simulating sharding with various placement combinations
3. Comparing redistributed outputs against ground truth
4. Reporting incorrect rules (DTensor claims valid but wrong) and
Member
Does this mean the DTensor OP test can pass the xfail, but in fact the op test is incomplete and the strategy is incorrect?

Contributor Author
i don't know if the dtensor op test was even run on this case; for example, it may not exhaustively cover all placement combos. All this means is, this util has a way to find incorrect dtensor rules whether or not they are tested in main.

Member
What does "DTensor claims valid" mean here?

Contributor Author
any strategy dtensor has registered. so if we run the dtensor strategy_fn with these inputs, it gives us back a list of possible strategies. Then we check if we can prove any of those strategies are numerically incorrect.
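
A rough sketch of the check being described; strategy_fn, simulate_sharded_run, and the surrounding names here are placeholders for illustration (assuming a single tensor output), not the actual APIs in strategy_validation.py:

import torch

def find_incorrect_rules(op, sample_input, strategy_fn, simulate_sharded_run):
    # Ground truth: run the op on the full, unsharded tensors.
    ground_truth = op(sample_input.input, *sample_input.args, **sample_input.kwargs)
    discrepancies = []
    # strategy_fn returns every placement combination DTensor claims is valid.
    for combo in strategy_fn(sample_input):
        # Simulate sharded execution for this combination and redistribute
        # the result back to Replicate for comparison.
        result = simulate_sharded_run(op, sample_input, combo)
        if not torch.allclose(result, ground_truth):
            discrepancies.append(combo)  # provably numerically incorrect
    return discrepancies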


def is_trivial_shard(p, shape: tuple[int, ...]) -> bool:
"""Check if placement is a Shard on a size-1 dimension (equivalent to Replicate)."""
return isinstance(p, Shard) and p.dim < len(shape) and shape[p.dim] == 1
Member
I may have missed it somewhere, but you also normalize P(max) --> R, right?

Contributor Author
No, we don't normalize Partial placements. The trivial-shard normalization is based on static shape information: Shard(dim) on a size-1 dim always produces the same validation result as Replicate regardless of tensor values (rank 0 gets all the data, rank 1's empty computation is vacuous). There's no analogous static condition for Partial: a Partial placement's validity depends on the op and the actual tensor values, not just the shape.

@wconstab
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Feb 13, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Feb 13, 2026
…ation engine (#174799)

Adds the validation engine that tests whether a sharding rule is correct
by simulating distributed execution on a single machine using LocalTensor.

For each placement combination, it creates local tensors that would
reduce to the original (e.g., for P(sum), splits values across ranks so
they sum back), runs the op on those local tensors, wraps the output as
a DTensor, redistributes to Replicate, and compares against ground
truth.

The main challenge is avoiding false positives where a rule appears
valid on a specific input but is actually incorrect. Several techniques
are used:

- Asymmetric splits for P(sum)/P(avg): instead of splitting evenly
  (tensor/2 per rank), uses a 60/40 ratio (varied by tensor index) so
  that ops which are not truly linear don't accidentally produce
  matching outputs.

- Sign-varying offsets for P(sum)/P(avg): adds an offset that
  alternates sign across elements, so local tensors have mixed positive
  and negative values. Without this, proportional splits preserve the
  sign pattern of the original tensor, causing non-linear ops like abs
  to falsely validate P(sum)->P(sum).

- Distinct magnitudes for P(min) vs P(max): P(min) offsets non-holding
  ranks by +0.7 while P(max) offsets by -1.3. Using different
  magnitudes prevents accidental cancellation when min and max
  placements appear in the same combination.

- Alternating rank ownership for P(min)/P(max): a mask alternating by
  element (shifted by tensor index) controls which rank holds the true
  value vs the offset value. This prevents the degenerate case where
  one rank always holds all true values.

Authored with Claude.
Pull Request resolved: #174799
Approved by: https://github.com/weifengpy, https://github.com/pianpwk, https://github.com/zpcore
ghstack dependencies: #174798
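
As a concrete illustration of the asymmetric-split and sign-varying-offset ideas described in the commit message above, here is a simplified standalone sketch for P(sum) on 2 ranks; the exact ratios and offset construction in #174799 may differ:

import torch


def split_for_partial_sum(t: torch.Tensor, tensor_idx: int = 0):
    # Uneven split (60/40, varied by tensor index) so ops that are not truly
    # linear don't accidentally produce matching outputs.
    ratio = 0.6 if tensor_idx % 2 == 0 else 0.4
    # Offset whose sign alternates per element, so local shards contain mixed
    # positive and negative values even when the original tensor does not.
    signs = torch.where(torch.arange(t.numel()) % 2 == 0, 1.0, -1.0).reshape(t.shape)
    local0 = ratio * t + 0.5 * signs
    local1 = (1.0 - ratio) * t - 0.5 * signs
    assert torch.allclose(local0 + local1, t)  # locals still sum to the original
    return local0, local1


# abs is not linear, so a bogus P(sum) -> P(sum) rule for abs gets caught:
t = torch.tensor([[1.0, -2.0], [3.0, -4.0]])
l0, l1 = split_for_partial_sum(t)
print(torch.allclose(l0.abs() + l1.abs(), t.abs()))  # False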
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 13, 2026
…tructures

ghstack-source-id: 19e4191
Pull Request resolved: pytorch/pytorch#174798
@github-actions github-actions Bot deleted the gh/wconstab/528/head branch March 16, 2026 02:24

Labels

ciflow/inductor · ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: distributed (dtensor)
