feat: add trtllm moe_allreduce_fusion #1108
Conversation
yzh119 left a comment
Please remove all usage of packed/unpacked data type and use vec_t instead.
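Something like this would be the target shape (a minimal sketch, assuming `flashinfer::vec_t` from `vec_dtypes.cuh` with its `load`/`store`/`operator[]` interface; the function name and vector width are illustrative, not from the PR):

```cpp
#include <flashinfer/vec_dtypes.cuh>

// Sketch only: vec_t<T, N> replaces hand-rolled packed/unpacked structs
// with a single vectorized load/store abstraction.
template <typename T, size_t VEC_SIZE>
__device__ __forceinline__ void add_residual_vec(const T* allreduce_in,
                                                 const T* residual_in,
                                                 T* residual_out, size_t offset) {
  flashinfer::vec_t<T, VEC_SIZE> ar, res;
  ar.load(allreduce_in + offset);   // one vectorized load covers VEC_SIZE elements
  res.load(residual_in + offset);
#pragma unroll
  for (size_t i = 0; i < VEC_SIZE; ++i) {
    // accumulate in fp32 regardless of the storage type T
    res[i] = static_cast<T>(static_cast<float>(ar[i]) + static_cast<float>(res[i]));
  }
  res.store(residual_out + offset); // one vectorized store back
}
```

This keeps the vector width as a single template parameter instead of duplicating it across separate packed/unpacked type definitions.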
## 📌 Description

Update the `create_ipc_buffer` implementation. Add unit tests for `create_ipc_buffer`.

## 🔍 Related Issues

To help debug #1108.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
Next step: uncomment and complete the fused quantization. Maybe dependent on #1142.
| "hidden_dim * sizeof(T) must be a multiple of kBytesPerAccess"); | ||
| if (params.residual_out && not params.norm_out && params.quant_out) { | ||
| // pattern1: AR+Add_RMS+Quant | ||
| // [m, 7168] bf16 allreduce_in, [m, 7168] bf16 residual_in |
Do we have a shape check somewhere?
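(For reference, a hedged sketch of the kind of guard being asked about, assuming `TORCH_CHECK` is available at the binding layer; tensor names follow the comments above and are illustrative:)

```cpp
// Illustrative only: validate shapes before launching the fused kernel.
TORCH_CHECK(allreduce_in.dim() == 2, "allreduce_in must be [m, hidden_dim]");
TORCH_CHECK(residual_in.sizes() == allreduce_in.sizes(),
            "residual_in must match allreduce_in's shape");
TORCH_CHECK((allreduce_in.size(1) * allreduce_in.element_size()) % kBytesPerAccess == 0,
            "hidden_dim * sizeof(T) must be a multiple of kBytesPerAccess");
```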
```python
torch.cuda.synchronize()

# 6. Check correctness
tolerance = 8e-2 if dtype == torch.float16 else 8e-1
```
8e-1 seems too large to me; can you give an example of the distribution of all_reduce_out?
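(One way to answer this inside the existing test, where `all_reduce_out` is already computed; `ref_out` is a hypothetical name for the reference result:)

```python
# Sketch: print summary statistics so the absolute tolerance can be judged
# against the actual value range of all_reduce_out.
out = all_reduce_out.float()
print(f"mean={out.mean().item():.4f}, std={out.std().item():.4f}, "
      f"min={out.min().item():.4f}, max={out.max().item():.4f}")

# A relative comparison is often more informative than a large fixed atol:
torch.testing.assert_close(all_reduce_out, ref_out, rtol=2e-2, atol=2e-2)
```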
```cpp
// [m, d] bf16 allreduce_in, [m, d] bf16 residual_in
// [m, d] bf16 residual_out, [m, d] bf16 norm_out, [m, d] fp4 quant_out

if (params.allreduce_out && params.residual_out && !params.norm_out && params.quant_out) {
```
The remaining part can still be dispatched: `DISPATCH_MOEREDUCTION_KERNEL(T, params, launch_with_pdl, ar, res, rms, quant)`
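A hypothetical sketch of that call site, assuming the macro lowers each runtime flag to a template parameter (the macro's internals are not shown in this excerpt):

```cpp
// Sketch only: forward the pattern's four flags as compile-time booleans
// so one templated kernel covers every output combination.
if (params.allreduce_out && params.residual_out && !params.norm_out && params.quant_out) {
  // pattern: AR (kept) + Add + Quant, no norm_out
  DISPATCH_MOEREDUCTION_KERNEL(T, params, launch_with_pdl,
                               /*ar=*/true, /*res=*/true, /*rms=*/false, /*quant=*/true);
}
```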
yzh119 left a comment
I'm good with the PR, thanks so much for your contribution!
Please refer to 9c229c9 on how to simplify the macro.
Some naming conventions: in flashinfer we usually write both the runtime variable and the constexpr in the macro definition, to make it easier for developers to track which new constexpr symbols a macro introduces:

`#define DISPATCH_*(var, CONST_EXPR)`

and we capitalize the `CONST_EXPR`.
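A minimal sketch of that convention (the macro, flag, and kernel names here are illustrative, not from the PR):

```cpp
// The runtime variable comes first; the constexpr the macro introduces is
// capitalized, so the call site makes the new compile-time symbol obvious.
#define DISPATCH_LAUNCH_WITH_PDL(launch_with_pdl, LAUNCH_WITH_PDL, ...) \
  do {                                                                  \
    if (launch_with_pdl) {                                              \
      constexpr bool LAUNCH_WITH_PDL = true;                            \
      __VA_ARGS__                                                       \
    } else {                                                            \
      constexpr bool LAUNCH_WITH_PDL = false;                           \
      __VA_ARGS__                                                       \
    }                                                                   \
  } while (0)

// Usage (T, grid, block, stream, params are placeholders): LAUNCH_WITH_PDL
// is visibly the constexpr introduced by the macro.
DISPATCH_LAUNCH_WITH_PDL(launch_with_pdl, LAUNCH_WITH_PDL, {
  fused_kernel<T, LAUNCH_WITH_PDL><<<grid, block, 0, stream>>>(params);
});
```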
## 📌 Description

We add the moe_all_reduce_fusion kernels from trt-llm.

## 🔍 Related Issues

We split this work into multiple PRs; see #1061. all_reduce_fusion will be next.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes