
perf: skip KV cache in FA backend for embedding mode #21971

Merged
Qiaolin-Yu merged 3 commits into sgl-project:main from jasperjiaguo:jiaguo/embedding-skip-kvcache on Apr 13, 2026

Conversation

Contributor

@jasperjiaguo jasperjiaguo commented Apr 2, 2026

Summary

Skip KV cache read/write in the FlashAttention backend when serving embedding models, eliminating store_kvcache and flash::prepare_varlen_num_blocks kernel overhead per decoder layer.

Motivation

In embedding mode with --chunked-prefill-size -1 and --disable-radix-cache, every request is a single prefill with no decode step. The KV cache is written to and read from but never reused. This wastes ~19µs per layer (store_kvcache ~15µs + prepare_varlen ~4µs).
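The per-layer overhead above compounds across the whole decoder stack. A back-of-envelope estimate, using the per-layer figures quoted in this PR and an assumed 28-layer model depth (illustrative only, e.g. a Qwen3-0.6B-class model):

```python
# Estimated cost of the redundant KV-cache work described above.
# Per-layer kernel times come from the PR description; the layer count
# is a hypothetical example, not a measurement from this PR.
store_kvcache_us = 15.0
prepare_varlen_us = 4.0
per_layer_us = store_kvcache_us + prepare_varlen_us  # ~19 us per layer

num_layers = 28  # assumed decoder depth for illustration
total_ms = per_layer_us * num_layers / 1000.0
print(f"~{per_layer_us:.0f} us/layer -> ~{total_ms:.2f} ms per forward pass")
```

Even at these small per-kernel times, the wasted work is pure launch and memory-traffic overhead on every single request.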

Changes

flashattention_backend.py:

  • Skip set_kv_buffer (KV cache write) when layer.fa_skip_kv_cache is true
  • Use flash_attn_varlen_func with raw K/V tensors instead of flash_attn_with_kvcache (bypasses KV cache read + paged attention)
  • Use cu_seqlens_q for both Q and K sequence lengths (no prefix cache in this mode)
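The branching these bullets describe can be sketched as follows. This is an illustrative stand-in, not the actual sglang code: `attn_varlen` and `attn_with_kvcache` are hypothetical stubs for flash_attn_varlen_func and flash_attn_with_kvcache, and the cache write stands in for set_kv_buffer.

```python
# Hedged sketch of the forward-extend dispatch added by this PR.
# All helpers are stubs; real code lives in flashattention_backend.py.

def attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k):
    # Stub for flash_attn_varlen_func: attends over raw K/V tensors.
    return ("varlen", len(cu_seqlens_q) - 1)  # (path taken, num sequences)

def attn_with_kvcache(q, kv_cache):
    # Stub for flash_attn_with_kvcache: paged attention over the cache.
    return ("paged", len(kv_cache))

def forward_extend(q, k, v, cu_seqlens_q, kv_cache, fa_skip_kv_cache):
    if fa_skip_kv_cache:
        # Embedding mode: skip the KV write and attend over raw K/V.
        # cu_seqlens_q doubles as cu_seqlens_k since there is no prefix cache.
        return attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_q)
    kv_cache.append((k, v))  # normal path: write cache, then paged read
    return attn_with_kvcache(q, kv_cache)

cache = []
out = forward_extend("q", "k", "v", [0, 5, 9], cache, fa_skip_kv_cache=True)
assert out == ("varlen", 2) and cache == []  # two sequences, cache untouched
```

The key property the sketch captures: on the skip path, the cache is never written *and* never read, so both the store_kvcache kernel and the paged-attention block-table preparation disappear.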

Why fa_skip_kv_cache requires disable_radix_cache

If radix cache is enabled, prefix KV entries could be shared across requests. Skipping the KV cache write would leave the cache unpopulated, causing radix cache lookups to read garbage. The disable_radix_cache guard ensures we only skip when caching is fully disabled.
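The safety condition can be summarized as a single predicate. The attribute names below are illustrative; the real flag wiring lives in sglang's server args and model runner.

```python
# Sketch of when skipping the KV cache is safe (assumed condition, per
# the guard described above; exact names in sglang may differ).
def can_skip_kv_cache(is_embedding, disable_radix_cache, chunked_prefill_size):
    # Safe only when nothing will ever read the cache back:
    # embedding mode (no decode step), radix cache disabled (no
    # cross-request prefix sharing), and whole-prompt prefill
    # (no chunking that revisits earlier KV entries).
    return is_embedding and disable_radix_cache and chunked_prefill_size == -1

assert can_skip_kv_cache(True, True, -1)
assert not can_skip_kv_cache(True, False, -1)  # radix lookups would read garbage
```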

Why other backends are unaffected

The save_kv_cache=False override was previously applied in radix_attention.forward(), which broke the torch_native and triton backends: they skip the write but still read K/V from the (now empty) cache. The fix moves the skip logic entirely into flashattention_backend.py, which has its own varlen path that uses raw K/V directly.
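The failure mode can be made concrete with a toy backend that, like torch_native and triton, always reads K/V back from the cache. The class and its behavior are illustrative only:

```python
# Toy illustration of why a blanket save_kv_cache=False override breaks
# cache-reading backends: they skip the write but still read the cache,
# getting empty (here: None) entries instead of the current K/V.
class ToyCacheReadingBackend:
    def __init__(self):
        self.kv_cache = {}

    def forward(self, layer_id, k, v, save_kv_cache=True):
        if save_kv_cache:
            self.kv_cache[layer_id] = (k, v)
        # Always reads from the cache -- the torch_native/triton pattern.
        return self.kv_cache.get(layer_id)  # None if the write was skipped

backend = ToyCacheReadingBackend()
assert backend.forward(0, "k0", "v0", save_kv_cache=False) is None  # broken
assert backend.forward(0, "k0", "v0", save_kv_cache=True) == ("k0", "v0")
```

This is why the skip must live inside the FlashAttention backend, whose varlen path consumes raw K/V and never touches the cache at all.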

Test plan

  • Embedding correctness: cosine similarity > 0.99 vs baseline
  • Cross-encoder test: torch_native/triton backends unaffected (no save_kv_cache override)
  • CI tests pass
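A minimal version of the correctness criterion in the test plan, with made-up embedding vectors for illustration:

```python
# Embedding correctness check as described in the test plan: outputs
# from the new path should have cosine similarity > 0.99 vs baseline.
# The vectors below are hypothetical, not actual model outputs.
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

baseline = [0.12, -0.80, 0.57, 0.05]
candidate = [0.121, -0.799, 0.571, 0.049]  # near-identical, as expected
assert cosine_similarity(baseline, candidate) > 0.99
```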

🤖 Generated with Claude Code


@jasperjiaguo jasperjiaguo force-pushed the jiaguo/embedding-skip-kvcache branch 2 times, most recently from 5bbfa1a to 74a7315 Compare April 2, 2026 23:05
@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 2, 2026
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/embedding-skip-kvcache branch 2 times, most recently from 47962ff to 74a7315 Compare April 3, 2026 00:50
@jasperjiaguo jasperjiaguo changed the title from "[WIP] perf: skip KV cache and use varlen FA in embedding mode" to "perf: skip KV cache and use varlen FA in embedding mode" Apr 3, 2026
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/embedding-skip-kvcache branch 4 times, most recently from 9a45729 to 1247d45 Compare April 3, 2026 19:07
@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo jasperjiaguo force-pushed the jiaguo/embedding-skip-kvcache branch 2 times, most recently from dc606bb to 020412a Compare April 4, 2026 20:58
@jasperjiaguo
Contributor Author

/rerun-failed-ci

3 similar comments

jasperjiaguo added a commit to jasperjiaguo/sglang that referenced this pull request Apr 14, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to
eliminate the 4-byte memset that nvjet requires before each GEMM launch.
This memset creates ~20us pipeline bubbles between triton fusion kernels
and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs).

Changes:
- Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes
  (adapted from vLLM) that handle per-tensor scalar scales natively via
  runtime bool flag, eliminating expand+contiguous overhead
- Add out= parameter to fp8_scaled_mm for zero-copy GEMM output
- Add runtime wrapper that replaces extern_kernels._scaled_mm with
  CUTLASS fp8_scaled_mm, preserving inductor triton fusion
- Update fake tensor implementation for torch.compile compatibility

Profile results (7k token FP8 embedding, H200):
- Memset: 112 -> 0
- nvjet GEMM: 112 -> 0 CUTLASS
- Total GPU kernels: 357 (unchanged, fusion preserved)

Benchmark (Qwen3-0.6B FP8, production traffic distribution):
- Baseline (main): 30.77 items/sec
- With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec
- + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
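The figures quoted in the commit message above are internally consistent, as a quick arithmetic check shows:

```python
# Sanity-checking the numbers in the commit message above.
memset_bubbles = 112  # one ~20 us bubble per GEMM launch
bubble_us = 20.0
total_bubble_ms = memset_bubbles * bubble_us / 1000.0  # 2.24 ms (~2.2 ms quoted)

baseline = 30.77       # items/sec on main
with_cutlass = 38.77   # items/sec with this PR stacked on the others
speedup_pct = (with_cutlass / baseline - 1.0) * 100.0  # ~26%, as quoted
print(f"bubbles: {total_bubble_ms:.2f} ms, speedup: +{speedup_pct:.0f}%")
```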
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
jasperjiaguo added a commit to jasperjiaguo/sglang that referenced this pull request Apr 20, 2026
…l-project#21971 compat)

PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses
flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out,
so the DtoD copy elimination from sgl-project#21985 did not cover it.
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026