[SwitchTransformer] Significant performance improvement on MoE blocks #31173
ArthurZucker merged 4 commits into huggingface:main

Conversation
younesbelkada left a comment
Thanks a lot! Looks very good! Can you make sure the styling checks pass: `make fixup && make fix-copies`
Thanks @younesbelkada! How do I do this correctly? Would you please let me know how to do it?
Hi @ranggihwang
Shouldn't the styling check be done for the
Since gpt_san_japanese uses blocks that are copied from switch transformers, running
@younesbelkada @ranggihwang gpt san has been deprecated, so we don't really want these changes to be propagated. I've just merged #31153, which removes the
Perfect, thanks for the heads up @amyeroberts!
Force-pushed from 52b6c57 to 4965da6 (compare)
@amyeroberts @younesbelkada I've just rebased it onto main and committed. Would you please check if it is correct?
Thanks @ranggihwang! Now the styling checks are failing, can you run
Okay, now
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I'll review this as I reviewed the previous PR; I want to make sure the suggestions are all applied!
ArthurZucker left a comment
Could you apply the suggestion I made in the previous PR
Suggested change:

```diff
 router_mask = router_mask.bool()
-idx_mask = router_mask.transpose(1, 2)  # Batch * experts * tokens
-idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)  # 1 * experts * (batch * tokens)
-idx_mask = idx_mask.sum(dim=2)
-idx_mask = idx_mask.squeeze()  # length: number of experts / value: number of tokens
-idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
-    0
-].tolist()  # length: number of "activated" expert / value: index
+idx_mask = router_mask.reshape(batch*seq_len, num_experts).transpose(0, 1).sum(dim=1)
+idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
```
- the comment about shapes! 🤗
The batch_size, seq_len, and num_experts are not defined in the function,
so I've derived them from the router_mask and applied your suggestions.
Thank you @ArthurZucker !
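For context, the two versions discussed above compute the same list of activated expert indices; the suggested one just replaces the transpose/split/cat dance with a single reshape. A quick illustrative check (not part of the PR itself):

```python
import torch

batch_size, seq_len, num_experts = 2, 5, 4
# One-hot router mask: each token is assigned to exactly one expert
router_mask = torch.nn.functional.one_hot(
    torch.randint(0, num_experts, (batch_size, seq_len)), num_experts
)

# Original multi-step version
idx_mask = router_mask.bool().transpose(1, 2)                  # batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)   # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2).squeeze()                       # per-expert token counts
original = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()

# Suggested two-line version
idx_mask = router_mask.bool().reshape(batch_size * seq_len, num_experts).transpose(0, 1).sum(dim=1)
suggested = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()

assert original == suggested  # same activated-expert indices, in the same order
```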
@ArthurZucker @younesbelkada
Changed the title from "[SwitchTransformer] Significant performance improvement on MoE blocks of SwitchTransformer" to "[SwitchTransformer] Significant performance improvement on MoE blocks"
younesbelkada left a comment
Still LGTM! Let's wait for @ArthurZucker's final review!
Could this be propagated to the qwen code @ranggihwang? I know that they have some variants with lots of experts!
@ArthurZucker I think it can be adopted by many MoE models in Hugging Face, not only qwen-moe but also NLLB-MoE, Mixtral, etc.
Awesome! Then if you are interested, feel free to open a PR and ping me! 🤗
What does this PR do?
This is an edited version of the previously closed PR (#30490).
This PR includes a performant implementation of SwitchTransformersSparseMLP in the Google SwitchTransformer. The current implementation of the SwitchTransformer loops over all possible experts, including the inactive ones, which causes serious performance degradation. In this custom implementation of SwitchTransformersSparseMLP, only the active experts are accessed and computed.
Advantages
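To illustrate the main advantage, here is a minimal, hypothetical sketch of the technique (a simplified toy layer, not the actual SwitchTransformersSparseMLP code): the router mask is reduced to per-expert token counts, and the forward loop visits only experts that received at least one token.

```python
import torch
import torch.nn as nn

class SparseMLPSketch(nn.Module):
    """Toy MoE layer: runs only the experts that actually received tokens."""

    def __init__(self, num_experts=8, d_model=16, d_ff=32):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, hidden_states, router_mask):
        # router_mask: (batch, seq_len, num_experts), one-hot expert assignment
        batch_size, seq_len, num_experts = router_mask.shape
        router_mask = router_mask.bool()
        # Per-expert token counts; keep only experts with at least one token
        counts = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
        active = torch.nonzero(counts, as_tuple=True)[0].tolist()
        next_states = hidden_states.clone()
        for idx in active:  # inactive experts are never touched
            token_mask = router_mask[:, :, idx]  # (batch, seq_len) bool
            next_states[token_mask] = self.experts[idx](hidden_states[token_mask])
        return next_states
```

With one-hot routing every token is handled by exactly one expert, so the loop does the same total expert work as before, but skips the per-expert dispatch overhead for experts that received no tokens, which is where the speedup comes from when most experts are idle.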
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker and @younesbelkada