
feat: add trtllm moe_allreduce_fusion#1108

Merged
yzh119 merged 52 commits into flashinfer-ai:main from yyihuang:trtllm-moear on Jun 17, 2025
Conversation

Collaborator

@yyihuang yyihuang commented Jun 2, 2025

📌 Description

This PR adds the moe_allreduce_fusion kernels ported from TensorRT-LLM.

🔍 Related Issues

This work is split across multiple PRs (see #1061); all_reduce_fusion will follow in a subsequent PR.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Collaborator

@yzh119 yzh119 left a comment


Please remove all usage of packed/unpacked data type and use vec_t instead.

Comment threads on include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh (most marked outdated)
yyihuang added a commit that referenced this pull request Jun 9, 2025

## 📌 Description

Update the create_ipc_buffer implementation. Add unit tests for
create_ipc_buffer.

## 🔍 Related Issues

To help debug #1108.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

@yyihuang yyihuang requested a review from yzh119 June 16, 2025 02:31
@yyihuang
Collaborator Author

Next step: uncomment and complete the fused quantization. This may depend on #1142.

Comment threads on include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh (most marked outdated)
```cpp
    "hidden_dim * sizeof(T) must be a multiple of kBytesPerAccess");
if (params.residual_out && not params.norm_out && params.quant_out) {
  // pattern1: AR+Add_RMS+Quant
  // [m, 7168] bf16 allreduce_in, [m, 7168] bf16 residual_in
```
Collaborator


Do we have shape check somewhere?

```python
torch.cuda.synchronize()

# 6. Check correctness
tolerance = 8e-2 if dtype == torch.float16 else 8e-1
```
Collaborator


8e-1 seems too large to me; can you give an example of the distribution of all_reduce_out?

Comment thread tests/test_trtllm_moe_allreduce_fusion.py Outdated
@yyihuang yyihuang requested a review from yzh119 June 16, 2025 08:20
```cpp
// [m, d] bf16 allreduce_in, [m, d] bf16 residual_in
// [m, d] bf16 residual_out, [m, d] bf16 norm_out, [m, d] fp4 quant_out

if (params.allreduce_out && params.residual_out && !params.norm_out && params.quant_out) {
```
Collaborator


The remaining part can still be dispatched:

```cpp
DISPATCH_MOEREDUCTION_KERNEL(T, params, launch_with_pdl, ar, res, rms, quant)
```

@yyihuang yyihuang requested a review from yzh119 June 17, 2025 03:39
Collaborator

@yzh119 yzh119 left a comment


I'm good with the PR, thanks so much for your contribution!

Please refer to 9c229c9 on how to simplify the macro.

Some naming conventions: in flashinfer we usually write both the runtime variable and the constexpr in the macro definition, to make it easier for developers to track which new constexpr names a macro introduces:

```cpp
#define DISPATCH_*(var, CONST_EXPR)
```

and we capitalize the CONST_EXPR.

@yzh119 yzh119 merged commit 0a754ce into flashinfer-ai:main Jun 17, 2025
2 checks passed