
[lora][moe] Decoupled LoRA MoE backend with Marlin support#21858

Merged
Fridge003 merged 5 commits into sgl-project:main from klshuster:kurt/lora-virtual-experts-decoupled-20260331 on Apr 11, 2026

Conversation

@klshuster (Contributor) commented Apr 1, 2026

Motivation

Applying LoRA adapters to MoE layers currently couples LoRA injection logic into each backend-specific runner subclass (e.g., TritonRunnerCoreWithLoRA), making it difficult to add new backends. This PR:

  1. Refactors the LoRA MoE runner architecture from per-backend subclasses to a generic hook-based injection pattern, decoupling LoRA logic from backend-specific code.
  2. Adds a Marlin int4/int8 MoE backend with LoRA support, enabling quantized base model inference with LoRA adapters.

Modifications

Decoupled LoRA/MoE backends:

  • Refactored lora/lora_moe_runners.py from a class-based TritonRunnerCoreWithLoRA (which wholesale replaced the backend runner) to a hooks-based architecture (LoRAHooks, build_lora_hooks). Hooks are injected into the MoE runner's RunnerInput, decoupling LoRA from the base MoE backend (see the sketch after this list).
  • MoeRunner now accepts lora_enabled flag and pre_run_hook/post_run_hook in RunnerInput.
  • compressed_tensors quantization scheme updated to expose both get_triton_quant_info and get_marlin_quant_info, with backend selection via get_moe_runner_backend().
  • New MarlinLoraRunnerCore (lora/lora_moe_runner_marlin.py) enables Marlin wNa16 GEMM for base experts when LoRA is active.
  • Naive CPU-side alignment fallback for small batches where CUDA kernel launch overhead dominates.
  • CUDA graph buffer support for LoRA alignment tensors.
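
A minimal sketch of the resulting hook shape. The LoRAHooks, build_lora_hooks, and RunnerInput names come from this PR; the field names, hook signatures, and LoRA math below are illustrative assumptions, not the actual implementation:

```python
# Sketch of the hook-based injection pattern described above; hook signatures
# and the LoRA math are assumed for illustration.
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class LoRAHooks:
    # Runs before the base expert GEMMs (e.g., compute the LoRA A projection).
    pre_run_hook: Optional[Callable[[torch.Tensor, dict], None]] = None
    # Runs after the base expert GEMMs (e.g., add the LoRA B delta).
    post_run_hook: Optional[Callable[[torch.Tensor, dict], None]] = None


def build_lora_hooks(lora_a: torch.Tensor, lora_b: torch.Tensor) -> LoRAHooks:
    def pre(hidden_states: torch.Tensor, running_state: dict) -> None:
        running_state["lora_a_out"] = hidden_states @ lora_a

    def post(output: torch.Tensor, running_state: dict) -> None:
        output += running_state["lora_a_out"] @ lora_b

    return LoRAHooks(pre_run_hook=pre, post_run_hook=post)


# A backend runner then stays LoRA-agnostic: it only calls whatever hooks
# arrive on its RunnerInput, e.g.
#   if hooks and hooks.pre_run_hook:
#       hooks.pre_run_hook(hidden_states, running_state)
#   ... base MoE computation ...
#   if hooks and hooks.post_run_hook:
#       hooks.post_run_hook(output, running_state)
```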

Bug fix:

  • Changed the moe-align kernel fallback condition from <= to a strict < inequality, fixing an illegal memory access (IMA) for Qwen3 (paraphrased below).
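
Illustratively (the function, variable names, and which path over-ran the buffer are assumptions; this paraphrases the one-character fix, not the kernel source):

```python
def choose_moe_align_path(size: int, fallback_threshold: int) -> str:
    # The fallback condition was `<=` and is now a strict `<`: the boundary
    # case (size == fallback_threshold) switches paths, avoiding the
    # out-of-bounds access behind the Qwen3 IMA.
    if size < fallback_threshold:  # previously: size <= fallback_threshold
        return "fallback"
    return "cuda_kernel"


assert choose_moe_align_path(256, 256) == "cuda_kernel"  # boundary case moved
```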

Accuracy Tests

All 24 parametrized tests pass (test_lora_moe_runner.py):

  • test_lora_moe_runner_multi_expert (16 configs) — verifies LoRA delta matches between hook-based and baseline implementations
  • test_lora_moe_runner_marlin (8 configs) — verifies Marlin backend matches Triton backend for base expert computation with LoRA

test_marlin_lora_correctness.py — end-to-end correctness comparing Marlin vs Triton LoRA backends.
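
For reference, a hedged sketch of how such hook-vs-baseline parametrized comparisons are commonly structured (the helper names and the parameter grid are assumptions, not the contents of the actual test files):

```python
import pytest
import torch


# Hypothetical stand-ins for the real harness in test_lora_moe_runner.py;
# each would build the MoE layer and return its output tensor.
def run_moe_baseline(num_experts, top_k, dtype, lora_rank):
    ...


def run_moe_with_hooks(num_experts, top_k, dtype, lora_rank):
    ...


@pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("lora_rank", [8, 16])
def test_lora_delta_matches_baseline(num_experts, top_k, dtype, lora_rank):
    hooked = run_moe_with_hooks(num_experts, top_k, dtype, lora_rank)
    baseline = run_moe_baseline(num_experts, top_k, dtype, lora_rank)
    torch.testing.assert_close(hooked, baseline, rtol=2e-2, atol=2e-2)
```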

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces LoRA support for Mixture-of-Experts (MoE) models by implementing a hook-based injection mechanism. It adds a new Marlin-based MoE runner for LoRA, updates existing runners to support LoRA hooks, and introduces virtual expert computation for improved efficiency. I have identified critical issues regarding dimension calculations in the CUDA buffer allocation logic that could lead to runtime errors, as well as opportunities to improve type safety by replacing 'Any' with 'LoRAHooks' in method signatures.

Comment thread: python/sglang/srt/lora/layers.py (Outdated), lines +789 to +790
```python
_, N, _ = qinfo.w13_qweight.shape
hidden_dim = qinfo.w2_qweight.shape[1]
```
Severity: critical

There appears to be an issue with how dimensions are calculated for the Marlin backend, which could lead to incorrect CUDA buffer allocations and potential runtime errors.

  1. N is being assigned the packed dimension 2*N_packed from qinfo.w13_qweight.shape. However, it's used to allocate intermediate_cache1, which requires the full gate_up_dim. It should be multiplied by 16 to get the unpacked dimension.
  2. hidden_dim is being assigned qinfo.w2_qweight.shape[1], which is the packed intermediate dimension (N_packed), not the hidden dimension.

To fix this, you should use the unpacked dimensions for allocation.

Suggested change:

```diff
- _, N, _ = qinfo.w13_qweight.shape
- hidden_dim = qinfo.w2_qweight.shape[1]
+ N = qinfo.w13_qweight.shape[1] * 16
+ hidden_dim = qinfo.w13_qweight.shape[2]
```
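
For context on the ×16 factor, a minimal sketch of the standard GPTQ-Marlin packing arithmetic (the tile size and pack factor follow the usual Marlin layout; whether sglang's qinfo tensors use exactly this layout is an assumption):

```python
# Standard GPTQ-Marlin weight layout arithmetic (assumed to apply here).
num_bits = 4
pack_factor = 32 // num_bits          # 8 int4 values per int32
tile = 16                             # Marlin operates on 16-row tiles
K, N = 4096, 14336                    # example unpacked GEMM dims
qweight_shape = (K // tile, N * tile // pack_factor)
# Recovering the unpacked dims from the packed shape:
assert qweight_shape[0] * tile == K            # first packed dim unpacks by *16
assert qweight_shape[1] * pack_factor // tile == N
```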


```python
M, K = hidden_states.shape
E = quant_info.w13_qweight.shape[0]
N = quant_info.w2_qweight.shape[1] * 16
```
Severity: critical

The calculation for N (intermediate dimension) seems incorrect. quant_info.w2_qweight has a shape of [E, N_packed, K], where E is number of experts, N_packed is the packed intermediate dimension, and K is the hidden dimension. Therefore, shape[1] corresponds to N_packed. The correct way to get the intermediate dimension N would be to use shape[2] from w13_qweight or w2_qweight and unpack it, or use w13_qweight.shape[1] and unpack it.

Given the shapes, quant_info.w2_qweight.shape[2] would give the hidden dimension, which is also not correct for N. The intermediate dimension N should be derived from w13_qweight's second dimension, unpacked.

Suggested change:

```diff
- N = quant_info.w2_qweight.shape[1] * 16
+ N = quant_info.w13_qweight.shape[1] // 2 * 16
```

```python
runner_input: RunnerInput,
quant_info: MoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
```
Severity: medium

For better type safety and maintainability, consider using the specific LoRAHooks type instead of Any for the hooks parameter. This makes the expected type explicit and improves code clarity.

You would need to add a forward reference import within the TYPE_CHECKING block:

```python
if TYPE_CHECKING:
    from sglang.srt.lora.lora_moe_runners import LoRAHooks
    ...
```

And then change the signature to:

```python
    def run(
        self,
        runner_input: RunnerInput,
        quant_info: MoeQuantInfo,
        running_state: dict,
        hooks: Optional["LoRAHooks"] = None,
    ) -> RunnerOutput:
```
Suggested change:

```diff
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
```

```python
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
```
Severity: medium

For better type safety and maintainability, you can replace Any with a more specific type hint for the hooks parameter. Using "LoRAHooks" as a forward reference would make the code clearer about the expected object type.

This would involve adding an import inside the TYPE_CHECKING block:

```python
if TYPE_CHECKING:
    from sglang.srt.lora.lora_moe_runners import LoRAHooks
    ...
```

And updating the method signature accordingly.

Suggested change:

```diff
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
```

```python
runner_input: TritonRunnerInput,
quant_info: TritonMoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
```
Severity: medium

To improve type safety, consider replacing Any with the specific "LoRAHooks" type hint for the hooks parameter. This will make the expected interface for hooks explicit.

You'll need to add from sglang.srt.lora.lora_moe_runners import LoRAHooks inside the TYPE_CHECKING block at the top of the file.

Suggested change:

```diff
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
```

```python
runner_input: TritonKernelsRunnerInput,
quant_info: TritonKernelsQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
```
Severity: medium

For better type safety and code clarity, it's recommended to use the specific "LoRAHooks" type hint instead of Any for the hooks parameter, even if it's not used in this particular implementation. This maintains API consistency and improves readability.

You would need to add from sglang.srt.lora.lora_moe_runners import LoRAHooks to the TYPE_CHECKING block.

Suggested change:

```diff
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
```

@yushengsu-thu self-assigned this Apr 1, 2026
@klshuster force-pushed the kurt/lora-virtual-experts-decoupled-20260331 branch from 8bec729 to bc9a320 on April 4, 2026 20:26
@klshuster changed the title from "[lora][moe] LoRA virtual experts, decoupled (marlin) moe backend" to "[lora][moe] Decoupled LoRA MoE backend with Marlin support" on Apr 4, 2026
@klshuster force-pushed the kurt/lora-virtual-experts-decoupled-20260331 branch 2 times, most recently from 484c581 to 4291c5a on April 4, 2026 20:53
Refactor LoRA MoE runner from per-backend subclass (TritonRunnerCoreWithLoRA)
to a generic hook-based injection pattern, decoupling LoRA logic from the
base MoE backend. Add Marlin int4/int8 MoE backend for LoRA.
@klshuster force-pushed the kurt/lora-virtual-experts-decoupled-20260331 branch from 4291c5a to 220d2d5 on April 4, 2026 20:54
@yushengsu-thu (Collaborator)

/tag-run-ci-label

@github-actions (Bot) added the run-ci label Apr 9, 2026
@yushengsu-thu mentioned this pull request Apr 9, 2026
@yushengsu-thu force-pushed the kurt/lora-virtual-experts-decoupled-20260331 branch from 709045c to a3ddc91 on April 10, 2026 07:15
@yushengsu-thu (Collaborator)

/tag-run-ci-label

@yushengsu-thu (Collaborator)

/rerun-failed-ci


@yushengsu-thu (Collaborator)

/rerun-failed-ci


@yushengsu-thu (Collaborator)

/rerun-failed-ci

@yushengsu-thu enabled auto-merge (squash) on April 11, 2026 06:57
@yushengsu-thu (Collaborator)

Test Suite: stage-b-test-4-gpu-b200

| Test File / Suite | Test Class | Test Method | Key Metric | Value | Threshold | Status |
| --- | --- | --- | --- | --- | --- | --- |
| test_deepseek_v3_fp4_mtp_small.py | TestDeepseekV3FP4MTP | test_a_gsm8k | GSM8K score | 0.970 | > 0.94 | PASS |
| test_deepseek_v3_fp4_mtp_small.py | TestDeepseekV3FP4MTP | test_a_gsm8k | avg_spec_accept_length | 3.025 | > 2.7 | PASS |
| test_deepseek_v3_fp4_mtp_small.py | TestDeepseekV3FP4MTP | test_bs_1_speed | acc_length | 2.94 | > 2.65 | PASS |
| test_deepseek_v3_fp4_mtp_small.py | TestDeepseekV3FP4MTP | test_bs_1_speed | speed (token/s) | 299.24 | > 150 | PASS |
| test_flash_attention_4.py | TestFlashAttention4 | test_gsm8k | GSM8K score | ~0.91 | > 0.89 | PASS |
| test_flash_attention_4.py | TestFlashAttention4SpeculativeDecodeTopk | test_gsm8k | GSM8K score | 0.945 | > 0.89 | PASS |
| test_flash_attention_4.py | TestFlashAttention4SpeculativeDecodeTopk | test_gsm8k | avg_spec_accept_length | 3.388 | > 1.5 | PASS |
| test_flash_attention_4.py (jit_kernel) | — | — | test cases passed | 312 | 312 | PASS |

@yushengsu-thu (Collaborator)

/rerun-failed-ci


@yushengsu-thu (Collaborator)

/tag-run-ci-label

@yushengsu-thu (Collaborator)

/rerun-failed-ci

@yushengsu-thu (Collaborator)

/rerun-stage stage-c-test-4-gpu-b200

@yushengsu-thu (Collaborator)

/rerun-stage stage-c-test-4-gpu-H100

@github-actions (Bot)

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@github-actions (Bot)

❌ Stage stage-c-test-4-gpu-H100 doesn't support isolated runs yet.

NVIDIA stages:

  • stage-a-test-1-gpu-small
  • stage-a-test-cpu
  • stage-b-test-1-gpu-small
  • stage-b-test-1-gpu-large
  • stage-b-test-2-gpu-large
  • stage-b-test-4-gpu-b200
  • stage-c-test-4-gpu-h100
  • stage-c-test-8-gpu-h200
  • stage-c-test-8-gpu-h20
  • stage-c-test-4-gpu-b200
  • stage-c-test-4-gpu-gb200
  • stage-c-test-deepep-4-gpu-h100
  • stage-c-test-deepep-8-gpu-h200
  • multimodal-gen-test-1-gpu
  • multimodal-gen-test-2-gpu
  • multimodal-gen-component-accuracy-1-gpu
  • multimodal-gen-component-accuracy-2-gpu
  • multimodal-gen-test-1-b200

AMD stages:

  • sgl-kernel-unit-test-amd
  • sgl-kernel-unit-test-2-gpu-amd
  • stage-a-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd-nondeterministic
  • stage-b-test-1-gpu-small-amd-mi35x
  • stage-b-test-1-gpu-large-amd
  • stage-b-test-2-gpu-large-amd
  • multimodal-gen-test-1-gpu-amd
  • multimodal-gen-test-2-gpu-amd
  • stage-c-test-large-8-gpu-amd
  • stage-c-test-large-8-gpu-amd-mi35x

Other stages will be added soon. For now, use /rerun-failed-ci for those stages.

@yushengsu-thu (Collaborator)

/rerun-stage stage-c-test-8-gpu-h200

@github-actions (Bot)

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@yushengsu-thu disabled auto-merge on April 11, 2026 16:35
@yushengsu-thu (Collaborator)

/rerun-failed-ci

@Fridge003 merged commit 8da1cfb into sgl-project:main on Apr 11, 2026
621 of 923 checks passed

```python
assert self.runner_core is not None

def _maybe_build_lora_hooks(_runner_input: Any) -> LoRAHooks:
```

@klshuster @yushengsu-thu Do not define a function in the forward / run critical path. Clean this up!
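
One way to address this (a sketch under assumed names, not the actual follow-up change): construct the hooks once when the adapter is loaded and reuse them on every call, instead of defining a closure inside the run path.

```python
# Sketch: hoist hook construction out of the per-call run path.
class MoeRunnerWithLoRA:
    def __init__(self, runner_core, lora_weights):
        self.runner_core = runner_core
        # build_lora_hooks as introduced by this PR (signature assumed);
        # built once here instead of on every forward/run invocation.
        self._hooks = build_lora_hooks(lora_weights)

    def run(self, runner_input):
        runner_input.hooks = self._hooks  # no nested def in the hot path
        return self.runner_core.run(runner_input)
```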
