(1/n - prefill optimize) feat(lora): enable csgmv backend with virtual experts for MoE LoRA #24007
Conversation
Replace `sgemm.seg_indptr[0] = 0` with `sgemm.seg_indptr[0:1].zero_()` to prevent a CPU-GPU sync point that breaks CUDA graph capture. The scalar assignment triggers a host-to-device copy, which is incompatible with CUDA graph recording, while the slice-based `.zero_()` stays entirely on-device. Made-with: Cursor
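A standalone sketch of the two write patterns (requires a CUDA device; only the `seg_indptr` name comes from the PR):

```python
import torch

seg_indptr = torch.zeros(16, dtype=torch.int32, device="cuda")

# Scalar assignment: PyTorch wraps the Python int in a CPU tensor and
# copies it host-to-device, a sync point that CUDA graph capture rejects.
seg_indptr[0] = 0

# Slice + zero_(): a purely device-side fill with no host involvement,
# safe to record inside a CUDA graph.
seg_indptr[0:1].zero_()
```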
Pull request overview
Enables using the csgmv LoRA backend together with MoE “virtual experts” by removing the triton-only restriction and by plumbing per-request routing metadata through the LoRA batching and MoE LoRA info path to avoid incorrect token→adapter mappings (and CUDA graph crashes).
Changes:
- Remove the `--lora-use-virtual-experts` ⇒ `--lora-backend triton` assertion to allow `csgmv` + virtual experts.
- Extend `LoRABatchInfo` with per-request routing fields (`req_seg_indptr`, `req_weight_indices`) and populate them in the chunked/csgmv backend.
- Update MoE LoRA `_get_lora_info()` to prefer per-request routing fields when provided; adjust a CUDA-graph-sensitive write in the triton backend.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `python/sglang/srt/server_args.py` | Removes backend restriction for virtual experts. |
| `python/sglang/srt/lora/utils.py` | Adds per-request routing fields to `LoRABatchInfo`. |
| `python/sglang/srt/lora/backend/chunked_backend.py` | Computes/stores per-request routing info for csgmv (+ CUDA graph prealloc). |
| `python/sglang/srt/lora/layers.py` | MoE LoRA uses per-request routing fields when available. |
| `python/sglang/srt/lora/backend/triton_backend.py` | Makes `seg_indptr` initialization CUDA-graph-safe. |
| """Build per-request cumulative token boundaries on CPU (pinned).""" | ||
| bs = forward_batch.batch_size | ||
| if forward_batch.forward_mode.is_decode(): | ||
| indptr = torch.arange(bs + 1, dtype=torch.int32, pin_memory=True) |
`torch.arange(..., pin_memory=True)` is not a valid signature in PyTorch (`arange` doesn't accept `pin_memory`), so this will raise a `TypeError` at runtime. Create the CPU tensor first and then call `.pin_memory()` (or allocate via `torch.empty(..., pin_memory=True)` and fill) to keep the intended pinned-memory behavior.
Suggested change:

```diff
- indptr = torch.arange(bs + 1, dtype=torch.int32, pin_memory=True)
+ indptr = torch.arange(bs + 1, dtype=torch.int32).pin_memory()
```
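For context on the decode branch: each decode request contributes exactly one token, so the cumulative boundaries are simply 0..bs. A quick check of the suggested form (pinning requires a CUDA-enabled build):

```python
import torch

bs = 4  # four decode requests, one token each
indptr = torch.arange(bs + 1, dtype=torch.int32).pin_memory()
print(indptr)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
# Request i's tokens occupy [indptr[i], indptr[i+1]), i.e. one token each.
```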
```python
if self.lora_use_virtual_experts:
    assert self.lora_backend == "triton", (
        "--lora-use-virtual-experts requires --lora-backend triton. "
        f"Got: {self.lora_backend}"
    )
    logger.info("Virtual expert computation enabled.")
```
With the triton-only restriction removed, `--lora-use-virtual-experts` can now be combined with `--lora-backend csgmv`, but there doesn't appear to be any automated test coverage for this flag combination in `python/sglang/test` (no references found). Adding at least one regression test that boots an MoE model with `--lora-backend csgmv --lora-use-virtual-experts` (ideally with CUDA graph enabled) would help prevent reintroducing the crash this PR is fixing.
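One possible shape for such a test, assuming sglang's usual test helpers (`popen_launch_server`, `DEFAULT_URL_FOR_TEST`); the model name and adapter path below are placeholders:

```python
import unittest

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestCsgmvVirtualExperts(unittest.TestCase):
    def test_boot_and_generate(self):
        # Boot an MoE model with the previously crashing flag combination.
        process = popen_launch_server(
            "Qwen/Qwen3-30B-A3B",  # placeholder MoE model
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--lora-backend", "csgmv",
                "--lora-use-virtual-experts",
                "--lora-paths", "lora0=/path/to/adapter",  # placeholder
            ],
        )
        try:
            # A decode-heavy request exercises CUDA graph replay.
            resp = requests.post(
                DEFAULT_URL_FOR_TEST + "/generate",
                json={
                    "text": "Hello",
                    "lora_path": "lora0",
                    "sampling_params": {"max_new_tokens": 32},
                },
            )
            self.assertEqual(resp.status_code, 200)
        finally:
            kill_process_tree(process.pid)
```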
Remove the assertion that restricted `--lora-use-virtual-experts` to triton-only. The virtual experts path uses its own triton kernel (`merged_experts_fused_moe_lora_add`) independent of the dense-LoRA backend, so csgmv is fully compatible.

The root cause of the CUDA graph crash was that the csgmv backend's `seg_indptr`/`weight_indices` use chunked-segment semantics (grouped by adapter), while the MoE virtual experts code in `_compute_token_lora_mapping` expects per-request semantics. This mismatch produced garbage token-to-adapter mappings, triggering device-side asserts during CUDA graph capture.

Fix: add `req_seg_indptr` and `req_weight_indices` fields to `LoRABatchInfo` so MoE layers always get per-request segment info regardless of which dense-LoRA backend is active. The csgmv backend now computes and stores these alongside its internal chunked segments.

Benchmark (Qwen3-30B-A3B, TP=4, GB300, in=1024, out=2048):
- BS=1: triton 201.0 tps vs csgmv 205.5 tps (102.2%)
- BS=128: triton 2753.3 tps vs csgmv 2800.3 tps (101.7%)
- BS=512: triton 2758.7 tps vs csgmv 2807.6 tps (101.8%)

Accuracy: KL(sglang, trainer) = 0.00392, matching the triton baseline.

Made-with: Cursor
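For reference, an illustrative sketch of the `LoRABatchInfo` extension described above (a reduced subset; the real dataclass in `python/sglang/srt/lora/utils.py` carries additional fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LoRABatchInfo:
    # Chunked-segment view used by the csgmv kernels: segments are grouped
    # by adapter, so a single segment may span multiple requests.
    seg_indptr: torch.Tensor
    weight_indices: torch.Tensor

    # Per-request view consumed by MoE virtual experts: segment i covers
    # exactly request i's tokens, and entry i is request i's adapter id.
    # Optional so non-chunked backends can leave them unset.
    req_seg_indptr: Optional[torch.Tensor] = None
    req_weight_indices: Optional[torch.Tensor] = None
```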
Force-pushed from e96b722 to 5b65ea5.
`RowParallelLinearWithLoRA.forward()` previously issued two separate NCCL all-reduce calls: one for the base output and one for the LoRA A output. Profiling showed this accounted for ~610 ms (76%) of total LoRA overhead on Qwen3-30B-A3B with TP=4. Fuse them: cat both tensors, do a single all-reduce, then split. This halves the NCCL round-trips per RowParallel layer per forward pass. Made-with: Cursor
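A minimal sketch of that fusion, assuming plain `torch.distributed` collectives (sglang's actual code goes through its tensor-parallel communicator; names here are illustrative):

```python
import torch
import torch.distributed as dist


def fused_row_parallel_all_reduce(
    base_out: torch.Tensor, lora_out: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # One NCCL round-trip instead of two: concatenate along the feature
    # dim, reduce once, then split back into the two reduced outputs.
    fused = torch.cat([base_out, lora_out], dim=-1)
    dist.all_reduce(fused)
    base_red, lora_red = torch.split(
        fused, [base_out.shape[-1], lora_out.shape[-1]], dim=-1
    )
    return base_red, lora_red
```

The extra concat/split copies are cheap next to a second collective's latency, which is why the fusion pays off at TP=4.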
This reverts commit cf94518.
/rerun-failed-ci
Summary
- Remove the assertion restricting `--lora-use-virtual-experts` to triton-only. The virtual experts path uses its own triton kernel (`merged_experts_fused_moe_lora_add`) independent of the dense-LoRA backend, so csgmv is fully compatible.
- Fix the CUDA graph crash under `--lora-backend csgmv --lora-use-virtual-experts` by adding per-request segment info (`req_seg_indptr`, `req_weight_indices`) to `LoRABatchInfo`. The csgmv backend's chunked-segment semantics differ from what MoE virtual experts expect (per-request), causing garbage token-to-adapter mappings and device-side asserts; see the illustration below.
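A hypothetical illustration of the mismatch (batch shapes and merging behavior assumed for exposition, not taken from the kernel code):

```python
# 3 requests with 4, 2, 3 tokens, using adapters [0, 0, 1].

# Per-request view, what the MoE virtual-experts mapping expects:
req_seg_indptr = [0, 4, 6, 9]    # tokens of request i: [indptr[i], indptr[i+1])
req_weight_indices = [0, 0, 1]   # adapter id of request i

# Chunked view kept internally by csgmv, assuming adjacent same-adapter
# requests merge into one segment:
seg_indptr = [0, 6, 9]           # only 2 segments for 3 requests
weight_indices = [0, 1]

# Treating the chunked arrays as per-request (e.g. indexing with request 2)
# reads past the valid segments, producing the garbage token-to-adapter
# mappings and device-side asserts described above.
```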
Changes

| File | Change |
|---|---|
| `server_args.py` | Remove the triton-only restriction for `--lora-use-virtual-experts` |
| `lora/utils.py` | Add `req_seg_indptr` and `req_weight_indices` fields to `LoRABatchInfo` |
| `lora/backend/chunked_backend.py` | Compute and store per-request routing info for csgmv |
| `lora/layers.py` | `_get_lora_info` uses per-request fields when available |

Benchmark Results (Qwen3-30B-A3B, TP=4, GB300, in=1024, out=2048)

| BS | triton (tps) | csgmv (tps) | ratio |
|---|---|---|---|
| 1 | 201.0 | 205.5 | 102.2% |
| 128 | 2753.3 | 2800.3 | 101.7% |
| 512 | 2758.7 | 2807.6 | 101.8% |
csgmv matches triton throughput (~1.7-2.2% faster) since MoE LoRA uses the same virtual experts triton kernel regardless of the dense-LoRA backend.
Accuracy
Accuracy verified: sglang csgmv+virtual_experts output closely matches the training reference (lower KL than the original sampler).
Test plan
Made with Cursor