[DTensor] Optimize redistribute to use flattened mesh dims for consecutive reductions #171913
ezyang wants to merge 1 commit into gh/ezyang/3233/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/171913
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures as of commit 6590dfa with merge base 3d2e7de.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
[DTensor] Optimize redistribute to use flattened mesh dims for consecutive reductions
ghstack-source-id: 7c4441a
Pull-Request: #171913
) -> list[_TransformInfo]:
    """
    Optimize transform_infos by merging consecutive all-reduce operations on
    contiguous mesh dimensions when a flattened mesh/PG exists for those dimensions.
Is it only possible to merge reductions on contiguous mesh dims?
Further, if we had a mesh like [dp, pp, tp] and we sliced out spmd_mesh = parent[dp, tp], the indices of dp and tp would be 0 and 1 and appear 'contiguous' to this code, but they would not actually be contiguous in the parent mesh. Is that a problem?
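A minimal sketch of that index trap, using plain tensor arithmetic rather than the DeviceMesh API (the mesh shape and dim names here are illustrative only):

```python
import torch

# Parent mesh over 8 ranks with dims (dp, pp, tp) = (2, 2, 2).
parent = torch.arange(8).reshape(2, 2, 2)

# Ranks grouped together if we flatten the truly adjacent dims (pp, tp)
# for dp index 0: a contiguous block of ranks.
print(parent[0].flatten().tolist())        # [0, 1, 2, 3]

# Ranks grouped together if we flatten (dp, tp) for pp index 0: the dims
# look adjacent in the sliced submesh, but the ranks are strided by pp.
print(parent[:, 0, :].flatten().tolist())  # [0, 1, 4, 5]
```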
If we can leverage the coalescing of the layout, that will solve the case @wconstab mentioned here.
Trying to find a corner case where the contiguous assumption does not provide an identity guarantee:
If we use graph-based DDP + FSDP + TP, for RMSNorm.weight we would have
- param: (Replicate, Shard, Replicate)
- grad before reduction: (Partial, Partial, Partial)
I guess the reductions will still happen in the order AR, RS, AR and result in the DDP / TP ranks not having the same results.
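A hedged sketch of that scenario, assuming a recent PyTorch where DTensor lives under torch.distributed.tensor and an 8-rank run (e.g., via torchrun); the mesh sizes, dim names, and tensor shape are illustrative only:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

# 2 x 2 x 2 mesh over 8 ranks; dim names are illustrative.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "fsdp", "tp"))

# Local gradient for RMSNorm.weight on this rank, pending reduction on
# every mesh dim.
local_grad = torch.randn(16, device="cuda")
grad = DTensor.from_local(local_grad, mesh, (Partial(), Partial(), Partial()))

# Reduce toward the parameter's placements (Replicate, Shard, Replicate).
# With one collective per mesh dim this lowers to all-reduce(dp),
# reduce-scatter(fsdp), all-reduce(tp); the question above is whether merging
# mesh dims into a flattened reduction keeps all ranks' results identical.
reduced = grad.redistribute(mesh, (Replicate(), Shard(0), Replicate()))
```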
I am pretty sure contiguous-only is sound, but not complete (as wconstab mentions above). Are you worried about unsoundness here too?
I do worry about completeness. If a moderately complicated solution doesn't solve the problem, I would prefer we error out.
Also, I don't care AT ALL about the code here (entirely Claude-coded), so if someone wants to redo it from scratch or commandeer, I am not trying to lick the cookie.
I did take over this PR and replace it with #172121, FYI. Closing this one.
Stack from ghstack (oldest at bottom):
Authored with Claude Code
When there are reductions that need to occur on multiple mesh dims, we
currently issue a separate collective for each mesh dim. When we want to
do a reduction on multiple contiguous mesh dims, and a flattened dim of
those contiguous dims exists (i.e., we have already paid the cost of
initializing PGs for the flattened dim), it is better to do the
reduction all in one go on the flattened mesh dim.
The redistribute algorithm currently operates by proposing a sequence of
collectives to perform. This change looks for multiple consecutive
reductions and greedily tests whether they have a flattened mesh dim/PG.
If they do, it replaces that part of the plan with one that does the
reduction in a single step.
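A simplified sketch of the greedy merge described above; ReduceStep and has_flattened_group are stand-ins for illustration, not DTensor's actual _TransformInfo or redistribute internals:

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class ReduceStep:
    mesh_dims: tuple[int, ...]  # mesh dims this step reduces over
    is_reduction: bool          # True for Partial -> Replicate/Shard style steps


def merge_consecutive_reductions(
    steps: list[ReduceStep],
    has_flattened_group: Callable[[tuple[int, ...]], bool],
) -> list[ReduceStep]:
    """Greedily replace runs of consecutive reductions on contiguous mesh dims
    with a single step on the flattened mesh dim, when a flattened PG exists."""
    out: list[ReduceStep] = []
    i = 0
    while i < len(steps):
        j = i
        # Extend the run while the next step is also a reduction on the
        # mesh dim immediately following the current one.
        while (
            j + 1 < len(steps)
            and steps[j].is_reduction
            and steps[j + 1].is_reduction
            and steps[j + 1].mesh_dims[0] == steps[j].mesh_dims[-1] + 1
        ):
            j += 1
        dims = tuple(d for s in steps[i : j + 1] for d in s.mesh_dims)
        if j > i and has_flattened_group(dims):
            # One collective on the flattened dim instead of one per mesh dim.
            out.append(ReduceStep(mesh_dims=dims, is_reduction=True))
        else:
            out.extend(steps[i : j + 1])
        i = j + 1
    return out
```

For example, three consecutive per-dim reductions on mesh dims 0, 1, and 2 collapse into a single step on the flattened (0, 1, 2) dim when has_flattened_group((0, 1, 2)) returns True; otherwise the original per-dim plan is kept.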