
perf: enable inductor combo_kernels for horizontal fusion #21977

Merged
ispobock merged 1 commit into sgl-project:main from jasperjiaguo:jiaguo/enable-combo-kernels
Apr 10, 2026

Conversation

jasperjiaguo (Contributor) commented Apr 2, 2026

Enable combo_kernels and benchmark_combo_kernel in inductor config to allow horizontal fusion of sibling ops with different shapes. This fuses operations like q_norm + k_norm (QK normalization) into a single triton kernel instead of generating separate kernels for each.

Requires torch >= 2.9.0.
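
A minimal standalone sketch of the two flags this PR enables (plain torch.compile usage under torch >= 2.9 with CUDA; the actual wiring into SGLang's inductor config is in the diff and not reproduced here):

```python
import torch
import torch._inductor.config as inductor_config

# The two flags named in this PR; both live in torch._inductor.config.
inductor_config.combo_kernels = True           # horizontally fuse sibling ops
inductor_config.benchmark_combo_kernel = True  # keep a combo only when it benchmarks faster

def qk_norm(q: torch.Tensor, k: torch.Tensor, eps: float = 1e-6):
    # Sibling RMSNorm-style reductions over differently shaped tensors --
    # the q_norm + k_norm pattern that combo kernels can fuse into one kernel.
    qn = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps)
    kn = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps)
    return qn, kn

compiled = torch.compile(qk_norm)
q = torch.randn(7168, 16, 128, device="cuda")  # hypothetical shapes: 16 query heads
k = torch.randn(7168, 8, 128, device="cuda")   # hypothetical shapes: 8 KV heads
qn, kn = compiled(q, k)  # with combo kernels on, both norms can land in one triton kernel
```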
(Two profiler screenshots attached to the original PR.)

Profile Results

Qwen3-0.6B FP8 embeddings on H200, PCG inductor, 7k tokens:

| Metric | Before | After |
| --- | --- | --- |
| GPU kernels per forward | 413 | 357 (-14%) |
| QK norm kernels per layer | 4 | 2 |
| `split_with_sizes` / `clone` in kernel names | Present | Gone |

The QK norm reduction + pointwise kernels for q and k are now horizontally fused into single kernels.

Throughput impact is neutral at 60 RPS (kernel launch overhead is not the bottleneck at this load), but the reduced kernel count should help at higher concurrency or with smaller models where launch overhead is proportionally larger.
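
Kernel counts like these can be read off a PyTorch profiler trace; a generic sketch of one way to do it (not necessarily the exact methodology behind the numbers above):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def dump_kernel_table(fn, *args):
    fn(*args)  # warm-up call so compilation is excluded from the trace
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(*args)
        torch.cuda.synchronize()
    # One row per distinct kernel name; after this PR the separate q/k norm
    # kernels should collapse into a single triton_* combo entry.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
```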

jasperjiaguo requested a review from hebiao064 as a code owner, April 2, 2026 22:42

jasperjiaguo: /tag-and-rerun-ci

github-actions bot added the run-ci label, Apr 2, 2026

jasperjiaguo: /rerun-failed-ci

jasperjiaguo: /tag-and-rerun-ci

jasperjiaguo: /rerun-failed-checks

jasperjiaguo: /rerun-failed-ci

2 similar comments

jasperjiaguo force-pushed the jiaguo/enable-combo-kernels branch from 28c6263 to 5e8404d, April 7, 2026 08:01
jasperjiaguo: /rerun-failed-ci

jasperjiaguo force-pushed the jiaguo/enable-combo-kernels branch 2 times, most recently from 2d4da8e to 3fd92ee, April 7, 2026 18:22
Qiaolin-Yu requested a review from ispobock, April 7, 2026 19:51
jasperjiaguo: /rerun-failed-ci

8 similar comments

jasperjiaguo added a commit to jasperjiaguo/sglang that referenced this pull request Apr 8, 2026
Replace nvjet (cooperative-algorithm) FP8 GEMMs with CUTLASS kernels to
eliminate the 4-byte memset that nvjet requires before each GEMM launch.
This memset creates ~20us pipeline bubbles between triton fusion kernels
and GEMM kernels, totaling ~2.2ms per forward pass (112 GEMMs).

Changes:
- Add Sm90ColOrScalarBroadcast/Sm90RowOrScalarBroadcast custom EVT nodes
  (adapted from vLLM) that handle per-tensor scalar scales natively via
  runtime bool flag, eliminating expand+contiguous overhead
- Add out= parameter to fp8_scaled_mm for zero-copy GEMM output
- Add runtime wrapper that replaces extern_kernels._scaled_mm with
  CUTLASS fp8_scaled_mm, preserving inductor triton fusion
- Update fake tensor implementation for torch.compile compatibility

Profile results (7k token FP8 embedding, H200):
- Memset: 112 -> 0
- nvjet GEMMs: 112 -> 0 (replaced by CUTLASS)
- Total GPU kernels: 357 (unchanged, fusion preserved)

Benchmark (Qwen3-0.6B FP8, production traffic distribution):
- Baseline (main): 30.77 items/sec
- With PRs sgl-project#21734+sgl-project#21971+sgl-project#21977: 37.77 items/sec
- + This PR (CUTLASS): 38.77 items/sec (+26% vs baseline)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
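
A rough sketch of the wrapper idea this commit message describes (every name here apart from the torch builtins is a hypothetical stand-in; the real integration lives in the referenced commit):

```python
import torch

def make_cutlass_scaled_mm(fp8_scaled_mm):
    """Build a drop-in replacement for the extern _scaled_mm call emitted in
    inductor-generated wrapper code, routing FP8 GEMMs to a CUTLASS kernel
    instead of nvjet. `fp8_scaled_mm` stands in for the CUTLASS entry point
    named in the commit message."""
    def scaled_mm(a, b, scale_a, scale_b, bias=None,
                  out_dtype=torch.bfloat16, out=None):
        if out is None:
            # Allocate only when the caller did not hand us an out= buffer;
            # when inductor does, the GEMM writes into it zero-copy.
            out = torch.empty(a.shape[0], b.shape[1],
                              dtype=out_dtype, device=a.device)
        fp8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, out=out)
        return out
    return scaled_mm
```

Patching at this extern-kernel boundary, rather than rewriting the compiled graph, is what keeps the surrounding triton fusion intact (hence the unchanged 357-kernel count).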
jasperjiaguo: /rerun-failed-ci

jasperjiaguo: /rerun-failed-ci

jasperjiaguo added the same commit to jasperjiaguo/sglang, referencing this pull request, 28 more times between Apr 14 and May 9, 2026 (commit message identical to the above).

yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026