
(1/n - prefill optimize)feat(lora): enable csgmv backend with virtual experts for MoE LoRA #24007

Merged

yushengsu-thu merged 5 commits into sgl-project:main from yushengsu-thu:lora-perf-optimize-0 on May 4, 2026

Conversation

@yushengsu-thu
Collaborator

Summary

  • Remove the assertion that restricted --lora-use-virtual-experts to triton-only. The virtual experts path uses its own triton kernel (merged_experts_fused_moe_lora_add) independent of the dense-LoRA backend, so csgmv is fully compatible.
  • Fix CUDA graph crash when using --lora-backend csgmv --lora-use-virtual-experts by adding per-request segment info (req_seg_indptr, req_weight_indices) to LoRABatchInfo. The csgmv backend's chunked-segment semantics differ from what MoE virtual experts expect (per-request), causing garbage token-to-adapter mappings and device-side asserts.

Changes

| File | Change |
| --- | --- |
| server_args.py | Remove triton-only assertion for --lora-use-virtual-experts |
| lora/utils.py | Add req_seg_indptr and req_weight_indices fields to LoRABatchInfo |
| lora/backend/chunked_backend.py | csgmv backend computes and stores per-request segment info (incl. CUDA graph pre-allocation) |
| lora/layers.py | MoE LoRA _get_lora_info uses per-request fields when available |

Benchmark Results (Qwen3-30B-A3B, TP=4, GB300, in=1024, out=2048)

| BS | triton (e2e tps) | csgmv (e2e tps) | csgmv vs triton |
| --- | --- | --- | --- |
| 1 | 201.0/s | 205.5/s | 102.2% |
| 128 | 2753.3/s | 2800.3/s | 101.7% |
| 512 | 2758.7/s | 2807.6/s | 101.8% |

csgmv matches or slightly exceeds triton throughput (~1.7-2.2% faster), as expected: MoE LoRA uses the same virtual experts triton kernel regardless of the dense-LoRA backend.

Accuracy

| Metric | Value |
| --- | --- |
| KL(sglang, trainer) | 0.00392 |
| KL(orig_sampler, trainer) | 0.00424 |
| KL(sglang, orig_sampler) | 0.00497 |

Accuracy verified — sglang csgmv+virtual_experts output closely matches the training reference (lower KL than the original sampler).
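
For context, a minimal sketch of one common way to estimate such a KL from per-token logprobs (hypothetical helper; the actual test harness may differ):

```python
import torch

def estimate_kl(logp_p: torch.Tensor, logp_q: torch.Tensor) -> float:
    """Monte-Carlo estimate of KL(p || q) from per-token logprobs.

    logp_p / logp_q: 1-D tensors of log-probabilities that samplers p and q
    assign to the same sampled token sequence. For tokens drawn from p,
    E[log p(x) - log q(x)] equals KL(p || q).
    """
    return (logp_p - logp_q).mean().item()
```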

Test plan

  • Perf benchmark: triton vs csgmv with virtual experts (BS=1,128,512)
  • Accuracy test: KL divergence against training logprobs
  • CUDA graph capture succeeds with csgmv + virtual experts

Made with Cursor

Replace sgemm.seg_indptr[0] = 0 with sgemm.seg_indptr[0:1].zero_()
to prevent a CPU-GPU sync point that breaks CUDA graph capture. The
scalar assignment triggers a host-to-device copy which is incompatible
with CUDA graph recording, while the slice-based .zero_() stays
entirely on-device.

Made-with: Cursor
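
For context, a minimal sketch contrasting the two initialization patterns (hypothetical buffer name; not the exact sglang code):

```python
import torch

# Hypothetical stand-in for the backend's segment-indptr buffer.
seg_indptr = torch.empty(16, dtype=torch.int32, device="cuda")

# Scalar item assignment copies a Python scalar from host to device,
# which is incompatible with CUDA graph recording:
#   seg_indptr[0] = 0

# The slice-based zero_() launches a fill kernel that stays entirely
# on-device, so it is safe to record inside a CUDA graph:
seg_indptr[0:1].zero_()
```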
Copilot AI review requested due to automatic review settings April 29, 2026 05:58

Contributor

Copilot AI left a comment


Pull request overview

Enables using the csgmv LoRA backend together with MoE “virtual experts” by removing the triton-only restriction and by plumbing per-request routing metadata through the LoRA batching and MoE LoRA info path to avoid incorrect token→adapter mappings (and CUDA graph crashes).

Changes:

  • Remove the assertion that --lora-use-virtual-experts requires --lora-backend triton, to allow csgmv + virtual experts.
  • Extend LoRABatchInfo with per-request routing fields (req_seg_indptr, req_weight_indices) and populate them in the chunked/csgmv backend.
  • Update MoE LoRA _get_lora_info() to prefer per-request routing fields when provided (sketched below); adjust a CUDA-graph-sensitive write in the triton backend.
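
A minimal sketch of that preference logic (attribute names taken from this PR's description; not the verbatim sglang code):

```python
# Hypothetical fallback inside a _get_lora_info-style helper:
def pick_routing(batch_info):
    # Prefer the per-request routing fields when the active backend
    # populates them (csgmv); otherwise fall back to the backend's
    # native segment fields.
    if getattr(batch_info, "req_seg_indptr", None) is not None:
        return batch_info.req_seg_indptr, batch_info.req_weight_indices
    return batch_info.seg_indptr, batch_info.weight_indices
```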

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| python/sglang/srt/server_args.py | Removes backend restriction for virtual experts. |
| python/sglang/srt/lora/utils.py | Adds per-request routing fields to LoRABatchInfo. |
| python/sglang/srt/lora/backend/chunked_backend.py | Computes/stores per-request routing info for csgmv (+ CUDA graph prealloc). |
| python/sglang/srt/lora/layers.py | MoE LoRA uses per-request routing fields when available. |
| python/sglang/srt/lora/backend/triton_backend.py | Makes seg_indptr initialization CUDA-graph-safe. |


"""Build per-request cumulative token boundaries on CPU (pinned)."""
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode():
indptr = torch.arange(bs + 1, dtype=torch.int32, pin_memory=True)

Copilot AI Apr 29, 2026


torch.arange(..., pin_memory=True) is not a valid signature in PyTorch (arange doesn't accept pin_memory), so this will raise a TypeError at runtime. Create the CPU tensor first and then call .pin_memory() (or allocate via torch.empty(..., pin_memory=True) and fill) to keep the intended pinned-memory behavior.

Suggested change

```diff
- indptr = torch.arange(bs + 1, dtype=torch.int32, pin_memory=True)
+ indptr = torch.arange(bs + 1, dtype=torch.int32).pin_memory()
```

Comment on lines 6907 to 6909
```python
if self.lora_use_virtual_experts:
    assert self.lora_backend == "triton", (
        "--lora-use-virtual-experts requires --lora-backend triton. "
        f"Got: {self.lora_backend}"
    )
    logger.info("Virtual expert computation enabled.")
```


Copilot AI Apr 29, 2026


With the triton-only restriction removed, --lora-use-virtual-experts can now be combined with --lora-backend csgmv, but there doesn't appear to be any automated test coverage for this flag/combination in python/sglang/test (no references found). Adding at least one regression test that boots an MoE model with --lora-backend csgmv --lora-use-virtual-experts (ideally with CUDA graph enabled) would help prevent reintroducing the crash this PR is fixing.

@yushengsu-thu yushengsu-thu enabled auto-merge (squash) April 29, 2026 06:02
@jybsuper jybsuper self-assigned this Apr 29, 2026
Remove the assertion that restricted --lora-use-virtual-experts to
triton-only. The virtual experts path uses its own triton kernel
(merged_experts_fused_moe_lora_add) independent of the dense-LoRA
backend, so csgmv is fully compatible.

The root cause of the CUDA graph crash was that the csgmv backend's
seg_indptr/weight_indices use chunked-segment semantics (grouped by
adapter), while the MoE virtual experts code in
_compute_token_lora_mapping expects per-request semantics. This
mismatch produced garbage token-to-adapter mappings, triggering
device-side asserts during CUDA graph capture.

Fix: add req_seg_indptr and req_weight_indices fields to LoRABatchInfo
so MoE layers always get per-request segment info regardless of which
dense-LoRA backend is active. The csgmv backend now computes and stores
these alongside its internal chunked segments.
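
To make the semantics gap concrete, here is a minimal sketch of per-request segment info (hypothetical example batch; not the exact sglang internals):

```python
import torch

# Hypothetical batch: 3 requests with token counts [4, 2, 3]; requests
# 0 and 2 use adapter 1, request 1 uses adapter 0.
seq_lens = torch.tensor([4, 2, 3], dtype=torch.int32)
adapter_of_req = torch.tensor([1, 0, 1], dtype=torch.int32)

# Per-request semantics (what the MoE virtual experts path expects):
# one segment per request, with boundaries at cumulative token counts.
req_seg_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
req_seg_indptr[1:] = torch.cumsum(seq_lens, dim=0)  # -> [0, 4, 6, 9]
req_weight_indices = adapter_of_req                 # -> [1, 0, 1]

# Token i maps to the adapter of the request r whose interval
# [req_seg_indptr[r], req_seg_indptr[r + 1]) contains i. The csgmv
# backend's own seg_indptr instead describes chunked segments grouped
# by adapter, so reading it as per-request boundaries produces garbage
# token-to-adapter mappings.
```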

Benchmark (Qwen3-30B-A3B, TP=4, GB300, in=1024, out=2048):
  BS=1:   triton 201.0 tps vs csgmv 205.5 tps (102.2%)
  BS=128: triton 2753.3 tps vs csgmv 2800.3 tps (101.7%)
  BS=512: triton 2758.7 tps vs csgmv 2807.6 tps (101.8%)

Accuracy: KL(sglang, trainer)=0.00392 — matches triton baseline.
Made-with: Cursor
@yushengsu-thu yushengsu-thu self-assigned this Apr 29, 2026
RowParallelLinearWithLoRA.forward() previously issued two separate
NCCL all-reduce calls — one for the base output and one for the
LoRA A output. Profiling showed this accounted for ~610ms (76%) of
total LoRA overhead on Qwen3-30B-A3B with TP=4.

Fuse them: cat both tensors, single all-reduce, then split. This
halves the NCCL round-trips per RowParallel layer per forward pass.

Made-with: Cursor
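
A minimal sketch of the fusion pattern with torch.distributed (hypothetical shapes and names; not the exact sglang code):

```python
import torch
import torch.distributed as dist

def fused_all_reduce(base_out: torch.Tensor, lora_a_out: torch.Tensor):
    """Replace two NCCL all-reduces with one via concat + split.

    base_out: [num_tokens, hidden] partial sums from the row-parallel
    matmul; lora_a_out: [num_tokens, rank] partial sums from the LoRA-A
    matmul. Both need summation across tensor-parallel ranks.
    """
    fused = torch.cat([base_out, lora_a_out], dim=-1)
    dist.all_reduce(fused)  # one NCCL round-trip instead of two
    return torch.split(
        fused, [base_out.shape[-1], lora_a_out.shape[-1]], dim=-1
    )
```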
@yushengsu-thu
Collaborator Author

/rerun-failed-ci

@yushengsu-thu yushengsu-thu changed the title from "feat(lora): enable csgmv backend with virtual experts for MoE LoRA" to "(1/n - prefill optimize)feat(lora): enable csgmv backend with virtual experts for MoE LoRA" on May 2, 2026
@yushengsu-thu yushengsu-thu merged commit b7fefc0 into sgl-project:main May 4, 2026
125 of 163 checks passed