perf: skip KV cache in FA backend for embedding mode#21971
Merged
Qiaolin-Yu merged 3 commits intosgl-project:mainfrom Apr 13, 2026
Merged
perf: skip KV cache in FA backend for embedding mode#21971Qiaolin-Yu merged 3 commits intosgl-project:mainfrom
Qiaolin-Yu merged 3 commits intosgl-project:mainfrom
Conversation
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
5bbfa1a to
74a7315
Compare
Contributor
Author
|
/tag-and-rerun-ci |
47962ff to
74a7315
Compare
9a45729 to
1247d45
Compare
Contributor
Author
|
/tag-and-rerun-ci |
dc606bb to
020412a
Compare
Contributor
Author
|
/rerun-failed-ci |
3 similar comments
Contributor
Author
|
/rerun-failed-ci |
Contributor
Author
|
/rerun-failed-ci |
Contributor
Author
|
/rerun-failed-ci |
5 tasks
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 14, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
pyc96
pushed a commit
to pyc96/sglang
that referenced
this pull request
Apr 14, 2026
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 14, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 14, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 15, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 15, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 16, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 16, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 16, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 16, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 17, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 17, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 18, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 19, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 20, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
…l-project#21971 compat) PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
jasperjiaguo
added a commit
to jasperjiaguo/sglang
that referenced
this pull request
Apr 21, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to eliminate the 4-byte memset that nvjet requires before each GEMM launch. This memset creates ~20us pipeline bubbles between triton fusion kernels and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs). Changes: - Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes (adapted from vLLM) that handle per-tensor scalar scales natively via runtime bool flag, eliminating expand+contiguous overhead - Add out= parameter to fp8_scaled_mm for zero-copy GEMM output - Add runtime wrapper that replaces extern_kernels._scaled_mm with CUTLASS fp8_scaled_mm, preserving inductor triton fusion - Update fake tensor implementation for torch.compile compatibility Profile results (7k token FP8 embedding, H200): - Memset: 112 -> 0 - nvjet GEMM: 112 -> 0 CUTLASS - Total GPU kernels: 357 (unchanged, fusion preserved) Benchmark (Qwen3-0.6B FP8, production traffic distribution): - Baseline (main): 30.77 items/sec - With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec - + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
yhyang201
pushed a commit
to yhyang201/sglang
that referenced
this pull request
Apr 22, 2026
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Skip KV cache read/write in the FlashAttention backend when serving embedding models, eliminating
store_kvcacheandflash::prepare_varlen_num_blockskernel overhead per decoder layer.Motivation
In embedding mode with
--chunked-prefill-size -1and--disable-radix-cache, every request is a single prefill with no decode step. The KV cache is written to and read from but never reused. This wastes ~19µs per layer (store_kvcache~15µs +prepare_varlen~4µs).Changes
flashattention_backend.py:set_kv_buffer(KV cache write) whenlayer.fa_skip_kv_cacheis trueflash_attn_varlen_funcwith raw K/V tensors instead offlash_attn_with_kvcache(bypasses KV cache read + paged attention)cu_seqlens_qfor both Q and K sequence lengths (no prefix cache in this mode)Why
fa_skip_kv_cacherequiresdisable_radix_cacheIf radix cache is enabled, prefix KV entries could be shared across requests. Skipping the KV cache write would leave the cache unpopulated, causing radix cache lookups to read garbage. The
disable_radix_cacheguard ensures we only skip when caching is fully disabled.Why other backends are unaffected
The
save_kv_cache=Falseoverride was previously applied inradix_attention.forward(), which broketorch_nativeandtritonbackends — they skip the write but still read K/V from the (now empty) cache. The fix moves the skip logic entirely intoflashattention_backend.py, which has its own varlen path that uses raw K/V directly.Test plan
save_kv_cacheoverride)🤖 Generated with Claude Code