
Commit 6da3aba

jasperjiaguo and claude authored
perf: optimize PCG inductor path for FP8 models (sgl-project#21734)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3cb3f7c · commit 6da3aba

2 files changed: 35 additions & 10 deletions


python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 28 additions & 6 deletions
@@ -18,6 +18,7 @@
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
+    fp8_min,
     is_fp8_fnuz,
     mxfp8_block_scaled_matmul_triton,
     per_token_group_quant_fp8,
@@ -28,6 +29,7 @@
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     ceil_align,
     ceil_div,
@@ -1455,12 +1457,32 @@ def apply_fp8_linear(
         num_token_padding = output_padding
         if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
             num_token_padding = None
-            qinput, x_scale = scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                num_token_padding=num_token_padding,
-                use_per_token_if_dynamic=use_per_token_if_dynamic,
-            )
+            # For static per-tensor activation scales when using the inductor
+            # compiler, use pure PyTorch ops instead of the opaque sgl_kernel
+            # quant kernel. Inductor fuses these with surrounding ops (RMSNorm,
+            # residual add), eliminating a separate kernel launch per linear
+            # layer. The weight_scale shape does not matter here -- it is only
+            # used in the GEMM epilogue, not in the activation quant fusion.
+            # This path activates only when piecewise_cuda_graph_compiler ==
+            # "inductor"; eager PCG and decode both use the faster custom kernel.
+            if (
+                input_scale is not None
+                and input_scale.numel() == 1
+                and get_global_server_args().piecewise_cuda_graph_compiler == "inductor"
+            ):
+                qinput = (
+                    (input_2d * input_scale.reciprocal())
+                    .clamp(min=fp8_min, max=fp8_max)
+                    .to(fp8_dtype)
+                )
+                x_scale = input_scale
+            else:
+                qinput, x_scale = scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=num_token_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
         else:
             # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
             if input_scale is not None:
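For readers skimming the diff, the new quantization path is worth seeing in isolation. Below is a minimal sketch, not sglang's implementation: fp8_dtype, fp8_min, and fp8_max are assumed to be torch.float8_e4m3fn and its finfo limits (ROCm's fnuz variant differs), and static_per_tensor_quant is an illustrative name. The point is that multiply, clamp, and cast are all traceable pointwise ops, so inductor can fuse them into the kernels of the preceding RMSNorm or residual add, whereas an sgl_kernel call is an opaque boundary.

import torch

# Assumed stand-ins for sglang's fp8_dtype / fp8_min / fp8_max
# (sglang derives these elsewhere, e.g. float8_e4m3fnuz on ROCm).
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max
FP8_MIN = torch.finfo(FP8_DTYPE).min

def static_per_tensor_quant(x: torch.Tensor, scale: torch.Tensor):
    # Divide by the precomputed per-tensor scale, saturate to the FP8
    # range, then cast: three pointwise ops that inductor can fuse with
    # whatever pointwise kernels surround them.
    q = (x * scale.reciprocal()).clamp(min=FP8_MIN, max=FP8_MAX).to(FP8_DTYPE)
    return q, scale  # x_scale is simply the static input_scale

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
scale = torch.tensor(0.05, device="cuda")
qinput, x_scale = torch.compile(static_per_tensor_quant)(x, scale)

Without a compiler in the loop (eager PCG, decode), the custom scaled_fp8_quant kernel stays faster, which is why the diff gates this path on piecewise_cuda_graph_compiler == "inductor".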

python/sglang/srt/models/utils.py

Lines changed: 7 additions & 4 deletions
@@ -30,6 +30,7 @@
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import get_current_device_stream_fast, is_cuda, is_hip
 from sglang.srt.utils.custom_op import register_custom_op

@@ -422,6 +423,8 @@ def apply_qk_norm(
         and allow_inplace  # TODO(dark): this can be relaxed if needed
         and (q_eps == k_eps)  # TODO(dark): this can also be relaxed
         and not envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()
+        and get_global_server_args().piecewise_cuda_graph_compiler
+        != "inductor"  # let inductor fuse QK norm
         and can_use_fused_inplace_qknorm(head_dim, q.dtype)
     ):
         fused_inplace_qknorm(
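The guard reads inverted at first glance: the hand-fused in-place QK-norm kernel is skipped exactly when inductor compiles the piecewise CUDA graph, because a custom CUDA kernel is an opaque call torch.compile cannot fuse across, while plain q_norm/k_norm ops stay visible to the compiler. A rough sketch of the unfused path inductor gets to optimize, using a toy RMSNorm and illustrative names rather than sglang's actual modules:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Toy RMSNorm for illustration only.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

@torch.compile  # inductor fuses both norms with the surrounding views/casts
def qk_norm_unfused(q, k, q_weight, k_weight, head_dim: int = 64):
    qn = rms_norm(q.view(*q.shape[:-1], -1, head_dim), q_weight).view(q.shape)
    kn = rms_norm(k.view(*k.shape[:-1], -1, head_dim), k_weight).view(k.shape)
    return qn, kn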
@@ -437,16 +440,16 @@ def apply_qk_norm(
         if alt_stream is not None and get_is_capture_mode():
             current_stream = get_current_device_stream_fast()
             alt_stream.wait_stream(current_stream)
-            q_by_head = q.reshape(-1, head_dim)
+            q_by_head = q.view(*q.shape[:-1], -1, head_dim)
             q_by_head = q_norm(q_by_head)
             with torch.cuda.stream(alt_stream):
-                k_by_head = k.reshape(-1, head_dim)
+                k_by_head = k.view(*k.shape[:-1], -1, head_dim)
                 k_by_head = k_norm(k_by_head)
             current_stream.wait_stream(alt_stream)
         else:
-            q_by_head = q.reshape(-1, head_dim)
+            q_by_head = q.view(*q.shape[:-1], -1, head_dim)
             q_by_head = q_norm(q_by_head)
-            k_by_head = k.reshape(-1, head_dim)
+            k_by_head = k.view(*k.shape[:-1], -1, head_dim)
             k_by_head = k_norm(k_by_head)
         q = q_by_head.view(q.shape)
         k = k_by_head.view(k.shape)
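The reshape-to-view change is subtle: reshape(-1, head_dim) collapses every leading dimension into one, while view(*q.shape[:-1], -1, head_dim) keeps the token dimension explicit and is guaranteed to alias rather than copy. A small shape sketch with made-up sizes:

import torch

head_dim = 64
q = torch.randn(8, 4 * head_dim)  # (num_tokens, num_heads * head_dim)

a = q.reshape(-1, head_dim)              # (32, 64): leading dims collapsed
b = q.view(*q.shape[:-1], -1, head_dim)  # (8, 4, 64): token dim preserved

# Both expose per-head vectors to the norm and round-trip to q's shape,
# so the change is behavior-preserving for the norm itself.
assert torch.equal(a.view(q.shape), b.view(q.shape))

The commit does not state the rationale, but a plausible one is that the preserved leading dims leave inductor a single symbolic token dimension instead of a product of dims, and view's no-copy guarantee keeps the op free under compilation.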
