
misc: Add XQA decode to microbenchmark for sm90 and sm120 #2055

Merged
yzh119 merged 3 commits into flashinfer-ai:main from bkryu:bench_sm90_120_xqa_decode on Nov 7, 2025

Conversation

Collaborator

@bkryu bkryu commented Nov 6, 2025

📌 Description

In #2001, XQA decode kernels became available through trtllm_batch_decode_with_kv_cache on SM90 and SM120.

This PR adds the ability to benchmark these kernels through the microbenchmark.

Example microbenchmark command and outputs before and after:

```
### Before current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping.
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping.
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec

### After current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec
[PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec
```

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Standardized backend identifier to "trtllm-native" and expanded its support across benchmark routines and utilities.
    • Argument parsing now canonicalizes deprecated backend aliases and emits a deprecation warning when encountered.
  • Documentation
    • README and tool-facing messages updated to use the canonical backend name and include contextual notes about the change.

Contributor

coderabbitai Bot commented Nov 6, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR renames and normalizes a deprecated backend identifier: it replaces occurrences of trtllm-gen-native with trtllm-native in docs and backend mappings, adds a runtime normalization helper to canonicalize backend names during CLI parsing, and updates allowed-backend lists plus adjacent NOTE comments in flashinfer benchmark utilities. No behavioral logic changes beyond name normalization and list updates.

Changes

Backend mapping updates (benchmarks/routines/flashinfer_benchmark_utils.py):
Replaced/expanded backend entries to include trtllm-native (previously referenced as trtllm-gen-native) across multiple wrapper-version mappings (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper). Added adjacent NOTE comments documenting which underlying trtllm functions each backend invokes.

Backend normalization & CLI handling (benchmarks/routines/attention.py):
Added normalize_backends(backends) to canonicalize backend names (replacing trtllm-gen-native with trtllm-native and emitting a warning; see the sketch after the diagram below). parse_attention_args now calls this normalization after parsing; updated docstrings, help messages, and backend references to use the canonical name.

Documentation updates (benchmarks/README.md):
Replaced occurrences of trtllm-gen-native with trtllm-native in the general backends line, the four routine rows noted above, and the Backend Legend.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant CLI as parse_attention_args
    participant Norm as normalize_backends
    participant Runner as BenchmarkRunner

    User->>CLI: invoke CLI with --backends [user list]
    CLI->>Norm: normalize_backends(backends)
    Note right of Norm: replace deprecated<br/>"trtllm-gen-native" → "trtllm-native"<br/>and emit warning
    Norm-->>CLI: normalized backends
    CLI->>Runner: start benchmark with normalized backends
    Runner-->>User: run results
```
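For illustration, a minimal sketch of what such a normalization helper might look like (hypothetical; the actual normalize_backends in benchmarks/routines/attention.py may differ in its alias table, warning mechanism, and wording):

```python
import warnings

# Sketch only: the alias table and warning text are assumptions, not the
# actual implementation in benchmarks/routines/attention.py.
_DEPRECATED_ALIASES = {"trtllm-gen-native": "trtllm-native"}

def normalize_backends(backends):
    """Return backends with deprecated aliases replaced by canonical names."""
    normalized = []
    for backend in backends:
        canonical = _DEPRECATED_ALIASES.get(backend)
        if canonical is not None:
            warnings.warn(
                f"Backend '{backend}' is deprecated; use '{canonical}' instead.",
                DeprecationWarning,
            )
            backend = canonical
        normalized.append(backend)
    return normalized
```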

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Changes are cohesive (docs, CLI normalization, backend maps) but touch multiple responsibilities.
  • Review focus:
    • Ensure no remaining references to trtllm-gen-native.
    • Verify NOTE comments in flashinfer_benchmark_utils.py reference correct underlying function names.
    • Confirm normalize_backends warning format and that CLI flow uses normalized list.

Suggested reviewers

  • Anerudhan
  • cyx-6
  • nvmbreughe

Poem

🐇 I hopped through mappings, quick and bright,
swapped an old name for a new delight.
A gentle warning, a tidy chain —
benchmarks hum, and I nibble a grain. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 2 inconclusive)

  • Title check ⚠️ Warning: The title 'misc: Add XQA decode to microbenchmark for sm90 and sm120' is vague and misleading. It describes adding XQA decode support, but the changes actually rename backends from trtllm-gen-native to trtllm-native and update backend support mappings. Resolution: revise the title to accurately reflect the main change, such as 'Rename trtllm-gen-native backend to trtllm-native and add support for SM90/SM120 microbenchmarks' or 'Update backend naming and add XQA decode support to benchmarks'.
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 63.64%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check ❓ Inconclusive: The description is incomplete and missing required sections. While it provides context and before/after examples, it lacks a clear summary of what changes were made (backend renaming), a related issue link, and an adequate explanation of why the changes are needed. Resolution: add a clear explanation of the backend naming change (trtllm-gen-native to trtllm-native), link to related issue #2001, and clarify how backend support mappings were updated for SM90 and SM120.
  • Linked Issues check ❓ Inconclusive: The PR description mentions issue #2001, where XQA decode kernels became available, but does not formally link it in the 'Related Issues' section of the pull request. Resolution: explicitly link the related issue (#2001) in the PR description's 'Related Issues' section for better traceability.

✅ Passed checks (1 passed)

  • Out of Scope Changes check ✅ Passed: The changes appear to be in scope, as they consistently update backend naming and support mappings across benchmark utilities, README, and attention routines. All modifications are related to enabling XQA decode benchmarking on SM90/SM120.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9392306 and 8c125b8.

📒 Files selected for processing (1)
  • benchmarks/routines/attention.py (13 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/routines/attention.py (3)

86-95: LGTM: Good backward compatibility approach.

The choices list correctly includes both the new "trtllm-native" and the deprecated "trtllm-gen-native" with a clear comment indicating deprecation.


179-181: LGTM: Normalization applied at the right time.

The normalization is correctly applied early in the argument parsing flow, ensuring all subsequent code uses the canonical backend name.


217-217: LGTM: Consistent usage of the canonical backend name.

All docstrings, backend conditionals, and error messages throughout the file consistently use "trtllm-native" instead of the deprecated "trtllm-gen-native". The implementation is thorough and well-integrated.

Also applies to: 522-522, 646-646, 729-735, 987-987, 1210-1224, 1436-1436, 1570-1570, 1666-1674, 1839-1839



Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request extends the microbenchmarking framework to support performance evaluation of XQA decode kernels, specifically trtllm_batch_decode_with_kv_cache, on NVIDIA's SM90 and SM120 architectures. By integrating trtllm-gen-native into the benchmark utility, developers can now measure the performance of these optimized kernels alongside existing backends, ensuring comprehensive performance analysis for advanced GPU hardware.

Highlights

  • XQA Decode Benchmarking: Enabled benchmarking of trtllm-gen-native (XQA decode kernels) for BatchDecodeWithPagedKVCacheWrapper on NVIDIA compute capabilities SM90 (H200) and SM120 (RTX 5090).
  • Microbenchmark Configuration Update: Modified benchmarks/routines/flashinfer_benchmark_utils.py to include trtllm-gen-native in the list of supported backends for the relevant decode routine on SM90 and SM120 (see the sketch after this list).
  • Code Clarity: Added inline comments to flashinfer_benchmark_utils.py to explicitly link trtllm-gen-native backend calls to their corresponding trtllm functions for different attention wrappers.
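For illustration, a compute-capability gate over such a mapping might look like the sketch below, producing the "not supported ... Skipping." warnings seen in the example output above. All names and entries here are assumptions, not the actual flashinfer_benchmark code.

```python
import torch

# Hypothetical excerpt of a routine -> compute capability -> backends map.
routine_cc_to_supported_backends = {
    "BatchDecodeWithPagedKVCacheWrapper": {
        "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
        "12.0": ["fa2", "cudnn", "trtllm-native"],
    },
}

def filter_supported_backends(routine, backends, device=0):
    """Drop backends unsupported on the current GPU, warning for each."""
    major, minor = torch.cuda.get_device_capability(device)
    cc = f"{major}.{minor}"
    supported = routine_cc_to_supported_backends.get(routine, {}).get(cc, [])
    kept = []
    for backend in backends:
        if backend in supported:
            kept.append(backend)
        else:
            print(
                f"[WARNING] {backend} for routine {routine} is not supported "
                f"on compute capability {cc}. Skipping."
            )
    return kept
```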


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for benchmarking the XQA decode kernel (trtllm-gen-native backend) on SM90 and SM120 architectures. The changes correctly update the routine_cc_to_supported_backends dictionary to include this new backend for the specified compute capabilities. The addition of comments explaining which underlying functions are called by the trtllm-gen-native backend is a good step towards improving code clarity. I have one suggestion to enhance maintainability by applying this commenting practice consistently for other backends as well.

```python
routine_cc_to_supported_backends = {
    # ATTENTION
    "BatchDecodeWithPagedKVCacheWrapper": {
        # NOTE: trtllm-gen-native calls trtllm_batch_decode_with_kv_cache
        # ...
```

Severity: medium

These new comments explaining which function the trtllm-gen-native backend calls are very helpful for understanding the code. However, this information is only provided for the trtllm-gen-native backend, while other backends like fa2, cudnn, and trtllm-gen lack similar explanations.

For consistency and better maintainability, I suggest adding similar comments for the other backends across all routines in this dictionary. This will make it easier for future contributors to understand the purpose of each backend.

For example, for BatchDecodeWithPagedKVCacheWrapper:

"BatchDecodeWithPagedKVCacheWrapper": {
    # NOTE: 'fa2' uses FlashAttention-2, 'cudnn' uses cuDNN SDPA, etc.
    # NOTE: 'trtllm-gen-native' calls trtllm_batch_decode_with_kv_cache
    "7.5": ["fa2"],
    # ...
},

Alternatively, if this dictionary is expected to grow, consider refactoring this information into a more structured format, like a separate dictionary mapping routines and backends to the functions they call. This would be more scalable.
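As a sketch of that alternative, under the assumption of a hypothetical BACKEND_IMPLEMENTATIONS name that is not part of the codebase:

```python
# Hypothetical: map (routine, backend) pairs to the underlying function
# each backend invokes, replacing the free-form NOTE comments.
BACKEND_IMPLEMENTATIONS = {
    ("BatchDecodeWithPagedKVCacheWrapper", "trtllm-gen-native"):
        "trtllm_batch_decode_with_kv_cache",
    # ... further (routine, backend) pairs would go here
}
```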

@bkryu bkryu self-assigned this Nov 6, 2025
"8.6": ["fa2", "fa2_tc", "cudnn"],
"8.9": ["fa2", "fa2_tc", "cudnn"],
"9.0": ["fa2", "fa2_tc", "cudnn"],
"9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen-native"],
Collaborator
I'd prefer to remove "gen" from trtllm: trtllm-gen is the codegen framework designed specifically for sm_100 and sm_103, and the other backends do not go through trtllm-gen.

Collaborator Author

I agree with your sentiment for the trtllm-gen-native backends that call the trtllm_batch_... prefill/decode/MLA APIs. For the trtllm kernels called through wrappers, I would still like to call them trtllm-gen, since that is the actual backend for them, as stated in our documentation. For example:

backend (str) – The implementation backend, could be auto/fa2 or trtllm-gen. Defaults to auto. If set to auto, the wrapper will automatically choose the backend based on the device architecture and kernel availability.

In the latest commit, I updated the PR so that:

  1. The microbenchmark backend trtllm-gen-native is renamed to trtllm-native.
  2. Added a check that replaces user-provided trtllm-gen-native with trtllm-native and prints a warning message about future deprecation (see the sketch below).
  3. Replaced pretty much all instances of trtllm-gen-native with trtllm-native.
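For illustration, the check in item 2 would behave roughly as follows, assuming the hypothetical normalize_backends sketch shown earlier on this page:

```python
# Deprecated alias is replaced and a deprecation warning is emitted.
backends = normalize_backends(["fa2", "trtllm-gen-native", "cudnn"])
# -> ["fa2", "trtllm-native", "cudnn"]
```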

Collaborator

I see; then I suppose we should update the unified attention wrapper as well (in a future PR). Thanks for spotting this issue!

Comment thread benchmarks/routines/flashinfer_benchmark_utils.py Outdated

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 23a0d31 and 9392306.

📒 Files selected for processing (3)
  • benchmarks/README.md (3 hunks)
  • benchmarks/routines/attention.py (13 hunks)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/routines/flashinfer_benchmark_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/routines/attention.py (2)

93-94: LGTM! Well-designed backward compatibility.

The approach of accepting both the canonical and deprecated backend names in the choices list (lines 93-94), followed by runtime normalization (lines 179-181), provides smooth backward compatibility while guiding users toward the new naming convention.

Also applies to: 179-181


217-217: LGTM! Consistent backend name usage.

All references to the TensorRT-LLM native backend have been consistently updated to use "trtllm-native" throughout the file, including docstrings, condition checks, and backend validation logic.

Also applies to: 522-522, 646-646, 729-735, 987-987, 1210-1224, 1436-1436, 1570-1570, 1666-1674, 1839-1839

benchmarks/README.md (1)

120-120: LGTM! Documentation properly updated.

The documentation has been consistently updated to reflect the canonical backend name trtllm-native across the general backends list (line 120), the routine support matrix (lines 220-223), and the backend legend (line 241). This aligns with the code changes and provides clear guidance to users.

Also applies to: 220-223, 241-241

Comment thread benchmarks/routines/attention.py
Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119 yzh119 merged commit f566d49 into flashinfer-ai:main Nov 7, 2025
4 checks passed
@bkryu bkryu deleted the bench_sm90_120_xqa_decode branch November 7, 2025 23:45
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…-ai#2055)

<!-- .github/pull_request_template.md -->

## 📌 Description

In flashinfer-ai#2001 , XQA decode kernels became available through
`trtllm_batch_decode_with_kv_cache` on SM90 and SM120.

Current PR adds the ability to benchmark through the microbenchmark.

Example microbenchmark command and outputs before and after:
```
### Before current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping.
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping.
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec

### After current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec
[PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Chores**
* Standardized backend identifier to "trtllm-native" and expanded its
support across benchmark routines and utilities.
* Argument parsing now canonicalizes deprecated backend aliases and
emits a deprecation warning when encountered.
* **Documentation**
* README and tool-facing messages updated to use the canonical backend
name and include contextual notes about the change.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->