Commit fc80fbd
Fix cuda graph capture dtype mismatch (#17)
* Fix dtype mismatch in rotary embedding with FP8 KV cache
When using FP8 KV cache quantization (e.g., with ModelOpt FP8 models),
the query and key tensors may have different dtypes during CUDA graph
capture. The query tensor remains in bfloat16 for computation, while
the key tensor might need to be in FP8 format for KV cache storage.
The issue was in DeepseekScalingRotaryEmbedding.forward_native() which
only captured query's dtype and then converted both query and key to
that same dtype. This caused a dtype mismatch error during CUDA graph
capture: "query and key must have the same dtype".
The fix preserves the original dtypes of both query and key tensors
separately, ensuring they maintain their intended dtypes after the
rotary position embedding computation.
This resolves the CUDA graph capture failure with Qwen3MoE and other
models using FP8 KV cache quantization.
* Fix FA4 dtype mismatch with FP8 KV cache
When using FlashAttention 4 (FA4) with FP8 KV cache quantization,
there was a dtype mismatch between the query tensor (bfloat16) and
the cached key/value tensors (FP8). FA4 requires all input tensors
(q, k, v) to have the same dtype.
The previous code converted the query to FP8 only when NOT using FA4
(fa_impl_ver != 4), on the assumption that FA4 does not support FP8.
In fact, FA4 can work with FP8 tensors as long as all tensors have
matching dtypes.
The key difference is that FA4 doesn't support descale parameters for
on-the-fly dequantization (unlike FA3). So we:
1. Convert query to FP8 to match the KV cache dtype for both FA3 and FA4
2. Only set k_descale/v_descale for FA3 (FA4 doesn't support them)
This resolves the "query and key must have the same dtype" error when
using FP8 KV cache with FA4.
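The two steps above can be sketched as a small decision function. This is an illustration only: `plan_fp8_attention_inputs`, the `"fp8"` dtype-string prefix, and the returned dictionary are hypothetical stand-ins, not the attention backend's actual API. Only `fa_impl_ver`, `k_descale`, and `v_descale` are names taken from the commit message.

```python
from typing import Any, Dict, Optional


def plan_fp8_attention_inputs(
    kv_cache_dtype: str,
    fa_impl_ver: int,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> Dict[str, Any]:
    """Decide the query cast and descale arguments for an attention call."""
    plan: Dict[str, Any] = {
        "cast_query_to": None,
        "k_descale": None,
        "v_descale": None,
    }
    # Assumption: FP8 cache dtypes are named with an "fp8" prefix.
    if not kv_cache_dtype.startswith("fp8"):
        return plan  # non-FP8 cache: leave the query untouched

    # Step 1: match the query dtype to the KV cache for both FA3 and FA4.
    plan["cast_query_to"] = kv_cache_dtype

    # Step 2: descale factors are passed only to FA3.
    if fa_impl_ver == 3:
        plan["k_descale"] = k_scale
        plan["v_descale"] = v_scale
    return plan


fa3_plan = plan_fp8_attention_inputs("fp8_e4m3", 3, k_scale=0.5, v_scale=0.25)
fa4_plan = plan_fp8_attention_inputs("fp8_e4m3", 4, k_scale=0.5, v_scale=0.25)
bf16_plan = plan_fp8_attention_inputs("bfloat16", 4)
```

Under these assumptions, both FA3 and FA4 plans cast the query to the cache dtype, and only the FA3 plan carries descale values.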
---------
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent df08f34 commit fc80fbd
2 files changed
Lines changed: 11 additions & 8 deletions
Lines changed: 8 additions & 6 deletions
[Diff content not preserved in this extract. First file: one hunk at original lines 693-708 (8 additions, 6 deletions). Second file: two hunks at original lines 816-822 and 847-853 (3 additions, 2 deletions).]