[VLM] Optimize Gemma4 VLM with PCG and fuse RMSNorm + residual add + scalar#24048
Kangyan-Zhou merged 1 commit into sgl-project:main
Conversation
/tag-and-rerun-ci
Code Review
This pull request introduces a fused Triton kernel, gemma_dual_rmsnorm_residual_scalar, to optimize the Gemma4 model's forward pass by combining multiple RMSNorm and residual operations. It also refactors model resolution logic in model_runner.py and piecewise_cuda_graph_runner.py to better handle different model architectures. Feedback highlights the need for stricter input validation in the new Triton wrapper to prevent potential memory issues and identifies a logic error in resolve_language_model that would lead to a guaranteed AttributeError.
kpham-sgl
left a comment
Nice! Can you run MMMU and verify against the score here https://docs.sglang.io/cookbook/autoregressive/Google/Gemma4#mmmu
Updated.
Motivation
Optimize Gemma4 26B-A4B prefill performance through two complementary approaches: capturing the prefill path with piecewise CUDA graphs (PCG), and fusing RMSNorm + residual add + scalar multiplication into a single Triton kernel.
Modifications
Fused Triton Kernels (sglang/srt/layers/gemma4_fused_ops.py): adds the gemma_dual_rmsnorm_residual_scalar kernel, which fuses RMSNorm, residual add, and scalar multiplication into a single launch (a simplified sketch of this kind of fusion follows this list).
Gemma4 Model Integration (sglang/srt/models/gemma4_causal.py): wires the fused kernel into the Gemma4 forward pass.
PCG VLM Compatibility (sglang/srt/model_executor/model_runner.py): refactors model resolution so piecewise CUDA graphs can target the language-model component of the VLM.
PCG Runner Fix (sglang/srt/model_executor/piecewise_cuda_graph_runner.py): aligns the piecewise CUDA graph runner with the refactored model-resolution logic.
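For readers unfamiliar with the fusion, below is a minimal sketch of a fused RMSNorm + residual add + scalar Triton kernel. It is illustrative only: the function names, argument defaults, and the residual-then-normalize ordering are assumptions, and the actual gemma_dual_rmsnorm_residual_scalar kernel in gemma4_fused_ops.py fuses two norms and handles layout and validation differently.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _rmsnorm_residual_scalar_kernel(
    x_ptr, res_ptr, weight_ptr, out_ptr,
    hidden_size, eps, scale,
    BLOCK_SIZE: tl.constexpr,
):
    # One program instance processes one token (one row of the hidden states).
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < hidden_size
    offs = row * hidden_size + cols

    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    res = tl.load(res_ptr + offs, mask=mask, other=0.0).to(tl.float32)

    # Residual add first, then RMSNorm over the combined hidden state
    # (ordering here is an assumption for illustration).
    h = x + res
    variance = tl.sum(h * h, axis=0) / hidden_size
    inv_rms = 1.0 / tl.sqrt(variance + eps)

    w = tl.load(weight_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Gemma-style RMSNorm multiplies by (1 + weight); `scale` is the extra scalar.
    y = h * inv_rms * (1.0 + w) * scale

    tl.store(out_ptr + offs, y.to(out_ptr.dtype.element_ty), mask=mask)


def rmsnorm_residual_scalar(x, residual, weight, eps=1e-6, scale=1.0):
    # Hypothetical wrapper; the shape/contiguity checks echo the review's
    # request for stricter input validation before launching the kernel.
    assert x.dim() == 2 and x.shape == residual.shape
    assert x.is_contiguous() and residual.is_contiguous() and weight.is_contiguous()
    num_tokens, hidden_size = x.shape
    out = torch.empty_like(x)
    block_size = triton.next_power_of_2(hidden_size)
    _rmsnorm_residual_scalar_kernel[(num_tokens,)](
        x, residual, weight, out, hidden_size, eps, scale, BLOCK_SIZE=block_size
    )
    return out
```

Fusing these elementwise operations avoids separate kernel launches and extra global-memory round trips between the norm, the residual add, and the scalar multiply.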
Accuracy Tests
The MMMU score matches the official 54.9%.
Speed Tests and Profiling
PCG eliminates CPU dispatch overhead by capturing per-layer CUDA graphs for prefill. The benefit is most significant at small-to-medium token counts where kernel launch latency dominates compute time.
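As background on the mechanism (not SGLang's actual PCG implementation), the sketch below shows plain PyTorch CUDA graph capture and replay for a single layer: after a one-time capture, replay launches every recorded kernel with a single CPU call, which is where the dispatch-overhead saving comes from. The layer, shapes, and dtypes are arbitrary placeholders.

```python
import torch

# Stand-in "layer": any fixed-shape forward can be captured.
layer = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16).eval()
static_in = torch.zeros(256, 4096, device="cuda", dtype=torch.float16)

# Warm up on a side stream before capture, as CUDA graph capture requires.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.no_grad():
    for _ in range(3):
        layer(static_in)
torch.cuda.current_stream().wait_stream(side)

# Capture the layer's forward into a graph with static input/output buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out = layer(static_in)

# At runtime: copy fresh activations into the static buffer and replay.
# Replay issues all captured kernels with a single CPU call, removing the
# per-kernel dispatch overhead that PCG targets for prefill.
static_in.copy_(torch.randn_like(static_in))
graph.replay()
torch.cuda.synchronize()
print(static_out.shape)
```

Piecewise capture applies this idea per layer (or per graph-safe segment), so dynamic pieces such as attention over varying sequence lengths can stay outside the captured regions.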
Compatibility
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci