feat: Deterministic mode for filtered topK kernel #2759

Linda-Stadter wants to merge 2 commits into flashinfer-ai:main
Conversation
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
📝 Walkthrough

This PR extends the top-k implementation across Python bindings, CUDA source, and kernels to support deterministic and sorted-output modes. New parameters for these modes are threaded through the call chain.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Python Application
    participant TopK as top_k Function
    participant Dispatch as TopKDispatch<br/>(CUDA Kernel)
    participant MainKernel as RadixTopK Kernel
    participant SortKernel as Sort Kernel<br/>(Value or Index)
    participant GPU as GPU Memory
    App->>TopK: top_k(input, k, sorted=True, deterministic=True)
    activate TopK
    TopK->>TopK: Compute sorted_cuda flag<br/>(sorted AND deterministic)
    TopK->>Dispatch: TopKDispatch(..., sorted_output,<br/>deterministic, stream)
    activate Dispatch
    Dispatch->>MainKernel: Launch RadixTopK kernel
    activate MainKernel
    MainKernel->>GPU: Write output_indices,<br/>output_values
    deactivate MainKernel
    alt sorted_cuda is True (kernel-internal sort)
        Dispatch->>GPU: Skip post-processing
    else sorted_cuda is False (post-processing sort)
        Dispatch->>SortKernel: Launch StableSortTopKByValueKernel
        activate SortKernel
        SortKernel->>GPU: Reorder output by values
        deactivate SortKernel
    end
    deactivate Dispatch
    TopK->>App: Return sorted top-k values & indices
    deactivate TopK
```
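The branch in the diagram above can be summarized as a small Python sketch. The function name `plan_topk_launch` is hypothetical, written here only to illustrate the `sorted_cuda` flag computed in the Python wrapper and which kernels end up launched:

```python
def plan_topk_launch(sorted_output: bool, deterministic: bool) -> list:
    """Return the kernel sequence the diagram above describes.

    sorted_cuda mirrors the flag in the diagram: the kernel sorts
    internally only when both sorted and deterministic are requested.
    """
    sorted_cuda = sorted_output and deterministic
    kernels = ["RadixTopK"]
    if sorted_output and not sorted_cuda:
        # Post-processing path from the `else` branch of the diagram.
        kernels.append("StableSortTopKByValueKernel")
    return kernels
```

With `sorted=True, deterministic=True` (the call shown in the diagram) only the main kernel runs; with `sorted=True, deterministic=False` the post-processing sort kernel runs afterwards.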
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~65 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a deterministic mode for the filtered top-K kernel, addressing the need for bitwise-reproducible results in top-K selection operations. The implementation involves modifying existing CUDA kernels and adding new sorting kernels to ensure consistent output, especially when dealing with ties. This enhancement provides greater reliability and predictability for applications requiring strict reproducibility.

Highlights
Activity
Code Review
This pull request introduces a deterministic mode for the filtered top-k kernel, which is a valuable feature for reproducibility. The implementation is well-structured, using if constexpr for compile-time dispatch between deterministic and non-deterministic paths, and adding static_cast<size_t> for safer pointer arithmetic. The changes also include refactoring to reduce code duplication in kernel launch sites. The accompanying Python bindings and extensive test suite are well-written. I have two suggestions for the CUDA implementation in include/flashinfer/topk.cuh: one to fix a correctness issue in the deterministic path that could lead to non-deterministic output in an edge case, and another to further improve maintainability by refactoring duplicated code.
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)
2872-2924: Verify RuntimeError behavior for unsatisfied preconditions.

The deterministic path validates preconditions using `FLASHINFER_CHECK` macros, which is good. However, the error messages could be clearer about what the user should do.

💡 Consider more actionable error messages:

```diff
 FLASHINFER_CHECK(top_k_val <= FILTERED_TOPK_MAX_K,
-                 "deterministic=True requires k <=", FILTERED_TOPK_MAX_K,
-                 "but got k =", top_k_val);
+                 "deterministic=True requires k <= ", FILTERED_TOPK_MAX_K,
+                 ", but got k = ", top_k_val,
+                 ". Use deterministic=False for larger k values.");
 FLASHINFER_CHECK(CanImplementFilteredTopK(),
-                 "deterministic=True requires GPU support for 128KB shared memory");
+                 "deterministic=True requires GPU support for 128KB shared memory. "
+                 "Current GPU does not meet this requirement.");
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2872 - 2924, The checks in the deterministic branch use FLASHINFER_CHECK but the messages are terse; update the two checks (FLASHINFER_CHECK(top_k_val <= FILTERED_TOPK_MAX_K, ...) and FLASHINFER_CHECK(CanImplementFilteredTopK(), ...)) to provide actionable guidance: include the invalid value (top_k_val), the allowed limit (FILTERED_TOPK_MAX_K) and concrete remediation steps (e.g., "use non-deterministic mode, reduce k, or compile with larger shared memory support"), and for CanImplementFilteredTopK() include how to enable GPU support (e.g., required GPU compute capability or driver/compile flags); keep the checks at the same call sites so behavior unchanged but messages are more helpful to users.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/utils/test_topk.py`:
- Line 1641: The unpacked variable ref_values from the call ref_values,
ref_indices = torch.topk(logits, k, dim=-1, sorted=True) is unused; remove it by
only capturing the indices (e.g., assign ref_indices = torch.topk(logits, k,
dim=-1, sorted=True)[1] or use _ to ignore the first return like _, ref_indices
= torch.topk(...)) so only ref_indices is kept and the unused variable is
eliminated.
---
Nitpick comments:
In `@include/flashinfer/topk.cuh`:
- Around line 2872-2924: The checks in the deterministic branch use
FLASHINFER_CHECK but the messages are terse; update the two checks
(FLASHINFER_CHECK(top_k_val <= FILTERED_TOPK_MAX_K, ...) and
FLASHINFER_CHECK(CanImplementFilteredTopK(), ...)) to provide actionable
guidance: include the invalid value (top_k_val), the allowed limit
(FILTERED_TOPK_MAX_K) and concrete remediation steps (e.g., "use
non-deterministic mode, reduce k, or compile with larger shared memory
support"), and for CanImplementFilteredTopK() include how to enable GPU support
(e.g., required GPU compute capability or driver/compile flags); keep the checks
at the same call sites so behavior unchanged but messages are more helpful to
users.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f36f1d97-5749-466c-9acf-bfce82f8b179
📒 Files selected for processing (5)
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
- flashinfer/topk.py
- include/flashinfer/topk.cuh
- tests/utils/test_topk.py
Hi @Linda-Stadter, is it ready for review? By the way, there is another PR from the community which seems to be working on the same thing: #2661. I haven't checked carefully, but would you mind also taking a look?
This PR is now part of #2661
## 📌 Description

> Part of the FilteredTopK implementation refers to or is adapted from @Linda-Stadter's work in #2759

### Deterministic Mode for Top-K Kernels

#### FilteredTopK Kernel

FilteredTopKKernel implements deterministic mode as follows:

1. Build a coarse histogram.
   - Build a coarse histogram on the top 8 bits to locate the coarse threshold bin that contains the k-th largest element.
   - Same as non-deterministic mode, elements with bin > threshold_bin are appended to `s_indices` via **atomicAdd** (see `collect_gt_and_nondet_eq_threshold`); their final order is determined by the post-sort kernel.
2. Refine with 8-bit radix passes.
   - Run multiple 8-bit refine passes to find the exact pivot.
   - Deterministic `== pivot` selection is performed by `collect_det_eq_pivot`, which writes the selected tie elements into `s_indices` in deterministic **thread-strided order**.

   > **Thread-strided order** means, for example, if `BLOCK_THREADS = 4`, then the logical scan order is:
   >
   > - thread 0: `0, 4, 8, ...`
   > - thread 1: `1, 5, 9, ...`
   > - thread 2: `2, 6, 10, ...`
   > - thread 3: `3, 7, 11, ...`
   >
   > If the `== pivot` positions are:
   >
   > - thread 0: `0, 8`
   > - thread 1: `5`
   > - thread 2: none
   > - thread 3: `3, 7`
   >
   > then the deterministic collection order is `[0, 8, 5, 3, 7]`. That is, we order elements first by thread ID, and then by each thread's strided traversal order.

3. Post-sort kernels.
   - After FilteredTopKKernel finishes, `SortTopKByIndexKernel` is applied to produce index-ascending output and make the final ordering deterministic (we use atomicAdd to collect > pivot at stage 1).
   - If the Python API is called with sorted=True, `StableSortTopKByValueKernel` is applied afterward to produce value-descending output.

#### RadixTopK Kernel

1. `RadixSelectFindPivot`
   - Finds `ordered_pivot`, which Stage 2 uses to determine whether an element is >= `ordered_pivot`.
   - Computes `cta_local_eq_count` and `cta_local_gt_count`, which Stage 2 uses to **determine** how many elements the current CTA may emit and where each emitted element should be placed.
2. collect_indices (`RadixCollectIndicesDeterministic`)
   - After the pivot is known, assigns each CTA a fixed output range, then writes all > pivot elements followed by the required == pivot elements in a deterministic order.
   - Order definition:
     - Emit > pivot elements first, then == pivot elements.
     - For each category, earlier CTAs write to earlier output positions.
     - Within each CTA, emit elements in thread-strided order.

### Benchmarks

Machine: NVIDIA A100-PCIE-40GB

Command (fp32/fp16/bf16):

```bash
python -u benchmarks/bench_topk.py \
  --op all \
  --dtype fp32 \
  --deterministic \
  --compare-torch-deterministic \
  --input-pattern random
```

Raw results: [output.txt](https://github.com/user-attachments/files/26337712/output-group2-current-v1_5-20260330-101411.txt)

**Summary**

| dtype | geomean det slowdown vs non-det | geomean speedup vs torch.det |
| --- | ---: | ---: |
| fp32 | 1.0992x | 1.7660x |
| fp16 | 1.0777x | 1.3381x |
| bf16 | 1.0745x | 1.3055x |

NOTE: FlashInfer deterministic **underperforms** PyTorch mainly on short-sequence workloads. Importantly, this is not unique to the deterministic path: FlashInfer non-deterministic top-k is also slower than PyTorch in the same short-sequence regime. This suggests the gap is primarily a short-sequence top-k issue rather than a deterministic-specific regression. Optimizing short-sequence top-k, for both non-deterministic and deterministic modes, is better treated as future work.

## 🔍 Related Issues

close: #2584

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

```
unittest I ran:
test_topk.py
test_sampling.py
test_logits_processor.py
```

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Deterministic mode for top-k and fused transforms (stable, repeatable tie ordering) with API flag to enable deterministic outputs and stable sorting behavior.
* **Benchmarks**
  * Expanded benchmarking to compare deterministic vs nondeterministic runs, pre-generated input patterns, DSA workload cases, and richer CLI output.
* **Tests**
  * Large suite of determinism and correctness tests (ties, multi-CTA, streams, sorted behavior, cache transitions).
* **Bug Fixes**
  * Improved runtime-error labeling and benchmark cache handling.

---

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
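The thread-strided collection order described above can be emulated in plain Python. This sketch is an illustration, not the kernel code; it reproduces the worked example from the description (`BLOCK_THREADS = 4`, with the listed `== pivot` positions per thread):

```python
def thread_strided_order(n, block_threads, selected):
    """Collect positions in `selected` the way collect_det_eq_pivot is
    described to: thread by thread, each thread scanning its strided
    positions t, t + block_threads, t + 2*block_threads, ...
    """
    order = []
    for t in range(block_threads):
        for pos in range(t, n, block_threads):
            if pos in selected:
                order.append(pos)
    return order

# Worked example from the description: == pivot at positions
# {0, 8} (thread 0), {5} (thread 1), {3, 7} (thread 3).
print(thread_strided_order(12, 4, {0, 8, 5, 3, 7}))  # → [0, 8, 5, 3, 7]
```

Because the order depends only on thread ID and stride position, not on which thread happens to finish an atomic operation first, the collection is bitwise-reproducible across runs.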
📌 Description
Adds a `deterministic=True` parameter to `flashinfer.top_k()` for bitwise-reproducible output. Optional and defaults to `False`.

```python
values, indices = flashinfer.top_k(logits, k, deterministic=True)
```

When `deterministic=True`:

- Raises a `RuntimeError` if preconditions aren't met.
- `> pivot` elements are still collected via `atomicAdd`.
- `SortTopKByIndexKVKernel` enforces deterministic order by sorting output indices ascending.
- With `sorted=True`, an additional stable value sort produces (value desc, index asc) ordering. The sort through `torch.sort` in `topk.py` is then disabled.
- `batch_size * vocab_size > 2^32` is not supported.
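A pure-Python reference for the output-ordering contract described here. This is a sketch under a simplifying assumption: it keeps tied values by smallest index, whereas the kernel's actual tie selection follows its internal collection order; only the ordering of the emitted output is modeled faithfully.

```python
def topk_deterministic(values, k, sorted_output=False):
    """Reference model of the deterministic top-k output contract.

    Unsorted output is index-ascending (as after the index sort kernel);
    with sorted_output=True the final order is (value desc, index asc).
    Tie selection by smallest index is a simplifying assumption here.
    """
    # Rank all positions by (value desc, index asc), keep the first k.
    ranked = sorted(range(len(values)), key=lambda i: (-values[i], i))
    picked = sorted(ranked[:k])  # index-ascending output
    if sorted_output:
        # Stable value sort: (value desc, index asc).
        picked.sort(key=lambda i: (-values[i], i))
    return [values[i] for i in picked], picked

vals = [0.5, 0.9, 0.5, 0.7]
print(topk_deterministic(vals, 3))                      # ([0.5, 0.9, 0.7], [0, 1, 3])
print(topk_deterministic(vals, 3, sorted_output=True))  # ([0.9, 0.7, 0.5], [1, 3, 0])
```

Running the same inputs twice always yields identical tuples, which is the property the `deterministic=True` flag guarantees for the CUDA path.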
Benchmark

Machine: B200

Benchmark code

Limitations
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Tests