
[nvidia] Gemma4 nvfp4 fix #22079

Merged
ispobock merged 6 commits into sgl-project:main from wenscarl:gemma4-nvfp4-fix
Apr 10, 2026

Conversation

@wenscarl
Collaborator

wenscarl commented Apr 3, 2026

Based on #21952 and depends on flashinfer-ai/flashinfer#2959

Motivation

Gemma 4 NVFP4 checkpoints do not work on GB200 for the following reasons:

Triton attention kernel — PTX register exhaustion

When running Gemma4 with the triton attention backend on GB200, the engine crashes during prefill:

triton.runtime.errors.PTXASError: PTXAS error: Internal Triton PTX codegen error
ptxas fatal: Register allocation failed with register count of '255'.

Root cause: _get_block_sizes_for_extend_attention had no dedicated branch for CUDA_CAPABILITY[0] == 10 (GB200/B200/sm_100a). sm_100a fell into the >= 9 Hopper catch-all, selecting BLOCK_M=32, BLOCK_N=64, num_warps=8 for Lq > 256. Gemma4 uses a global head dim of 512, so this config is always hit for global attention layers.

The crash is specifically triggered when the KV cache dtype is fp8 — which Gemma4-NVFP4 enables automatically via quant_config.kv_cache_quant_algo = "FP8". The fp8 dequantization instructions in the kernel body increase register pressure enough to push over sm_100a's ptxas allocation limit. The same crash reproduces with any bf16 model that explicitly sets kv_cache_dtype=fp8_e4m3 on GB200.
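For reference, a minimal repro sketch of the bf16 + fp8-KV-cache variant mentioned above. The model path is a placeholder, and the Engine keyword arguments are assumed to mirror the --attention-backend / --kv-cache-dtype server flags; adjust to your SGLang version.

```python
import sglang as sgl

# Placeholder checkpoint: any bf16 model with large head dims, per the note above.
llm = sgl.Engine(
    model_path="<bf16-model-with-large-head-dim>",
    attention_backend="triton",
    kv_cache_dtype="fp8_e4m3",  # the setting Gemma4-NVFP4 enables automatically
)

# Before this fix, prefill on GB200 aborted with the PTXASError shown above.
llm.generate("hello", {"max_new_tokens": 8})
llm.shutdown()
```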

Modifications

In extend_attention.py: Add a dedicated CUDA_CAPABILITY[0] == 10 branch before the >= 9 Hopper catch-all with smaller tile sizes (BLOCK_M=16, BLOCK_N=64 for Lq > 256) to stay within the sm_100a register budget.
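For illustration, a simplified sketch of the branch ordering this introduces. The helper name and signature are hypothetical; the real _get_block_sizes_for_extend_attention also chooses num_warps and covers more head-dim, dtype, and architecture cases.

```python
def pick_extend_attention_blocks(Lq: int, cuda_capability: tuple) -> dict:
    """Hypothetical, simplified sketch of the tile-size selection."""
    if cuda_capability[0] == 10 and Lq > 256:
        # Dedicated sm_100a (GB200/B200) branch, checked before the Hopper
        # catch-all: the smaller BLOCK_M keeps the kernel under ptxas'
        # 255-register budget even with fp8 KV-cache dequantization in the
        # kernel body.
        return {"BLOCK_M": 16, "BLOCK_N": 64}
    if cuda_capability[0] >= 9 and Lq > 256:
        # Hopper catch-all that sm_100a previously fell into.
        return {"BLOCK_M": 32, "BLOCK_N": 64, "num_warps": 8}
    raise NotImplementedError("remaining cases omitted from this sketch")
```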

Accuracy Tests

Tested on GB200 with nvidia/Gemma-4-31B-IT-NVFP4 + triton attention backend. Script completes without exception and produces correct output.
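A hedged sketch of that kind of check (not the exact test script; Engine kwargs are assumed to mirror the server CLI flags, so adjust to your SGLang build):

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="nvidia/Gemma-4-31B-IT-NVFP4",
    attention_backend="triton",  # exercises the extend-attention Triton kernel on sm_100a
)

# Pre-fix this crashed during prefill with the PTXASError; post-fix it should
# return a sensible completion.
out = llm.generate(
    "The capital of France is",
    {"temperature": 0.0, "max_new_tokens": 16},
)
print(out["text"])
llm.shutdown()
```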

Speed Tests and Profiling

cc. @nvpohanh

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@kpham-sgl
Collaborator

/tag-and-rerun-ci

github-actions bot added the run-ci label Apr 7, 2026
kpham-sgl self-assigned this Apr 7, 2026
@kpham-sgl
Collaborator

kpham-sgl commented Apr 8, 2026

/tag-and-rerun-ci again

Comment thread on python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@kpham-sgl
Collaborator

kpham-sgl commented Apr 8, 2026

/rerun-failed-ci again

@kpham-sgl
Collaborator

kpham-sgl commented Apr 9, 2026

/rerun-failed-ci one

@jeremylea

Any reason this isn't handling sm_120a (RTX 6000)?

ispobock merged commit 5638d40 into sgl-project:main Apr 10, 2026
194 of 233 checks passed
Fridge003 pushed a commit that referenced this pull request Apr 11, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
@baoskee

baoskee commented May 7, 2026

hey, did you guys test this with the docker images? It does not work for:

docker pull lmsysorg/sglang:cu13-gemma4 # CUDA 13

I'm getting:

File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 115, in forward return self.forward_extend( ^^^^^^^^^^^^^^^^^^^^ File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/triton_backend.py", line 936, in forward_extend self.extend_attention_fwd( File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/triton_ops/extend_attention.py", line 609, in extend_attention_fwd _fwd_kernel[grid]( File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 419, in <lambda> return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 733, in run kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 861, in _do_compile kernel = self.compile(src, target=target, options=options.__dict__) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 320, in compile next_module = compile_ir(module, metadata) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 520, in <lambda> stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 503, in make_cubin raise PTXASError(error) triton.runtime.errors.PTXASError: PTXAS error: Internal Triton PTX codegen error ptxas stderr: ptxas fatal : (C7600) Register allocation failed with register count of '255'. Compile the program with a higher register target ptxas fatal : Ptx assembly aborted due to errors Repro command: /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas -lineinfo -v --gpu-name=sm_100a /tmp/tmp8ca_n_mf.ptx -o /tmp/tmp8ca_n_mf.ptx.o

@nvpohanh
Collaborator

nvpohanh commented May 7, 2026

@baoskee that container is quite old. could you try the latest dev-cu13 container?

