[Deterministic] Blockscale FP8 kernel #11491
b8zhong wants to merge 2 commits into sgl-project:main
Conversation
Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request addresses the need for deterministic inference in Blockscale FP8 kernels, particularly for R1/V3 models. It achieves this by making targeted modifications to the underlying CUTLASS GEMM operations, ensuring that computations yield consistent results by controlling scheduler behavior and reduction modes. The changes prevent sources of non-determinism, such as the use of StreamK and non-deterministic reduction strategies, when deterministic inference is explicitly requested.
Code Review
This pull request introduces changes to enable deterministic inference for the Blockscale FP8 kernel. This is achieved by adding a deterministic flag which is propagated down from an environment variable. This flag is used to avoid the non-deterministic StreamKScheduler and ReductionMode::Nondeterministic in CUTLASS. The changes are logical and correctly implement the intended behavior. I have one suggestion to improve code maintainability by reducing code duplication.
```diff
   torch::Tensor scales_b_contiguous = scales_b.contiguous();
   if (out_dtype == torch::kBFloat16) {
     cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
-        out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
+        out_padded,
+        mat_a_padded,
+        mat_b,
+        scales_a_padded,
+        scales_b_contiguous,
+        getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"));
   } else {
     cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
-        out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
+        out_padded,
+        mat_a_padded,
+        mat_b,
+        scales_a_padded,
+        scales_b_contiguous,
+        getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"));
   }
```
To improve readability and avoid redundant calls to getBoolEnv, you can fetch the deterministic flag once before the if block and store it in a const bool variable. This also reduces code duplication within the if-else branches.
```cpp
torch::Tensor scales_b_contiguous = scales_b.contiguous();
const bool deterministic = getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE");
if (out_dtype == torch::kBFloat16) {
  cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
      out_padded,
      mat_a_padded,
      mat_b,
      scales_a_padded,
      scales_b_contiguous,
      deterministic);
} else {
  cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
      out_padded,
      mat_a_padded,
      mat_b,
      scales_a_padded,
      scales_b_contiguous,
      deterministic);
}
```
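As a side note on the flag itself, here is a hypothetical Python analogue of what a `getBoolEnv`-style helper typically does. The set of accepted truthy spellings is an assumption for illustration; the real C++ helper in sgl-kernel may recognize different values.

```python
import os

def get_bool_env(name: str) -> bool:
    """Hypothetical Python analogue of a getBoolEnv-style C++ helper.

    Assumption: "1", "true", "yes", and "on" (case-insensitive) count as
    true; everything else, including an unset variable, counts as false.
    """
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}

# Fetch the flag once and reuse it, as the review suggests for the C++ side.
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
deterministic = get_bool_env("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
print(deterministic)  # True
```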
Nice PR! Have you verified that DeepSeek was not deterministic before this change? And would you please provide a benchmark comparing non-deterministic vs. deterministic mode?
@hebiao064 Sorry, actually I rebased some changes with main, but I swapped to the flashinfer backend because I encountered a strange Triton crash (issue at #11529). I think some MLA decode might still not be deterministic at the moment.
Complete in #11491
Motivation
Part of #11402
Enable deterministic inference on R1/V3. According to @Fridge003, if FP8 fused MoE is deterministic, then this should be one of the remaining kernels (maybe for Hopper only).
Modifications
Make two changes to blockwise FP8 for deterministic inference:

1. Avoid the StreamK scheduler when deterministic inference is requested.
2. Avoid ReductionMode::Nondeterministic. To my understanding, this goes together with the StreamK scheduler: partial reductions can be combined in any order, and because floating-point addition is not associative, the final result depends on that order.

I also disabled DeepGEMM through env vars.
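The reduction-order argument can be seen with a tiny example: floating-point addition is not associative, so combining the same values in a different order can change the result. (Python doubles here, but the same effect applies to the kernel's accumulators.)

```python
a, b, c = 1e16, -1e16, 0.5

left = (a + b) + c   # partials combined left to right
right = a + (b + c)  # same values, different association

# The 0.5 is absorbed when added to -1e16 first, because it is smaller
# than half a unit in the last place of 1e16.
print(left, right)  # 0.5 0.0
```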
Accuracy Tests
Before cmd (100% not deterministic):
Results:
Trying to enable deterministic mode before this PR (the default blockwise FP8 backend is flashinfer):
After cmd:
Results (TODO redo these):
The script to test the determinism of the op itself.
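The exact script is not reproduced here, but the core idea of such a determinism check can be sketched as follows: run the same reduction twice with a fixed combination order and require bit-identical results, while a permuted combination order (standing in for a StreamK-style arrival-order reduction) is allowed to drift. `blockwise_sum` and the input sizes below are illustrative stand-ins, not the real kernel.

```python
import random

def blockwise_sum(xs, block=8, order=None):
    # Stand-in for a blockwise reduction: compute per-block partial sums,
    # then combine the partials in the given order. StreamK-style
    # reductions combine partials in arrival order, which can vary
    # from run to run.
    partials = [sum(xs[i:i + block]) for i in range(0, len(xs), block)]
    if order is None:
        order = range(len(partials))
    total = 0.0
    for i in order:
        total += partials[i]
    return total

rng = random.Random(0)
xs = [rng.uniform(-1e8, 1e8) for _ in range(256)]

# Deterministic mode: the same fixed order must give bit-identical results.
r1 = blockwise_sum(xs)
r2 = blockwise_sum(xs)
assert r1 == r2

# A permuted combination order may give a slightly different float result,
# which is exactly the behavior the PR disables.
shuffled = list(range(256 // 8))
rng.shuffle(shuffled)
r3 = blockwise_sum(xs, order=shuffled)
print(r1, r3)
```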