
DTensor must generate flattened PGs to avoid allreduce result inconsistency across Replicate when reducing over multiple mesh dims #171916

@ezyang


🐛 Describe the bug

Suppose you have a mesh (ddp=8, fsdp=5, tp=1) (yes, you read that right, not (5, 8)) where your NVLink domain is size 8. This results in a jagged assignment of nodes across NVLink domains:

(figure: jagged assignment of nodes across NVLink domains, omitted)
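To see the jaggedness concretely, here is a plain-Python sketch (no torch) that enumerates the fsdp groups of a (ddp=8, fsdp=5, tp=1) mesh and checks which NVLink domains each one touches. The row-major rank layout is an assumption for illustration; the real layout is whatever DeviceMesh assigns.

```python
# Assumed row-major layout: rank = ddp_idx * (FSDP * TP) + fsdp_idx * TP + tp_idx.
DDP, FSDP, TP = 8, 5, 1
NVLINK = 8  # ranks per NVLink domain

def fsdp_group(ddp_idx):
    # Ranks sharing a ddp coordinate form one fsdp group.
    return [ddp_idx * FSDP * TP + f * TP for f in range(FSDP)]

for d in range(DDP):
    ranks = fsdp_group(d)
    domains = sorted({r // NVLINK for r in ranks})
    print(f"fsdp group {d}: ranks {ranks} -> NVLink domains {domains}")
```

Group 0 (ranks 0..4) sits entirely in domain 0, while group 1 (ranks 5..9) straddles domains 0 and 1, and so on: because 5 does not divide 8, the groups alternate between fitting inside a domain and crossing a boundary.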

Let's say you have a DTensor which is (Partial, Partial, Replicate) and you want to redistribute it to (Replicate, Replicate, Replicate). We will do this in two allreduces (e.g., first on fsdp and then on ddp). In principle, it would be better to do it in a single allreduce over fsdp+ddp, but DTensor's historical contract is that it never generates new flattened device mesh dims on the fly, as doing so has setup cost; and without the flattened ddp+fsdp mesh dim, we don't actually have a PG that can do the all-ranks allreduce.
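The process groups involved can be sketched in plain Python (again assuming a row-major rank layout; tp=1 is elided). The two-step path uses one PG per fsdp group and then one PG per ddp group; the one-step path needs a flattened ddp+fsdp PG that DTensor does not create today:

```python
DDP, FSDP = 8, 5

# Step 1 PGs: ranks sharing a ddp coordinate (contiguous under this layout).
fsdp_groups = [[d * FSDP + f for f in range(FSDP)] for d in range(DDP)]
# Step 2 PGs: ranks sharing an fsdp coordinate (strided).
ddp_groups = [[d * FSDP + f for d in range(DDP)] for f in range(FSDP)]
# The flattened ddp+fsdp PG that a single allreduce would need.
flat_group = list(range(DDP * FSDP))

print(fsdp_groups[0])   # [0, 1, 2, 3, 4]
print(ddp_groups[0])    # [0, 5, 10, 15, 20, 25, 30, 35]
print(len(flat_group))  # 40
```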

However, there is a very big problem with doing it in two steps: you won't get the same result on all ranks. Specifically, NCCL doesn't guarantee that reductions inside an NVLink domain are bitwise identical to reductions across NVLink domains. When we reduce over fsdp first, some fsdp groups fall entirely within one NVLink domain while others straddle a domain boundary, so ranks that are supposed to end up with identical replicated values can see bitwise-different results from that first reduction.
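The root cause is that floating-point addition is not associative, so two reduction algorithms that combine the same operands in different orders can produce bitwise-different sums. A tiny standalone illustration (the values are contrived, not taken from the issue):

```python
vals = [0.1, 0.2, 0.3, 1e16, -1e16]

# Order 1: strict left-to-right accumulation.
left_to_right = ((((vals[0] + vals[1]) + vals[2]) + vals[3]) + vals[4])
# Order 2: a tree-shaped combination, as a ring vs. tree allreduce might use.
tree_order = (vals[0] + vals[1]) + ((vals[2] + vals[3]) + vals[4])

print(left_to_right, tree_order, left_to_right == tree_order)
```

Here the small terms are absorbed by 1e16 in one order but not the other, so the two sums disagree; the same effect makes an intra-NVLink reduction differ from a cross-domain one.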

Although it is nice to have DTensor avoid generating new PGs, we get incorrect results if we don't generate them. This seems to decisively favor having DTensor generate new PGs so that it can always arrange for the full reduction to happen in a single collective.

#171913 is a band-aid fix that handles the simple situations, but it only applies when the flattened mesh already exists. I concretely propose that we SHOULD generate the flattened mesh as well.

Versions

main

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

Labels

    oncall: distributed, triaged
