
Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) #9077

Merged
zhyncs merged 4 commits into sgl-project:main from fzyzcjy:kv_rope_fuse_part1
Aug 12, 2025

Conversation

@fzyzcjy
Collaborator

@fzyzcjy fzyzcjy commented Aug 11, 2025

Motivation

See #9014 for details.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @fzyzcjy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization by fusing the Key-Value (KV) buffer writing operation directly into the Rotary Positional Embedding (RoPE) kernel within the sgl-kernel component. This initial part focuses on the sgl-kernel implementation, aiming to reduce overhead associated with separate kernel launches and memory transfers, thereby improving overall performance for large language models.
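The fusion pattern described above can be sketched in plain Python (all names here are hypothetical illustrations, not the PR's actual code; lists stand in for GPU tensors, and each loop stands in for a CUDA kernel launch):

```python
def rotate(x):
    # Stand-in for applying rotary positional embedding to one element;
    # the real kernel rotates pairs of dimensions by position-dependent angles.
    return x * 2

def unfused(k, k_buffer, loc):
    # Two passes (two kernel launches on a GPU): RoPE first, then a
    # separate copy of the rotated keys into the KV cache.
    k_rot = [rotate(x) for x in k]   # "kernel" 1: RoPE
    k_buffer[loc] = list(k_rot)      # "kernel" 2: write KV cache
    return k_rot

def fused(k, k_buffer, loc):
    # One pass: each rotated element is written to the KV cache in the
    # same iteration that produced it, so no second launch or extra
    # round trip through memory is needed.
    k_rot, buf = [], []
    for x in k:
        r = rotate(x)
        k_rot.append(r)
        buf.append(r)                # cache write fused into the same pass
    k_buffer[loc] = buf
    return k_rot
```

Both paths produce identical results and cache contents; the fused path simply does the work in a single traversal, which is where the GPU-side savings come from.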

Highlights

  • Fused KV Buffer Writing in RoPE Kernel: The core change involves modifying the CUDA RoPE kernel (BatchQKApplyRotaryPosIdsCosSinCacheEnhanced) to directly write the computed Key and Value tensors into a provided KV buffer, eliminating the need for a subsequent separate memory copy or kernel.
  • New FusedSetKVBufferArg Dataclass: A Python dataclass FusedSetKVBufferArg is introduced to cleanly encapsulate the necessary arguments for this fused operation, including value tensor, key/value buffers, and cache location.
  • Updated Python API for RoPE: The apply_rope_with_cos_sin_cache_inplace Python function now accepts an optional fused_set_kv_buffer_arg, allowing users to leverage the new fused functionality.
  • Enhanced Benchmarking and Testing: New benchmarking utilities (bench_kineto) and a dedicated rotary embedding benchmark script have been added. The existing rotary embedding tests are extended to validate the correctness of the fused KV cache saving.
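Based on the highlights above, the new Python surface might look roughly like the following sketch. Only FusedSetKVBufferArg, fused_set_kv_buffer_arg, and apply_rope_with_cos_sin_cache_inplace are named in the PR; the parameter and field names are assumptions, and real arguments would be torch.Tensor:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FusedSetKVBufferArg:
    # Hypothetical field names; the PR describes this argument as carrying
    # the value tensor, the key/value buffers, and the cache location.
    value: Any       # v tensor, stored alongside the rotated keys
    k_buffer: Any    # KV-cache key buffer
    v_buffer: Any    # KV-cache value buffer
    cache_loc: Any   # slot indices in the cache to write into

def apply_rope_with_cos_sin_cache_inplace(
    query: Any,
    key: Any,
    cos_sin_cache: Any,
    pos_ids: Any,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> None:
    """Sketch of the extended signature: when fused_set_kv_buffer_arg is
    provided, the kernel also writes the post-RoPE keys and the values
    into the given buffers at cache_loc. Body elided."""
    ...
```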

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request fuses the KV cache writing operation into the RoPE kernel, which is a solid optimization for performance. The changes are well-structured, introducing new CUDA kernels, C++ and Python bindings, along with corresponding benchmarks and tests. I've identified a few issues: a potential ZeroDivisionError in the new benchmark utility script, a redundant check in the C++ code, and a minor bug in the test logic. Overall, this is a good contribution.

Comment on lines +122 to +135

```python
total_time = 0
total_num = 0
for line in prof_lines:
    if name in line:
        time_str = line.split()[-2]
        num_str = line.split()[-1]
        for unit, scale in units.items():
            if unit in time_str:
                total_time += (
                    float(time_str.replace(unit, "")) / scale * int(num_str)
                )
                total_num += int(num_str)
                break
kernel_times.append(total_time / total_num)
```

high

The code could raise a ZeroDivisionError on line 135 if a kernel name is not found in the profiler output, as total_num would be 0. This is only partially guarded by the assertion on lines 110-112, which does not run when with_multiple_kernels is True. It's safer to check if total_num is zero before the division and raise an informative error.

Suggested change

```python
total_time = 0
total_num = 0
for line in prof_lines:
    if name in line:
        time_str = line.split()[-2]
        num_str = line.split()[-1]
        for unit, scale in units.items():
            if unit in time_str:
                total_time += (
                    float(time_str.replace(unit, "")) / scale * int(num_str)
                )
                total_num += int(num_str)
                break
if total_num == 0:
    raise ValueError(f"Kernel '{name}' not found in profiler output: {prof_lines}")
kernel_times.append(total_time / total_num)
```

```python
x_flashinfer = getattr(pool_flashinfer, field)[0]
torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
nonzero_ref = x_ref != 0
nonzero_flashinfer = x_ref != 0
```

high

There appears to be a copy-paste error here: nonzero_flashinfer is computed from x_ref instead of x_flashinfer. As written, the subsequent assertion compares two masks both derived from x_ref, so it can never catch a mismatch in the flashinfer output.

Suggested change

```diff
-nonzero_flashinfer = x_ref != 0
+nonzero_flashinfer = x_flashinfer != 0
```


```cpp
const bool save_kv_cache = v.has_value();
if (save_kv_cache) {
  TORCH_CHECK(v.has_value());
```

medium

The check TORCH_CHECK(v.has_value()); is redundant: it sits inside an if (save_kv_cache) block, and save_kv_cache is defined as v.has_value() on line 38, so the check can never fail there. This line can be safely removed.

@zhyncs zhyncs self-assigned this Aug 11, 2025
@fzyzcjy
Collaborator Author

fzyzcjy commented Aug 12, 2025

CI passes now.

[screenshot: CI check results]

@zhyncs zhyncs merged commit 9aea255 into sgl-project:main Aug 12, 2025
87 of 98 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
