[DTensor] Strategy Validation (1/3): placement utilities and data structures #174798
wconstab wants to merge 7 commits into gh/wconstab/528/base from …
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174798
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (11 Unrelated Failures)
As of commit 81f2bd6 with merge base c24b6a2:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
… and data structures" Adds the foundation layer for a DTensor sharding rule validation tool. This commit provides placement parsing/normalization (handling trivial shards on size-1 dims), equivalence checking, placement enumeration for inputs and outputs, and pytree-based tensor extraction from SampleInput. Read this commit first to understand the core data types (PlacementCombination, Discrepancy, ComparisonStats) that the rest of the stack builds on. Authored with Claude. [ghstack-poisoned]
```
Check if two placements are equivalent for a given tensor shape.

Shard(dim) is equivalent to Replicate() when shape[dim] == 1, because
sharding a size-1 dimension produces the same result as replicating.
```
probably this is contextual from later PRs' usages, but why is this true? and could we add a comment explaining?
this is total nonsense. haha. let me see if claude can figure out its own code
claude's explanation:

The equivalence is a validator noise reduction heuristic: when brute-force ground truth testing finds that Shard(0) on a [1, 4] tensor validates as correct (because tensor_split produces a [1,4] chunk and a [0,4] chunk, the op runs on each, and concatenation recovers the original), this is "technically correct" but uninteresting. If DTensor's rule says Replicate() for that input, we don't want to report it as a "missing rule."
but also, this function appears to be no longer needed after adding normalize_placement and is_trivial_shard later on, so i am trying to remove it
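For concreteness, a minimal runnable illustration of the size-1-dim behavior described above (plain torch, no DTensor machinery involved; this is a toy demonstration, not code from the PR):

```python
import torch

t = torch.randn(1, 4)
chunks = torch.tensor_split(t, 2, dim=0)
print([c.shape for c in chunks])  # [torch.Size([1, 4]), torch.Size([0, 4])]

# Running an elementwise op on each chunk and concatenating recovers the
# full result, so Shard(0) "validates" here; but only trivially, since one
# rank holds everything, which makes Replicate() the meaningful placement.
out = torch.cat([torch.abs(c) for c in chunks], dim=0)
assert torch.equal(out, torch.abs(t))
```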
```python
def has_equivalent_rule(
    combo_key: tuple,
```
This one is tuple[str, str]. Shall we complete those type hints, including in other functions, for easier review?
deleted has_equivalent_rule. checking for any other missing hints.
```
1. Running operators on full tensors to get ground truth
2. Simulating sharding with various placement combinations
3. Comparing redistributed outputs against ground truth
4. Reporting incorrect rules (DTensor claims valid but wrong) and
```
Does this mean the DTensor op test can pass (or xfail) while the op test is in fact incomplete and the strategy is incorrect?
i don't know if the dtensor op test was even run on this case; for example, it may not exhaustively cover all placement combos. All this means is that this util has a way to find incorrect dtensor rules, whether they are tested or not in main.
What does "DTensor claims valid" mean here?
any strategy dtensor has registered. so if we run the dtensor strategy_fn with these inputs, it gives us back a list of possible strategies. Then we check whether we can prove any of those strategies is numerically incorrect.
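To make steps 1-3 concrete, here is a hedged toy version of that check for a single placement choice. It is a 2-rank, single-process stand-in; the real engine goes through LocalTensor/DTensor and the registered strategy_fn, and `validate_shard0_rule` is a made-up name:

```python
import torch

def validate_shard0_rule(op, t: torch.Tensor) -> bool:
    """Is Shard(0) -> Shard(0) numerically valid for `op` on input `t`?"""
    ground_truth = op(t)                              # 1. run op on the full tensor
    local_chunks = torch.tensor_split(t, 2, dim=0)    # 2. simulate Shard(0) on 2 ranks
    local_outs = [op(c) for c in local_chunks]
    full_out = torch.cat(local_outs, dim=0)           # 3. "redistribute" to Replicate
    return torch.allclose(full_out, ground_truth)

assert validate_shard0_rule(torch.abs, torch.randn(4, 3))   # elementwise: valid
assert not validate_shard0_rule(                            # reduces over dim 0:
    lambda x: x.softmax(dim=0), torch.randn(4, 3))          # provably incorrect
```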
```python
def is_trivial_shard(p, shape: tuple[int, ...]) -> bool:
    """Check if placement is a Shard on a size-1 dimension (equivalent to Replicate)."""
    return isinstance(p, Shard) and p.dim < len(shape) and shape[p.dim] == 1
```
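Usage, assuming the public torch.distributed.tensor import path for Shard:

```python
from torch.distributed.tensor import Shard

assert is_trivial_shard(Shard(0), (1, 4))      # size-1 dim, acts like Replicate()
assert not is_trivial_shard(Shard(1), (1, 4))  # dim 1 has size 4
assert not is_trivial_shard(Shard(2), (1, 4))  # out-of-range dim is not trivial
```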
I may have missed it somewhere, but you also normalize P(max) --> R, right?
No, we don't normalize Partial placements. The trivial-shard normalization is based on static shape information: Shard(dim) on a size-1 dim always produces the same validation result as Replicate regardless of tensor values (rank 0 gets all the data, rank 1's empty computation is vacuous). There's no analogous static condition for Partial: a Partial placement's validity depends on the op and the actual tensor values, not just the shape.
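A minimal sketch of what that shape-based normalization looks like; the name normalize_placement comes from the commit message, but the exact signature here is an assumption:

```python
from torch.distributed.tensor import Placement, Replicate, Shard

def normalize_placement(p: Placement, shape: tuple[int, ...]) -> Placement:
    # Shard on a size-1 dim is statically equivalent to Replicate: one rank
    # holds the whole tensor and the rest hold empty slices, regardless of
    # the tensor's values.
    if isinstance(p, Shard) and p.dim < len(shape) and shape[p.dim] == 1:
        return Replicate()
    # Partial is deliberately left alone: whether P(sum)/P(max) holds
    # depends on the op and the actual values, not on the shape.
    return p
```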
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ation engine (#174799)

Adds the validation engine that tests whether a sharding rule is correct by simulating distributed execution on a single machine using LocalTensor. For each placement combination, it creates local tensors that would reduce to the original (e.g., for P(sum), splits values across ranks so they sum back), runs the op on those local tensors, wraps the output as a DTensor, redistributes to Replicate, and compares against ground truth.

The main challenge is avoiding false positives where a rule appears valid on a specific input but is actually incorrect. Several techniques are used:

- Asymmetric splits for P(sum)/P(avg): instead of splitting evenly (tensor/2 per rank), uses a 60/40 ratio (varied by tensor index) so that ops which are not truly linear don't accidentally produce matching outputs.
- Sign-varying offsets for P(sum)/P(avg): adds an offset that alternates sign across elements, so local tensors have mixed positive and negative values. Without this, proportional splits preserve the sign pattern of the original tensor, causing non-linear ops like abs to falsely validate P(sum)->P(sum).
- Distinct magnitudes for P(min) vs P(max): P(min) offsets non-holding ranks by +0.7 while P(max) offsets by -1.3. Using different magnitudes prevents accidental cancellation when min and max placements appear in the same combination.
- Alternating rank ownership for P(min)/P(max): a mask alternating by element (shifted by tensor index) controls which rank holds the true value vs the offset value. This prevents the degenerate case where one rank always holds all true values.

Authored with Claude.

Pull Request resolved: #174799
Approved by: https://github.com/weifengpy, https://github.com/pianpwk, https://github.com/zpcore
ghstack dependencies: #174798
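To illustrate the asymmetric-split plus sign-varying-offset technique from the first two bullets, a hedged 2-rank sketch; the helper name is made up, and the real engine builds LocalTensors rather than plain lists:

```python
import torch

def split_for_partial_sum(t: torch.Tensor, tensor_idx: int = 0) -> list[torch.Tensor]:
    """Per-rank local tensors that sum back to `t`. A symmetric t/2 split
    would let non-linear ops slip through: abs(t/2) + abs(t/2) == abs(t),
    falsely validating P(sum) -> P(sum) for abs."""
    ratio = 0.6 if tensor_idx % 2 == 0 else 0.4   # 60/40, varied by tensor index
    offset = torch.ones_like(t)                   # offset alternates sign across
    offset.view(-1)[::2] = -1.0                   # elements -> mixed-sign locals
    return [t * ratio + offset, t * (1 - ratio) - offset]

t = torch.randn(4, 4)
a, b = split_for_partial_sum(t)
assert torch.allclose(a + b, t)                        # locals still reduce to t
assert not torch.allclose(a.abs() + b.abs(), t.abs())  # abs no longer sneaks past
```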
Stack from ghstack (oldest at bottom):
Adds the foundation layer for a DTensor sharding rule validation tool.
This commit provides placement parsing/normalization (handling trivial
shards on size-1 dims), equivalence checking, placement enumeration for
inputs and outputs, and pytree-based tensor extraction from SampleInput.
Read this commit first to understand the core data types
(PlacementCombination, Discrepancy, ComparisonStats) that the rest of
the stack builds on.
Authored with Claude.
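For orientation, a hedged sketch of what the three core data types might look like; only the class names come from this description, and every field here is an assumption:

```python
from dataclasses import dataclass, field

from torch.distributed.tensor import Placement

@dataclass(frozen=True)
class PlacementCombination:
    """One candidate assignment of placements to inputs and outputs."""
    input_placements: tuple[tuple[Placement, ...], ...]  # per input tensor
    output_placements: tuple[Placement, ...]

@dataclass
class Discrepancy:
    """A combination where DTensor's registered rule disagrees with the
    brute-force ground truth (an incorrect or missing rule)."""
    combo: PlacementCombination
    kind: str            # e.g. "incorrect_rule" or "missing_rule"
    detail: str = ""

@dataclass
class ComparisonStats:
    """Aggregate counts over all combinations tested for one op."""
    total: int = 0
    matched: int = 0
    discrepancies: list[Discrepancy] = field(default_factory=list)
```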