
[Deterministic] Blockscale FP8 kernel #11491

Closed
b8zhong wants to merge 2 commits into sgl-project:main from bzhng-development:deterministic-blockscale-fp8

Conversation

Collaborator

@b8zhong b8zhong commented Oct 12, 2025

Motivation

Part of #11402
Enable deterministic inference on R1/V3. According to @Fridge003, if FP8 fused MoE is already deterministic, then this should be one of the remaining kernels (maybe Hopper-only).

After rebasing, I encountered #11529 on the Triton backend and don't know how to solve it yet, so I switched to FlashInfer, which is still not deterministic.

Modifications

This PR makes two changes to the blockwise FP8 CUTLASS path for deterministic inference (see the sketch below):

  1. Avoid the Stream-K scheduler (a variant of Split-K decomposition) in CUTLASS.
  2. Avoid ReductionMode::Nondeterministic. To my understanding this is also part of the Stream-K scheduler: partial results can be reduced in any order, and floating-point addition is not associative, so the output can vary between runs.

I also disabled DeepGEMM through env vars.
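
For illustration, here is a minimal, self-contained C++ sketch of the selection logic these changes amount to. The enum names and the choose_schedule helper are placeholders (the real change lives in the CUTLASS dispatch code in sgl-kernel), and the k > 3 * n Stream-K threshold is taken from the review summary further down; treat this as a sketch, not the actual patch.

// Sketch only: placeholder types mirroring the CUTLASS concepts involved.
#include <cstdint>
#include <cstdlib>
#include <cstring>

enum class TileScheduler { Persistent, StreamK };
enum class ReductionMode { Deterministic, Nondeterministic };

struct ScheduleChoice {
  TileScheduler scheduler;
  ReductionMode reduction;
};

// Stand-in for the getBoolEnv helper used in sgl-kernel.
static bool get_bool_env(const char* name) {
  const char* v = std::getenv(name);
  return v != nullptr && std::strcmp(v, "1") == 0;
}

// Deterministic mode forces the persistent scheduler (no Stream-K/Split-K
// decomposition) and a deterministic reduction, so partial sums are always
// combined in the same order. Otherwise Stream-K is kept for K-heavy shapes.
static ScheduleChoice choose_schedule(int64_t /*m*/, int64_t n, int64_t k) {
  const bool deterministic =
      get_bool_env("SGLANG_ENABLE_DETERMINISTIC_INFERENCE");
  if (deterministic) {
    return {TileScheduler::Persistent, ReductionMode::Deterministic};
  }
  // Illustrative heuristic: Stream-K only pays off for K-heavy problems.
  if (k > 3 * n) {
    return {TileScheduler::StreamK, ReductionMode::Nondeterministic};
  }
  return {TileScheduler::Persistent, ReductionMode::Deterministic};
}
// With SGLANG_ENABLE_DETERMINISTIC_INFERENCE=1, every shape maps to the
// persistent scheduler with deterministic reduction.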

Accuracy Tests

Before command (completely non-deterministic):

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1   --tp 8   --trust-remote-code

Results:

python3 -m sglang.test.test_deterministic --test-mode single
python3 -m sglang.test.test_deterministic --test-mode mixed
python3 -m sglang.test.test_deterministic --test-mode prefix
...
Total samples: 50, Unique samples: 3

Prompt 1: total samples: 424, Unique samples: 6
Prompt 2: total samples: 629, Unique samples: 6
Long prompt: total samples: 222, Unique samples: 14

Prompt 0 with prefix length 1: total samples: 340, Unique samples: 58
Prompt 1 with prefix length 511: total samples: 312, Unique samples: 163
Prompt 2 with prefix length 2048: total samples: 305, Unique samples: 115
Prompt 3 with prefix length 4097: total samples: 318, Unique samples: 115

Trying to enable deterministic inference before this PR (the default blockwise FP8 backend is FlashInfer):

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1 --tp 8 --trust-remote-code --enable-deterministic-inference --disable-radix-cache --attention-backend flashinfer --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' --flashinfer-mla-disable-ragged
python3 -m sglang.test.test_deterministic --test-mode single
python3 -m sglang.test.test_deterministic --test-mode mixed
python3 -m sglang.test.test_deterministic --test-mode prefix
...
Total samples: 50, Unique samples: 6

Prompt 1: total samples: 559, Unique samples: 4
Prompt 2: total samples: 509, Unique samples: 7
Long prompt: total samples: 207, Unique samples: 3

Prompt 0 with prefix length 1: total samples: 310, Unique samples: 7
Prompt 1 with prefix length 511: total samples: 321, Unique samples: 23
Prompt 2 with prefix length 2048: total samples: 316, Unique samples: 9
Prompt 3 with prefix length 4097: total samples: 328, Unique samples: 20

After command:

Note: by default this setup seems to use the FlashInfer blockwise FP8 path; I have not yet looked into how to make that path deterministic.

SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=0 SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1 python3 -m sglang.launch_server   --model-path deepseek-ai/DeepSeek-R1  --tp 8   --trust-remote-code --enable-deterministic-inference --disable-radix-cache --attention-backend triton --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' --attention-backend flashinfer --flashinfer-mla-disable-ragged

Results (TODO redo these):

...
Total samples: 50, Unique samples: 14
...
Prompt 1: total samples: 609, Unique samples: 5
Prompt 2: total samples: 485, Unique samples: 5
Long prompt: total samples: 181, Unique samples: 7
...
Prompt 0 with prefix length 1: total samples: 322, Unique samples: 5
Prompt 1 with prefix length 511: total samples: 322, Unique samples: 11
Prompt 2 with prefix length 2048: total samples: 334, Unique samples: 7
Prompt 3 with prefix length 4097: total samples: 297, Unique samples: 20
...

The script below tests the determinism of the op itself:

# sgl-kernel/tests/test_fp8_blockwise_gemm_deterministic.py
import os
import torch

from sgl_kernel.gemm import fp8_blockwise_scaled_mm


def _m_major_view_2d(t: torch.Tensor) -> torch.Tensor:
    # Return an M-major view (stride(0) == 1) by materializing a transposed
    # contiguous base and transposing it back.
    base = t.t().contiguous()
    return base.t()


def _col_major_view_k_n(mat_k_n: torch.Tensor) -> torch.Tensor:
    # Return a column-major (k, n) view (stride(0) == 1) by materializing a
    # transposed contiguous base and returning the transpose view.
    base = mat_k_n.t().contiguous()     # (n, k), row-major contiguous
    return base.t()                     # (k, n), column-major view (stride(0) == 1)


def _make_fp8_with_scales(m, k, n, device, dtype_out):
    torch.manual_seed(1234)
    assert k % 128 == 0 and n % 128 == 0

    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max

    # A: (m, k), per-128 block scales along K
    a_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
    a_scales_blocks = (
        a_bf16.abs().view(m, k // 128, 128).amax(dim=2) / fp8_max
    ).clamp_min(1e-8).to(torch.float32).contiguous()
    a_scale_full = a_scales_blocks.repeat_interleave(128, dim=1)
    a_fp8 = torch.clamp(
        (a_bf16 / a_scale_full).to(torch.float32), min=-fp8_max, max=fp8_max
    ).to(fp8_dtype).contiguous()
    # M major
    a_scales = _m_major_view_2d(a_scales_blocks)
    assert a_scales.size() == (m, k // 128) and a_scales.stride(0) == 1

    # B: (k, n), per-128x128 block scales
    b_bf16 = torch.randn(k, n, device=device, dtype=torch.bfloat16)
    b_scales_blocks = (
        b_bf16.abs().view(k // 128, 128, n // 128, 128).amax(dim=(1, 3))
    ).clamp_min(1e-8).to(torch.float32).contiguous()
    b_scale_full = b_scales_blocks.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
    b_fp8_k_n = torch.clamp(
        (b_bf16 / b_scale_full).to(torch.float32), min=-fp8_max, max=fp8_max
    ).to(fp8_dtype)
    # Col major
    b_fp8 = _col_major_view_k_n(b_fp8_k_n)
    assert b_fp8.size() == (k, n) and b_fp8.stride(0) == 1

    b_scales = _m_major_view_2d(b_scales_blocks)
    assert b_scales.size() == (k // 128, n // 128) and b_scales.stride(0) == 1

    return a_fp8, a_scales, b_fp8, b_scales, dtype_out


def _run_once(m, k, n, device, dtype_out):
    a_fp8, a_scales, b_fp8, b_scales, out_dtype = _make_fp8_with_scales(
        m, k, n, device, dtype_out
    )
    # print(a_fp8.shape, b_fp8.shape, a_scales.shape, b_scales.shape)
    return fp8_blockwise_scaled_mm(a_fp8, b_fp8, a_scales, b_scales, out_dtype)


def test_fp8_blockwise_gemm_deterministic_sm90():
    if not torch.cuda.is_available():
        return
    dev = torch.device("cuda")
    os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"

    shapes = [(1024, 4096, 1024), (2048, 2048, 2048)]
    dtypes = [torch.bfloat16, torch.half]

    for (m, k, n) in shapes:
        for out_dtype in dtypes:
            y0 = _run_once(m, k, n, dev, out_dtype)
            for _ in range(5):
                y = _run_once(m, k, n, dev, out_dtype)
                assert torch.equal(y0, y), f"Non-deterministic result for shape {(m,k,n)} dtype {out_dtype}"

@gemini-code-assist
Contributor

Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses the need for deterministic inference in Blockscale FP8 kernels, particularly for DeepSeek R1/V3. It achieves this by making targeted modifications to the underlying CUTLASS GEMM operations, ensuring that computations yield consistent results by controlling scheduler behavior and reduction modes. The changes prevent sources of non-determinism, such as the use of StreamK and non-deterministic reduction strategies, when deterministic inference is explicitly requested.

Highlights

  • Deterministic Inference: Introduced a deterministic flag to control the behavior of FP8 blockwise GEMM operations, enabling consistent results across runs.
  • CUTLASS Scheduler Modification: Modified the CUTLASS scheduler to avoid StreamK decomposition and explicitly set ReductionMode::Deterministic when the deterministic flag is active, preventing non-deterministic behavior from floating-point additions.
  • Conditional StreamK Usage: The StreamKScheduler is now only used when deterministic inference is not required and a specific condition (k > 3 * n) is met, otherwise the PersistentScheduler is used.
  • Environment Variable Integration: The deterministic flag is dynamically controlled by the SGLANG_ENABLE_DETERMINISTIC_INFERENCE environment variable, allowing users to enable or disable deterministic behavior at runtime.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces changes to enable deterministic inference for the Blockscale FP8 kernel. This is achieved by adding a deterministic flag which is propagated down from an environment variable. This flag is used to avoid the non-deterministic StreamKScheduler and ReductionMode::Nondeterministic in CUTLASS. The changes are logical and correctly implement the intended behavior. I have one suggestion to improve code maintainability by reducing code duplication.

Comment on lines 418 to 435
  torch::Tensor scales_b_contiguous = scales_b.contiguous();
  if (out_dtype == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
-       out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
+       out_padded,
+       mat_a_padded,
+       mat_b,
+       scales_a_padded,
+       scales_b_contiguous,
+       getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"));
  } else {
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
-       out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
+       out_padded,
+       mat_a_padded,
+       mat_b,
+       scales_a_padded,
+       scales_b_contiguous,
+       getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"));
  }

medium

To improve readability and avoid redundant calls to getBoolEnv, you can fetch the deterministic flag once before the if block and store it in a const bool variable. This also reduces code duplication within the if-else branches.

    torch::Tensor scales_b_contiguous = scales_b.contiguous();
    const bool deterministic = getBoolEnv("SGLANG_ENABLE_DETERMINISTIC_INFERENCE");
    if (out_dtype == torch::kBFloat16) {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
          out_padded,
          mat_a_padded,
          mat_b,
          scales_a_padded,
          scales_b_contiguous,
          deterministic);
    } else {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
          out_padded,
          mat_a_padded,
          mat_b,
          scales_a_padded,
          scales_b_contiguous,
          deterministic);
    }

@Fridge003 Fridge003 self-assigned this Oct 12, 2025
@hebiao064 hebiao064 self-assigned this Oct 12, 2025
@hebiao064
Collaborator

hebiao064 commented Oct 12, 2025

Nice PR!

Have you verified that DeepSeek was not deterministic before this change?

And could you please provide a benchmark comparing non-deterministic vs. deterministic mode?

@b8zhong
Collaborator Author

b8zhong commented Oct 14, 2025

@hebiao064 Sorry, I rebased some changes onto main, but I swapped to the FlashInfer backend because I encountered a strange Triton crash (#11529).

I think some MLA decode paths might still not be deterministic at the moment.

@b8zhong
Collaborator Author

b8zhong commented Oct 31, 2025

Complete in #11491

@b8zhong b8zhong closed this Oct 31, 2025
@b8zhong b8zhong deleted the deterministic-blockscale-fp8 branch October 31, 2025 18:43