refactor: refactor trtllm-gen attention kernel integration code (#1289)
yzh119 merged 42 commits into flashinfer-ai:main
Conversation
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on refactoring the integration of TRTLLM-GEN attention kernels within the FlashInfer library. The primary goal is to streamline the interface between Python and C++/CUDA code by transitioning from at::Tensor objects to raw pointers and explicit dimension parameters in the low-level kernel launchers. This change aims to improve performance by reducing PyTorch tensor overheads and enables more direct control over kernel parameters, such as the Streaming Multiprocessor count, for better hardware utilization. Additionally, the C++ code has been cleaned up with improved data type dispatching and parameter structure initialization.
Highlights
- **C++ Kernel Interface Refactoring:** Refactored the C++ kernel launcher signatures (`trtllm_paged_attention_decode_launcher`, `trtllm_paged_attention_context_launcher`, `trtllm_paged_attention_mla_launcher`) to accept raw pointers and explicit dimension/count parameters (e.g., `batch_size`, `head_dim`, `sm_count`) instead of `at::Tensor` objects. This reduces PyTorch tensor overhead and provides more direct control over kernel execution.
- **Data Type Dispatching Improvement:** Replaced verbose `if`/`else if` blocks for data type dispatching in the C++ attention functions (`trtllm_paged_attention_decode`, `trtllm_paged_attention_context`) with the cleaner `DISPATCH_PYTORCH_DTYPE_TO_CTYPE` macro, improving code readability and maintainability.
- **Hardware-Aware Optimization:** Introduced `sm_count` (streaming multiprocessor count) as a direct parameter to the CUDA kernel launchers, allowing the kernels to make more informed decisions based on the target GPU's architecture.
- **Parameter Structure Initialization:** Added a default constructor to `TllmGenFmhaRunnerParams` that performs zero-initialization using `memset`, eliminating the need for explicit `memset` calls in the kernel launcher functions.
- **Python Frontend Updates:** Updated the Python frontend calls in `flashinfer/decode.py`, `flashinfer/prefill.py`, and `flashinfer/jit/attention/pytorch.py` to retrieve the device's `sm_count` using a new utility function and pass it down to the refactored C++ kernels.
Code Review
This pull request refactors the TensorRT-LLM attention kernel integration, primarily by decoupling the CUDA/C++ code from PyTorch tensors and improving parameter handling. The changes are a good step towards better modularity and clarity.
I've identified a critical bug related to incorrect batch_size handling that needs to be addressed, along with a few other medium to high severity issues concerning potential regressions and code correctness. Please see the detailed comments for specifics.
The local variable batch_size_ is initialized with query.size(0), which corresponds to the total number of query tokens (sum_seq_q), not the batch size. This variable shadows the batch_size function parameter, which holds the correct number of sequences. This is a critical bug.
Please remove this line and use the batch_size function parameter in the call to trtllm_paged_attention_context_launcher on line 260.
The batch_size_ variable used here is incorrect as it holds the total number of query tokens instead of the batch size. Please use the batch_size function parameter instead.
After this change, please also review the initialization of runner_params.mMaxSeqLenQ inside trtllm_paged_attention_context_launcher. It is currently set to batch_size, which would become the number of sequences after the fix above. This is likely incorrect, as it should be the maximum query sequence length in the batch.
batch_size, max_seq_len, num_qo_heads, num_kv_heads, head_dim, page_size,
The parameter runner_params.mNumPagesInMemPool is hardcoded to 0. The previous implementation calculated this value based on the total device memory. Setting it to 0 could cause issues if the underlying kernel relies on this value for resource allocation or performance tuning.
Since key_value_cache is now a raw pointer, its size is not available here. Consider passing the total number of pages in the cache pool as an argument from the Python side to correctly calculate this value. For example, in csrc/trtllm_mla_kernel_launcher.cu, this was calculated as key_value_cache.size(0) * 2.
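One way to implement this suggestion on the Python side: compute the pool size while the tensor shape is still available and pass it down as a plain integer. The helper below is hypothetical and only mirrors the old `key_value_cache.size(0) * 2` arithmetic:

```python
# Hypothetical Python-side helper: compute the number of pages in the KV
# memory pool before the cache tensor is flattened to a raw pointer.
def num_pages_in_mem_pool(kv_cache_shape):
    """kv_cache_shape is [num_pages, 2, num_kv_heads, page_size, head_dim];
    the leading dimension counts pages, and the K/V planes double it,
    matching the old key_value_cache.size(0) * 2 computation."""
    num_pages = kv_cache_shape[0]
    return num_pages * 2
```

The resulting integer could then be forwarded through the launcher signature alongside `batch_size` and `sm_count`.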
The function signature of trtllm_paged_attention_mla_launcher still uses at::Tensor arguments, while trtllm_paged_attention_decode_launcher and trtllm_paged_attention_context_launcher in trtllm_fmha_kernel_launcher.cu have been refactored to use raw pointers.
For consistency and to decouple the CUDA kernels from PyTorch's at::Tensor, consider applying the same refactoring to this function as well.
Force-pushed from bbf2db0 to b759e9e.
    int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t max_num_blocks_per_seq,
    double bmm1_scale, double bmm2_scale, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
    cudaStream_t stream, int* cum_seq_lens_q = nullptr, int* cum_seq_lens_kv = nullptr) {
  if (num_qo_heads % num_kv_heads != 0) {
If there are other restrictions on the group size (i.e., `num_qo_heads // num_kv_heads`), we should throw an error for those as well.
  runner_params.outputScale = bmm2_scale;
  runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
  runner_params.mChunkedAttentionSize = INT_MAX;
  runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
It is added by one because FlashInfer assumes that sliding window attention should consider the extra token during masking, right? Probably we can add a comment here.
Suggested change:

  // Add one to include the extra token during masking.
  runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
      use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
  runner_params.mMultiCtasKvMode = use_multi_block;

  size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
Is the rounding for 16B alignment? Then what about the `workspace_buffer`?
Suggested change:

  // Round up num_semaphores to a multiple of 8 since `multiCtasKvScratchPtr` requires 16B alignment.
  size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
@PerkzZheng `workspace_buffer` is allocated by users. We should ask users to make sure `workspace_buffer` is 16B aligned. I have added comments for this.
  Context,
  ForGen,
};
nit (P1): add documentation:
//! \brief Helper function to launch a trtllm paged attention kernel.
//! \note This function should not be called directly from another file. Use `trtllm_paged_attention_decode` and `trtllm_paged_attention_context` instead.
//!
//! \param out Device pointer to the output tensor.
//! \param query Device pointer to the input query tensor.
//! \param key_cache Device pointer to the input paged key cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param value_cache Device pointer to the input paged value cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param workspace_buffer Device pointer to the workspace. Must be at least 16-byte aligned. Recommended to allocate at least 128MB for workspace.
//! \param block_tables Device pointer to the block tables. The table shape is [batch_size, max_num_blocks_per_seq].
//! \param seq_lens Device pointer to the sequence lengths. The shape is [batch_size].
//! \param batch_size Batch size, i.e. the number of sequences in the batch.
//! \param max_q_len Maximum number of query tokens per sequence in the batch.
//! \param max_kv_len Maximum number of key/value tokens per sequence in the batch.
//! \param num_pages Maximum number of pages of the kv-cache.
//! \param num_qo_heads Number of query heads.
//! \param num_kv_heads Number of key/value heads.
//! \param head_dim_qk Head dimension of query/key.
//! \param head_dim_vo Head dimension of value/output.
//! \param page_size Number of tokens per page.
//! \param kv_stride_0 Stride of the "page_size" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_1 Stride of the "num_kv_heads" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_2 Stride of the "2" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param max_num_blocks_per_seq Maximum number of blocks that can be allocated for a sequence.
//! \param bmm1_scale The scaling factor applied between BMM1 and Softmax, not including the LOG2E factor.
//! \param bmm2_scale The scaling factor applied after BMM2, not including the scaling factor for Softmax output.
//! \param window_left The window size for sliding window attention. Set to -1 to disable sliding window attention.
//! \param sum_seq_q Total number of query tokens within the batch.
//! \param sm_count The number of SMs on the GPU.
//! \param stream The CUDA stream to launch the kernel on.
//! \param cum_seq_lens_q Device pointer to the tensor storing the accumulated sequence lengths of the query tokens in the batch. Not used in ForGen mode.
//! \param cum_seq_lens_kv Device pointer to the tensor storing the accumulated sequence lengths of the key/value tokens in the batch. Not used in ForGen mode.
  } else {
    auto runner = std::make_shared<TllmGenFmhaRunner>(q_data_type, kv_data_type, o_data_type);
    cache.emplace(key, runner);
    return runner;
Since `std::unordered_map` never invalidates references and the cache outlives everything, we can do

  static std::unordered_map<Key, TllmGenFmhaRunner, KeyHash> cache;

and return `TllmGenFmhaRunner&` instead:

  auto runner = TllmGenFmhaRunner(q_data_type, kv_data_type, o_data_type);
  auto [it, ok] = cache.emplace(key, std::move(runner));
  return it->second;
📌 Description
Simplify and unify the interface for trtllm-gen decode/prefill/mla kernels, and add support for shared-kv (in MLA, #1273).
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes