
Deterministic Mode: Add 1-stage triton kernel for prefill#11147

Merged
ispobock merged 13 commits into main from bhe/1_stage_triton_kernel on Oct 19, 2025
Conversation

@hebiao064
Collaborator

@hebiao064 hebiao064 commented Oct 1, 2025

Motivation

Part of this Issue: #10278
Related PR: #10721

Inspired by @ispobock

Co-authored with @zminglei @byjiang1996

Before this PR:

              K (Keys, total 9 tokens)
        ╔═══════════════════╦═════════════════╗
        ║   Prefix (0-4)    ║  Extend (5-8)   ║
        ║   [from cache]    ║  [new tokens]   ║
    ════╬═══════════════════╬═════════════════╣
    Q 5 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✗   ✗   ✗ ║
      6 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✓   ✗   ✗ ║
      7 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✓   ✓   ✗ ║
      8 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✓   ✓   ✓ ║
    ════╩═══════════════════╩═════════════════╝
         ╚═════Stage 1═════╝ ╚═══Stage 2════╝
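The two-stage mask in the diagram above can be sketched in NumPy (illustrative only; the real logic lives inside the Triton kernel, and the sizes match the diagram: prefix tokens 0-4 come from cache, extend tokens 5-8 are new):

```python
import numpy as np

# Illustrative sketch of the two-stage mask, not the actual kernel code.
prefix_len, extend_len = 5, 4

# Stage 1: every extend query attends to all cached prefix keys.
stage1 = np.ones((extend_len, prefix_len), dtype=bool)

# Stage 2: causal mask over the extend keys only.
q_idx = np.arange(extend_len)[:, None]
k_idx = np.arange(extend_len)[None, :]
stage2 = k_idx <= q_idx

# The two partial attention results are merged afterwards; that second
# reduction is where run-to-run numerical variation can creep in.
full_mask = np.concatenate([stage1, stage2], axis=1)
print(full_mask.astype(int))
```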


Prompt 0 with prefix length 1: total samples: 306, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 339, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 320, Unique samples: 5
Prompt 3 with prefix length 4097: total samples: 310, Unique samples: 1

100%|█████████████████████████████████████████████████████| 1319/1319 [01:44<00:00, 12.68it/s]
Accuracy: 0.911
Invalid: 0.000
Latency: 104.082 s
Output throughput: 1555.579 token/s

After this PR:


              K (Keys, total 8 tokens)
        ╔═══════════════════╦═════════════╗
        ║   Prefix (0-4)    ║ Extend (5-7)║
        ║   [from cache]    ║ [new tokens]║
    ════╬═══════════════════╬═════════════╣
    Q 5 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✗   ✗ ║
      6 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✓   ✗ ║
      7 ║  ✓   ✓   ✓   ✓   ✓║  ✓   ✓   ✓ ║
    ════╩═══════════════════╩═════════════╝
         ╚════════ Unified Stage ════════╝
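Conceptually, the unified kernel derives the same mask from absolute token positions in a single pass, so no second merge step (and no second reduction order) is needed. A NumPy sketch, illustrative only, with sizes matching the diagram:

```python
import numpy as np

# Single-pass mask: one comparison over absolute positions covers both the
# prefix keys and the extend keys. Sizes match the diagram above.
prefix_len, extend_len = 5, 3
total_keys = prefix_len + extend_len

q_pos = prefix_len + np.arange(extend_len)[:, None]  # absolute query positions 5..7
k_pos = np.arange(total_keys)[None, :]               # absolute key positions 0..7
unified_mask = k_pos <= q_pos
print(unified_mask.astype(int))
```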

Prompt 0 with prefix length 1: total samples: 348, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 283, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 326, Unique samples: 1
Prompt 3 with prefix length 4097: total samples: 318, Unique samples: 1

Accuracy: 0.909
Invalid: 0.000
Latency: 111.468 s
Output throughput: 1446.923 token/s

Modifications

Accuracy Tests

Deterministic Test:

Prompt 0 with prefix length 1: total samples: 84, Unique samples: 1
Prompt 1 with prefix length 8000: total samples: 72, Unique samples: 1
Prompt 2 with prefix length 10000: total samples: 95, Unique samples: 1
Prompt 3 with prefix length 12500: total samples: 74, Unique samples: 1

GSM

100%|███████████████████████████████████████████████████████████████████████| 200/200 [00:25<00:00,  7.71it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 26.031 s
Output throughput: 915.098 token/s

MMLU

subject: abstract_algebra, #q:100, acc: 0.560
subject: anatomy, #q:135, acc: 0.704
subject: astronomy, #q:152, acc: 0.895
subject: business_ethics, #q:100, acc: 0.770
subject: clinical_knowledge, #q:265, acc: 0.785
subject: college_biology, #q:144, acc: 0.875
subject: college_chemistry, #q:100, acc: 0.540
subject: college_computer_science, #q:100, acc: 0.740
subject: college_mathematics, #q:100, acc: 0.570
subject: college_medicine, #q:173, acc: 0.815
Total latency: 31.060
Average accuracy: 0.748

Spec Decoding

python3 -m sglang.launch_server \
    --model /shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693 \
    --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path /shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B/e5ed08d66f528a95ce89f5d4fd136a28f6def714 \
    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \
    --trust-remote-code --dtype float16 --enable-torch-compile \
    --attention-backend triton --cuda-graph-max-bs 2 --enable-deterministic-inference

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:18<00:00, 10.70it/s]
Accuracy: 0.770
Invalid: 0.000
Latency: 18.762 s
Output throughput: 909.155 token/s

python3 -m sglang.test.test_deterministic --test-mode prefix --n-trials 50
Prompt 0 with prefix length 1: total samples: 286, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 333, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 311, Unique samples: 1
Prompt 3 with prefix length 4097: total samples: 345, Unique samples: 1

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @hebiao064, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant architectural change to improve the determinism of prefill operations within the attention mechanism. It integrates a new, unified 1-stage Triton kernel that efficiently handles both prefix and extended Key-Value (KV) tokens. This modification is crucial for ensuring consistent model outputs, especially in scenarios requiring high reproducibility. The changes also include updates to the attention wrapper to leverage this new kernel under deterministic settings and minor adjustments to testing for better consistency.

Highlights

  • New Unified Triton Kernel: Introduced a new 1-stage Triton kernel (_fwd_kernel_unified) and its Python wrapper (extend_attention_fwd_unified) for prefill operations in deterministic mode. This kernel processes both prefix and extend KV tokens in a single pass.
  • Deterministic Mode Integration: The AttentionWrapper now conditionally uses the new unified kernel via a new _forward_extend_unified method when enable_deterministic is active, streamlining the prefill attention logic.
  • Improved Determinism: The changes aim to enhance the determinism of prefill operations, as demonstrated by the 'After this PR' results showing a reduction in 'Unique samples' for various prefix lengths, indicating more consistent output.
  • Radix Cache Compatibility Note: A line disabling radix cache for non-FA3 attention backends in deterministic mode was commented out, potentially indicating ongoing work or a change in compatibility strategy.
  • Test Consistency: A default sampling_seed of 42 has been added to the deterministic tests to ensure consistent and reproducible test runs.
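The conditional dispatch described in the highlights might look roughly like this. Only the names extend_attention_fwd_unified and enable_deterministic come from the PR; the wrapper shape here is a hypothetical sketch, not the actual implementation:

```python
# Hypothetical sketch of the deterministic-mode dispatch. The kernel
# functions are passed in as callables to keep the sketch self-contained.
def forward_extend(q, k, v, *, enable_deterministic, unified_fn, two_stage_fn):
    if enable_deterministic:
        # One kernel launch with a fixed reduction order -> reproducible output.
        return unified_fn(q, k, v)
    # Default path: prefix and extend handled in two stages, then merged.
    return two_stage_fn(q, k, v)
```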

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new 1-stage Triton kernel for prefill in deterministic mode, aiming to improve determinism. The core changes involve adding the new _fwd_kernel_unified Triton kernel and the corresponding Python logic to prepare data and launch it. While this is a good step, I've identified a critical correctness issue with the sliding window attention logic, a missing feature for sink token handling, and a significant performance bottleneck in the data preparation loop. Addressing these issues will be crucial for the correctness and performance of this new feature.
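For context on the sliding-window concern: in a unified prefix+extend layout the window has to be applied over absolute key positions, which is an easy detail to get wrong in a single-pass kernel. A NumPy sketch of what a correct mask would compute (illustrative, not the PR's code):

```python
import numpy as np

# Sliding-window attention restricts each query to the most recent `window`
# key positions, on top of the causal constraint. Both comparisons use
# absolute positions so prefix and extend keys are handled uniformly.
def sliding_window_mask(prefix_len, extend_len, window):
    q_pos = prefix_len + np.arange(extend_len)[:, None]  # absolute query positions
    k_pos = np.arange(prefix_len + extend_len)[None, :]  # absolute key positions
    return (k_pos <= q_pos) & (k_pos > q_pos - window)

mask = sliding_window_mask(prefix_len=5, extend_len=3, window=3)
```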

Comment thread python/sglang/srt/layers/attention/triton_ops/extend_attention.py
Comment thread python/sglang/srt/layers/attention/triton_backend.py Outdated
Comment thread python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@hebiao064 hebiao064 changed the title [Not Ready for Review] Deterministic Mode: Add 1-stage triton kernel for prefill Deterministic Mode: Add 1-stage triton kernel for prefill Oct 13, 2025
@Fridge003 Fridge003 self-assigned this Oct 15, 2025
@Fridge003
Collaborator

@hebiao064 Have you tested accuracy for this kernel?

@hebiao064
Collaborator Author

Accuracy Tests

See the PR description.

Comment thread python/sglang/srt/layers/attention/triton_ops/extend_attention.py Outdated
Comment thread python/sglang/srt/layers/attention/triton_ops/extend_attention.py
final_mask = mask_m[:, None] & mask_n[None, :]

# Apply causal mask for extend part
if IS_CAUSAL:
Collaborator


The custom mask for speculative decoding seems not to be considered?

Collaborator Author


Yes, you are right. @zminglei supported it in the latest commit.
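For readers following the thread: with EAGLE-style speculative decoding the draft tokens form a tree, so the extend part needs a caller-supplied custom mask rather than the plain causal one. A minimal sketch under an assumed interface (not the kernel's actual signature):

```python
import numpy as np

# Assumed interface for illustration only: when a custom mask is supplied
# (e.g. the draft-token tree mask used by speculative decoding), it replaces
# the plain causal mask over the extend tokens.
def extend_token_mask(extend_len, custom_mask=None):
    if custom_mask is not None:
        return custom_mask
    # Default prefill case: lower-triangular causal mask.
    return np.tril(np.ones((extend_len, extend_len), dtype=bool))
```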

Collaborator

@ispobock ispobock left a comment


LGTM!

@ispobock ispobock merged commit 4fff1ec into main Oct 19, 2025
132 of 140 checks passed
@ispobock ispobock deleted the bhe/1_stage_triton_kernel branch October 19, 2025 17:47