[DTensor] make debugmode print optimized transforminfos #173436
wconstab wants to merge 9 commits into gh/wconstab/506/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173436
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 6 Unrelated Failures
As of commit 1100b4f with merge base 7754b55:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
)
shard_order_dict[src_dim].pop()
# Remove mesh dims in order (from shard_order_dict perspective)
for _ in mesh_dims_to_update:
Can we add an assert that `x = shard_order_dict[src_dim].pop()` satisfies `x in mesh_dims_to_update`? Same for `dst_dim_placement`. Just to be safe that the optimized transform info is correct.
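Something like this minimal sketch (variable names taken from the diff above; the exact surrounding code is assumed, not the actual patch):

```python
# Hedged sketch of the suggested safety check: every mesh dim popped off
# the source entry should be one we planned to move.
for _ in mesh_dims_to_update:
    x = shard_order_dict[src_dim].pop()
    assert x in mesh_dims_to_update, (
        f"optimized transform info inconsistent: popped mesh dim {x} "
        f"not in {mesh_dims_to_update}"
    )
```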
cur_placement[transform_info.mesh_dim] = dst_dim_placement
# Add mesh dims in order
for mesh_dim in mesh_dims_to_update:
    shard_order_dict[dst_dim].append(mesh_dim)
The order here is related to #172610 (comment). Let's address that one first.
Nit: suggest using something like `-->` (double dash) to show that this is an optimized transform.
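For example (illustrative only, assuming the single-dash trace format shown elsewhere in this PR):

```
RRR -> RS(0)R      # ordinary transform step
RRR --> RS(0)R     # same step, marked as produced by an optimized transform
```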
self.assertExpectedInline(
    trace_str,
-   """S(0)[0]S(0)[1]_S(0, 3)->S(0)[0]S(0)[1]R->S(0)[0]RR->RRR->RS(0)[1]R->RS(0)[1]S(0)[2]""",  # noqa: B950
+   """S(0)[0]S(0)[1]_S(0, 3)->S(0)[0]S(0)[1]R->S(0)RR->RRR->RS(0)R->RS(0)[0]S(0)[1]""",  # noqa: B950
@zpcore do you buy this explanation?
● I see the issue. The old code had a bug where it didn't handle _StridedShard in the pop logic (since _StridedShard.is_shard() returns False). This caused it to produce an incorrect shard_order, and the test's expected traces were written to match that incorrect behavior.
Now that I've fixed the code to properly handle _StridedShard, the correct output differs from the expected. I need to update the test's expected traces to match the correct behavior.
Let me run the test with EXPECTTEST_ACCEPT=1 to see what all the correct traces should be.
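A hedged sketch of the kind of fix being described (the helper name is illustrative and the import path is an assumption, not the actual patch):

```python
from torch.distributed.tensor.placement_types import _StridedShard

def _is_shard_like(placement) -> bool:
    # _StridedShard.is_shard() returns False (per the discussion above),
    # so it must be checked explicitly alongside regular Shard placements
    # before deciding whether to pop from shard_order_dict.
    return placement.is_shard() or isinstance(placement, _StridedShard)
```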
I see the issue: I didn't add the string of the order for _StridedShard in the output...
The fix from Claude still has some missing order for _StridedShard.
> The fix from Claude still has some missing order for _StridedShard.

Can you say more about this?
Oh, just that we don't have the [i] after _S in the string repr. Fixing that.
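Illustratively, using the trace notation from the test above (before vs. after the fix being described):

```
_S(0, 3)      # before: strided shard prints with no shard-order index
_S(0, 3)[1]   # after: "[i]" appended, matching how S(0)[1] already prints
```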
zpcore left a comment:
I think we can fix the _StridedShard order later, since _StridedShard is not that important now.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed
Reason: 1 job has failed, first few of them are: trunk / win-vs2022-cpu-py3 / build
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 1 check: trunk / win-vs2022-cpu-py3 / build
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Pull Request resolved: pytorch#173436
Approved by: https://github.com/zpcore
ghstack dependencies: pytorch#173593, pytorch#172610
@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"
This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).
@pytorchbot successfully started a revert job. Check the current status here.
)" This reverts commit 47260be. Reverted #173436 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](#173436 (comment)))
@wconstab your PR has been successfully reverted.
relanding via #174630
Reland of #172610: same code as previous land except:
- includes #173873 (credit @bdhirsh)
- includes #173790 (credit @IvanKobzarev)
- includes #173436
- adds disable contextmanager + test

Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and, particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit) for more info.)

Example: For a (2,2,2) mesh with dims (A,B,C), when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate), the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. (A usage sketch follows below.)

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types: (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transform infos and then merging like kinds
- warns once per mesh shape for missing flattened meshes
- won't optimize reduce_scatters when they shard an unevenly sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but it is not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh; otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device meshes are not automatically created, due to a preference for explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would enable an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: an initial attempt used a stable sort with a __lt__ method in TransformInfo comparing the comm-type key, but this was not correct, because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh dim affects results.

Differential Revision: D92540256
Pull Request resolved: #174630
Approved by: https://github.com/zpcore
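A hedged usage sketch of the (2,2,2) example above. It assumes an already-initialized 8-rank process group and that DeviceMesh slicing and the private `_flatten` API behave as in recent PyTorch; this is not code from the PR itself.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# (2, 2, 2) mesh with named dims A, B, C, as in the example above.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Per this PR, flattened meshes are not created automatically; the user
# creates one explicitly so DTensor can find and use it.
mesh["A", "C"]._flatten("AC")

# A DTensor that is Partial(sum) on mesh dims A and C, Replicate on B.
local = torch.randn(4, 4, device="cuda")
dt = DTensor.from_local(local, mesh, [Partial(), Replicate(), Partial()])

# Redistributing to fully replicated previously issued two all_reduces
# (one per partial mesh dim); with the flattened "AC" mesh available it
# can be merged into a single, larger all_reduce.
out = dt.redistribute(mesh, [Replicate(), Replicate(), Replicate()])
```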