feat: Deterministic mode for filtered topK kernel #2759

Closed
Linda-Stadter wants to merge 2 commits into flashinfer-ai:main from Linda-Stadter:filtered_topk_deterministic

Conversation

Linda-Stadter (Contributor) commented Mar 11, 2026

📌 Description

Adds deterministic=True parameter to flashinfer.top_k() for bitwise-reproducible output. Optional and defaults to False.

```python
values, indices = flashinfer.top_k(logits, k, deterministic=True)
```

When deterministic=True:

  • Forces the FilteredTopK algorithm (requires k ≤ 2048 and 128KB of shared memory). Raises RuntimeError if these preconditions aren't met.
  • Boundary elements equal to the pivot value are collected via a deterministic ballot-scan instead of a non-deterministic atomicAdd.
  • A post-kernel SortTopKByIndexKVKernel enforces a deterministic order by sorting the output indices in ascending order.
  • For sorted=True, an additional stable value sort produces (value desc, index asc) ordering; the torch.sort fallback in topk.py is then disabled.
  • Also fixes a uint32 overflow on large batches when batch_size * vocab_size > 2^32.
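The tie-breaking semantics can be checked against a host-side reference. This is an illustrative NumPy sketch, not the kernel code, and the `deterministic_topk` helper is hypothetical:

```python
import numpy as np

def deterministic_topk(values, k):
    """Reference top-k with the (value desc, index asc) tie-break
    that the deterministic path guarantees."""
    # np.lexsort sorts by the last key first: negated values give a
    # descending value order, and the index key breaks ties ascending.
    order = np.lexsort((np.arange(len(values)), -values))
    top = order[:k]
    return values[top], top

vals = np.array([0.5, 0.9, 0.5, 0.7, 0.9], dtype=np.float32)
v, i = deterministic_topk(vals, 3)
# Ties at 0.9 (indices 1 and 4) resolve by ascending index: [1, 4, 3]
```

Two runs of such a reference on the same input trivially agree bitwise; the point of the PR is that the GPU path now offers the same guarantee.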

Benchmark

B200

Algorithm: filtered
  batch_size=4096  vocab_size=200000  dtype=fp32

       |                --- Unsorted ---                 |                 --- Sorted ---
     k |     NonDet    GB/s        Det    GB/s  overhead  |     NonDet    GB/s        Det    GB/s  overhead
-----------------------------------------------------------------------------------------------------------
   128 |    1.838ms  1783.0    1.894ms  1730.3     +3.0%  |    1.885ms  1738.3    1.912ms  1713.6     +1.4%
   256 |    1.846ms  1774.9    1.913ms  1712.9     +3.6%  |    1.956ms  1674.8    1.940ms  1689.3     -0.9%
   512 |    1.974ms  1659.6    2.054ms  1595.7     +4.0%  |    2.096ms  1563.6    2.103ms  1558.4     +0.3%
  1024 |    1.994ms  1643.5    2.105ms  1556.8     +5.6%  |    2.145ms  1527.3    2.202ms  1488.0     +2.6%
  2048 |    2.379ms  1377.4    2.581ms  1269.6     +8.5%  |    2.658ms  1232.8    2.782ms  1178.0     +4.7%

Algorithm: filtered
  batch_size=8192  vocab_size=200000  dtype=fp32

       |                --- Unsorted ---                 |                 --- Sorted ---
     k |     NonDet    GB/s        Det    GB/s  overhead  |     NonDet    GB/s        Det    GB/s  overhead
-----------------------------------------------------------------------------------------------------------
   128 |    3.652ms  1794.8    3.765ms  1740.6     +3.1%  |    3.723ms  1760.2    3.797ms  1725.9     +2.0%
   256 |    3.622ms  1809.2    3.749ms  1748.0     +3.5%  |    3.813ms  1718.7    3.800ms  1724.4     -0.3%
   512 |    3.939ms  1663.6    4.088ms  1603.2     +3.8%  |    4.155ms  1577.1    4.181ms  1567.4     +0.6%
  1024 |    3.973ms  1649.7    4.195ms  1562.1     +5.6%  |    4.265ms  1536.5    4.382ms  1495.6     +2.7%
  2048 |    4.747ms  1380.5    5.159ms  1270.2     +8.7%  |    5.272ms  1243.2    5.556ms  1179.5     +5.4%

Algorithm: filtered
  batch_size=16384  vocab_size=200000  dtype=fp32

       |                --- Unsorted ---                 |                 --- Sorted ---
     k |     NonDet    GB/s        Det    GB/s  overhead  |     NonDet    GB/s        Det    GB/s  overhead
-----------------------------------------------------------------------------------------------------------
   128 |    7.250ms  1808.0    7.462ms  1756.6     +2.9%  |    7.366ms  1779.4    7.526ms  1741.7     +2.2%
   256 |    7.189ms  1823.3    7.430ms  1764.0     +3.4%  |    7.545ms  1737.2    7.524ms  1742.0     -0.3%
   512 |    7.806ms  1679.1    8.101ms  1617.9     +3.8%  |    8.232ms  1592.3    8.279ms  1583.1     +0.6%
  1024 |    7.879ms  1663.7    8.332ms  1573.1     +5.8%  |    8.432ms  1554.5    8.695ms  1507.4     +3.1%
  2048 |    9.416ms  1392.0   10.180ms  1287.5     +8.1%  |   10.421ms  1257.8   10.965ms  1195.4     +5.2%
Benchmark code

import os

import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time

BATCH_SIZES = [4096, 8192, 16384]
VOCAB_SIZE = 200000
K_VALUES = [128, 256, 512, 1024, 2048]
ELEM_SIZE = 4  # fp32

ALGOS = ["filtered"]


def bench_one(fn):
    times = bench_gpu_time(fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100)
    return float(np.median(times))


def input_gbps(time_ms, batch_size):
    return (batch_size * VOCAB_SIZE * ELEM_SIZE) / (time_ms * 1e-3) / 1e9


def fmt_overhead(test_ms, base_ms):
    if base_ms == 0:
        return "    N/A"
    pct = (test_ms / base_ms - 1) * 100
    return f"{pct:>+7.1f}%"


def print_row(
    k, nd_ms, nd_gbps, d_ms, d_gbps, oh, nd_s_ms, nd_s_gbps, d_s_ms, d_s_gbps, oh_s
):
    print(
        f"  {k:>4} |"
        f"  {nd_ms:>7.3f}ms {nd_gbps:>7.1f}"
        f"  {d_ms:>7.3f}ms {d_gbps:>7.1f}"
        f"  {oh:>8}"
        f"  |"
        f"  {nd_s_ms:>7.3f}ms {nd_s_gbps:>7.1f}"
        f"  {d_s_ms:>7.3f}ms {d_s_gbps:>7.1f}"
        f"  {oh_s:>8}"
    )


@torch.inference_mode()
def main():
    dtype = torch.float32

    for algo in ALGOS:
        if algo == "auto":
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = algo

        for batch_size in BATCH_SIZES:
            scores = torch.randn(batch_size, VOCAB_SIZE, device="cuda", dtype=dtype)

            print(f"Algorithm: {algo}")
            print(f"  batch_size={batch_size}  vocab_size={VOCAB_SIZE}  dtype=fp32")
            print()

            half = 48
            print(f"       |{'--- Unsorted ---':^{half}} |{'--- Sorted ---':^{half}}")
            print(
                f"     k |"
                f"  {'NonDet':>9} {'GB/s':>7}"
                f"  {'Det':>9} {'GB/s':>7}"
                f"  {'overhead':>8}"
                f"  |"
                f"  {'NonDet':>9} {'GB/s':>7}"
                f"  {'Det':>9} {'GB/s':>7}"
                f"  {'overhead':>8}"
            )
            print("-" * (7 + 1 + half + 2 + 1 + half))

            for k in K_VALUES:
                nd_unsorted = bench_one(
                    lambda k=k, s=scores: flashinfer.top_k(s, k, sorted=False)
                )
                nd_sorted = bench_one(
                    lambda k=k, s=scores: flashinfer.top_k(s, k, sorted=True)
                )

                d_unsorted = bench_one(
                    lambda k=k, s=scores: flashinfer.top_k(
                        s, k, sorted=False, deterministic=True
                    )
                )
                d_sorted = bench_one(
                    lambda k=k, s=scores: flashinfer.top_k(
                        s, k, sorted=True, deterministic=True
                    )
                )

                print_row(
                    k,
                    nd_unsorted,
                    input_gbps(nd_unsorted, batch_size),
                    d_unsorted,
                    input_gbps(d_unsorted, batch_size),
                    fmt_overhead(d_unsorted, nd_unsorted),
                    nd_sorted,
                    input_gbps(nd_sorted, batch_size),
                    d_sorted,
                    input_gbps(d_sorted, batch_size),
                    fmt_overhead(d_sorted, nd_sorted),
                )

            del scores
            print()

    os.environ.pop("FLASHINFER_TOPK_ALGO", None)


if __name__ == "__main__":
    main()

Limitations

  • Only supported for FilteredTopK (k ≤ 2048). Not available for the multi-CTA radix path.
  • Forces FilteredTopK even when the heuristic would prefer the multi-CTA path, which may be slower for some configurations.
  • Not optimized for bf16.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added deterministic mode for reproducible, bit-identical top-k results across runs.
    • Added sorted output option to control ordering of top-k selection results.
  • Tests

    • Extensive test coverage added for deterministic reproducibility and sorted output functionality.
    • Tests validating edge cases, tie-breaking behavior, and cross-algorithm correctness.

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>

coderabbitai Bot commented Mar 11, 2026

📝 Walkthrough

Walkthrough

This PR extends the top-k implementation across Python bindings, CUDA source, and kernels to support deterministic and sorted-output modes. New parameters sorted_output and deterministic are propagated through the API stack, new CUDA kernels handle post-processing sorting, and deterministic collection paths are introduced to enable reproducible results.

Changes

Cohort / File(s) Summary
CUDA Bindings & Source
csrc/flashinfer_topk_binding.cu, csrc/topk.cu
Extended radix_topk function signatures to accept two new int64_t parameters (sorted_output, deterministic), propagating them to underlying TopKDispatch calls with boolean conversion.
Python Wrapper
flashinfer/topk.py
Added sorted_output and deterministic parameters to radix_topk and _fake_radix_topk functions. Updated top_k public API to accept deterministic parameter and conditionally apply in-kernel or post-processing sorting based on combined flags. Expanded documentation for deterministic mode requirements and constraints.
CUDA Kernel Headers
include/flashinfer/topk.cuh
Introduced SortTopKByIndexKVKernel and StableSortTopKByValueKernel for post-processing sort operations. Updated TopKDispatch and FilteredTopK signatures with sorted_output and deterministic parameters. Extended kernel dispatch logic with deterministic collection paths, per-bit pivot refinement, and size_t indexing safety improvements across radix and filtered top-k implementations.
Test Coverage
tests/utils/test_topk.py
Added comprehensive test suite validating deterministic reproducibility across FilteredTopK variants, correctness of sorted/deterministic modes, edge cases (k equals vocab_size), cross-algorithm consistency, transform operations (page-table, ragged), and SGLang-style reference comparisons.
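The two post-processing kernels named in the walkthrough compose as two sorts, and their combined effect can be sketched on the host in NumPy (the pair data below is made up for illustration):

```python
import numpy as np

# Simulated unordered kernel output for one row: (value, index) pairs.
out_vals = np.array([0.7, 0.9, 0.9, 0.5], dtype=np.float32)
out_idx = np.array([3, 4, 1, 0])

# Stage 1 (SortTopKByIndexKVKernel analogue): sort pairs by index
# ascending, erasing any nondeterministic write order from atomicAdd.
by_index = np.argsort(out_idx, kind="stable")
out_vals, out_idx = out_vals[by_index], out_idx[by_index]

# Stage 2 (StableSortTopKByValueKernel analogue, applied for sorted=True):
# a *stable* sort by value descending preserves the ascending-index order
# among ties, yielding the final (value desc, index asc) ordering.
by_value = np.argsort(-out_vals, kind="stable")
out_vals, out_idx = out_vals[by_value], out_idx[by_value]
# out_idx -> [1, 4, 3, 0]
```

The key design point is stability: an unstable second sort could reshuffle tied values and reintroduce run-to-run variation.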

Sequence Diagram(s)

sequenceDiagram
    participant App as Python Application
    participant TopK as top_k Function
    participant Dispatch as TopKDispatch<br/>(CUDA Kernel)
    participant MainKernel as RadixTopK Kernel
    participant SortKernel as Sort Kernel<br/>(Value or Index)
    participant GPU as GPU Memory

    App->>TopK: top_k(input, k, sorted=True, deterministic=True)
    activate TopK
    TopK->>TopK: Compute sorted_cuda flag<br/>(sorted AND deterministic)
    TopK->>Dispatch: TopKDispatch(..., sorted_output,<br/>deterministic, stream)
    activate Dispatch
    
    Dispatch->>MainKernel: Launch RadixTopK kernel
    activate MainKernel
    MainKernel->>GPU: Write output_indices,<br/>output_values
    deactivate MainKernel
    
    alt sorted_cuda is True (kernel-internal sort)
        Dispatch->>GPU: Skip post-processing
    else sorted_cuda is False (post-processing sort)
        Dispatch->>SortKernel: Launch StableSortTopKByValueKernel
        activate SortKernel
        SortKernel->>GPU: Reorder output by values
        deactivate SortKernel
    end
    
    deactivate Dispatch
    TopK->>App: Return sorted top-k values & indices
    deactivate TopK

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes

Possibly related PRs

  • PR #2605: Modifies top-k implementation in include/flashinfer/topk.cuh with deterministic/sorted-output enhancements and FilteredTopK overflow fixes, directly overlapping with this PR's kernel changes.
  • PR #2215: Updates TopK/CUDA API signatures (radix_topk, TopKDispatch) and Python bindings in flashinfer/topk.py, sharing the same binding/export layer as this PR.
  • PR #2119: Extends topk codepaths and RadixTopK multi-CTA machinery in include/flashinfer/topk.cuh, related to the deterministic and sorted-output kernel dispatch refactors in this PR.

Suggested labels

v0.6.1

Suggested reviewers

  • jiahanc
  • kahyunnam
  • IwakuraRein
  • yzh119
  • nv-yunzheq
  • djmmoss

Poem

🐰 A fluffy topk tale:

With bits of radix and sorted arrays bright,
Deterministic spirits dance through the night,
New kernels sort values with careful, stable grace,
While indices waltz into their rightful place! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 58.33%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately and concisely summarizes the main feature: deterministic mode for the filtered topK kernel. |
| Description check | ✅ Passed | The PR description comprehensively covers changes, includes benchmarks and limitations, and confirms pre-commit and testing checklist completion. |


gemini-code-assist Bot commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a deterministic mode for the filtered top-K kernel, addressing the need for bitwise-reproducible results in top-K selection operations. The implementation involves modifying existing CUDA kernels and adding new sorting kernels to ensure consistent output, especially when dealing with ties. This enhancement provides greater reliability and predictability for applications requiring strict reproducibility.

Highlights

  • Deterministic Top-K Mode: Introduced a new deterministic boolean parameter to the flashinfer.top_k function, allowing users to guarantee bitwise-reproducible output across calls with the same input. This mode forces the use of the FilteredTopK algorithm.
  • CUDA Kernel Enhancements: Modified the FilteredTopKUnifiedKernel to support deterministic tie-breaking logic. New CUDA kernels, SortTopKByIndexKVKernel and StableSortTopKByValueKernel, were added to perform post-processing sorts for deterministic output, ensuring consistent ordering of indices and values.
  • API and Dispatch Logic Updates: Updated the C++ radix_topk and TopKDispatch functions, along with their Python bindings, to propagate the deterministic and sorted_output flags. The TopKDispatch now explicitly checks for deterministic mode and applies the FilteredTopK algorithm followed by CUB-based sorting kernels.
  • Comprehensive Testing: Added extensive test cases to tests/utils/test_topk.py to verify the reproducibility and correctness of the deterministic FilteredTopK implementation, including scenarios with heavy ties and various data types.



gemini-code-assist Bot left a comment


Code Review

This pull request introduces a deterministic mode for the filtered top-k kernel, which is a valuable feature for reproducibility. The implementation is well-structured, using if constexpr for compile-time dispatch between deterministic and non-deterministic paths, and adding static_cast<size_t> for safer pointer arithmetic. The changes also include refactoring to reduce code duplication in kernel launch sites. The accompanying Python bindings and extensive test suite are well-written. I have two suggestions for the CUDA implementation in include/flashinfer/topk.cuh: one to fix a correctness issue in the deterministic path that could lead to non-deterministic output in an edge case, and another to further improve maintainability by refactoring duplicated code.

Comment thread include/flashinfer/topk.cuh Outdated
Comment thread include/flashinfer/topk.cuh
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>

coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)

2872-2924: Verify RuntimeError behavior for unsatisfied preconditions.

The deterministic path validates preconditions using FLASHINFER_CHECK macros, which is good. However, the error messages could be clearer about what the user should do.

💡 Consider more actionable error messages
     FLASHINFER_CHECK(top_k_val <= FILTERED_TOPK_MAX_K,
-                     "deterministic=True requires k <=", FILTERED_TOPK_MAX_K,
-                     "but got k =", top_k_val);
+                     "deterministic=True requires k <= ", FILTERED_TOPK_MAX_K,
+                     ", but got k = ", top_k_val,
+                     ". Use deterministic=False for larger k values.");
     FLASHINFER_CHECK(CanImplementFilteredTopK(),
-                     "deterministic=True requires GPU support for 128KB shared memory");
+                     "deterministic=True requires GPU support for 128KB shared memory. "
+                     "Current GPU does not meet this requirement.");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 2872 - 2924, The checks in the
deterministic branch use FLASHINFER_CHECK but the messages are terse; update the
two checks (FLASHINFER_CHECK(top_k_val <= FILTERED_TOPK_MAX_K, ...) and
FLASHINFER_CHECK(CanImplementFilteredTopK(), ...)) to provide actionable
guidance: include the invalid value (top_k_val), the allowed limit
(FILTERED_TOPK_MAX_K) and concrete remediation steps (e.g., "use
non-deterministic mode, reduce k, or compile with larger shared memory
support"), and for CanImplementFilteredTopK() include how to enable GPU support
(e.g., required GPU compute capability or driver/compile flags); keep the checks
at the same call sites so behavior unchanged but messages are more helpful to
users.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/utils/test_topk.py`:
- Line 1641: The unpacked variable ref_values from the call ref_values,
ref_indices = torch.topk(logits, k, dim=-1, sorted=True) is unused; remove it by
only capturing the indices (e.g., assign ref_indices = torch.topk(logits, k,
dim=-1, sorted=True)[1] or use _ to ignore the first return like _, ref_indices
= torch.topk(...)) so only ref_indices is kept and the unused variable is
eliminated.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f36f1d97-5749-466c-9acf-bfce82f8b179

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and 4563b71.

📒 Files selected for processing (5)
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

Comment thread tests/utils/test_topk.py

yzh119 commented Mar 14, 2026

Hi @Linda-Stadter is it ready for review?

btw, there is another PR from the community which seems to be working on the same thing: #2661, I haven't checked carefully but would you mind also taking a look?

Linda-Stadter (Author) commented:

This PR is now part of #2661

aleozlx pushed a commit that referenced this pull request Apr 1, 2026
## 📌 Description

> Part of the FilteredTopK implementation refers to or is adapted from
@Linda-Stadter's work in #2759

### Deterministic Mode for Top-K Kernels

#### FilteredTopK Kernel

FilteredTopKKernel implements deterministic mode as follows:

1. Build a coarse histogram.

- Build a coarse histogram on the top 8 bits to locate the coarse
threshold bin that contains the k-th largest element.
- Same as non-deterministic mode, elements with bin > threshold_bin are
appended to s_indices via **atomicAdd** (see
`collect_gt_and_nondet_eq_threshold`); their final order is determined
by the post-sort kernel.

2. Refine with 8-bit radix passes.

- Run multiple 8-bit refine passes to find the exact pivot.
- Deterministic == pivot selection is performed by
`collect_det_eq_pivot`, which writes the selected tie elements into
`s_indices` in deterministic **thread-strided order**.

> **Thread-strided order** means, for example, if `BLOCK_THREADS = 4`,
then the logical scan order is:
>
> - thread 0: `0, 4, 8, ...`
> - thread 1: `1, 5, 9, ...`
> - thread 2: `2, 6, 10, ...`
> - thread 3: `3, 7, 11, ...`
>
> If the `== pivot` positions are:
> - thread 0: `0, 8`
> - thread 1: `5`
> - thread 2: none
> - thread 3: `3, 7`
>
>  then the deterministic collection order is: [0, 8, 5, 3, 7].
> That is, we order elements first by thread ID, and then by each
thread's strided traversal order.

3. Post-sort kernels.

- After FilteredTopKKernel finishes, `SortTopKByIndexKernel` is applied
to produce index-ascending output and make the final ordering
deterministic (we use atomicAdd to collect > pivot at stage 1).
- If the Python API is called with sorted=True,
`StableSortTopKByValueKernel` is applied afterward to produce
value-descending output.
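The thread-strided collection order from step 2 can be reproduced with a short host-side sketch (a hypothetical helper, not kernel code):

```python
# Reproduce the thread-strided collection order from the example above:
# elements equal to the pivot are gathered first by thread ID, then by
# each thread's strided traversal order.
BLOCK_THREADS = 4

def thread_strided_order(eq_positions, n, block_threads=BLOCK_THREADS):
    order = []
    for tid in range(block_threads):              # first by thread ID
        for pos in range(tid, n, block_threads):  # then by strided traversal
            if pos in eq_positions:
                order.append(pos)
    return order

# == pivot positions from the example: {0, 8, 5, 3, 7} over 12 elements.
print(thread_strided_order({0, 8, 5, 3, 7}, 12))  # -> [0, 8, 5, 3, 7]
```

Because this order depends only on thread IDs and element positions, not on scheduling, it is identical across runs.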

#### RadixTopK Kernel


1. RadixSelectFindPivot

- Finds `ordered_pivot`, which Stage 2 uses to determine whether an
element is >= `ordered_pivot`.
- Computes `cta_local_eq_count` and `cta_local_gt_count`, which Stage 2
uses to **determine** how many elements the current CTA may emit and
where each emitted element should be placed.

2. collect_indices (`RadixCollectIndicesDeterministic`)

RadixCollectIndicesDeterministic: after the pivot is known, assigns each
CTA a fixed output range, then writes all > pivot elements followed by
the required == pivot elements in a deterministic order.

Order definition:

- Emit > pivot elements first, then == pivot elements.
- For each category, earlier CTAs write to earlier output positions.
- Within each CTA, emit elements in thread-strided order.
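The fixed output ranges can be derived with exclusive prefix sums over the per-CTA counts. A sketch of the placement logic under assumed counts (the numbers are made up, not from the PR):

```python
import itertools

# Hypothetical per-CTA counts after the pivot pass:
gt_counts = [3, 1, 2]   # elements > pivot per CTA
eq_counts = [2, 4, 1]   # elements == pivot per CTA
k = 8

# All > pivot elements are emitted; only enough == pivot elements fill k.
total_gt = sum(gt_counts)
eq_needed = k - total_gt  # here 2

# Exclusive prefix sums give each CTA a fixed, collision-free > pivot range.
gt_offsets = list(itertools.accumulate([0] + gt_counts[:-1]))

# == pivot ranges start after all > pivot elements; earlier CTAs fill first.
eq_taken, eq_offsets = 0, []
for c in eq_counts:
    take = max(0, min(c, eq_needed - eq_taken))
    eq_offsets.append((total_gt + eq_taken, take))  # (start slot, count)
    eq_taken += take
# gt_offsets -> [0, 3, 4]; CTA0 takes == pivot slots (6, 2), later CTAs none.
```

Since each CTA's range is a pure function of the counts, no atomics are needed for placement and the output layout is reproducible.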

### Benchmarks

machine: NVIDIA A100-PCIE-40GB

command: (fp32/fp16/bf16)
```bash
python -u benchmarks/bench_topk.py \
  --op all \
  --dtype fp32 \
  --deterministic \
  --compare-torch-deterministic \
  --input-pattern random
```

raw results:


[output.txt](https://github.com/user-attachments/files/26337712/output-group2-current-v1_5-20260330-101411.txt)
**Summary**


| dtype | geomean det slowdown vs non-det | geomean speedup vs torch.det |
| --- | ---: | ---: |
| fp32 | 1.0992x | 1.7660x |
| fp16 | 1.0777x | 1.3381x |
| bf16 | 1.0745x | 1.3055x |
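For reference, the geomean used here is the exponential of the mean log ratio; a sketch with made-up placeholder ratios (not the benchmark's actual per-case numbers):

```python
import math

# Hypothetical per-case slowdown ratios (det_time / nondet_time).
ratios = [1.08, 1.12, 1.05, 1.15, 1.10]

# Geometric mean = exp of the mean of the logs; it averages multiplicative
# factors, so fast and slow cases offset instead of skewing the average.
geomean = math.exp(sum(math.log(r) for r in ratios) / len(ratios))
print(f"{geomean:.4f}x")
```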


NOTE: FlashInfer deterministic **underperforms** PyTorch mainly on
short-sequence workloads. Importantly, this is not unique to the
deterministic path: FlashInfer non-deterministic top-k is also slower
than PyTorch in the same short-sequence regime. This suggests the gap is
primarily a short-sequence top-k issue rather than a
deterministic-specific regression. Optimizing short-sequence top-k, for
both non-deterministic and deterministic modes, is better treated as
future work.



## 🔍 Related Issues

close: #2584


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

```
unittest I ran:
test_topk.py
test_sampling.py
test_logits_processor.py
```
## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
* Deterministic mode for top‑k and fused transforms (stable, repeatable
tie ordering) with API flag to enable deterministic outputs and stable
sorting behavior.

* **Benchmarks**
* Expanded benchmarking to compare deterministic vs nondeterministic
runs, pre-generated input patterns, DSA workload cases, and richer CLI
output.

* **Tests**
* Large suite of determinism and correctness tests (ties, multi‑CTA,
streams, sorted behavior, cache transitions).

* **Bug Fixes**
  * Improved runtime-error labeling and benchmark cache handling.

---------

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
