
CUTLASS FP8 Blockwise GEMM improvement of SM120#20887

Merged
BBuf merged 1 commit into main from brayden/optimize-fp8-gemm-sm120 on Mar 22, 2026
Conversation


@b8zhong b8zhong commented Mar 18, 2026

Motivation

The SM120 FP8 blockwise GEMM kernel was using KernelScheduleAuto as the schedule, which on SM120 happens to select only the cooperative kernel. This single-kernel approach misses a performance opportunity: the pingpong schedule is about 2x faster than cooperative for small M. I adapted the example from the CUTLASS repo.

Modifications

sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu:

  • Replaced KernelScheduleAuto with KernelScheduleSm120Blockwise for the cooperative path (used when M > 64) to avoid a specific CUTLASS issue: with the auto schedule, the refcheck output explodes in relative error. I'm not sure of the exact cause, but it appears to be a CUTLASS library issue.
  • Added a pingpong path using KernelTmaWarpSpecializedBlockwisePingpongSm120 with a 64x128x128 tile shape for M ≤ 64.
  • Added an M ≤ 64 runtime check to pick between the two paths.
  • Refactored the kernel setup a bit.

Accuracy and UT


Benchmarked on RTX 5090 (SM120) against FlashInfer and Triton across Qwen/Qwen3.5-27B-FP8 shapes for M from 1 to 512:

Performance (RTX 5090, N=1536, K=5120)

+-----+----------+------------+----------+
|  M  | this PR  | FlashInfer |  Triton  |
+-----+----------+------------+----------+
|  8  | 0.034 ms |  0.063 ms  | 0.041 ms |
| 64  | 0.034 ms |  0.063 ms  | 0.041 ms |
| 128 | 0.063 ms |  0.063 ms  | 0.042 ms |
| 512 | 0.063 ms |  0.063 ms  | 0.043 ms |
+-----+----------+------------+----------+


E2E Accuracy:

python -m sglang.test.run_eval --base-url http://localhost:30000 --eval-name gsm8k --num-examples 200 --max-tokens 16000 --repeat 5 --num-threads 48 --num-shots 5 --temperature 1.0 --top-p 0.95 --top-k 20 --min-p 0.0 --chat-template-kwargs '{"enable_thinking": true}'

Before:
{'score:std': np.float64(0.07053367989832945), 'scores': ['0.995', '0.975', '0.980', '0.990', '0.995'], 'mean_score': np.float64(0.9870000000000001)}

After:
{'score:std': np.float64(0.12155245781143219), 'scores': ['0.990', '0.985', '0.995', '0.995', '0.985'], 'mean_score': np.float64(0.99)}

BS = 1 speed:

Before:
(It will use the cooperative schedule)

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|   29.996    |  1024  |   1.000    |      34.14      |
+-------------+--------+------------+-----------------+

After:
(Profile zoomed in far enough to show the pingpong schedule's kernel name.)

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|   19.652    |  1024  |   1.000    |      52.11      |
+-------------+--------+------------+-----------------+

I think we can change the default GEMM backend on SM120 in a later PR.

Checklist

  • Format your code with pre-commit.
  • Add unit tests.
  • Update documentation.
  • Provide accuracy and speed benchmark results.



b8zhong commented Mar 20, 2026

I also attached some Nsight Compute (NCU) reports:

Before (Cooperative Mainloop / 2-Stage)

--------------------------------------
Kernel: device_kernel<...MainloopSm120TmaWarpSpecializedBlockwiseScaling<2, 2...>>
Duration:               89.06 us
Compute (SM) Throughput: 3.24 %
Memory Throughput:      5.08 %
Memory Bandwidth:       89.61 Gbyte/s
L2 Hit Rate:            50.77 %
Local Memory Spilling:  0 requests

After (Pingpong Mainloop / 3-Stage)

--------------------------------------
Kernel: device_kernel<...MainloopSm120TmaWarpSpecializedBlockwiseScaling<3, 2...>>
Duration:               47.65 us
Compute (SM) Throughput: 3.01 %
Memory Throughput:      9.51 %
Memory Bandwidth:       167.55 Gbyte/s
L2 Hit Rate:            34.99 %
Local Memory Spilling:  252 requests

Memory bandwidth increased by around 80% for M = 16. There is still a lot of room to improve its performance, but this should be an okay first step.


BBuf commented Mar 20, 2026

/tag-and-rerun-ci again

@BBuf BBuf merged commit 009eee8 into main Mar 22, 2026
247 of 327 checks passed
@BBuf BBuf deleted the brayden/optimize-fp8-gemm-sm120 branch March 22, 2026 09:55