Significant performance improvement on MoE block of SwitchTransformer #30490
ranggihwang wants to merge into huggingface:main
Conversation
younesbelkada
left a comment
Nice work and great investigation! Thanks for this! Can you confirm the slow SwitchTransformers tests pass?
src/transformers/models/switch_transformers/modeling_switch_transformers.py
Is there anything else that I need to do for this PR?
younesbelkada
left a comment
Hi @ranggihwang
Yes, please make sure to run `make fixup` so that the styling checks pass in our CI. For running the slow tests, can you run:

```
RUN_SLOW=1 pytest tests/models/switch_transformers/test_modeling_switch_transformers.py
```
@younesbelkada I've noticed the same error with the original SwitchTransformer code, so I assume it's not due to my changes. How can I resolve this issue? If you need the error log, I can attach it here.
ArthurZucker
left a comment
Hey! Though we do span all experts, the way the model is trained emphasizes even loading, meaning on average all the experts should be used, no?
Could you share a bit more which model you are using, are you pretraining?
Let's revert changes related to linting that are not supposed to be here as well
@ArthurZucker For linting, is it okay for me to revert it to the previous version? Actually, this is the first time I've contributed code to HuggingFace, so I'm a bit confused about what to do next to get this accepted for merging.
ArthurZucker
left a comment
Actually a lot better!
```python
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
```

to simplify.
And indeed, since decoding is only gonna activate at most 2 x batch, this is better.
Make sure to rebase on main and simply revert the unrelated changes! 🤗 thanks for the contribution
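For context, here is a toy run of the suggested simplification. The shapes and routing values below are illustrative, not taken from the PR:

```python
import torch

# Hypothetical shapes: 2 sequences of 3 tokens routed over 4 experts.
batch, seq_len, num_experts = 2, 3, 4

# router_mask is a one-hot top-1 routing mask of shape (batch, seq_len, num_experts).
# Here tokens are routed only to experts 0 and 2.
expert_index = torch.tensor([[0, 2, 0], [2, 2, 0]])
router_mask = torch.nn.functional.one_hot(expert_index, num_classes=num_experts)

# Per-expert token counts, then the indices of experts that received >= 1 token.
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
active_experts = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
print(active_experts)  # [0, 2]
```

Only the experts in `active_experts` then need to be executed, which is what makes the loop cheap when few experts are hit.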
Suggested change:

```python
# before
idx_mask = router_mask.transpose(1, 2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
# after
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1)
```

equivalent and more understandable
Adding a hint like `# batch * seq, num_expert` also helps!
```python
idx_mask = idx_mask.squeeze()  # length: number of experts / value: number of tokens
```
```python
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.sum(dim=1)
```
Force-pushed from 23cf9e8 to 6bd511a
Oh, I think I've done something wrong. I just wanted to rebase and revert the other unrelated changes, but somehow it closed the PR :( Can I open a new PR reflecting all the comments? @ArthurZucker
Shouldn't the activated experts be at most the batch size, since it's top-1 gating? @ArthurZucker
AH it's top-1, I was thinking of Mixtral which is top-2 😉
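To illustrate the point above: with top-1 routing, each token activates exactly one expert, so a decoding step (one token per sequence) can activate at most `batch` experts. A toy check, with assumed shapes:

```python
import torch

torch.manual_seed(0)
batch, num_experts = 4, 16

# One decoding step: each of the `batch` sequences routes its single token
# to exactly one expert under top-1 gating.
logits = torch.randn(batch, 1, num_experts)
expert_index = logits.argmax(dim=-1)   # (batch, 1), one expert per token
active = expert_index.unique()
assert len(active) <= batch            # at most `batch` experts are active
```

Under top-2 gating (as in Mixtral) the bound would be `2 * batch` instead.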
What does this PR do?
This PR includes a performant implementation of `SwitchTransformersSparseMLP` in the Google SwitchTransformer.

The current implementation of the SwitchTransformer spans all possible experts, including the inactive ones. This results in serious performance degradation of the SwitchTransformer. In my custom implementation of `SwitchTransformersSparseMLP`, only the active experts are accessed and computed.

Advantages
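The core idea (running only the experts that actually received tokens) can be sketched as follows. This is a simplified, hypothetical illustration, not the PR's actual code; the real SwitchTransformers block also scales outputs by the router probability, omitted here:

```python
import torch
import torch.nn as nn


class SparseMLPSketch(nn.Module):
    """Simplified top-1 MoE block; all names are illustrative."""

    def __init__(self, d_model=8, num_experts=4):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, d_model)
        logits = self.router(hidden_states)                  # (batch, seq, num_experts)
        expert_index = logits.argmax(dim=-1)                 # top-1 routing decision
        router_mask = nn.functional.one_hot(
            expert_index, num_classes=len(self.experts)
        ).bool()                                             # (batch, seq, num_experts)

        out = hidden_states.clone()
        # Key optimization: loop only over experts that received at least one
        # token, instead of running every expert on the full batch.
        active = router_mask.reshape(-1, len(self.experts)).any(dim=0)
        for idx in active.nonzero(as_tuple=True)[0].tolist():
            token_mask = router_mask[..., idx]               # (batch, seq) bool
            out[token_mask] = self.experts[idx](hidden_states[token_mask])
        return out


x = torch.randn(2, 5, 8)
y = SparseMLPSketch()(x)
```

With few tokens (e.g. autoregressive decoding), most experts receive nothing and are skipped entirely, which is where the speedup comes from.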
Before submitting

- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.