
feat: add trtllm-gen mha from direct call#8782

Merged
zhyncs merged 22 commits into sgl-project:main from yyihuang:flashinfer_trtllm_gen_mha_direct
Aug 5, 2025

Conversation

@yyihuang (Collaborator) commented on Aug 5, 2025

Motivation

Modifications

Accuracy Test

Benchmark & Profiling

Benchmarked with:
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1000 --parallel 1000

Results of the updated kernel:
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --trust-remote --attention-backend trtllm_mha --page-size 64

Accuracy: 0.797
Invalid: 0.001
Latency: 6.904 s
Output throughput: 14548.023 token/s

Results of the previous flashinfer kernel:
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --trust-remote --attention-backend flashinfer --page-size 64

Accuracy: 0.796
Invalid: 0.000
Latency: 10.477 s
Output throughput: 9496.439 token/s

Checklist

@gemini-code-assist Bot (Contributor) left a comment


Summary of Changes

Hello @yyihuang, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new, highly optimized attention backend that utilizes TensorRT-LLM's Multi-Head Attention (MHA) kernels through FlashInfer. The primary goal is to accelerate large language model inference, particularly during the token generation phase, by leveraging hardware-specific optimizations on NVIDIA Blackwell GPUs. The changes involve a new attention backend class, a generalized Triton kernel for efficient KV cache indexing, and robust validation for backend configuration.

Highlights

  • New Attention Backend: Introduced TRTLLMHAAttnBackend to integrate TensorRT-LLM's Multi-Head Attention (MHA) kernels via FlashInfer, specifically for the decode (token generation) path. This aims to leverage highly optimized kernels for improved inference performance.
  • Enhanced KV Cache Indexing: The core Triton kernel create_flashinfer_kv_indices_triton has been significantly updated to support variable page sizes, which is essential for efficient memory management and performance with the new TRT-LLM MHA backend. It now correctly computes and stores page indices for the KV cache (see the sketch after this list).
  • Hardware-Specific Optimization and Validation: The new trtllm_mha backend is explicitly enabled for NVIDIA Blackwell GPUs (SM100 architecture) and enforces specific page_size values (16, 32, or 64) to ensure optimal performance and compatibility with the underlying TRT-LLM kernels. Validation checks are added to prevent misconfiguration.
  • CUDA Graph Integration: The TRTLLMHAAttnBackend is designed to work seamlessly with CUDA graphs, allowing for further performance gains during the decode phase by capturing and replaying the attention computation graph.
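To make the KV-indexing change concrete, here is a minimal pure-PyTorch sketch of the page-table fill that a kernel like create_flashinfer_kv_indices_triton performs. It assumes a page index is the KV slot of a page's first token divided by the page size; the function name and exact layout are illustrative, not the PR's actual code.

import torch

def build_block_kv_indices(req_to_token, req_pool_indices, seq_lens, max_blocks, page_size):
    # Fill a (batch, max_blocks) block table with page indices, one padded row
    # per request. Pure-Python reference; the real kernel does this in Triton.
    batch = req_pool_indices.numel()
    block_kv_indices = torch.zeros(
        batch, max_blocks, dtype=torch.int32, device=req_to_token.device
    )
    for b in range(batch):
        seq_len = int(seq_lens[b])
        num_pages = (seq_len + page_size - 1) // page_size  # ceil division
        # KV slot of the first token in each used page (assumed layout)
        first_slots = req_to_token[req_pool_indices[b], 0:seq_len:page_size]
        block_kv_indices[b, :num_pages] = (first_slots // page_size).to(torch.int32)
    return block_kv_indices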
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature               Command                Description
Code Review           /gemini review         Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary        Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist    Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                  /gemini help           Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request adds a new trtllm-mha attention backend. The changes are a good starting point, but there are several critical issues that need to be addressed. The new backend implementation in trtllm_mha_backend.py has incorrect imports and calls the wrong Triton kernel with incorrect arguments, which will cause it to fail. The forward_extend method is also unimplemented. Additionally, there are some logical inconsistencies in the model runner and server arguments, and the tests for the new Triton kernel are incomplete. I've provided detailed comments and suggestions to fix these issues.

Comment on lines +19 to +20
create_flashinfer_kv_indices_triton,
)


critical

The wrong Triton kernel is being imported and used throughout this file. The trtllm_mha backend uses fixed-size block tables (a 2D tensor), but create_flashinfer_kv_indices_triton is designed for ragged tensors (1D tensor with kv_indptr).

The correct kernel to use is create_flashmla_kv_indices_triton, which is designed for 2D block tables. This needs to be imported instead.

Consequently, all calls to create_flashinfer_kv_indices_triton in this file (in _create_block_kv_indices, init_forward_metadata_capture_cuda_graph, and init_forward_metadata_replay_cuda_graph) are incorrect and need to be updated to call create_flashmla_kv_indices_triton with the correct arguments. This is a critical issue that will cause the backend to fail.

from sglang.srt.layers.attention.utils import (
    TRITON_PAD_NUM_PAGE_PER_BLOCK,
    create_flashmla_kv_indices_triton,
)
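For readers new to the distinction, here is a small illustration (assumed values, not taken from the PR) of the two index layouts being contrasted:

import torch

# Ragged layout consumed by create_flashinfer_kv_indices_triton:
# request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indptr = torch.tensor([0, 3, 5])
kv_indices = torch.tensor([7, 8, 9, 2, 3])

# Fixed-stride 2D block table consumed by create_flashmla_kv_indices_triton:
# one padded row per request, shape (batch, max_blocks).
block_kv_indices = torch.tensor([[7, 8, 9, 0],
                                 [2, 3, 0, 0]])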

Comment on lines +142 to +152
create_flashinfer_kv_indices_triton[(batch_size,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_blocks,
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE=self.page_size,
)


critical

This call to create_flashinfer_kv_indices_triton is incorrect. As mentioned in another comment, you should be calling create_flashmla_kv_indices_triton.

Furthermore, the arguments are mismatched. The correct call should use create_flashmla_kv_indices_triton and pass the correct arguments for a 2D block table. The kv_indices_ptr_stride should be max_blocks.

Here is the corrected call:

        create_flashmla_kv_indices_triton[(batch_size,)]( 
            self.req_to_token, 
            req_pool_indices, 
            seq_lens, 
            None,  # kv_start_idx 
            block_kv_indices, 
            self.req_to_token.stride(0), 
            max_blocks,  # kv_indices_ptr_stride 
            PAGED_SIZE=self.page_size, 
        )

Comment on lines +200 to +208
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.shape[1],
PAGE_SIZE=self.page_size,
)


critical

This call to create_flashinfer_kv_indices_triton is incorrect. It should be create_flashmla_kv_indices_triton.

The arguments are also incorrect. The call is missing the kv_indices_ptr_stride argument, and other arguments are shifted, which will lead to a runtime error. The kv_indices_ptr_stride should be block_kv_indices.stride(0).

        create_flashmla_kv_indices_triton[(bs,)]( 
            self.req_to_token, 
            req_pool_indices, 
            seq_lens, 
            None, 
            block_kv_indices, 
            self.req_to_token.shape[1], 
            block_kv_indices.stride(0), 
            PAGE_SIZE=self.page_size, 
        )

Comment on lines +243 to +251
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
metadata.block_kv_indices,
self.req_to_token.shape[1],
PAGE_SIZE=self.page_size,
)


critical

This call to create_flashinfer_kv_indices_triton is incorrect. It should be create_flashmla_kv_indices_triton.

Similar to the other calls in this file, the arguments are incorrect. The call is missing the kv_indices_ptr_stride argument, and other arguments are shifted. The kv_indices_ptr_stride should be metadata.block_kv_indices.stride(0).

        create_flashmla_kv_indices_triton[(bs,)]( 
            self.req_to_token, 
            req_pool_indices, 
            seq_lens, 
            None, 
            metadata.block_kv_indices, 
            self.req_to_token.shape[1], 
            metadata.block_kv_indices.stride(0), 
            PAGE_SIZE=self.page_size, 
        )

@@ -0,0 +1,372 @@
from __future__ import annotations

from python.sglang.srt.layers.radix_attention import RadixAttention


high

The import path python.sglang.srt.layers.radix_attention is incorrect. The python. prefix is not standard in this project and will likely cause import errors. Please remove it to align with the project's standard import style.

Suggested change
from python.sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.radix_attention import RadixAttention

Comment on lines +371 to +372
# todo(Yingyi): implement this
return None


high

The forward_extend method is not implemented and just returns None. This will likely cause a TypeError or other unexpected behavior downstream when the return value is used. If this functionality is not yet supported, it's better to raise a NotImplementedError to fail explicitly and clearly.

        # todo(Yingyi): implement this
        raise NotImplementedError("forward_extend is not implemented for TRTLLMHAAttnBackend yet.")

Comment on lines +1453 to +1454
if not self.use_mla_backend:
raise ValueError("trtllm_mha backend can only be used with MLA models.")


high

The check if not self.use_mla_backend: seems incorrect for the trtllm_mha backend. MHA (Multi-Head Attention) is the standard attention mechanism, while MLA (Multi-head Latent Attention) is a different mechanism used by models such as DeepSeek. This check implies that the MHA backend can only be used with MLA models, which is contradictory.

This looks like a copy-paste error from the trtllm_mla case above. The check should likely be removed or inverted so that the MHA backend can be used with standard MHA models.

Suggested change
if not self.use_mla_backend:
raise ValueError("trtllm_mha backend can only be used with MLA models.")
# if not self.use_mla_backend:
# raise ValueError("trtllm_mha backend can only be used with MLA models.")
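If the intent is simply the inverse of the trtllm_mla case, a hedged sketch of a corrected guard (illustrative only, not the PR's final code) could be:

# Hypothetical correction: reject MLA models instead of requiring them.
if self.use_mla_backend:
    raise ValueError("trtllm_mha backend cannot be used with MLA models.")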

Comment thread: test/srt/test_create_kvindices.py (outdated)
Comment on lines +75 to +89
max_pages = max_context_len // PAGE_SIZE
kv_indices_flashmla = torch.empty(
batch, max_pages, dtype=torch.int32, device="cuda"
)

create_flashmla_kv_indices_triton[(batch,)](
req_to_token,
req_pool_indices,
seq_lens,
None,
kv_indices_flashmla,
req_to_token.size(1),
max_pages,
PAGE_SIZE,
)


high

The test for create_flashmla_kv_indices_triton is incomplete. The kernel is called and kv_indices_flashmla is populated, but there is no assertion to verify its correctness. Please add a reference implementation and an assertion for kv_indices_flashmla to ensure the kernel works as expected.
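A minimal sketch of such a check, assuming the kernel's contract is that each page index equals the KV slot of the page's first token divided by PAGE_SIZE (an assumption about the kernel, not verified here); it reuses the names from the test above:

for b in range(batch):
    num_pages = (int(seq_lens[b]) + PAGE_SIZE - 1) // PAGE_SIZE
    expected = (
        req_to_token[req_pool_indices[b], 0 : int(seq_lens[b]) : PAGE_SIZE] // PAGE_SIZE
    ).to(torch.int32)
    # compare the populated prefix of each row against the reference
    torch.testing.assert_close(kv_indices_flashmla[b, :num_pages], expected)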

Comment thread: python/sglang/srt/server_args.py (outdated)
Comment on lines +455 to +458
# if self.speculative_algorithm is not None:
# raise ValueError(
# "trtllm_mla backend does not support speculative decoding yet."
# )


medium

The error message in this commented-out block refers to trtllm_mla backend, but this code is within the trtllm_mha block. This is likely a copy-paste error and could be confusing. Please update the message to refer to trtllm_mha for consistency, even if it's commented out.

Suggested change
# if self.speculative_algorithm is not None:
# raise ValueError(
# "trtllm_mla backend does not support speculative decoding yet."
# )
# if self.speculative_algorithm is not None:
# raise ValueError(
# "trtllm_mha backend does not support speculative decoding yet."
# )

Comment thread: test/srt/test_create_kvindices.py (outdated)
Comment on lines +98 to +113
# for debug
# BATCH = [4]
# MAX_BATCH = 4
# MAX_CONTEXT_LEN = 10
# Test for small batch size
for page_size in PAGE_SIZE[:1]:
print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)

# Test for larger batch size
for batch in BATCH[1:]:
for page_size in PAGE_SIZE:
print(
f"Running test for batch size: {batch} and page size: {page_size}"
)
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)


medium

This test contains commented-out debug code and print statements. These should be removed before merging to keep the test suite clean.

Suggested change
# for debug
# BATCH = [4]
# MAX_BATCH = 4
# MAX_CONTEXT_LEN = 10
# Test for small batch size
for page_size in PAGE_SIZE[:1]:
print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
# Test for larger batch size
for batch in BATCH[1:]:
for page_size in PAGE_SIZE:
print(
f"Running test for batch size: {batch} and page size: {page_size}"
)
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
# Test for small batch size
for page_size in PAGE_SIZE[:1]:
self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
# Test for larger batch size
for batch in BATCH[1:]:
for page_size in PAGE_SIZE:
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)

@zhyncs zhyncs marked this pull request as ready for review August 5, 2025 08:05
@zhyncs zhyncs self-assigned this Aug 5, 2025
@yyihuang yyihuang added the ready-to-merge label (The PR is ready to merge after the CI is green.) on Aug 5, 2025
@zhyncs zhyncs merged commit 40e3b2b into sgl-project:main Aug 5, 2025
22 of 60 checks passed
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

Labels

collaboration · high priority · ready-to-merge (The PR is ready to merge after the CI is green.)
