Significant performance improvement on MoE block of SwitchTransformer#30490

Closed
ranggihwang wants to merge 0 commits into huggingface:main from ranggihwang:google_switch_transformer

Conversation

@ranggihwang
Contributor

What does this PR do?

This PR includes a performant implementation of SwitchTransformersSparseMLP in the Google SwitchTransformer.
In the current implementation, the MoE block iterates over every possible expert, including the inactive ones:

for idx, expert in enumerate(self.experts.values()):
    token_indices = router_mask[:, :, idx].bool()
    next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

This results in serious performance degradation of the SwitchTransformer.

[screenshot] As shown in this figure, the current implementation iterates over inactive experts, unnecessarily increasing latency.

[screenshot] This issue can be particularly severe in models with a larger number of experts, as even more inactive experts are visited needlessly.

My custom implementation of SwitchTransformersSparseMLP, by contrast, accesses and computes only the active experts.
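The idea can be sketched as follows (a minimal illustration with assumed names and shapes, not the exact PR code): first find which experts actually received tokens, then run only those.

```python
import torch

def sparse_mlp_forward(hidden_states, router_mask, experts, next_states):
    # hidden_states: (batch, seq_len, d_model)
    # router_mask:   (batch, seq_len, num_experts), one-hot routing decisions
    # experts:       list of expert modules (callables)
    # next_states:   output buffer with the same shape as hidden_states
    batch, seq_len, num_experts = router_mask.shape
    # Count routed tokens per expert; experts with a zero count are inactive.
    tokens_per_expert = router_mask.reshape(batch * seq_len, num_experts).sum(dim=0)
    active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()
    # Visit only the active experts instead of looping over all of them.
    for idx in active_experts:
        token_indices = router_mask[:, :, idx].bool()
        next_states[token_indices] = experts[idx](hidden_states[token_indices]).to(next_states.dtype)
    return next_states
```

The inner two lines are unchanged from the original loop; only the set of experts visited shrinks, which is what removes the per-inactive-expert overhead.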

Advantages

  • This can significantly reduce the latency of the SwitchTransformer and make the model more accessible to a broader range of users.
  • This change achieves greater latency reductions when expert parameters are offloaded to the CPU or SSD.
  • This change addresses the problem of increasing latency proportional to the number of experts.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts
Contributor

cc @ArthurZucker @younesbelkada

Contributor

@younesbelkada younesbelkada left a comment


Nice work and great investigation! Thanks for this! Can you confirm the slow SwitchTransformers tests pass?

@ranggihwang
Contributor Author

ranggihwang commented Apr 26, 2024

Is there anything else that I need to do for this PR?

Contributor

@younesbelkada younesbelkada left a comment


Hi @ranggihwang
Yes, please make sure to run make fixup so that the styling checks pass in our CI. For running the slow tests, can you run:

RUN_SLOW=1 pytest tests/models/switch_transformers/test_modeling_switch_transformers.py

@ranggihwang
Contributor Author

ranggihwang commented Apr 27, 2024

@younesbelkada
I've changed the coding style using make fixup, but I'm encountering an error when running this command:

RUN_SLOW=1 pytest tests/models/switch_transformers/test_modeling_switch_transformers.py

I've noticed the same error with the original switch transformer code, so I assume it's not due to my changes. How can I resolve this issue?

If you need the error log, I can attach it here.

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey! Though we do iterate over all experts, the way the model is trained encourages an even load, meaning that on average all the experts should be used, no?
Could you share a bit more about which model you are using? Are you pretraining?

Let's also revert the changes related to linting that are not supposed to be here.

@ranggihwang
Contributor Author

@ArthurZucker
That's a great point.
In the SwitchTransformer model, all experts can be used during the training phase, but during inference only some of them are utilized; that is the inefficiency the original code misses.
I'm mainly talking about inference; however, my custom code can also be used for training, since it is mathematically equivalent to the original code.

As for the linting, is it okay for me to revert it to the previous version? This is actually my first time contributing code to Hugging Face, so I'm a bit confused about what to do next to get the PR accepted for merging.

Collaborator

@ArthurZucker ArthurZucker left a comment


Actually a lot better!

idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()

to simplify.
And indeed, since decoding is only gonna activate at most 2 x batch, this is better.

Make sure to rebase on main and simply revert the unrelated changes! 🤗 thanks for the contribution
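As a quick sanity check of the suggested two-liner (toy shapes assumed, not PR code), the snippet below builds a top-1 routing mask for a decoding step and shows that it recovers exactly the indices of the active experts:

```python
import torch
import torch.nn.functional as F

batch, seq_len, num_experts = 2, 1, 8  # a decoding step: seq_len == 1
# Top-1 routing: each token selects exactly one expert.
choices = torch.tensor([3, 5])
router_mask = F.one_hot(choices, num_experts).reshape(batch, seq_len, num_experts)

# The suggested simplification: per-expert token counts, then the nonzero indices.
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
print(idx_mask)  # [3, 5]: only the two active experts are listed
```

With top-1 routing and seq_len == 1, the list can never contain more than `batch` entries, so the per-step expert work is bounded by the batch size rather than by num_experts.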

Comment on lines 304 to 305
Collaborator


Suggested change
idx_mask = router_mask.transpose(1, 2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1)

equivalent and more understandable
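The equivalence of the two formulations can be verified on a random mask (a standalone check with assumed toy shapes, not part of the PR):

```python
import torch

batch, seq_len, num_experts = 2, 3, 4
router_mask = torch.randint(0, 2, (batch, seq_len, num_experts))

# Original formulation: transpose, then concatenate the batch slices along the token dim.
a = router_mask.transpose(1, 2)                 # (batch, num_experts, seq_len)
a = torch.cat(torch.split(a, 1, dim=0), dim=2)  # (1, num_experts, batch * seq_len)

# Suggested formulation: a single reshape + transpose.
b = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1)  # (num_experts, batch * seq_len)

assert torch.equal(a.squeeze(0), b)
```

Both produce a (num_experts, batch * seq_len) view where entry [e, b * seq_len + s] is router_mask[b, s, e]; the reshape version just avoids the split/cat round trip.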

Collaborator


Adding a hint like # batch * seq_len, num_experts also helps!

Collaborator


Suggested change
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens

Collaborator


Suggested change
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.sum(dim=1)

@ranggihwang ranggihwang closed this Jun 1, 2024
@ranggihwang ranggihwang force-pushed the google_switch_transformer branch from 23cf9e8 to 6bd511a Compare June 1, 2024 05:36
@ranggihwang
Contributor Author

Oh, I think I've done something wrong. I just wanted to rebase and revert the other unrelated changes, but somehow it closed the PR :( Can I open a new PR reflecting all the comments? @ArthurZucker

@ranggihwang
Contributor Author

Actually a lot better!

idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()

to simplify. And indeed, since decoding is only gonna activate at most 2 x batch, this is better.

Make sure to rebase on main and simply revert the unrelated changes! 🤗 thanks for the contribution

Shouldn't the number of activated experts be at most the batch size, since it's top-1 gating? @ArthurZucker

@ArthurZucker
Collaborator

Ah, it's top-1; I was thinking of Mixtral, which is top-2 😉
Anyway, no worries about the old PR vs. the new one, it happens.
