🐛 Describe the bug
Suppose you have a mesh (ddp=8, fsdp=5, tp=1) (yes, you read that right, not (5, 8)) where your NVLink domain is size 8. This results in a jagged assignment of nodes across NVLink domains:
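For concreteness, here is a small sketch (not from any real setup; it assumes the default row-major rank layout and one 8-GPU NVLink domain per node) showing which FSDP groups land inside a single NVLink domain and which straddle a boundary:

```python
# Sketch: map FSDP groups onto NVLink domains for a (ddp=8, fsdp=5, tp=1) mesh.
# Assumes default row-major rank ordering and 8 consecutive ranks per NVLink domain.
DDP, FSDP, TP = 8, 5, 1
NVLINK_DOMAIN = 8

for ddp_idx in range(DDP):
    # With tp=1, the FSDP group for this ddp index is 5 consecutive ranks.
    ranks = [ddp_idx * FSDP * TP + f * TP for f in range(FSDP)]
    domains = sorted({r // NVLINK_DOMAIN for r in ranks})
    where = "within one domain" if len(domains) == 1 else f"straddles domains {domains}"
    print(f"fsdp group {ddp_idx}: ranks {ranks} -> {where}")

# Groups 0, 2, 5, 7 sit inside one NVLink domain; groups 1, 3, 4, 6 straddle two.
```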
Let's say you have a DTensor which is (Partial, Partial, Replicate) and you want to redistribute it to (Replicate, Replicate, Replicate). We will do this in two allreduces (e.g., first on fsdp and then on ddp). In principle, it would be better to do it in one allreduce over fsdp+ddp, but DTensor's historical contract is that we will never generate new flattened device mesh dims on the fly, as this has setup cost, and without the flattened ddp+fsdp mesh dim, we don't actually have the PGs that can do the all-ranks allreduce.
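A minimal sketch of the redistribute in question (assuming a 40-rank job launched with torchrun, the NCCL backend, and the public torch.distributed.tensor APIs; not taken from any real repro, and the exact lowering may depend on DTensor internals):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# Assumes a 40-rank (8 * 5 * 1) job, e.g. launched with torchrun.
mesh = init_device_mesh("cuda", (8, 5, 1), mesh_dim_names=("ddp", "fsdp", "tp"))

# Each rank holds a partial sum along both the ddp and fsdp dims.
local = torch.randn(1024, device="cuda")
dt = DTensor.from_local(local, mesh, placements=(Partial(), Partial(), Replicate()))

# Today this lowers to two allreduces (one over the fsdp PG, then one over the
# ddp PG) rather than a single allreduce over a flattened ddp+fsdp group.
out = dt.redistribute(placements=(Replicate(), Replicate(), Replicate()))
```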
However, there is a very big problem with doing it in two steps: you won't get the same result on all ranks. Specifically, NCCL doesn't guarantee that reductions inside NVLink domains are bitwise identical to reductions across NVLink domains. When we do the reduction on FSDP first, some FSDP groups are entirely within one NVLink domain and others aren't, so you end up with different results for this first reduction.
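The underlying reason is ordinary floating-point non-associativity; a toy, single-process illustration of how two reduction orders over the same addends can disagree bitwise (the same effect you get when NCCL picks different algorithms for intra- vs cross-domain reductions):

```python
import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1e-8, dtype=torch.float32)

# Same three addends, two reduction orders -- as two process groups running
# different reduction algorithms effectively do.
left_to_right = (a + b) + c   # 1e-08
regrouped     = a + (b + c)   # 0.0 (b + c rounds back to -1e8 in fp32)
print(left_to_right.item(), regrouped.item(), bool(left_to_right == regrouped))
```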
Although it is nice to have DTensor not generate new PGs, we get incorrect results if we do not generate new PGs. Thus, this seems to decisively favor that DTensor should generate new PGs so that it can always arrange that all reductions happen in one go.
#171913 is a band-aid fix that handles the simple cases, but it only applies when the flattened mesh already exists. I concretely propose that we SHOULD generate the flattened mesh as well.
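For reference, a user can already opt into the single-allreduce path by pre-creating the flattened dim themselves; a sketch using the private (subject-to-change) `DeviceMesh._flatten` API, which is roughly what I am proposing DTensor should be allowed to do on its own:

```python
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (8, 5, 1), mesh_dim_names=("ddp", "fsdp", "tp"))

# Pre-create the flattened ddp+fsdp mesh dim (and its 40-rank PG) up front.
# Note: _flatten is a private API and may change between releases.
mesh["ddp", "fsdp"]._flatten(mesh_dim_name="ddp_fsdp")

# With the flattened dim in place, the #171913-style path can route the
# (Partial, Partial, Replicate) -> (Replicate, Replicate, Replicate)
# redistribute through a single allreduce over all 40 ranks.
```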
Versions
main
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci