feat: add trtllm-gen mha from direct call #8782
Conversation
Summary of Changes
Hello @yyihuang, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates a new, highly optimized attention backend that utilizes TensorRT-LLM's Multi-Head Attention (MHA) kernels through FlashInfer. The primary goal is to accelerate large language model inference, particularly during the token generation phase, by leveraging hardware-specific optimizations on NVIDIA Blackwell GPUs. The changes involve a new attention backend class, a generalized Triton kernel for efficient KV cache indexing, and robust validation for backend configuration.
Highlights
- New Attention Backend: Introduced TRTLLMHAAttnBackend to integrate TensorRT-LLM's Multi-Head Attention (MHA) kernels via FlashInfer, specifically for the decode (token generation) path. This aims to leverage highly optimized kernels for improved inference performance.
- Enhanced KV Cache Indexing: The core Triton kernel create_flashinfer_kv_indices_triton has been significantly updated to support variable page sizes, which is essential for efficient memory management and performance with the new TRT-LLM MHA backend. It now correctly computes and stores page indices for the KV cache (see the sketch after this list).
- Hardware-Specific Optimization and Validation: The new trtllm_mha backend is explicitly enabled for NVIDIA Blackwell GPUs (SM100 architecture) and enforces specific page_size values (16, 32, or 64) to ensure optimal performance and compatibility with the underlying TRT-LLM kernels. Validation checks are added to prevent misconfiguration.
- CUDA Graph Integration: The TRTLLMHAAttnBackend is designed to work seamlessly with CUDA graphs, allowing for further performance gains during the decode phase by capturing and replaying the attention computation graph.
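To make the KV-cache indexing point concrete, below is a minimal pure-PyTorch sketch of how per-request page indices can be derived from a request-to-token table for a given page size. It is illustrative only, not the actual Triton kernel; the tensor names (req_to_token, req_pool_indices, seq_lens) mirror those visible in the diffs below, and everything else is an assumption.

import torch

def build_block_kv_indices(req_to_token, req_pool_indices, seq_lens, page_size, max_pages):
    """Illustrative equivalent of a paged-KV indexing kernel (not the real one)."""
    batch = req_pool_indices.numel()
    block_kv_indices = torch.full(
        (batch, max_pages), -1, dtype=torch.int32, device=req_to_token.device
    )
    for i in range(batch):
        req = req_pool_indices[i].item()
        seq_len = seq_lens[i].item()
        num_pages = (seq_len + page_size - 1) // page_size  # ceil division
        # The first token slot of each page identifies the page in the KV pool,
        # since pages are page_size contiguous token slots wide.
        first_tokens = req_to_token[req, 0 : num_pages * page_size : page_size]
        block_kv_indices[i, :num_pages] = (first_tokens // page_size).to(torch.int32)
    return block_kv_indices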
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. commenting on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request adds a new trtllm-mha attention backend. The changes are a good starting point, but there are several critical issues that need to be addressed. The new backend implementation in trtllm_mha_backend.py has incorrect imports and calls the wrong Triton kernel with incorrect arguments, which will cause it to fail. The forward_extend method is also unimplemented. Additionally, there are some logical inconsistencies in the model runner and server arguments, and the tests for the new Triton kernel are incomplete. I've provided detailed comments and suggestions to fix these issues.
    create_flashinfer_kv_indices_triton,
)
The wrong Triton kernel is being imported and used throughout this file. The trtllm_mha backend uses fixed-size block tables (a 2D tensor), but create_flashinfer_kv_indices_triton is designed for ragged tensors (1D tensor with kv_indptr).
The correct kernel to use is create_flashmla_kv_indices_triton, which is designed for 2D block tables. This needs to be imported instead.
Consequently, all calls to create_flashinfer_kv_indices_triton in this file (in _create_block_kv_indices, init_forward_metadata_capture_cuda_graph, and init_forward_metadata_replay_cuda_graph) are incorrect and need to be updated to call create_flashmla_kv_indices_triton with the correct arguments. This is a critical issue that will cause the backend to fail.
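For intuition, here is a small, hypothetical illustration of the two layouts this comment contrasts: a ragged 1D kv_indices tensor addressed via kv_indptr, versus a fixed-size 2D block table padded to a maximum number of pages per request. It is a sketch only and does not reproduce either kernel's signature.

import torch

# Hypothetical page lists for two requests with 3 and 5 KV pages respectively.
pages_req0 = torch.tensor([7, 2, 9], dtype=torch.int32)
pages_req1 = torch.tensor([4, 1, 8, 3, 6], dtype=torch.int32)

# Ragged layout (what create_flashinfer_kv_indices_triton targets):
# one flat 1D tensor plus an indptr giving each request's slice.
kv_indices = torch.cat([pages_req0, pages_req1])        # shape [8]
kv_indptr = torch.tensor([0, 3, 8], dtype=torch.int32)  # request i owns [indptr[i], indptr[i+1])

# Fixed-size 2D block table (what the trtllm_mha backend expects):
# a dense [batch, max_blocks] tensor padded with a sentinel value.
max_blocks = 5
block_table = torch.full((2, max_blocks), -1, dtype=torch.int32)
block_table[0, : pages_req0.numel()] = pages_req0
block_table[1, : pages_req1.numel()] = pages_req1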
from sglang.srt.layers.attention.utils import (
TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
)

create_flashinfer_kv_indices_triton[(batch_size,)](
    self.req_to_token,
    req_pool_indices,
    seq_lens,
    None,
    block_kv_indices,
    self.req_to_token.stride(0),
    max_blocks,
    NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
    PAGED_SIZE=self.page_size,
)
This call to create_flashinfer_kv_indices_triton is incorrect. As mentioned in another comment, you should be calling create_flashmla_kv_indices_triton.
Furthermore, the arguments are mismatched. The correct call should use create_flashmla_kv_indices_triton and pass the correct arguments for a 2D block table. The kv_indices_ptr_stride should be max_blocks.
Here is the corrected call:
create_flashmla_kv_indices_triton[(batch_size,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None, # kv_start_idx
block_kv_indices,
self.req_to_token.stride(0),
max_blocks, # kv_indices_ptr_stride
PAGED_SIZE=self.page_size,
)

create_flashinfer_kv_indices_triton[(bs,)](
    self.req_to_token,
    req_pool_indices,
    seq_lens,
    None,
    block_kv_indices,
    self.req_to_token.shape[1],
    PAGE_SIZE=self.page_size,
)
This call to create_flashinfer_kv_indices_triton is incorrect. It should be create_flashmla_kv_indices_triton.
The arguments are also incorrect. The call is missing the kv_indices_ptr_stride argument, and other arguments are shifted, which will lead to a runtime error. The kv_indices_ptr_stride should be block_kv_indices.stride(0).
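For intuition on why the stride argument matters, here is a tiny illustration in plain PyTorch (shapes assumed, not the kernel itself): for a contiguous 2D block table, the row stride equals the number of columns, which is presumably the per-request offset the kernel uses when writing each row.

import torch

# Assumed shapes for illustration: 4 requests, up to 7 pages each.
block_kv_indices = torch.full((4, 7), -1, dtype=torch.int32)
# For a contiguous 2D tensor the row stride equals the column count.
assert block_kv_indices.stride(0) == block_kv_indices.shape[1] == 7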
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.shape[1],
block_kv_indices.stride(0),
PAGE_SIZE=self.page_size,
)

create_flashinfer_kv_indices_triton[(bs,)](
    self.req_to_token,
    req_pool_indices,
    seq_lens,
    None,
    metadata.block_kv_indices,
    self.req_to_token.shape[1],
    PAGE_SIZE=self.page_size,
)
This call to create_flashinfer_kv_indices_triton is incorrect. It should be create_flashmla_kv_indices_triton.
Similar to the other calls in this file, the arguments are incorrect. The call is missing the kv_indices_ptr_stride argument, and other arguments are shifted. The kv_indices_ptr_stride should be metadata.block_kv_indices.stride(0).
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
metadata.block_kv_indices,
self.req_to_token.shape[1],
metadata.block_kv_indices.stride(0),
PAGE_SIZE=self.page_size,
)

@@ -0,0 +1,372 @@
from __future__ import annotations

from python.sglang.srt.layers.radix_attention import RadixAttention
The import path python.sglang.srt.layers.radix_attention is incorrect. The python. prefix is not standard in this project and will likely cause import errors. Please remove it to align with the project's standard import style.
- from python.sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.radix_attention import RadixAttention
# todo(Yingyi): implement this
return None
The forward_extend method is not implemented and just returns None. This will likely cause a TypeError or other unexpected behavior downstream when the return value is used. If this functionality is not yet supported, it's better to raise a NotImplementedError to fail explicitly and clearly.
# todo(Yingyi): implement this
raise NotImplementedError("forward_extend is not implemented for TRTLLMHAAttnBackend yet.")

if not self.use_mla_backend:
    raise ValueError("trtllm_mha backend can only be used with MLA models.")
The check if not self.use_mla_backend: seems incorrect for the trtllm_mha backend. MHA (Multi-Head Attention) is the standard attention mechanism, while MLA (Multi-head Latent Attention) is a different attention variant. This check implies that the MHA backend can only be used with MLA models, which is contradictory.
This seems like a copy-paste error from the trtllm_mla case above. This check should likely be removed or corrected to ensure the MHA backend can be used with standard MHA models.
- if not self.use_mla_backend:
-     raise ValueError("trtllm_mha backend can only be used with MLA models.")
+ # if not self.use_mla_backend:
+ #     raise ValueError("trtllm_mha backend can only be used with MLA models.")
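Separately, a hedged sketch of the kind of server-argument checks the PR highlights describe for this backend (Blackwell/SM100 only, page_size in {16, 32, 64}) might look like the following; the function and parameter names here are hypothetical, not sglang's actual API.

SUPPORTED_TRTLLM_MHA_PAGE_SIZES = (16, 32, 64)

def validate_trtllm_mha_args(attention_backend: str, page_size: int, device_sm: int) -> None:
    # Hypothetical validation mirroring the constraints listed in the PR highlights;
    # names and the SM-version check are assumptions, not the actual sglang code.
    if attention_backend != "trtllm_mha":
        return
    if device_sm < 100:
        raise ValueError("trtllm_mha backend requires an NVIDIA Blackwell (SM100) GPU.")
    if page_size not in SUPPORTED_TRTLLM_MHA_PAGE_SIZES:
        raise ValueError(
            f"trtllm_mha backend requires page_size in {SUPPORTED_TRTLLM_MHA_PAGE_SIZES}, got {page_size}."
        )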
max_pages = max_context_len // PAGE_SIZE
kv_indices_flashmla = torch.empty(
    batch, max_pages, dtype=torch.int32, device="cuda"
)

create_flashmla_kv_indices_triton[(batch,)](
    req_to_token,
    req_pool_indices,
    seq_lens,
    None,
    kv_indices_flashmla,
    req_to_token.size(1),
    max_pages,
    PAGE_SIZE,
)
# if self.speculative_algorithm is not None:
#     raise ValueError(
#         "trtllm_mla backend does not support speculative decoding yet."
#     )
The error message in this commented-out block refers to trtllm_mla backend, but this code is within the trtllm_mha block. This is likely a copy-paste error and could be confusing. Please update the message to refer to trtllm_mha for consistency, even if it's commented out.
  # if self.speculative_algorithm is not None:
  #     raise ValueError(
- #         "trtllm_mla backend does not support speculative decoding yet."
+ #         "trtllm_mha backend does not support speculative decoding yet."
  #     )
# for debug
# BATCH = [4]
# MAX_BATCH = 4
# MAX_CONTEXT_LEN = 10
# Test for small batch size
for page_size in PAGE_SIZE[:1]:
    print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
    self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)

# Test for larger batch size
for batch in BATCH[1:]:
    for page_size in PAGE_SIZE:
        print(
            f"Running test for batch size: {batch} and page size: {page_size}"
        )
        self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
This test contains commented-out debug code and print statements. These should be removed before merging to keep the test suite clean.
- # for debug
- # BATCH = [4]
- # MAX_BATCH = 4
- # MAX_CONTEXT_LEN = 10
- # Test for small batch size
- for page_size in PAGE_SIZE[:1]:
-     print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
-     self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
- # Test for larger batch size
- for batch in BATCH[1:]:
-     for page_size in PAGE_SIZE:
-         print(
-             f"Running test for batch size: {batch} and page size: {page_size}"
-         )
-         self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
+ # Test for small batch size
+ for page_size in PAGE_SIZE[:1]:
+     self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
+ # Test for larger batch size
+ for batch in BATCH[1:]:
+     for page_size in PAGE_SIZE:
+         self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
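If per-configuration context in failure output is still wanted without print noise, one hedged alternative (assuming these are unittest-based tests, which the snippet does not confirm) is to wrap each run in subTest:

# Hypothetical variant using unittest subtests so a failure reports the
# batch/page-size combination without any print statements.
for page_size in PAGE_SIZE[:1]:
    with self.subTest(batch=BATCH[0], page_size=page_size):
        self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)

for batch in BATCH[1:]:
    for page_size in PAGE_SIZE:
        with self.subTest(batch=batch, page_size=page_size):
            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)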
…g/sglang into flashinfer_trtllm_gen_mha_direct
This reverts commit d19c10f.
…g/sglang into flashinfer_trtllm_gen_mha_direct
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Motivation
Modifications
Accuracy Test
Benchmark & Profiling
benchmark with

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1000 --parallel 1000

Results of the updated kernel:

python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --trust-remote --attention-backend trtllm_mha --page-size 64

Results of prev flashinfer kernel:

python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --trust-remote --attention-backend flashinfer --page-size 64

Checklist