Deterministic Mode: Add 1-stage triton kernel for prefill#11147
Conversation
Summary of Changes

Hello @hebiao064, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant architectural change to improve the determinism of prefill operations within the attention mechanism. It integrates a new, unified 1-stage Triton kernel that efficiently handles both prefix and extended Key-Value (KV) tokens. This modification is crucial for ensuring consistent model outputs, especially in scenarios requiring high reproducibility. The changes also include updates to the attention wrapper to leverage this new kernel under deterministic settings, and minor adjustments to testing for better consistency.

Highlights
Code Review
This pull request introduces a new 1-stage Triton kernel for prefill in deterministic mode, aiming to improve determinism. The core changes involve adding the new _fwd_kernel_unified Triton kernel and the corresponding Python logic to prepare data and launch it. While this is a good step, I've identified a critical correctness issue with the sliding window attention logic, a missing feature for sink token handling, and a significant performance bottleneck in the data preparation loop. Addressing these issues will be crucial for the correctness and performance of this new feature.
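To illustrate what the unified kernel computes, here is a minimal pure-numpy reference sketch (not the Triton implementation itself, and all names are hypothetical): each extend-token query attends over the concatenated prefix and extend KV in a single pass, with the causal mask applied only to the extend part.

```python
import numpy as np

def unified_prefill_attention(q, k_prefix, v_prefix, k_extend, v_extend):
    """Reference (non-Triton) sketch of 1-stage unified prefill attention.

    q:        (n_extend, d)  queries for the newly extended tokens
    k_prefix: (n_prefix, d)  cached prefix keys (fully visible)
    k_extend: (n_extend, d)  keys for the extended tokens (causally masked)
    v_prefix, v_extend: matching value tensors
    """
    n_extend, d = q.shape
    n_prefix = k_prefix.shape[0]
    k = np.concatenate([k_prefix, k_extend], axis=0)
    v = np.concatenate([v_prefix, v_extend], axis=0)
    scores = q @ k.T / np.sqrt(d)  # (n_extend, n_prefix + n_extend)
    # Causal mask applies only to the extend part: query i may attend to
    # extend key j iff j <= i. All prefix keys are always visible.
    qi = np.arange(n_extend)[:, None]
    kj = np.arange(n_prefix + n_extend)[None, :]
    allowed = kj < n_prefix + qi + 1
    scores = np.where(allowed, scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v
```

Doing prefix and extend KV in one stage (rather than two separate kernels whose partial results are combined) fixes the reduction order, which is what makes the prefill output bitwise reproducible.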
@hebiao064 Have you tested accuracy for this kernel?
See the PR description.
```python
final_mask = mask_m[:, None] & mask_n[None, :]

# Apply causal mask for extend part
if IS_CAUSAL:
```
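For context on the reviewed lines, here is a small numpy sketch of how a boundary mask and the causal mask typically combine inside one tile of such a kernel. The tile offsets and sizes below are hypothetical, chosen only to illustrate the masking logic.

```python
import numpy as np

# Hypothetical tile parameters mirroring a BLOCK_M x BLOCK_N kernel tile.
BLOCK_M, BLOCK_N = 4, 4
seq_len = 6              # extend length; positions past this are out of bounds
start_m, start_n = 4, 4  # offsets of this tile in the query/key dimensions

offs_m = start_m + np.arange(BLOCK_M)  # query positions covered by this tile
offs_n = start_n + np.arange(BLOCK_N)  # key positions covered by this tile

# Boundary masks keep positions inside the actual sequence length.
mask_m = offs_m < seq_len
mask_n = offs_n < seq_len
final_mask = mask_m[:, None] & mask_n[None, :]

IS_CAUSAL = True
if IS_CAUSAL:
    # Causal mask for the extend part: query i attends to key j iff j <= i.
    final_mask &= offs_m[:, None] >= offs_n[None, :]
```

Positions where `final_mask` is `False` would have their scores set to `-inf` before the softmax.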
custom mask for speculative decoding seems not considered?
Yes, you are right. @zminglei supported it in the latest commit.
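For speculative decoding, the plain causal mask is replaced by a caller-supplied custom (tree) mask, so each draft token attends only to its ancestors in the draft tree. A minimal sketch of that substitution, with all names hypothetical:

```python
import numpy as np

def apply_custom_mask(scores, custom_mask):
    """Apply a speculative-decoding tree mask instead of a causal mask.

    scores:      (n_draft, n_kv) raw attention scores
    custom_mask: (n_draft, n_kv) boolean; True where attention is allowed
    """
    return np.where(custom_mask, scores, -np.inf)

# Example: 3 draft tokens where token 2 branches off token 0,
# so it must not attend to its sibling, token 1.
scores = np.zeros((3, 3))
tree_mask = np.array([
    [True,  False, False],  # token 0 sees only itself
    [True,  True,  False],  # token 1 sees token 0 and itself
    [True,  False, True],   # token 2 sees token 0 and itself
])
masked = apply_custom_mask(scores, tree_mask)
```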
Motivation
Part of this Issue: #10278
Related PR: #10721
Inspired from @ispobock
Co-authored with @zminglei @byjiang1996
Before this PR:
After this PR:
Modifications
Accuracy Tests
Deterministic Test:
GSM
MMLU
Spec Decoding
Benchmarking and Profiling
Checklist