
Commit fc80fbd

Fix CUDA graph capture dtype mismatch (#17)
* Fix dtype mismatch in rotary embedding with FP8 KV cache

  When using FP8 KV cache quantization (e.g., with ModelOpt FP8 models), the query and key tensors may have different dtypes during CUDA graph capture. The query tensor remains in bfloat16 for computation, while the key tensor might need to be in FP8 format for KV cache storage.

  The issue was in DeepseekScalingRotaryEmbedding.forward_native(), which captured only the query's dtype and then converted both query and key to that same dtype. This caused a dtype mismatch error during CUDA graph capture: "query and key must have the same dtype".

  The fix preserves the original dtypes of the query and key tensors separately, ensuring they maintain their intended dtypes after the rotary position embedding computation. This resolves the CUDA graph capture failure with Qwen3MoE and other models using FP8 KV cache quantization.

* Fix FA4 dtype mismatch with FP8 KV cache

  When using FlashAttention 4 (FA4) with FP8 KV cache quantization, there was a dtype mismatch between the query tensor (bfloat16) and the cached key/value tensors (FP8). FA4 requires all input tensors (q, k, v) to have the same dtype.

  The previous code only converted the query to FP8 when NOT using FA4 (fa_impl_ver != 4). This was based on the assumption that FA4 does not support FP8, but FA4 can work with FP8 tensors as long as all tensors have matching dtypes. The key difference is that FA4 does not support the descale parameters used for on-the-fly dequantization (unlike FA3). So we:

  1. Convert the query to FP8 to match the KV cache dtype for both FA3 and FA4.
  2. Only set k_descale/v_descale for FA3 (FA4 does not support them).

  This resolves the "query and key must have the same dtype" error when using FP8 KV cache with FA4.

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
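
For context, a minimal sketch of the dtype situation the message above describes (illustrative shapes, not sglang code; requires a PyTorch build with float8_e4m3fn support):

import torch

# The query stays in the compute dtype; the KV cache stores FP8 values.
q = torch.randn(1, 8, 128, dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# Bug pattern: capturing only q.dtype and casting both rotary outputs to it
# silently turns the FP8 key back into bfloat16.
dtype = q.dtype
assert k.to(dtype).dtype != k.dtype

# Fix pattern: remember each tensor's own dtype and restore it separately.
q_dtype, k_dtype = q.dtype, k.dtype
q_out, k_out = q.float().to(q_dtype), k.float().to(k_dtype)
assert (q_out.dtype, k_out.dtype) == (torch.bfloat16, torch.float8_e4m3fn)
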
1 parent df08f34 commit fc80fbd

2 files changed

Lines changed: 11 additions & 8 deletions


python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 8 additions & 6 deletions
@@ -693,16 +693,18 @@ def forward_extend(
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
-        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        # 4) fa_impl_ver != 4 since fa4 does not support descale parameters (but FA4 can work with FP8 if all tensors have matching dtypes).
         if (
             self.kv_cache_dtype_str != "auto"
             and layer.head_dim <= 256
-            and self.fa_impl_ver != 4
         ):
-            if layer.k_scale is not None:
-                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-                k_descale = layer.k_scale.expand(descale_shape)
-                v_descale = layer.v_scale.expand(descale_shape)
+            if self.fa_impl_ver != 4:
+                # For FA3, use descale parameters for on-the-fly dequantization
+                if layer.k_scale is not None:
+                    descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                    k_descale = layer.k_scale.expand(descale_shape)
+                    v_descale = layer.v_scale.expand(descale_shape)
+            # Convert query to FP8 to match KV cache dtype (required for FA4, optional for FA3)
             q = q.to(self.kv_cache_dtype)
             q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
             k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
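
Read as one piece, the new control flow in the hunk above amounts to the following (a minimal sketch, not the sglang code itself; the helper name and flat argument list are hypothetical):

import torch

def prepare_fp8_attn_inputs(q, k_scale, v_scale, tp_k_head_num,
                            batch_size, kv_cache_dtype, fa_impl_ver):
    # FA3 only: keep descale factors so the kernel can dequantize the
    # FP8 KV cache on the fly; FA4 does not accept descale parameters.
    k_descale = v_descale = None
    if fa_impl_ver != 4:
        if k_scale is not None:
            descale_shape = (batch_size, tp_k_head_num)
            k_descale = k_scale.expand(descale_shape)
            v_descale = v_scale.expand(descale_shape)
    # For both FA3 and FA4, cast the query to the KV cache dtype so that
    # q, k, v dtypes match (FA4 requires this; FA3 tolerates it).
    q = q.to(kv_cache_dtype)
    return q, k_descale, v_descale

With fa_impl_ver=4 the descale outputs stay None and only the query cast happens, which is what removes the "query and key must have the same dtype" error under FA4.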

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 3 additions & 2 deletions
@@ -816,7 +816,8 @@ def forward_native(
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        dtype = query.dtype
+        query_dtype = query.dtype
+        key_dtype = key.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -847,7 +848,7 @@ def forward_native(
         else:
             query = query_rot
             key = key_rot
-        return query.to(dtype), key.to(dtype)
+        return query.to(query_dtype), key.to(key_dtype)
 
     def forward_npu(
         self,
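
A standalone sketch of the pattern this diff adopts (a plain GPT-NeoX-style rotation, not the DeepSeek scaling variant, and the function name is illustrative): the math runs in float32, and each output is cast back to that tensor's own original dtype.

import torch

def apply_rope_preserving_dtypes(q, k, cos, sin):
    # Remember each tensor's own dtype rather than only the query's.
    q_dtype, k_dtype = q.dtype, k.dtype

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # Do the rotation in float32 for accuracy.
    qf, kf = q.float(), k.float()
    q_rot = qf * cos + rotate_half(qf) * sin
    k_rot = kf * cos + rotate_half(kf) * sin

    # Restore the original dtypes separately, e.g. bf16 query + FP8 key.
    return q_rot.to(q_dtype), k_rot.to(k_dtype)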
