[Lora] Lora Kimi support #22381
Conversation
When adapter_config.json uses PEFT shorthands like `"all-linear"` or `"all"`, SGLang previously required users to explicitly specify `--lora-target-modules` on the CLI. This change adds a model-scanning approach that inspects the loaded base model to discover all LoRA-compatible linear modules automatically.

Changes:
- utils.py: add `auto_detect_lora_target_modules()`, which walks the model graph, collects LinearBase/FusedMoE/ParallelLMHead module suffixes, normalizes them, and filters to the set supported by `get_hidden_dim` and `init_buffers`.
- lora_manager.py: in `init_lora_shapes()`, resolve `"all-linear"`/`"all"` via model scanning instead of raising a ValueError when CLI target modules are not provided. In `init_lora_modules()`, guard against modules outside decoder layers (`layer_id is None`) to prevent a TypeError on non-layer modules.

Made-with: Cursor
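For context, a minimal sketch of what such a scanner could look like, matching on class names over plain `torch.nn` modules; the real SGLang helper uses `isinstance` checks against its own layer classes and applies additional normalization and filtering:

```python
import torch.nn as nn

def auto_detect_lora_target_modules(model: nn.Module) -> set[str]:
    """Collect suffixes of LoRA-compatible modules by walking the graph."""
    targets = set()
    for name, module in model.named_modules():
        # Matching on class names here for illustration; SGLang checks
        # isinstance against LinearBase / FusedMoE / ParallelLMHead.
        if type(module).__name__ in ("LinearBase", "FusedMoE", "ParallelLMHead"):
            # Keep only the last path component, e.g.
            # "model.layers.0.self_attn.qkv_proj" -> "qkv_proj".
            targets.add(name.split(".")[-1])
    return targets
```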
…fallbacks

1. layers.py: fix `RowParallelLinearWithLoRA` bias handling to pass `bias` into `quant_method.apply()`, matching base RowParallelLinear behavior; add interleaved gate/up layout support in `FusedMoEWithLoRA` for models using `gemm1_alpha` (e.g. gpt-oss-20b).
2. mem_pool.py: zero-initialize all LoRA buffers (`torch.empty` -> `torch.zeros`) to prevent garbage values in unused slots.
3. utils.py: fall back to `config.intermediate_size` when `moe_intermediate_size` is not available in `get_hidden_dim` (supports GptOss, Mixtral, OLMoE, PhiMoE, GraniteMoE, Grok, etc.; see the sketch below); accept the PEFT shorthand `"all-linear"` in `get_normalized_target_modules`; fix the isinstance order in `auto_detect_lora_target_modules` so ParallelLMHead is checked before VocabParallelEmbedding.
4. gpt_oss.py: add `should_apply_lora()` to GptOssForCausalLM for explicit LoRA module filtering, consistent with Qwen3VLMoe.

Made-with: Cursor
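A hedged sketch of the fallback in item 3; the helper name is hypothetical and SGLang's actual `get_hidden_dim` handles many more module types:

```python
# Hypothetical helper, not SGLang's actual get_hidden_dim.
def moe_intermediate_dim(config) -> int:
    # MoE configs such as GptOss or Mixtral may not define
    # moe_intermediate_size; fall back to the dense intermediate_size.
    size = getattr(config, "moe_intermediate_size", None)
    return size if size is not None else config.intermediate_size
```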
Regression test comparing SGLang LoRA logprobs against reference training logprobs (KL threshold 1e-2). Uses the 8-GPU H200 suite with the triton MoE runner and shared outer LoRA mode. Adapter checkpoint: yushengsu/lora-diff-gpt-oss-20b.

Made-with: Cursor
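A sketch of the kind of check such a regression test performs, assuming per-token logprobs for tokens sampled under the reference policy (so the mean logprob gap is a Monte Carlo estimate of the KL); names are illustrative, not the test's actual code:

```python
import torch

def assert_logprob_close(actual_logprobs: torch.Tensor,
                         ref_logprobs: torch.Tensor,
                         threshold: float = 1e-2) -> None:
    # With tokens sampled under the reference policy,
    # E[log p_ref - log p_actual] estimates KL(ref || actual).
    kl_estimate = (ref_logprobs - actual_logprobs).mean().item()
    assert kl_estimate < threshold, (
        f"KL estimate {kl_estimate:.4f} exceeds threshold {threshold}"
    )
```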
Pre-allocate MoE intermediate buffers before memory profiling so KV cache sizing accounts for them. Reuse fixed buffers during capture/replay instead of dynamic torch.empty() allocations.
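An illustrative pattern for this change, with buffers sized for the worst case up front; the class and field names are hypothetical, not SGLang's:

```python
import torch

class MoeScratchBuffers:
    """Hypothetical fixed-size scratch buffers, not SGLang's actual class."""

    def __init__(self, max_tokens: int, hidden: int, inter: int,
                 device: str = "cuda", dtype=torch.bfloat16):
        # Allocated once, before memory profiling, and sized for the worst
        # case so KV cache sizing accounts for the true activation peak.
        self.gate_up = torch.empty(max_tokens, 2 * inter, device=device, dtype=dtype)
        self.down = torch.empty(max_tokens, hidden, device=device, dtype=dtype)

    def views(self, num_tokens: int):
        # Hand out slices of the fixed buffers during capture/replay;
        # no torch.empty() call happens per forward pass.
        return self.gate_up[:num_tokens], self.down[:num_tokens]
```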
Extract `get_triton_quant_info()` into `FusedMoEMethodBase` and override it in each quant method (Fp8, W8A8Fp8, W8A8Int8, BlockInt8, MoeWNA16, Unquantized) so that `FusedMoEWithLoRA` uses the polymorphic method instead of hardcoding `TritonMoeQuantInfo`. This enables LoRA on quantized MoE models.

Made-with: Cursor
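A minimal sketch of the hook-based refactor, with a stand-in dataclass for `TritonMoeQuantInfo` (its field names follow the diff further down in this thread); the override shown is the unquantized case, and the exact signatures are assumptions:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class TritonMoeQuantInfo:  # stand-in; field names mirror the diff below
    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    b13: Optional[torch.Tensor] = None
    b2: Optional[torch.Tensor] = None

class FusedMoEMethodBase:
    def get_triton_quant_info(self, layer) -> TritonMoeQuantInfo:
        # Each quant method overrides this; the base raises so that
        # incompatible methods fail loudly instead of guessing.
        raise NotImplementedError(
            f"{type(self).__name__} does not provide Triton MoE quant info"
        )

class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    def get_triton_quant_info(self, layer) -> TritonMoeQuantInfo:
        # Unquantized case: plain weights plus optional biases, no scales.
        return TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            b13=getattr(layer, "w13_weight_bias", None),
            b2=getattr(layer, "w2_weight_bias", None),
        )
```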
- Add `ReplicatedLinearWithLoRA` for `fused_qkv_a_proj_with_mqa`, applying LoRA B via two separate sgemm calls for unequal output partitions (q_a_proj=1536 vs kv_a_proj_with_mqa=576); B slices are precomputed in `set_lora_info` to avoid per-forward allocation (see the sketch after this list).
- Add `normalize_fused_qkv_a_proj` to fuse q_a_proj + kv_a_proj_with_mqa adapter weights into a single stacked entry.
- Add a `stack_num` parameter to `run_lora_a_sgemm` across all 3 backends.
- Fix o_proj hidden dim to use `v_head_dim` for MLA models.
- Fix gate_up_proj/down_proj hidden dim to use the per-layer shared expert intermediate size on MoE layers.
- Exclude ReplicatedLinear from TP sharding in memory pool allocation.

Made-with: Cursor
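As noted in the first item, here is a sketch of the two-sgemm LoRA-B expand over unequal partitions, assuming the stacked shrink output has shape `(num_tokens, 2*r)` (stack_num=2) and that the precomputed B slices are `(1536, r)` and `(576, r)`; plain matmuls stand in for the backend sgemm kernels:

```python
import torch

def apply_fused_lora_b(lora_a_out: torch.Tensor,  # (num_tokens, 2 * r)
                       b_q: torch.Tensor,          # (1536, r), precomputed slice
                       b_kv: torch.Tensor,         # (576, r), precomputed slice
                       base_out: torch.Tensor,     # (num_tokens, 1536 + 576)
                       r: int) -> torch.Tensor:
    # Two separate GEMMs: a single fused expand kernel assumes equal output
    # slice sizes, but q_a (1536) and kv_a_with_mqa (576) differ.
    base_out[:, :1536] += lora_a_out[:, :r] @ b_q.t()
    base_out[:, 1536:] += lora_a_out[:, r:] @ b_kv.t()
    return base_out
```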
Made-with: Cursor
- Force triton-compatible MoE weights when LoRA is enabled for compressed-tensors quantized models (avoid the Marlin path, which is incompatible).
- Refactor `get_triton_quant_info` into a reusable method for the LoRA MoE runner.
- Make MoE LoRA runner backend detection robust with a `hasattr` fallback.
- Handle multi-modal model configs via `get_text_config()` in LoRAManager.
- Add a CI test for Kimi-K2.5 LoRA logprob accuracy.

Made-with: Cursor
Pull request overview
This PR extends SGLang’s LoRA support to cover Kimi-K2.5 / DeepSeek-style MLA fused projections and improves MoE+LoRA compatibility across multiple quantization backends, with new CI-registered regression tests validating LoRA logprob accuracy.
Changes:
- Add LoRA handling for the DeepSeek MLA fused projection (`fused_qkv_a_proj_with_mqa`), including target-module normalization, buffer sizing, weight normalization/fusion, and a new `ReplicatedLinearWithLoRA` wrapper.
- Refactor MoE quantization info plumbing via a new `get_triton_quant_info()` hook so the LoRA MoE runner can consume correct quant metadata across quant methods.
- Add registered regression tests for Kimi-K2.5 and DeepSeek-V3.1-Base LoRA logprob accuracy vs reference datasets.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/registered/lora/test_lora_kimi_k25_logprob_diff.py | New CI-registered Kimi-K2.5 LoRA logprob regression test using HF dataset reference. |
| test/registered/lora/test_lora_deepseek_v3_base_logprob_diff.py | New CI-registered DeepSeek-V3.1-Base LoRA logprob regression test with input/output normalization guards. |
| python/sglang/srt/lora/utils.py | Extend target-module normalization and hidden-dim logic for MLA fused projections and MoE shared-expert dims. |
| python/sglang/srt/lora/mem_pool.py | Adjust TP sharding rules for replicated fused projections; improve shared-outer MoE buffer zeroing behavior. |
| python/sglang/srt/lora/lora.py | Add fusion/normalization step to combine q_a + kv_a LoRA weights into fused MLA layout. |
| python/sglang/srt/lora/lora_manager.py | Use get_text_config() for VLM configs; initialize fused MLA LoRA modules with partition boundary metadata. |
| python/sglang/srt/lora/layers.py | Add ReplicatedLinearWithLoRA; refactor MoE LoRA runner init to use get_triton_quant_info() and handle missing runner fields. |
| python/sglang/srt/lora/backend/triton_backend.py | Add stack_num parameter passthrough for LoRA-A SGEMM. |
| python/sglang/srt/lora/backend/torch_backend.py | Add stack_num support wired into num_slices for LoRA-A ops. |
| python/sglang/srt/lora/backend/chunked_backend.py | Add stack_num support wired into chunked shrink op num_slices. |
| python/sglang/srt/layers/quantization/base_config.py | Introduce default FusedMoEMethodBase.get_triton_quant_info() API. |
| python/sglang/srt/layers/quantization/w8a8_int8.py | Factor Triton quant-info construction into get_triton_quant_info(). |
| python/sglang/srt/layers/quantization/w8a8_fp8.py | Factor Triton quant-info construction into get_triton_quant_info(). |
| python/sglang/srt/layers/quantization/unquant.py | Use get_triton_quant_info() in Triton MoE path (XPU). |
| python/sglang/srt/layers/quantization/moe_wna16.py | Add get_triton_quant_info() and reuse it in apply(). |
| python/sglang/srt/layers/quantization/fp8.py | Add get_triton_quant_info() and reuse it in Triton MoE path. |
| python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py | Add get_triton_quant_info() and reuse it in apply_weights(). |
| python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py | When LoRA is enabled, force Triton-compatible WNA16 MoE scheme; expose get_triton_quant_info() passthrough. |
| python/sglang/srt/layers/quantization/blockwise_int8.py | Add get_triton_quant_info() and reuse it in apply(). |
```python
kv_a_weight = (
    weights[kv_a_name]
    if kv_a_name in weights
    else torch.zeros_like(weights[q_a_name])
)

weights[fused_name] = torch.cat((weights[q_a_name], kv_a_weight), dim=0)
```
In normalize_fused_qkv_a_proj, the fallback for missing kv_a_proj_with_mqa uses torch.zeros_like(weights[q_a_name]). This is only safe for LoRA A (where q/kv LoRA-A shapes match), but for LoRA B the q_a and kv_a output dims differ, so zeros_like will produce the wrong shape and the subsequent torch.cat will create a fused B with an incorrect output dimension (leading to buffer shape mismatches or silent misalignment). Consider handling lora_A vs lora_B separately: for lora_B, either (a) require kv_a_name to exist and raise a clear error if missing, or (b) allocate zeros with the correct kv output dim derived from base_hf_config (kv_lora_rank + qk_rope_head_dim) and the adapter rank.
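One possible shape-safe fallback along the lines of option (b), treating LoRA A and LoRA B separately; the function name and the `kv_out_dim` plumbing are assumptions, not the PR's code:

```python
import torch

def kv_a_fallback(weights: dict, q_a_name: str,
                  is_lora_b: bool, kv_out_dim: int) -> torch.Tensor:
    q_a_weight = weights[q_a_name]
    if not is_lora_b:
        # LoRA A: q_a and kv_a shrink weights share shape (r, hidden_size),
        # so mirroring the q_a shape with zeros is safe.
        return torch.zeros_like(q_a_weight)
    # LoRA B: output dims differ (q_a=1536 vs kv_a=576 in this model), so
    # size the zeros from kv_out_dim = kv_lora_rank + qk_rope_head_dim.
    rank = q_a_weight.shape[1]  # LoRA-B weight is (out_dim, r)
    return torch.zeros(kv_out_dim, rank,
                       dtype=q_a_weight.dtype, device=q_a_weight.device)
```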
```diff
+ qm = base_layer.quant_method
+ if hasattr(qm, "runner") and qm.runner is not None:
+     runner_backend = qm.runner.runner_backend
+ else:
+     runner_backend = MoeRunnerBackend.TRITON
+
  self._lora_runner = MoeRunner(
-     base_layer.quant_method.runner.runner_backend,
+     runner_backend,
      base_layer.moe_runner_config,
      lora_enabled=True,
  )

  # Pre-compute quant info for efficiency (weights don't change during inference)
- self._quant_info = TritonMoeQuantInfo(
-     w13_weight=base_layer.w13_weight,
-     w2_weight=base_layer.w2_weight,
-     b13=getattr(base_layer, "w13_weight_bias", None),
-     b2=getattr(base_layer, "w2_weight_bias", None),
- )
+ self._quant_info = base_layer.quant_method.get_triton_quant_info(base_layer)
```
FusedMoEWithLoRA currently falls back to MoeRunnerBackend.TRITON when the quant method has no runner, and then always builds _quant_info via quant_method.get_triton_quant_info(). For quant methods whose MoE weights are not Triton-compatible (e.g., BitsAndBytesMoEMethod stores packed uint8 weights and doesn’t create a runner), this change can silently route execution into the Triton MoE runner with an invalid TritonMoeQuantInfo, likely producing incorrect results or runtime failures. Please add an explicit compatibility check here (e.g., require qm to expose a supported runner backend or a dedicated flag indicating triton-kernel compatibility) and raise a clear error when LoRA+MoE is requested with an unsupported quant method.
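A sketch of the explicit guard this comment asks for, reusing `MoeRunnerBackend` from the diff above; the `supports_triton_moe` flag is a hypothetical opt-in, not an existing SGLang attribute:

```python
def resolve_lora_moe_backend(quant_method):
    runner = getattr(quant_method, "runner", None)
    if runner is not None:
        return runner.runner_backend
    # No runner: only default to Triton when the quant method opts in,
    # instead of silently assuming its weights are Triton-compatible.
    if getattr(quant_method, "supports_triton_moe", False):
        return MoeRunnerBackend.TRITON
    raise ValueError(
        f"LoRA on MoE layers is not supported with "
        f"{type(quant_method).__name__}: no Triton-compatible MoE weights."
    )
```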
/tag-run-ci-label

/rerun-failed-ci

/rerun-failed-ci

/rerun-failed-ci

/rerun-failed-ci

Motivation
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci