fix: In-place Residual Update for add_rmsnorm_fp4quant #2385
yzh119 merged 2 commits into flashinfer-ai:main
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Code Review
This pull request correctly implements an in-place update for the add_rmsnorm_fp4quant kernel, which improves API consistency with other fused operations in FlashInfer. The changes are well-executed across the codebase, including necessary updates to the core kernel logic, documentation, benchmarks, and tests.
The kernel implementation for the cluster-mode path has been notably simplified by leveraging the new in-place update, which is a great improvement. The addition of the TestResidualInPlaceUpdate test suite is excellent, as it provides comprehensive validation for the new in-place behavior under various conditions, ensuring the correctness and robustness of this change. The overall quality of the code and tests is very high. I have no further comments.
/bot run
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2304-2313: Add missing decorator and fix in-place semantics for 3D non-contiguous residuals.
Two issues:
1. Missing @functools.cache decorator: Per coding guidelines, flashinfer/**/*.py functions need @functools.cache for module-level caching. The @flashinfer_api decorator is present but @functools.cache is missing.
2. In-place semantics broken for non-contiguous 3D residuals: For 3D inputs, residual_2d = residual.view(B * S, H).contiguous() followed by tensor_api(..., residual_2d.contiguous(), ...) creates a copy when the original residual is non-contiguous. The kernel then modifies the copy instead of the original tensor, violating the documented in-place update behavior. The existing tests only cover contiguous tensors and don't catch this case.
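The copy-versus-view pitfall described in the second issue can be reproduced without a GPU. The following is a minimal stdlib-only sketch that uses array and memoryview as stand-ins for tensors; add_one_in_place is a hypothetical placeholder for the kernel, and none of this is FlashInfer code. It shows that an update applied to a compacted copy never reaches the caller's storage, and that copying the result back is one possible remedy.

```python
from array import array


def add_one_in_place(buf):
    """Hypothetical stand-in for a kernel that updates its buffer in place."""
    for i in range(len(buf)):
        buf[i] += 1.0


storage = array("f", [0.0] * 8)
strided = memoryview(storage)[::2]      # every other element: a non-contiguous view
is_contiguous = strided.contiguous      # False for this strided view

# Pitfall: compact the view into a fresh contiguous buffer, then update that buffer.
compact = array("f", strided.tolist())  # this is a copy, not an alias of storage
add_one_in_place(compact)
after_copy_update = storage.tolist()    # the caller's data has not changed

# One possible remedy: copy the updated values back through the original view.
for i, value in enumerate(compact):
    strided[i] = value
after_copy_back = storage.tolist()      # now the update is visible to the caller
```

A test for the real kernel along these lines would pass a non-contiguous residual and compare it against the expected sum after the call.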
🧹 Nitpick comments (1)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)
1315-1351: Address unused variable warnings from static analysis.
The y_fp4 and block_scale variables are unpacked but never used in tests that only verify the residual in-place update. Per static analysis hints, prefix unused variables with underscores.
♻️ Suggested fix for unused variable warnings
  # Call kernel - residual should be modified in-place
- y_fp4, block_scale = add_rmsnorm_fp4quant(
+ _y_fp4, _block_scale = add_rmsnorm_fp4quant(
      x, r, weight, eps=eps, block_size=block_size
  )
Apply similar changes to other tests in TestResidualInPlaceUpdate that don't use the FP4 outputs (lines 1374, 1407, 1440, 1565, 1602).
[FAILED] Pipeline #42139148: 3/20 passed
/bot run
yzh119 left a comment
The failed UT doesn't look relevant, let's go ahead and merge it.
…2395)

## 📌 Description

**Must be merged after #2385**

This PR extends the `add_rmsnorm_fp4quant` API to support outputting both swizzled and unswizzled scale factors simultaneously. This is useful for scenarios where the quantized output needs to be consumed by both GEMMs (experts) and All-to-All without requiring a separate layout conversion pass.

When `output_both_sf_layouts=True`, the function returns a 3-tuple `(y_fp4, block_scale_swizzled, block_scale_unswizzled)` instead of the standard 2-tuple. This flag overrides `is_sf_swizzled_layout` when set.

**Changes Summary**

File | Change
-- | --
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py | Added `output_both_sf_layouts` and `block_scale_unswizzled` parameters; updated kernel to write both SF layouts
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py | Added `TestOutputBothSFLayouts` test class with 10 test methods covering NVFP4/MXFP4, 2D/3D inputs, auto/pre-allocation, and large hidden sizes
benchmarks/routines/norm.py | Added `--output_both_sf_layouts` flag; adjusted bandwidth calculation to account for 2× SF writes
benchmarks/samples/sample_testlist.txt | Added example benchmark commands for dual SF output

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Added --output_both_sf_layouts to emit both swizzled and unswizzled scale-factor layouts and to enable in-place residual updates for the FP4 quantization add+RMSNorm flow.
  * API now returns an extra unswizzled scale tensor when requested.
* **Bug Fixes / UX**
  * rmsnorm_fp4quant surfaces a user warning when the new flag is unsupported and ignores it.
  * Verbose output and result reporting now include the new flag state.
* **Tests**
  * Added extensive tests validating dual-layout outputs, dtypes, shapes, in-place residual semantics, and cross-path consistency.
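Because the follow-up commit message above describes a return value that is a 2-tuple by default and a 3-tuple when `output_both_sf_layouts=True`, callers may want to normalise it in one place. The helper below is a hypothetical sketch, not part of FlashInfer; its name and the dictionary keys are assumptions chosen to mirror the tuple names quoted in the commit message, and placeholder strings stand in for tensors.

```python
def unpack_fp4_outputs(result, output_both_sf_layouts=False):
    """Hypothetical helper: map the returned tuple to named fields."""
    if output_both_sf_layouts:
        # 3-tuple case, per the commit message:
        # (y_fp4, block_scale_swizzled, block_scale_unswizzled)
        y_fp4, sf_swizzled, sf_unswizzled = result
        return {
            "y_fp4": y_fp4,
            "block_scale_swizzled": sf_swizzled,
            "block_scale_unswizzled": sf_unswizzled,
        }
    # Standard 2-tuple case; the scale layout is whatever the caller requested.
    y_fp4, block_scale = result
    return {"y_fp4": y_fp4, "block_scale": block_scale}


# Placeholder strings stand in for the tensors a real call would return.
both = unpack_fp4_outputs(("y", "sf_swz", "sf_lin"), output_both_sf_layouts=True)
single = unpack_fp4_outputs(("y", "sf"))
```

Unpacking by arity in a single helper means a mismatch between the flag and the actual return shape fails loudly with a ValueError instead of silently mislabelling a tensor.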
📌 Description
In-place Residual Update for add_rmsnorm_fp4quant
This PR modifies the add_rmsnorm_fp4quant CuTe-DSL kernel to update the residual tensor in-place, matching the behavior of other fused add+rmsnorm APIs in FlashInfer.
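To make the change in semantics concrete, here is a minimal pure-Python sketch of the old (out-of-place) and new (in-place) behaviour for a single row. It is an illustration under stated assumptions, not the CuTe-DSL kernel: the function names and the eps default are hypothetical, plain lists stand in for tensors, and the final FP4 quantization step is omitted.

```python
import math


def rmsnorm(h, weight, eps):
    """RMSNorm over one row: h / sqrt(mean(h^2) + eps) * weight."""
    mean_sq = sum(v * v for v in h) / len(h)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(h, weight)]


def add_rmsnorm_out_of_place(x, residual, weight, eps=1e-6):
    """Old behaviour: the sum lives in a temporary; residual is left unchanged."""
    h = [a + b for a, b in zip(x, residual)]
    return rmsnorm(h, weight, eps)


def add_rmsnorm_in_place(x, residual, weight, eps=1e-6):
    """New behaviour: residual is overwritten with x + residual, then normalised."""
    for i, a in enumerate(x):
        residual[i] += a
    return rmsnorm(residual, weight, eps)


x = [1.0, 2.0, 3.0, 4.0]
weight = [1.0, 1.0, 1.0, 1.0]
res_old = [0.5, 0.5, 0.5, 0.5]
res_new = [0.5, 0.5, 0.5, 0.5]

y_old = add_rmsnorm_out_of_place(x, res_old, weight)
y_new = add_rmsnorm_in_place(x, res_new, weight)
```

Both paths produce the same normalised output; the only observable difference is that the caller's residual now holds the sum after the call, so a test that needs the original residual has to save it beforehand.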
Before:
1. h = input + residual (computed internally, residual unchanged)
2. y = RMSNorm(h) * weight
3. Quantize y to FP4
After:
1. residual = input + residual (in-place update)
2. y = RMSNorm(residual) * weight
3. Quantize y to FP4
Changes
File | Change
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py | Write h = input + residual back to residual tensor; fix cluster mode paths to read from updated residual
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py | […] h before kernel call; add tests for in-place residual verification
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py | […]
benchmarks/routines/norm.py | […]
🔍 Related Issues
#2259
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
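✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).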