
perf(mamba): use Triton conv1d for non-contiguous input to avoid .contiguous() copy#20469

Merged
Qiaolin-Yu merged 7 commits into sgl-project:main from
jasperjiaguo:jiaguo/triton-conv1d-noncontiguous-fallback
Mar 20, 2026

Conversation

@jasperjiaguo
Contributor

@jasperjiaguo jasperjiaguo commented Mar 12, 2026

Summary

causal_conv1d_fn in sgl_kernel requires the input tensor to have stride(-1) == 1 (contiguous last dimension). When called during prefill, the input x is a transposed view of the GEMM output:

```python
mixed_qkv = mixed_qkv.transpose(0, 1)           # [seq, dim] → [dim, seq], non-contiguous
mixed_qkv = causal_conv1d_fn(mixed_qkv, ...)    # triggers x.contiguous() internally → >0.6ms copy
```

The copy costs >0.6ms per layer on large prefill batches (e.g. 16K tokens × 6144 features = 188MB). The existing Triton conv1d kernel already accepts arbitrary strides by passing stride values directly to the kernel — no copy needed.
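For illustration, a minimal sketch of why the transposed view is non-contiguous and where the copy comes from (shapes taken from the example above; this is not code from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# 16K tokens x 6144 features in bfloat16, as in the prefill example above.
seq_len, dim = 16384, 6144
mixed_qkv = torch.randn(seq_len, dim, dtype=torch.bfloat16, device=device)

x = mixed_qkv.transpose(0, 1)   # view with shape [dim, seq]; stride(-1) == dim, not 1
print(x.is_contiguous())        # False -> the CUDA kernel cannot consume it directly

x_copy = x.contiguous()         # materializes the full-tensor copy (>0.6 ms per layer above)
```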

This PR adds a fallback: when x.stride(-1) != 1 and seq_lens_cpu has already been pre-computed by the caller (so no GPU-CPU sync is introduced), dispatch to the Triton kernel instead of copying.
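A rough sketch of the resulting dispatch (helper names like _conv1d_triton / _conv1d_cuda and the _HAS_SGL_KERNEL flag are placeholders, not the actual identifiers in causal_conv1d.py):

```python
def causal_conv1d_fn(x, weight, bias=None, **kwargs):
    if not _HAS_SGL_KERNEL:
        # No CUDA kernel available: always use the Triton implementation.
        return _conv1d_triton(x, weight, bias, **kwargs)

    seq_lens_cpu = kwargs.get("seq_lens_cpu")
    if x.stride(-1) != 1 and seq_lens_cpu is not None:
        # Non-contiguous prefill input and the caller already supplies CPU seq lens:
        # the Triton kernel takes strides directly, so no .contiguous() copy is needed.
        return _conv1d_triton(x, weight, bias, seq_lens_cpu=seq_lens_cpu)

    # Contiguous input, or seq_lens_cpu absent: keep the original CUDA path
    # (.contiguous() is a no-op on already-contiguous tensors, so the fast path is unchanged).
    return _conv1d_cuda(x.contiguous(), weight, bias)
```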

Also updates run_eval.py to support --dataset-path for GPQA, allowing local CSV files instead of the hardcoded URL.
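A minimal sketch of how the option can be wired up (the flag name is from the PR; the surrounding names and the use of pandas are illustrative, not the actual run_eval.py code):

```python
import argparse

import pandas as pd

# Placeholder for the hardcoded OpenAI blob URL used before this change.
GPQA_DEFAULT_URL = "<hardcoded OpenAI blob URL>"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset-path",
    type=str,
    default=None,
    help="Local GPQA CSV file; falls back to the default URL when omitted.",
)
args = parser.parse_args()

# pandas reads either a local path or a URL, so the fallback is a one-liner.
dataset = pd.read_csv(args.dataset_path or GPQA_DEFAULT_URL)
```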

Test

Embedding server with Qwen3.5-0.8B on H200, 16K token inputs:

```bash
python -m sglang.launch_server \
  --model /models/Qwen/Qwen3.5-0.8B/2fc06364715b967f1860aea9cf38778875588b17 \
  --is-embedding --dtype bfloat16 --tp 1 \
  --mem-fraction-static 0.88 --port 30003 \
  --trust-remote-code --disable-radix-cache \
  --chunked-prefill-size -1 --max-prefill-tokens 32768 \
  --context-length 32768 --attention-backend fa3 \
  --enforce-piecewise-cuda-graph \
  --piecewise-cuda-graph-compiler inductor \
  --piecewise-cuda-graph-max-tokens 32768 \
  --linear-attn-backend flashinfer
```

Prefill Throughput (embedding, product distribution):

|  | Throughput (tok/s) |
| --- | --- |
| Before | ~206K |
| After | ~267K (+30%) |

Decode throughput (generation, 64 requests, concurrency=16, input=10K, output=30K):

Using bench_serving.py --dataset-name random-ids --random-input-len 10000 --random-output-len 30000 --num-prompts 64 --max-concurrency 16:

| Metric | Before (baseline) | After (patched) | Diff |
| --- | --- | --- | --- |
| Output tok/s | 5076 | 5083 | ~0% |
| Total tok/s | 6479 | 6489 | ~0% |
| Mean TPOT (ms) | 2.62 | 2.63 | ~0% |
| Median TPOT (ms) | 2.77 | 2.74 | ~0% |
| Mean E2E latency (s) | 36.2 | 36.2 | ~0% |

No decode regression — causal_conv1d_update is unaffected by this patch (decode path always uses the CUDA kernel when sgl_kernel is available).

Quality (GPQA Diamond, 198 questions):

|  | Score |
| --- | --- |
| Before | 4.5% |
| After | 6.6% |

Quality (GPQA Main, 448 questions):

|  | Score |
| --- | --- |
| Before | 11.6% |
| After | 11.2% |

Cosine similarity between before/after embeddings: 0.99993 (numerical parity).
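For reference, a quick way to reproduce this kind of parity check, assuming the before/after embeddings have been collected into two [num_inputs, hidden_dim] tensors (names are illustrative):

```python
import torch
import torch.nn.functional as F

def mean_cosine_similarity(emb_before: torch.Tensor, emb_after: torch.Tensor) -> float:
    # Compare row-wise in float32 to avoid bfloat16 rounding in the metric itself.
    sims = F.cosine_similarity(emb_before.float(), emb_after.float(), dim=-1)
    return sims.mean().item()
```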


@jasperjiaguo jasperjiaguo force-pushed the jiaguo/triton-conv1d-noncontiguous-fallback branch 2 times, most recently from a28e87d to 1d3048d Compare March 12, 2026 20:19
@Qiaolin-Yu
Collaborator

/tag-and-rerun-ci

Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment


Btw, have you also tested decoding performance? Just to verify that decoding performance does not regress when using this kernel.

@jasperjiaguo jasperjiaguo force-pushed the jiaguo/triton-conv1d-noncontiguous-fallback branch from 1d3048d to fd5ecad Compare March 12, 2026 22:51
@Qiaolin-Yu Qiaolin-Yu self-assigned this Mar 12, 2026
@jasperjiaguo
Contributor Author

Good catch! The concern is valid for callers like lfm2.py and lfm2_moe.py which do not pass seq_lens_cpu.

Updated in fceca7a: the Triton fallback now only activates when seq_lens_cpu is already present in kwargs (pre-computed on CPU by the caller). When absent, we keep the original .contiguous() path to avoid any new GPU-CPU sync.

Callers that do pass seq_lens_cpu (e.g. gdn_backend.py via forward_batch.extend_seq_lens_cpu) are already using a CPU list — no sync there either.
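For context, a small illustration of the sync the guard avoids (tensor contents are made up; the point is the device-to-host transfer in the first path):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# GPU-resident cumulative sequence offsets, as the conv1d path sees them.
query_start_loc = torch.tensor([0, 4096, 8192, 16384], device=device, dtype=torch.int32)

# Avoided when seq_lens_cpu is absent: deriving the list from the GPU tensor
# forces a device-to-host sync inside the conv1d call.
seq_lens_from_gpu = query_start_loc.cpu().tolist()

# What the fallback requires instead: the caller (e.g. gdn_backend.py via
# forward_batch.extend_seq_lens_cpu) already tracks sequence lengths on the host,
# so passing them through kwargs adds no extra sync.
seq_lens_precomputed = [4096, 4096, 8192]
```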

Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment


I see a slight regression in decode performance; could you test a longer output len for decoding, e.g., bs 16 with input len 10000, output len 30000? Btw, I'm a bit concerned that the decoding performance of other models might be affected. If this slight regression is consistently reproduced, could we use this kernel only for prefill?

Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py
Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py Outdated
Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py Outdated
@jasperjiaguo
Contributor Author

I have cleaned up the code a bit; I think it now only impacts prefill.

@jasperjiaguo jasperjiaguo force-pushed the jiaguo/triton-conv1d-noncontiguous-fallback branch from 876e2e4 to 2536dfd Compare March 13, 2026 17:33
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment


several nits

Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py
Comment thread python/sglang/srt/layers/attention/mamba/causal_conv1d.py
@Qiaolin-Yu
Collaborator

/tag-and-rerun-ci

@jasperjiaguo jasperjiaguo force-pushed the jiaguo/triton-conv1d-noncontiguous-fallback branch from 6c11188 to 6d4589f Compare March 13, 2026 21:15
@Qiaolin-Yu
Collaborator

/tag-and-rerun-ci

perf(mamba): use Triton conv1d for non-contiguous input to avoid .contiguous() copy

On large prefill batches, causal_conv1d_fn receives a non-contiguous input
tensor because the GEMM output [seq, features] is transposed to [features, seq]
before being passed to the CUDA conv1d kernel. The CUDA kernel requires
stride(-1)==1, which forces a full tensor copy via .contiguous() costing
>0.6ms per layer.

The existing Triton conv1d kernel already accepts arbitrary strides by passing
stride values directly. This change falls back to the Triton path whenever the
input is non-contiguous, eliminating the copy entirely.

Tested on both embedding (Qwen3.5-0.8B) and generation workloads.

fix: avoid GPU-CPU sync when seq_lens_cpu not pre-computed

Only fall back to Triton conv1d when seq_lens_cpu is already available
in kwargs (pre-computed on CPU by the caller). When absent, keep the
original .contiguous() path to avoid introducing a GPU-CPU sync via
query_start_loc.cpu().tolist().

fix: correct fallback logic — use .contiguous() when seq_lens_cpu unavailable

Previous fix had dead code. Now clearly:
- no sgl_kernel: always use Triton (compute seq_lens_cpu if needed)
- sgl_kernel + contiguous: use CUDA kernel (fast path, unchanged)
- sgl_kernel + non-contiguous + seq_lens_cpu available: use Triton (no copy)
- sgl_kernel + non-contiguous + seq_lens_cpu absent: .contiguous() + CUDA (avoids GPU-CPU sync)

style: improve comments and docstring clarity

feat(eval): support --dataset-path for GPQA eval in run_eval.py

Allow users to pass a local CSV file path via --dataset-path instead of
downloading from the hardcoded OpenAI blob URL. Falls back to the
original URL when --dataset-path is not provided.

refactor: use use_triton variable for dispatch clarity

Address reviewer nit: unwrap the dispatch condition into a named
use_triton variable in both causal_conv1d_fn and causal_conv1d_update
for readability.

Also revert docstring format to match the original inline style,
adding a brief dispatch note at the end rather than a full rewrite.

fix: restore original docstrings, keep use_triton variable

Revert docstring changes in causal_conv1d_fn and causal_conv1d_update
to match upstream exactly. The dispatch logic explanation now lives
only in inline comments, not in the docstring.

fix: remove kwargs from causal_conv1d_update, restore clean state

- Remove **kwargs from causal_conv1d_update signature and triton call
  (no callers pass extra kwargs to the decode update function)
- Keep **kwargs only in causal_conv1d_fn where seq_lens_cpu is passed
- Docstrings now match upstream exactly
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/triton-conv1d-noncontiguous-fallback branch from 6d4589f to 2337425 Compare March 14, 2026 00:23
@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

2 similar comments
@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@jasperjiaguo
Contributor Author

/rerun-failed-ci

3 similar comments
@jasperjiaguo
Contributor Author

/rerun-failed-ci

@jasperjiaguo
Contributor Author

/rerun-failed-ci

@jasperjiaguo
Contributor Author

/rerun-failed-ci

@Qiaolin-Yu Qiaolin-Yu enabled auto-merge (squash) March 19, 2026 22:24
@Qiaolin-Yu Qiaolin-Yu disabled auto-merge March 20, 2026 02:38
@Qiaolin-Yu Qiaolin-Yu merged commit 87549f8 into sgl-project:main Mar 20, 2026
573 of 621 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026