
[NVIDIA] Support TF32 matmul to improve MiniMax gate gemm performance#22744

Open
trevor-m wants to merge 1 commit into sgl-project:main from trevor-m:tf32

Conversation

@trevor-m
Collaborator

Motivation

Before this change, the fp32 gate gemm takes 9.1% of e2e decode time for MiniMax-M2.5 at bs 64. With --enable-tf32-matmul, it is reduced to 3.3%.
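The profiled shares above imply a rough upper bound on the e2e decode speedup. A back-of-the-envelope sketch, assuming everything outside the gate gemm is unchanged (note the 3.3% is a share of the new, shorter decode step):

```python
# Gate gemm is 9.1% of decode before the change and 3.3% of the
# (shorter) decode step after it; all other work is assumed unchanged.
rest = 1.0 - 0.091              # non-gate share of the original decode step
new_total = rest / (1 - 0.033)  # solve: rest / new_total = 1 - 0.033
speedup = 1.0 / new_total
print(f"expected decode speedup: {(speedup - 1) * 100:.1f}%")  # ~6.4%
```

This lines up with the ~7% output-throughput gain measured at batch size 64 below.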

Modifications

Use torch.set_float32_matmul_precision('high') so that FP32 matmuls use TF32 as the internal computation type when available. This improves performance with only a small impact on accuracy.

See torch docs: https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
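For intuition on the accuracy trade-off: TF32 keeps fp32's 8-bit exponent (so no range loss) but reduces the mantissa from 23 to 10 bits, and accumulation stays in fp32. A minimal pure-Python sketch of the input rounding (truncating variant; real tensor cores may round-to-nearest):

```python
import struct

def to_tf32(x: float) -> float:
    """Truncate an fp32 value to TF32 precision (10 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # zero the low 23 - 10 = 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_tf32(1.0))           # 1.0 (exactly representable, unchanged)
print(to_tf32(1.0 + 2**-10))  # 1.0009765625 (still exact in 10 bits)
# Values needing more mantissa bits lose at most ~2**-10 relative error:
err = abs(to_tf32(0.3) - 0.3) / 0.3
print(f"rel error for 0.3: {err:.2e}")
```

So per-element relative error is bounded near 1e-3, which is why gate logits tolerate it while the overall accuracy (GPQA below) is preserved.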

Accuracy Tests

GPQA

sglang serve --model-path MiniMaxAI/MiniMax-M2.5 --tp 8 --trust-remote-code --mem-fraction-static=0.85 --reasoning-parser=minimax-append-think --kv-cache-dtype fp8_e4m3 --moe-runner-backend flashinfer_trtllm_routed --dtype bfloat16 --enable-flashinfer-allreduce-fusion  --enable-tf32-matmul
python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 64000 --repeat 4 --temperature 1.0 --top-k 40 --top-p 0.95
[METRIC] gpqa_mean_score=0.8017676767676768 labels={"model": "MiniMaxAI/MiniMax-M2.5", "eval": "gpqa", "repeat": 4}

Speed Tests and Profiling

# Before 
sglang serve --model-path MiniMaxAI/MiniMax-M2.5 --tp 8 --ep 8 --trust-remote-code --mem-fraction-static=0.85 --reasoning-parser=minimax-append-think
python3 -m sglang.bench_one_batch_server \
    --model MiniMaxAI/MiniMax-M2.5 \
    --base-url http://localhost:30000 \
    --batch-size 1 8 16 64 \
    --input-len 4096 \
    --output-len 512 \
    --show-report \
    --no-append-to-github-summary \
    --trust-remote-code \
    --server-args-for-metrics \
    --trust-remote-code --ep=8 --mem-fraction-static=0.85 --reasoning-parser=minimax-append-think

Input lens: [4096]. Output lens: [512].
|   batch size |   input len |   latency (s) |   input throughput (tok/s) |   output throughput (tok/s) | acc length   |   ITL (ms) |   input cost ($/1M) |   output cost ($/1M) | cache hit rate   |
|--------------|-------------|---------------|----------------------------|-----------------------------|--------------|------------|---------------------|----------------------|------------------|
|            1 |        4096 |          5.09 |                    35600.7 |                      102.9  | n/a          |       9.72 |                0.04 |                10.8  | n/a              |
|            8 |        4096 |          6.75 |                    60989.9 |                      659.3  | n/a          |      12.13 |                0.03 |                 1.69 | n/a              |
|           16 |        4096 |          7.96 |                    63056.5 |                     1183.73 | n/a          |      13.52 |                0.03 |                 0.94 | n/a              |
|           64 |        4096 |         14.68 |                    64979.8 |                     3076.99 | n/a          |      20.8  |                0.02 |                 0.36 | n/a              |

# After (with --enable-tf32-matmul)
sglang serve --model-path MiniMaxAI/MiniMax-M2.5 --tp 8 --ep 8 --trust-remote-code --mem-fraction-static=0.85 --reasoning-parser=minimax-append-think --enable-tf32-matmul
Input lens: [4096]. Output lens: [512].
|   batch size |   input len |   latency (s) |   input throughput (tok/s) |   output throughput (tok/s) | acc length   |   ITL (ms) |   input cost ($/1M) |   output cost ($/1M) | cache hit rate   |
|--------------|-------------|---------------|----------------------------|-----------------------------|--------------|------------|---------------------|----------------------|------------------|
|            1 |        4096 |          5.03 |                    35126.1 |                      104.3  | n/a          |       9.59 |                0.05 |                10.65 | n/a              |
|            8 |        4096 |          6.6  |                    61127.5 |                      675.34 | n/a          |      11.85 |                0.03 |                 1.65 | n/a              |
|           16 |        4096 |          7.69 |                    70905.7 |                     1209.99 | n/a          |      13.22 |                0.02 |                 0.92 | n/a              |
|           64 |        4096 |         13.5  |                    73254.5 |                     3302.03 | n/a          |      19.38 |                0.02 |                 0.34 | n/a              |
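Reading the two reports together, the per-batch-size output-throughput gains work out as follows (numbers copied from the tables above):

```python
# Output throughput (tok/s) from the before/after bench reports above.
before = {1: 102.9, 8: 659.3, 16: 1183.73, 64: 3076.99}
after = {1: 104.3, 8: 675.34, 16: 1209.99, 64: 3302.03}

gains = {bs: round((after[bs] / before[bs] - 1) * 100, 1) for bs in before}
print(gains)  # {1: 1.4, 8: 2.4, 16: 2.2, 64: 7.3}
```

The gain is largest at batch size 64, consistent with the 9.1% gate-gemm share of decode time having been measured at bs 64.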

Before

(profiler screenshot, 2026-04-13 4:51 PM)

After

(profiler screenshot, 2026-04-13 4:50 PM)

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@github-actions bot added the documentation (Improvements or additions to documentation) label on Apr 14, 2026
@Fridge003
Collaborator

@trevor-m Instead of adding a new argument that globally enables this feature, can we enable the tf32 matmul flag only for the MiniMax model, since it has only been verified on that model?

@trevor-m
Collaborator Author

Hi @Fridge003, this feature can benefit any model with FP32 gemms, so I think it's worth having as an optional flag; it's currently off by default. Users can enable it after verifying that accuracy holds up. For MiniMax, would you like it enabled by default?

thanhhao98 pushed a commit to thanhhao98/sglang that referenced this pull request Apr 25, 2026
Stacks on top of the previous commit (default routed-MoE). Calls
torch.set_float32_matmul_precision("high") in the same Glm4MoeForCausalLM
sm100 auto-default block, so any GLM-4.7-NVFP4 launch on Blackwell
gets TF32 tensor-core path for the FP32 router gemm (5120 -> 160).

This is a port of the pending sgl-project#22744 ("[NVIDIA]
Support TF32 matmul to improve MiniMax gate gemm performance"). On
Minimax-M2.5: +7% output throughput, -8% latency at batch=64;
FP32 router gemm 9.1% -> 3.3% of decode time; GPQA accuracy preserved.

GLM-4.7 has the same router topology (5120 -> N_experts FP32 cast)
so the same gain should apply. Bench data to follow as optimal_v2.

The FP32 cast from PR sgl-project#21660 still happens upstream of the matmul;
this changes only the matmul kernel to use TF32 tensor cores. Gate
with the existing GSM8K accuracy CI before merging.
thanhhao98 pushed a commit to thanhhao98/sglang that referenced this pull request Apr 25, 2026
The previous attempt (commit d6a435b) called
torch.set_float32_matmul_precision("high") in ServerArgs.__post_init__,
which runs in the parent process only. SGLang's worker processes are
spawned (not forked), so PyTorch state set in the parent does not
propagate. Bench data confirmed v2 was a no-op:

  v1 vs v2 nvfp4_tp8 throughput at all 10 cc points: |delta| < 1%

Microbench inside the v2 image confirmed TF32 is functional (3.18x
speedup on the gate matmul shape) but only after explicit setting in
the same process. Hence each TP rank's worker must call
set_float32_matmul_precision itself.

This commit moves the call into Glm4MoeGate.__init__, gated by a
class-level _tf32_set flag so it fires exactly once per worker. The
parent-side call from d6a435b is left in place for redundancy
(harmless, idempotent).

Glm4MoeGate is the only class that performs an FP32 matmul in the
GLM-4.7 forward graph (see lines 372-375 — the FP32-cast gate
projection per PR sgl-project#21660). The TF32 setting therefore has zero scope
beyond what's needed; no other FP32 matmuls get accelerated.

Expected gain on Blackwell sm100: ~5-7% throughput at high cc per
PR sgl-project#22744's Minimax-M2.5 measurement, mediated by what fraction of
decode is the gate gemm. To be measured as optimal_v3.
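The once-per-worker pattern this commit describes can be sketched as follows. This is a simplified stand-in, not the actual SGLang code: Glm4MoeGateSketch and the call counter are illustrative, and the real module would invoke torch.set_float32_matmul_precision("high") where the counter increments.

```python
class Glm4MoeGateSketch:
    """Illustrative stand-in for the gate module described above."""
    _tf32_set = False   # class-level flag, shared by all instances in a worker
    tf32_calls = 0      # instrumentation so the once-only behavior is visible

    def __init__(self) -> None:
        if not Glm4MoeGateSketch._tf32_set:
            # Real code: torch.set_float32_matmul_precision("high").
            # This must run inside each worker process, because spawned
            # workers do not inherit PyTorch state set in the parent.
            Glm4MoeGateSketch.tf32_calls += 1
            Glm4MoeGateSketch._tf32_set = True

# However many gate instances a worker constructs, the setting fires once.
gates = [Glm4MoeGateSketch() for _ in range(4)]
print(Glm4MoeGateSketch.tf32_calls)  # 1
```

A class-level flag (rather than a module-level one) keeps the guard next to the only class that performs the FP32 matmul, matching the scoping argument made in the commit message.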