
[4/4] Introduce CachedKernel to reduce CSGMV kernel launch overheads #10704

Closed
lifuhuang wants to merge 61 commits into main from lifu/cached-kernel
Conversation

@lifuhuang
Collaborator

Motivation

TODO

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

lifuhuang and others added 30 commits September 18, 2025 03:50
@lifuhuang lifuhuang changed the title from "Introduce CachedKernel to reduce LoRA kernel launch overheads" to "[4/4] Introduce CachedKernel to reduce CSGMV kernel launch overheads" on Sep 21, 2025
@gemini-code-assist
Contributor

Summary of Changes

Hello @lifuhuang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance and modularity of the inference engine, particularly for LoRA and disaggregated setups. It introduces a new speculative decoding algorithm and expands multimodal model support, while also ensuring more predictable inference behavior through deterministic settings. The changes aim to improve efficiency, flexibility, and the range of supported models within the system.

Highlights

  • LoRA Performance Optimization: Introduced a CachedKernel mechanism to reduce LoRA kernel launch overheads, particularly within the chunked SGMV backend. This includes a new heuristic for determining optimal chunk sizes, aiming to improve performance for LoRA operations (a rough illustrative sketch of the caching idea follows this list).
  • Disaggregation System Refactoring: Undertook a significant refactoring of the disaggregation connection management. Common logic for KV managers, senders, receivers, and bootstrap servers has been centralized into new Common classes. Bootstrap server communication now includes more detailed parallelization information (TP, DP, PP sizes and ranks) and improved handling of multiple prefill responses, enhancing multi-node inference capabilities.
  • New Multimodal Model Integration: Added support for the Sarashina2VisionForCausalLM multimodal model. This includes its specific multimodal data processing, a new Jinja chat template, and integration into the model configuration, enabling multimodal capabilities.
  • Lookahead Speculative Decoding: Implemented a new speculative decoding algorithm called 'Lookahead'. This features a dedicated C++ backend for efficient tree-based token generation and verification, along with new environment variables and related utility functions.
  • Attention Backend Modularity: Refactored the attention backend selection into a new attention_registry module. This change centralizes the registration and creation of various attention backends (e.g., FlashInfer, Triton, Flex Attention), making the system more modular and extensible.
  • Deterministic Inference Support: Introduced new environment variables and logic to enable deterministic inference, specifically for the FlashInfer attention backend. This ensures consistent outputs by controlling split tile sizes for prefill and decode operations and disabling KV split for CUDA graphs.
  • Core Dependency Updates: Updated several core dependencies, including DeepEP, sgl-kernel (to 0.3.11), mooncake-transfer-engine (to 0.3.6.post1), and flashinfer_python (to 0.4.0rc1). These updates bring the project in line with the latest versions and incorporate new features and bug fixes from these libraries.
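
As a rough illustration of the caching idea from the first highlight, the sketch below memoizes a kernel launcher per static configuration so that repeated launches skip re-specialization. All names here (`CachedKernel`, `build_kernel`, the cache key) are hypothetical and do not reflect the actual implementation in this PR.

```python
from typing import Callable, Dict, Tuple

class CachedKernel:
    """Minimal sketch: memoize a launcher per static configuration so repeated
    launches with the same constants skip re-specialization overhead."""

    def __init__(self, build_kernel: Callable[..., Callable]):
        # build_kernel(**constants) returns a launcher specialized for those constants.
        self._build_kernel = build_kernel
        self._cache: Dict[Tuple, Callable] = {}

    def __call__(self, *args, **constants):
        key = tuple(sorted(constants.items()))
        launcher = self._cache.get(key)
        if launcher is None:
            # First launch with this configuration: specialize once, reuse afterwards.
            launcher = self._build_kernel(**constants)
            self._cache[key] = launcher
        return launcher(*args)
```

In a chunked SGMV setting, the constants would be things like block and chunk sizes, so steady-state decode steps hit the cache instead of re-resolving the kernel on every call.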

@lifuhuang lifuhuang closed this Sep 21, 2025

@gemini-code-assist (Bot) left a comment


Code Review

This is a large pull request that introduces several significant features and refactorings. The main changes include the introduction of CachedKernel to reduce kernel launch overheads, support for lookahead speculative decoding, and deterministic inference. There are also major refactorings in the disaggregation connection logic, attention backend selection, and the Rust router's worker management. Additionally, new models and attention backends are added. The code quality is generally high, with good use of design patterns like registries and builders to improve modularity and maintainability. My review focuses on potential issues and areas for improvement in the new and refactored code.

```diff
 try:
-    url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+    url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
     response = requests.get(url)
```

Severity: high

This requests.get call is missing a timeout. Other network requests in this file have a timeout set (e.g., lines 139 and 372), and adding one here would prevent the thread from hanging indefinitely if the bootstrap server is unresponsive.

Suggested change

```diff
-response = requests.get(url)
+response = requests.get(url, timeout=5)
```
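
A slightly more defensive variant (an assumed pattern, not code from this PR; `logger` stands in for whatever logger the module uses) would also catch the failure so the calling thread can retry instead of raising:

```python
import logging
import requests

logger = logging.getLogger(__name__)

try:
    response = requests.get(url, timeout=5)
    response.raise_for_status()
except requests.RequestException as exc:
    # Covers timeouts, connection errors, and HTTP error codes alike.
    logger.warning("Bootstrap route lookup failed, will retry: %s", exc)
    response = None
```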

Comment on lines 166 to +167
```diff
 else:
-    from vllm._custom_ops import awq_dequantize
+    pass
```

Severity: high

The removal of the vllm._custom_ops import for awq_dequantize on non-CUDA platforms means that AWQ quantization for this model will no longer work on platforms like ROCm/HIP, unless sgl-kernel provides an equivalent. This appears to be a regression in functionality for non-CUDA backends.
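
One way to avoid that regression, sketched here as an assumption (in particular, whether `sgl_kernel` exports an `awq_dequantize` equivalent is unverified), is to guard both imports and fall back explicitly:

```python
try:
    # Preferred path if sgl-kernel provides an equivalent op (unverified assumption).
    from sgl_kernel import awq_dequantize
except ImportError:
    try:
        # Keep the vLLM fallback for non-CUDA platforms such as ROCm/HIP.
        from vllm._custom_ops import awq_dequantize
    except ImportError:
        awq_dequantize = None  # AWQ dequantization unavailable on this platform
```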

```diff
 def init_engine(self):
     # TransferEngine initialized on ascend.
-    local_ip = get_local_ip_by_remote()
+    local_ip = get_local_ip_auto()
```

Severity: medium

The new get_local_ip_auto function can raise a ValueError if it fails to determine the IP address and no fallback is provided. In a server context, this could prevent the service from starting in certain network environments. It would be safer to provide a fallback to bind to all interfaces.

Suggested change

```diff
-local_ip = get_local_ip_auto()
+local_ip = get_local_ip_auto(fallback="0.0.0.0")
```
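
For illustration, a generic version of that fallback behavior could look like the sketch below; this is not the project's `get_local_ip_auto`, whose real signature and detection logic may differ.

```python
import socket

def get_local_ip_with_fallback(fallback: str = "0.0.0.0") -> str:
    """Best-effort local IP discovery that never raises; bind to all interfaces on failure."""
    try:
        # Connecting a UDP socket sends no packets but selects a routable local address.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except OSError:
        return fallback
```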

Comment on lines 503 to 509
```diff
 def grid(META):
+    max_seq_len = max(seq_lens_cpu)
     return (
-        num_program(META, args),
+        len(seq_lens_cpu),  # batch_size
+        (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
         triton.cdiv(dim, META["BLOCK_N"]),
     )
```

Severity: medium

The new grid calculation launches kernels for every sequence up to the max_seq_len in the batch. For batches with varying sequence lengths, this will launch many empty kernels for shorter sequences, which then exit early due to the segment_len <= 0 check. This can introduce significant kernel launch overhead. It might be more efficient to compute a tighter grid, similar to the previous implementation's num_program function, to only launch kernels that will do actual work.
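
As one possible shape of that suggestion (a sketch only: `seq_lens_cpu` and `dim` come from the snippet above, and the kernel body would additionally need a mapping, e.g. a prefix-sum tensor, from the flattened block index back to its sequence), the grid could launch exactly the number of BLOCK_M tiles the batch needs:

```python
import triton

def make_tight_grid(seq_lens_cpu, dim):
    def grid(META):
        # Total BLOCK_M tiles the batch actually needs, instead of batch_size * max_blocks.
        total_m_blocks = sum(
            (int(s) + META["BLOCK_M"] - 1) // META["BLOCK_M"] for s in seq_lens_cpu
        )
        return (
            total_m_blocks,
            triton.cdiv(dim, META["BLOCK_N"]),
        )
    return grid
```

Each program id would then be decoded inside the kernel to its (sequence, block) pair, presumably along the lines of what the previous num_program-based layout did.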

Comment on lines 588 to +601
```diff
     # Compute common indices once to avoid duplication
     last_steps_all = (accepted_length - 1).to(torch.int64)
-    valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
-    last_steps = last_steps_all[valid_mask].to(torch.int64)
-
-    if valid_state_indices.numel() > 0:
-        chunk = 256
-        num_valid = valid_state_indices.numel()
-
-        # SSM state updates
-        for i in range(0, num_valid, chunk):
-            idx = valid_state_indices[i : i + chunk]
-            steps = last_steps[i : i + chunk]
-            # per (cache line, step)
-            for j in range(idx.numel()):
-                ci = idx[j].item()
-                st = steps[j].item()
-                ssm_states[:, ci, :].copy_(
-                    intermediate_state_cache[:, ci, st].to(
-                        ssm_states.dtype, copy=False
-                    )
-                )
-
-        # Conv window updates
-        for i in range(0, num_valid, chunk):
-            idx = valid_state_indices[i : i + chunk]
-            steps = last_steps[i : i + chunk]
-            for j in range(idx.numel()):
-                ci = idx[j].item()
-                st = steps[j].item()
-                conv_states[:, ci, :, :].copy_(
-                    intermediate_conv_window_cache[:, ci, st].to(
-                        conv_states.dtype, copy=False
-                    )
-                )
+    valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)  # [N]
+    last_steps = last_steps_all[valid_mask].to(torch.int64)  # [N]
+
+    # Scatter into ssm_states at the chosen cache lines
+    ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
+        :, valid_state_indices, last_steps
+    ].to(ssm_states.dtype, copy=False)
+
+    # Scatter into conv_states at the chosen cache lines
+    conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
+        :, valid_state_indices, last_steps
+    ].to(conv_states.dtype, copy=False)
```

Severity: medium

This refactoring from Python loops to vectorized tensor operations is a significant performance improvement. The use of advanced indexing is much more efficient and readable.
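
As a standalone illustration of the pattern being praised here (toy shapes and hypothetical values, not the PR's actual tensors), the vectorized scatter via advanced indexing does in one statement what the removed nested loops did element by element:

```python
import torch

layers, cache_lines, steps, width = 2, 8, 4, 16
intermediate_state_cache = torch.randn(layers, cache_lines, steps, width)
ssm_states = torch.zeros(layers, cache_lines, width)

valid_state_indices = torch.tensor([1, 3, 6])  # accepted cache lines
last_steps = torch.tensor([2, 0, 3])           # last accepted step per cache line

# For each pair (ci, st), copy cache[:, ci, st] into states[:, ci] in a single scatter.
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
    :, valid_state_indices, last_steps
]

# Sanity check against the loop formulation.
for ci, st in zip(valid_state_indices.tolist(), last_steps.tolist()):
    assert torch.equal(ssm_states[:, ci], intermediate_state_cache[:, ci, st])
```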

@zhyncs zhyncs deleted the lifu/cached-kernel branch September 21, 2025 07:07