Fix: several bugs/issues with trtllm-gen attention kernels. #2062
yzh119 merged 5 commits into flashinfer-ai:main from
Conversation
Walkthrough: Updates the TRTLLM FMHA artifact path and checksum constants; extends the FMHA kernel hash encoding to include a new sparseMla flag with an adjusted bit-field layout and stricter head-dimension checks; and adds paged-KV / sparse-related fields to KernelParams with zero-initialization and log2 computation for numTokensPerPage.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Runner as Runner / Dispatch
    participant Selector as Kernel Selector
    participant Meta as KernelMeta
    participant Loader as Kernel Loader
    Note over Runner,Selector: Build selection key from runtime params
    Runner->>Selector: hashFromRunnerParams(params, /* sparseMla */ false)
    Selector->>Meta: select candidate KernelMeta
    Note right of Meta: KernelMeta includes mSparseMla
    Selector->>Loader: hashID(kernelMeta, sparseMla=Meta.mSparseMla)
    Loader->>Loader: assemble 64-bit hash (includes sparseMla bit, log2(numTokensPerPage))
    Loader->>Runner: return selected kernel / load artifacts (uses updated artifact checksum)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ✅ Passed checks (3 passed)
Code Review
This pull request updates artifact hashes and refines the kernel selection logic for trtllm-gen attention kernels. Key changes include adding a sparseMla parameter to the hashID function, adjusting bit shifts for head dimensions, and enforcing that numTokensPerPage must be a power of 2. New members have been added to the KernelParams struct to support these changes, and the struct is now explicitly zero-initialized using memset for improved safety. These modifications appear to address the reported CUDA launch errors and masking bugs, enhancing the robustness and correctness of the attention kernels.
@PerkzZheng would you mind rebasing to the main branch? It seems there are some merge conflicts.
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Force-pushed from 8dc0a1b to e4d7f46
It was rebased to a wrong remote. It should be good now. Thanks.
nvmbreughe left a comment:
LGTM. Just wondering: for what config did we get failures without this fix? I think it would be good to have a test. I can add it after this PR.
```diff
 )
 TRTLLM_GEN_BMM: str = (
-    "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd"
+    "1ebace613389a4f2e10b14315da5d522642c5dcaae23f01213d56c59068f148b"
```
Why do we need to update the BMM hash in this PR?
/bot run

[FAILED] Pipeline #38107936: 7/17 passed

/bot run

[FAILED] Pipeline #38135771: 14/17 passed
…er-ai#2062)

## 📌 Description

This MR fixes:
1. unspecified cuda launch errors with 2CTA MLA kernels
2. masking bug of SWA decode kernels.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Added Sparse MLA support and propagated its flag through kernel selection and dispatch.
* **Bug Fixes / Improvements**
  * Enforced power-of-two page sizing for paged KV caches and tightened head-dimension limits for broader hardware compatibility.
  * Updated kernel trait encoding and hash construction to include the sparse MLA flag and revised bit-field layout.
* **Chores**
  * Updated runtime kernel artifact identifiers and checksums.
  * Extended kernel parameter fields, zero-initialized params on setup, and populated tokens-per-page log2 for paged KV.

---

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
Co-authored-by: Zihao Ye <expye@outlook.com>