
fix: In-place Residual Update for add_rmsnorm_fp4quant #2385

Merged
yzh119 merged 2 commits into flashinfer-ai:main from bkryu:add_rmsn_f4q_residual
Jan 22, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Jan 20, 2026

📌 Description

In-place Residual Update for add_rmsnorm_fp4quant
This PR modifies the add_rmsnorm_fp4quant CuTe-DSL kernel to update the residual tensor in-place, matching the behavior of other fused add+rmsnorm APIs in FlashInfer.
Before:

  • h = input + residual (computed internally, residual unchanged)
  • y = RMSNorm(h) * weight
  • Quantize y to FP4

After:

  • residual = input + residual (in-place update)
  • y = RMSNorm(residual) * weight
  • Quantize y to FP4
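The new contract can be sketched as a NumPy reference (illustrative only; the real kernel is a fused CuTe-DSL GPU kernel and also emits the FP4-quantized output, which is omitted here):

```python
import numpy as np

def add_rmsnorm_ref(x, residual, weight, eps=1e-6):
    """Reference for the new semantics: mutates `residual` in place."""
    residual += x                        # residual = input + residual
    rms = np.sqrt(np.mean(residual**2, axis=-1, keepdims=True) + eps)
    return (residual / rms) * weight     # y; FP4 quantization would follow
```

After the call, the caller's residual tensor holds `input + residual`, which is what the updated tests check for instead of the old behavior of leaving residual untouched.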

Changes

File | Change
-- | --
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py | Write h = input + residual back to the residual tensor; fix cluster-mode paths to read from the updated residual
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py | Compute reference h before the kernel call; add tests for in-place residual verification
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py | Add the residual write to the bandwidth calculation; fix the sanity check
benchmarks/routines/norm.py | Update docstring and comments for clarity
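The bandwidth-accounting change can be illustrated with a back-of-the-envelope byte count. This is a sketch under assumed sizes (2-byte input dtype, packed 4-bit output at two values per byte, 1-byte scale per block), not the benchmark's exact code:

```python
def estimate_bytes(num_tokens, hidden, in_dtype_bytes=2, block_size=16):
    # Reads: x, residual, and the weight vector
    read_bytes = (2 * num_tokens * hidden + hidden) * in_dtype_bytes
    # Writes now include the in-place residual write-back in addition to
    # the packed FP4 output and the per-block scales
    write_bytes = (
        num_tokens * hidden * in_dtype_bytes   # residual = input + residual
        + num_tokens * hidden // 2             # y_fp4, two 4-bit values per byte
        + num_tokens * hidden // block_size    # block_scale, 1 byte per block
    )
    return read_bytes, write_bytes
```

Omitting the first write term is exactly the undercount the benchmark fix addresses: with in-place semantics, the residual write-back is real traffic.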

🔍 Related Issues

#2259

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Enhancements

    • In-place residual updates: Add + RMSNorm + FP4 Quant now modifies residual in-place for improved memory efficiency.
    • Input tensors remain read-only and unmodified.
    • Updated memory bandwidth reporting to reflect in-place behavior.
  • Tests

    • Expanded test coverage with comprehensive validation of in-place residual updates across multiple configurations and data layouts.


@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the add_rmsnorm_fp4quant CuTe-DSL kernel to perform an in-place update of the residual tensor, aligning its behavior with other fused add+rmsnorm APIs in FlashInfer. This change optimizes memory usage by directly modifying the residual tensor with the sum of input and residual, then applying RMSNorm and FP4 quantization to this updated residual. The modifications span the core kernel logic, testing infrastructure, and performance benchmarks to ensure correctness and accurate representation of the new operational flow.

Highlights

  • In-place Residual Update: The add_rmsnorm_fp4quant CuTe-DSL kernel now updates the residual tensor in-place (residual = input + residual), aligning its behavior with other fused add+rmsnorm APIs in FlashInfer.
  • Optimized Memory Usage: This change avoids creating a new temporary tensor for h = input + residual, potentially reducing memory overhead and improving performance by directly modifying the existing residual tensor.
  • Updated Logic for RMSNorm: The RMSNorm calculation now directly uses the modified residual tensor, ensuring the correct sequence of operations: residual = input + residual, then y = RMSNorm(residual) * weight, followed by FP4 quantization.
  • Comprehensive Testing: New and modified tests have been added to rigorously verify the correctness of the in-place update across various tensor shapes, data types, and quantization formats (NVFP4, MXFP4), including scenarios with large hidden sizes, pre-allocated outputs, and swizzled layouts, and confirming that the input tensor remains unchanged.
  • Documentation and Benchmarking Updates: Relevant documentation, comments, and benchmark calculations have been adjusted to accurately reflect the new in-place behavior and its impact on memory bandwidth.


@coderabbitai
Contributor

coderabbitai Bot commented Jan 20, 2026

📝 Walkthrough

Walkthrough

The pull request modifies the add_rmsnorm_fp4quant fused kernel to perform in-place residual updates (residual := residual + input) before applying RMSNorm and FP4 quantization. Bandwidth computations, kernel implementation across multiple code paths, test validation logic, and documentation are updated to reflect this behavioral change.

Changes

Cohort / File(s) | Summary
-- | --
Core Kernel Implementation: flashinfer/cute_dsl/add_rmsnorm_fp4quant.py | Refactored to treat residual as an in-place container; added store mechanisms (copy_atom_store) to write the intermediate h = x + r back to residual before RMSNorm and quantization. Updated both cluster and non-cluster code paths (NVFP4/MXFP4, swizzled/unswizzled) to consistently propagate in-place residual updates. Revised docstrings to reflect the residual-modification semantics.
Benchmark & Bandwidth Calculations: benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py, benchmarks/routines/norm.py | Updated the write_bytes computation to include the in-place residual write-back alongside y_fp4 and block_scale. Modified sanity checks to precompute h_ref before kernel invocation to avoid in-place interference. Updated comments to document residual mutation and bandwidth estimates.
Test Suite Expansion: tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py | Introduced a new TestResidualInPlaceUpdate suite with extensive parametrized tests covering 2D/3D inputs, MXFP4/NVFP4 formats, swizzled/unswizzled layouts, and large hidden sizes. Updated existing tests to compute reference values before kernel invocation and to clone residuals when multiple runs are needed, verifying in-place residual correctness and output fidelity across a broad parameter space.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • kaixih
  • aleozlx
  • jimmyzho
  • jiahanc
  • cyx-6
  • kahyunnam

Poem

🐰 A residual hops in place,
No extra copies slow the race—
In-place updates make kernels quick,
RMSNorm and FP4 do the trick!
✨ Fused and clever, fast and neat! 🎉

🚥 Pre-merge checks: ✅ 3 of 3 passed

Check name | Status | Explanation
-- | -- | --
Title check | ✅ Passed | The title clearly and specifically describes the main change: converting the add_rmsnorm_fp4quant kernel to perform in-place residual updates, which matches the primary objective across all modified files.
Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
Description check | ✅ Passed | The PR description is comprehensive and follows the template structure, with clear sections for the description, related issues, and checklist.


@bkryu bkryu self-assigned this Jan 20, 2026
@bkryu bkryu moved this from Todo to In Progress in FlashInfer Roadmap Jan 20, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request correctly implements an in-place update for the add_rmsnorm_fp4quant kernel, which improves API consistency with other fused operations in FlashInfer. The changes are well-executed across the codebase, including necessary updates to the core kernel logic, documentation, benchmarks, and tests.

The kernel implementation for the cluster-mode path has been notably simplified by leveraging the new in-place update, which is a great improvement. The addition of the TestResidualInPlaceUpdate test suite is excellent, as it provides comprehensive validation for the new in-place behavior under various conditions, ensuring the correctness and robustness of this change. The overall quality of the code and tests is very high. I have no further comments.

@bkryu
Collaborator Author

bkryu commented Jan 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !251 has been created, and the CI pipeline #42139148 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2304-2313: Add missing decorator and fix in-place semantics for 3D non-contiguous residuals.

Two issues:

  1. Missing @functools.cache decorator: Per coding guidelines, flashinfer/**/*.py functions need @functools.cache for module-level caching. The @flashinfer_api decorator is present but @functools.cache is missing.

  2. In-place semantics broken for non-contiguous 3D residuals: For 3D inputs, residual_2d = residual.view(B * S, H).contiguous() followed by tensor_api(..., residual_2d.contiguous(), ...) creates a copy when the original residual is non-contiguous. The kernel then modifies the copy instead of the original tensor, violating the documented in-place update behavior. The existing tests only cover contiguous tensors and don't catch this case.
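The second issue can be reproduced with a NumPy analogue of `.contiguous()`: `np.ascontiguousarray` copies non-contiguous data, just as `torch.Tensor.contiguous()` does, so any "in-place" write afterwards lands on the copy:

```python
import numpy as np

r3d = np.arange(24.0).reshape(2, 3, 4)
noncontig = r3d.transpose(1, 0, 2)   # non-contiguous 3D "residual"
flat = np.ascontiguousarray(noncontig).reshape(6, 4)  # forces a copy

flat += 100.0   # the "in-place" update hits the copy...
# ...so the original tensor never sees it, and the documented
# in-place semantics are silently lost for non-contiguous inputs
```

This is why the comment flags that tests covering only contiguous tensors cannot catch the bug.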

🧹 Nitpick comments (1)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)

1315-1351: Address unused variable warnings from static analysis.

The y_fp4 and block_scale variables are unpacked but never used in tests that only verify the residual in-place update. Per static analysis hints, prefix unused variables with underscores.

♻️ Suggested fix for unused variable warnings
         # Call kernel - residual should be modified in-place
-        y_fp4, block_scale = add_rmsnorm_fp4quant(
+        _y_fp4, _block_scale = add_rmsnorm_fp4quant(
             x, r, weight, eps=eps, block_size=block_size
         )

Apply similar changes to other tests in TestResidualInPlaceUpdate that don't use the FP4 outputs (lines 1374, 1407, 1440, 1565, 1602).

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42139148: 3/20 passed

@yzh119
Collaborator

yzh119 commented Jan 21, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !251 has been updated with latest changes, and the CI pipeline #42206776 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment


The failed UT doesn't look relevant; let's go ahead and merge it.

@yzh119 yzh119 merged commit 54afc1e into flashinfer-ai:main Jan 22, 2026
61 of 69 checks passed
@github-project-automation github-project-automation Bot moved this from In Progress to Done in FlashInfer Roadmap Jan 22, 2026
yzh119 pushed a commit that referenced this pull request Jan 23, 2026
…2395)


## 📌 Description

**Must be merged after #2385**

This PR extends the `add_rmsnorm_fp4quant` API to support outputting
both swizzled and unswizzled scale factors simultaneously. This is
useful for scenarios where the quantized output needs to be consumed by
both GEMMs (experts) and All-to-All without requiring a separate layout
conversion pass.

When `output_both_sf_layouts=True`, the function returns a 3-tuple
`(y_fp4, block_scale_swizzled, block_scale_unswizzled)` instead of the
standard 2-tuple. This flag overrides `is_sf_swizzled_layout` when set.

**Changes Summary**
File | Change
-- | --
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py | Added `output_both_sf_layouts` and `block_scale_unswizzled` parameters; updated kernel to write both SF layouts
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py | Added `TestOutputBothSFLayouts` test class with 10 test methods covering NVFP4/MXFP4, 2D/3D inputs, auto/pre-allocation, and large hidden sizes
benchmarks/routines/norm.py | Added `--output_both_sf_layouts` flag; adjusted bandwidth calculation to account for 2× SF writes
benchmarks/samples/sample_testlist.txt | Added example benchmark commands for dual SF output




## 🔍 Related Issues


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes



## Summary by CodeRabbit

* **New Features**
* Added --output_both_sf_layouts to emit both swizzled and unswizzled
scale-factor layouts and to enable in-place residual updates for the FP4
quantization add+RMSNorm flow.
  * API now returns an extra unswizzled scale tensor when requested.

* **Bug Fixes / UX**
* rmsnorm_fp4quant surfaces a user warning when the new flag is
unsupported and ignores it.
  * Verbose output and result reporting now include the new flag state.

* **Tests**
* Added extensive tests validating dual-layout outputs, dtypes, shapes,
in-place residual semantics, and cross-path consistency.

@claude claude Bot mentioned this pull request Jan 23, 2026
5 tasks
@coderabbitai coderabbitai Bot mentioned this pull request Jan 27, 2026
5 tasks
@bkryu bkryu deleted the add_rmsn_f4q_residual branch January 29, 2026 22:29

Projects

Status: Done

4 participants