Fuse writing KV buffer into rope kernel (part 1: sgl-kernel)#9077
zhyncs merged 4 commits into sgl-project:main from
Conversation
Summary of Changes
Hello @fzyzcjy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant optimization by fusing the Key-Value (KV) buffer writing operation directly into the Rotary Positional Embedding (RoPE) kernel within the sgl-kernel component. This initial part focuses on the sgl-kernel implementation, aiming to reduce overhead associated with separate kernel launches and memory transfers, thereby improving overall performance for large language models.
Highlights
- **Fused KV Buffer Writing in RoPE Kernel:** The core change modifies the CUDA RoPE kernel (`BatchQKApplyRotaryPosIdsCosSinCacheEnhanced`) to write the computed Key and Value tensors directly into a provided KV buffer, eliminating the need for a subsequent separate memory copy or kernel.
- **New `FusedSetKVBufferArg` Dataclass:** A Python dataclass `FusedSetKVBufferArg` is introduced to cleanly encapsulate the arguments for this fused operation, including the value tensor, key/value buffers, and cache location.
- **Updated Python API for RoPE:** The `apply_rope_with_cos_sin_cache_inplace` Python function now accepts an optional `fused_set_kv_buffer_arg`, letting callers opt into the fused behavior (see the sketch after this list).
- **Enhanced Benchmarking and Testing:** New benchmarking utilities (`bench_kineto`) and a dedicated rotary embedding benchmark script have been added, and the existing rotary embedding tests are extended to validate the correctness of the fused KV cache saving.
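A minimal usage sketch of the new argument. The field and parameter names below are assumptions inferred from the highlights above (value tensor, K/V buffers, cache location), not the exact API; see the PR diff for the real definition.

```python
from dataclasses import dataclass
import torch

@dataclass
class FusedSetKVBufferArg:
    # Assumed fields for illustration only.
    value: torch.Tensor       # V for the current tokens
    k_buffer: torch.Tensor    # destination key cache
    v_buffer: torch.Tensor    # destination value cache
    cache_loc: torch.Tensor   # cache slot index for each token

num_tokens, num_heads, head_dim, cache_slots = 4, 8, 128, 64
v = torch.randn(num_tokens, num_heads, head_dim)
k_cache = torch.zeros(cache_slots, num_heads, head_dim)
v_cache = torch.zeros(cache_slots, num_heads, head_dim)
out_cache_loc = torch.arange(num_tokens)

arg = FusedSetKVBufferArg(
    value=v, k_buffer=k_cache, v_buffer=v_cache, cache_loc=out_cache_loc
)

# Passed to the updated API so RoPE and the cache write happen in one kernel
# launch instead of two:
#   apply_rope_with_cos_sin_cache_inplace(..., fused_set_kv_buffer_arg=arg)
```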
Code Review
This pull request fuses the KV cache writing operation into the RoPE kernel, which is a solid optimization for performance. The changes are well-structured, introducing new CUDA kernels, C++ and Python bindings, along with corresponding benchmarks and tests. I've identified a few issues: a potential ZeroDivisionError in the new benchmark utility script, a redundant check in the C++ code, and a minor bug in the test logic. Overall, this is a good contribution.
```python
total_time = 0
total_num = 0
for line in prof_lines:
    if name in line:
        time_str = line.split()[-2]
        num_str = line.split()[-1]
        for unit, scale in units.items():
            if unit in time_str:
                total_time += (
                    float(time_str.replace(unit, "")) / scale * int(num_str)
                )
                total_num += int(num_str)
                break
kernel_times.append(total_time / total_num)
```
The code could raise a ZeroDivisionError on line 135 if a kernel name is not found in the profiler output, as total_num would be 0. This is only partially guarded by the assertion on lines 110-112, which does not run when with_multiple_kernels is True. It's safer to check if total_num is zero before the division and raise an informative error.
Suggested change:

```python
total_time = 0
total_num = 0
for line in prof_lines:
    if name in line:
        time_str = line.split()[-2]
        num_str = line.split()[-1]
        for unit, scale in units.items():
            if unit in time_str:
                total_time += (
                    float(time_str.replace(unit, "")) / scale * int(num_str)
                )
                total_num += int(num_str)
                break
if total_num == 0:
    raise ValueError(f"Kernel '{name}' not found in profiler output: {prof_lines}")
kernel_times.append(total_time / total_num)
```
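To make the failure mode concrete, here is a self-contained sketch of the parsing logic and the suggested guard. The profiler row format (a line ending in `<avg time><unit> <call count>`) and the `units` mapping are assumptions for illustration, not taken from the PR.

```python
units = {"ms": 1e3, "us": 1e6}  # suffix -> divisor converting to seconds

def avg_kernel_time(prof_lines, name):
    total_time, total_num = 0.0, 0
    for line in prof_lines:
        if name in line:
            time_str, num_str = line.split()[-2], line.split()[-1]
            for unit, scale in units.items():
                if unit in time_str:
                    total_time += float(time_str.replace(unit, "")) / scale * int(num_str)
                    total_num += int(num_str)
                    break
    if total_num == 0:  # the suggested guard: fail loudly, not with ZeroDivisionError
        raise ValueError(f"Kernel '{name}' not found in profiler output")
    return total_time / total_num

rows = ["rope_kernel  90.0%  12.5us  100"]
print(avg_kernel_time(rows, "rope_kernel"))   # 1.25e-05 seconds per call
# avg_kernel_time(rows, "missing")            # would raise ValueError
```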
```python
x_flashinfer = getattr(pool_flashinfer, field)[0]
torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
nonzero_ref = x_ref != 0
nonzero_flashinfer = x_ref != 0
```
There seems to be a copy-paste error here: nonzero_flashinfer is computed from x_ref instead of x_flashinfer. This makes the assertion on the next line ineffective, since it compares nonzero_ref against a second mask derived from x_ref and can never catch a mismatch in the FlashInfer output.
Suggested change:

```python
nonzero_flashinfer = x_flashinfer != 0
```
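A minimal illustration of why the original comparison is vacuous (the surrounding assertion is not shown in this hunk, so the equality check below is only a stand-in):

```python
import torch

x_ref = torch.tensor([0.0, 1.0, 2.0])
x_flashinfer = torch.tensor([0.0, 1.0, 0.0])  # disagrees with x_ref at index 2

nonzero_ref = x_ref != 0
nonzero_flashinfer = x_ref != 0                # bug: same mask, x_ref vs. itself
print(torch.equal(nonzero_ref, nonzero_flashinfer))  # True even though outputs differ

nonzero_flashinfer = x_flashinfer != 0         # fix: mask from the tested output
print(torch.equal(nonzero_ref, nonzero_flashinfer))  # False, mismatch is caught
```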
|
|
```cpp
const bool save_kv_cache = v.has_value();
if (save_kv_cache) {
  TORCH_CHECK(v.has_value());
```
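This appears to be the redundant check flagged in the review summary: `save_kv_cache` is defined as `v.has_value()`, so the `TORCH_CHECK(v.has_value())` inside the `if (save_kv_cache)` branch can never fire and can simply be dropped.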

Motivation
see #9014 for details
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist