
Conversation

@siddharth9820 (Contributor) commented on Jul 6, 2022

Note: This PR works in conjunction with a companion PR on the Megatron-DeepSpeed repo.

This PR adds tensor parallelism for non-experts. Combined with ZeRO-2, this allows us to scale to roughly 2x larger base models than ZeRO-2 alone. When tensor parallelism is enabled only for non-experts, there are duplicate tokens at each gate. It is important to drop these duplicates before they reach the experts; otherwise we run into convergence issues. In the current implementation, we drop tokens right before the AlltoAll and gather them right after the AlltoAll. These calls are made in sharded_moe.py.
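
A minimal sketch of the drop/gather pattern described above, assuming plain torch.distributed collectives. The helper names and group handling below are hypothetical, not the actual sharded_moe.py API, and autograd-aware collectives are omitted for brevity:

```python
import torch
import torch.distributed as dist


def _drop_duplicate_tokens(tokens: torch.Tensor, tp_group) -> torch.Tensor:
    """Keep only this tensor-parallel rank's slice of the (duplicated) token batch."""
    tp_world_size = dist.get_world_size(group=tp_group)
    tp_rank = dist.get_rank(group=tp_group)
    assert tokens.shape[0] % tp_world_size == 0
    chunk = tokens.shape[0] // tp_world_size
    return tokens[tp_rank * chunk:(tp_rank + 1) * chunk].contiguous()


def _gather_tokens(tokens: torch.Tensor, tp_group) -> torch.Tensor:
    """Re-assemble the full token batch across the tensor-parallel group."""
    tp_world_size = dist.get_world_size(group=tp_group)
    gathered = [torch.empty_like(tokens) for _ in range(tp_world_size)]
    dist.all_gather(gathered, tokens, group=tp_group)
    return torch.cat(gathered, dim=0)


def dispatch_to_experts(dispatched_input: torch.Tensor, ep_group, tp_group) -> torch.Tensor:
    # Drop duplicates so each expert sees each token exactly once ...
    dispatched_input = _drop_duplicate_tokens(dispatched_input, tp_group)
    # ... exchange tokens with the expert-parallel group ...
    output = torch.empty_like(dispatched_input)
    dist.all_to_all_single(output, dispatched_input, group=ep_group)
    # ... and gather afterwards so every tensor-parallel rank has the full batch again.
    return _gather_tokens(output, tp_group)
```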

Update: This PR now supports tensor parallelism for experts as well.

@siddharth9820 (Contributor, Author) commented on Jul 6, 2022

Comparing loss curves with no tensor parallelism

[image: loss curves with and without tensor parallelism]

@jeffra (Collaborator) commented on Jul 26, 2022

Let's add some functional unit tests that exercise the new code paths. This will ensure things are at least functionally working going forward. It would be great to have basic correctness unit tests as well, but we can discuss that offline.

@siddharth9820 (Contributor, Author) commented

@jeffra I have added some tests in tests/unit/test_moe_tp.py. Can you please check if they are good enough for now?
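
For reference, a rough illustration of the kind of functional check such a test might perform. This is a hypothetical sketch, not the contents of tests/unit/test_moe_tp.py, and it assumes a single-GPU distributed environment (ep_size=1) that can be initialized via deepspeed.init_distributed():

```python
import torch
import deepspeed
from deepspeed.moe.layer import MoE


def test_moe_forward_shape():
    # Assumes the distributed backend / DeepSpeed process groups can be set up here.
    deepspeed.init_distributed()
    hidden_dim = 16
    expert = torch.nn.Linear(hidden_dim, hidden_dim)
    # Hypothetical configuration: 4 experts, no expert parallelism, top-1 gating.
    moe = MoE(hidden_size=hidden_dim, expert=expert, num_experts=4, ep_size=1, k=1).cuda()
    x = torch.randn(4, 8, hidden_dim).cuda()
    # DeepSpeed's MoE layer returns (output, aux_loss, expert_counts).
    out, l_aux, _ = moe(x)
    assert out.shape == x.shape
```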

@siddharth9820 siddharth9820 enabled auto-merge (squash) July 29, 2022 16:40
@siddharth9820 siddharth9820 disabled auto-merge July 29, 2022 16:41
@siddharth9820 siddharth9820 enabled auto-merge (squash) July 31, 2022 23:24
@siddharth9820 siddharth9820 disabled auto-merge July 31, 2022 23:24
@siddharth9820 siddharth9820 enabled auto-merge (squash) July 31, 2022 23:24
@siddharth9820 siddharth9820 merged commit 5fe9d61 into master Aug 1, 2022
@siddharth9820 siddharth9820 deleted the moe-tensor-parallelism branch August 2, 2022 18:10
@jerryli1981 commented

Comparing loss curves with no tensor parallelism

Hi, I tested tensor parallelism for MoE; the loss curve is still higher, and it becomes NaN after roughly 600 million tokens.
