
Conversation

@yetiansh
Contributor

Previously, enabling both the Tutel optimization and top-2 gating in MoE model training would fail: MoELayer would try to unpack the top-2 gate's output here, which fails because a top-2 gate does not produce that number of outputs.

Fixed by checking the gate's top-k setting when constructing MoELayer.
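For reference, a minimal sketch of the kind of guard described above (simplified and illustrative, not the exact DeepSpeed code; the TUTEL_INSTALLED flag and the gate.k attribute mirror what the diff below references, everything else is an assumption):

import torch

try:
    import tutel  # noqa: F401
    TUTEL_INSTALLED = True
except ImportError:
    TUTEL_INSTALLED = False


class MOELayer(torch.nn.Module):
    """Simplified stand-in for the real MoE layer."""

    def __init__(self, gate, experts, use_tutel=False):
        super().__init__()
        self.gate = gate
        self.experts = experts
        # Only take the Tutel fast path when Tutel is installed AND the gate
        # is top-1: the Tutel path unpacks extra dispatch tensors that a
        # top-2 gate does not return.
        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

    def forward(self, hidden_states):
        if self.use_tutel:
            # Tutel path: the gate returns additional metadata for fast dispatch.
            ...
        else:
            # Generic path: works for both top-1 and top-2 gating.
            ...

With a guard like this, requesting Tutel together with a top-2 gate simply falls back to the generic path instead of crashing during unpacking.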

@awan-10
Contributor

awan-10 commented Jun 29, 2022

Thank you for the PR @yetiansh :) It looks good to me. Alex added the Tutel support, so let me tag him and ask for a quick review.

@alexandremuzio - can you please review this real quick?

@awan-10
Contributor

awan-10 commented Jun 29, 2022

@yetiansh - can you please follow the guide here and update your PR? I see it's failing the format checks.

https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md

@alexandremuzio
Contributor

Looks good to me. Thanks!

@yetiansh
Contributor Author

Thanks @alexandremuzio @awan-10. I've run pre-commit, and it looks like the format-checking workflow needs your approval to run.

@yetiansh
Contributor Author

Hi, is this PR still active? @awan-10 @alexandremuzio

@awan-10 requested a review from samadejacobs as a code owner, July 22, 2022 22:40
@awan-10
Contributor

awan-10 commented Jul 22, 2022

Sorry for the delay in getting back to you, @yetiansh. I approved this PR so the tests can run. Will merge it as soon as the tests pass. Thank you!

logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning(
Collaborator

Can we wrap this in an if torch.distributed.get_rank() == 0: check?

Contributor Author

Yes, that's possible. But I wonder whether we should also wrap the other warnings and info messages, for example the ones at L480 and L482-483?
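For illustration, the rank-0 guard being discussed could be factored into a small helper roughly like the sketch below (not the merged code; the is_initialized() fallback is an added assumption so single-process runs still log):

import torch.distributed as dist


def warn_rank0(logger, msg):
    # Emit the warning only on rank 0 so multi-GPU runs do not print one
    # copy per process; if torch.distributed has not been initialized
    # (e.g. single-process debugging), log unconditionally.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.warning(msg)

The same helper could then cover the other warning and info calls mentioned above (L480, L482-483).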

@awan-10 enabled auto-merge (squash) July 26, 2022 18:05
@awan-10 merged commit 31582d7 into deepspeedai:master Jul 26, 2022