
[lora][moe] Virtual experts for LoRA MoE#22122

Merged
yushengsu-thu merged 8 commits into sgl-project:main from klshuster:kurt/lora-vexp-20260404
Apr 13, 2026

Conversation

@klshuster
Contributor

Motivation

NOTE: depends on the hooks-based architecture introduced in #21858

This PR introduces virtual expert computation for LoRA+MoE: instead of iterating over each LoRA adapter separately (one alignment + kernel call per adapter), we treat [num_loras, num_experts] weight combinations as a flat [virtual_num_experts] space. This allows LoRA deltas to be computed in a single fused MoE kernel call by reusing the existing invoke_fused_moe_kernel infrastructure, significantly reducing kernel launch overhead for multi-adapter serving.

Enabled via --lora-use-virtual-experts.
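
As a rough illustration of the indexing idea (names and shapes below are illustrative, not the actual kernel code in this PR), the flattening amounts to:

import torch

# Minimal sketch, assuming per-token adapter ids and per-token routed expert
# ids are already available. Adapter i's expert e becomes virtual expert
# i * num_experts + e, matching LoRA weights flattened to
# [max_loras * num_experts, ...].
def to_virtual_expert_ids(
    expert_ids: torch.Tensor,  # [num_tokens, top_k], routed expert per token
    lora_ids: torch.Tensor,    # [num_tokens], adapter index per token
    num_experts: int,
) -> torch.Tensor:
    return lora_ids.unsqueeze(-1) * num_experts + expert_ids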

Modifications

  • New Triton kernel (lora/triton_ops/virtual_experts.py) that maps (lora_adapter, expert) pairs into virtual expert IDs, flattens LoRA weights from [max_loras, num_experts, ...] to [max_loras * num_experts, ...], and runs fused MoE for LoRA A and B in a single pass.
  • Split-K support for the virtual experts kernel for better GPU utilization.
  • _compute_token_lora_mapping maps each token to its adapter index for the virtual routing (sketched after this list).
  • fused_moe_triton_kernels.py: added lora_num_experts_override to allow virtual experts to override the expert count in the align kernel, and fuse_add_to_output / add_output_mask for masked in-place addition (tokens with no LoRA adapter are skipped).
  • --lora-use-virtual-experts flag in server_args.py, propagated through lora_manager.py and layers.py to LoRAInfo.
  • Registered as a custom op via direct_register_custom_op; the routing_cache dict (not supported by torch.library.infer_schema) is handled by a thin wrapper.
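
A minimal sketch of how the token-to-adapter mapping and the masked in-place addition could fit together (all names here are illustrative, not the PR's actual kernels, which live in lora/triton_ops/virtual_experts.py and fused_moe_triton_kernels.py):

import torch

def compute_token_lora_mapping_sketch(
    seg_lens: torch.Tensor,      # [num_segments], tokens per request segment
    seg_lora_ids: torch.Tensor,  # [num_segments], adapter index per segment (-1 = no LoRA)
) -> torch.Tensor:
    # Expand per-segment adapter ids to one adapter id per token.
    return torch.repeat_interleave(seg_lora_ids, seg_lens)

def masked_add_sketch(
    output: torch.Tensor,          # [num_tokens, hidden]
    lora_delta: torch.Tensor,      # [num_tokens, hidden]
    token_lora_ids: torch.Tensor,  # [num_tokens]
) -> None:
    # fuse_add_to_output-style behavior: only tokens with a LoRA adapter
    # receive the delta; tokens without an adapter are left untouched.
    mask = (token_lora_ids >= 0).unsqueeze(-1)
    output.add_(torch.where(mask, lora_delta, torch.zeros_like(lora_delta)))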

Accuracy Tests

All 16 test_lora_moe_runner_virtual_experts parametrized configs pass — each verifies that the virtual experts path produces the same LoRA delta as the per-adapter baseline.
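
In spirit, each parametrized case compares the two paths roughly as follows (a hypothetical test shape, not the actual test code):

import torch

def check_virtual_experts_match_baseline(run_baseline, run_virtual, inputs):
    # Both paths should produce the same LoRA delta up to numerical tolerance:
    # run_baseline loops over adapters, run_virtual uses one fused call.
    baseline_delta = run_baseline(**inputs)
    virtual_delta = run_virtual(**inputs)
    torch.testing.assert_close(virtual_delta, baseline_delta, rtol=1e-2, atol=1e-2)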

Checklist

Refactor LoRA MoE runner from per-backend subclass (TritonRunnerCoreWithLoRA)
to a generic hook-based injection pattern, decoupling LoRA logic from the
base MoE backend. Add Marlin int4/int8 MoE backend for LoRA.
Add virtual expert computation for LoRA+MoE: treats (adapter, expert)
pairs as a flat virtual_num_experts space, allowing LoRA deltas to be
computed by reusing existing fused MoE kernels. Includes split-K support
for better GPU utilization.

Enabled via --lora-use-virtual-experts flag.

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for LoRA injection in MoE models using virtual experts, enabling more efficient LoRA integration across different backends including Triton and Marlin. It adds hook-based injection points in the MoE pipeline, updates the runner infrastructure to support these hooks, and includes a new Marlin-based runner core. The changes also introduce a virtual expert routing mechanism to handle LoRA adapters and provide comprehensive tests for correctness.

Comment on lines +804 to +809
assert (
    not fuse_sum_all_reduce
), "fuse_add_to_output and fuse_sum_all_reduce are mutually exclusive"
assert (
    add_output_mask is not None
), "add_output_mask required when fuse_add_to_output=True"

medium

The assertion assert add_output_mask is not None is redundant because add_output_mask is already type-hinted as Optional[torch.Tensor] and the function signature defaults it to None. If fuse_add_to_output is True, it is better to handle the missing mask gracefully or raise a more descriptive error if the tensor is required for the kernel logic.
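
For example, the suggestion amounts to something along these lines (illustrative only, not the merged code):

if fuse_add_to_output:
    if fuse_sum_all_reduce:
        raise ValueError(
            "fuse_add_to_output and fuse_sum_all_reduce are mutually exclusive"
        )
    if add_output_mask is None:
        raise ValueError(
            "add_output_mask must be provided when fuse_add_to_output=True; "
            "it selects which tokens receive the in-place addition"
        )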

Comment on lines +511 to +512
if lora_info is None or lora_info.max_lora_rank == 0:
return LoRAHooks()

medium

Returning an empty LoRAHooks() object when lora_info is None or the rank is 0 is correct, but consider whether callers should receive None instead, to avoid unnecessary object creation in hot paths.
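
A minimal sketch of the alternative the comment hints at (names are hypothetical):

from typing import Callable, Optional

def maybe_build_lora_hooks(lora_info, build_hooks: Callable) -> Optional[object]:
    # Returning None on the no-LoRA path lets hot-path callers skip hook
    # dispatch entirely instead of allocating an empty hooks object per call.
    if lora_info is None or lora_info.max_lora_rank == 0:
        return None
    return build_hooks(lora_info)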

global _MARLIN_WORKSPACE
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput

assert hooks is not None, "hooks must be provided for MarlinLoraRunnerCore"

medium

The assertion assert hooks is not None is good, but consider providing a more informative error message that explains why hooks are required for this specific runner core.
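
For instance (illustrative only):

if hooks is None:
    raise ValueError(
        "MarlinLoraRunnerCore applies LoRA through injected hooks; "
        "it must be constructed with the hooks produced by the LoRA manager"
    )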

@yushengsu-thu yushengsu-thu self-assigned this Apr 4, 2026
@yushengsu-thu
Collaborator

/tag-run-ci-label

3 similar comments

@yushengsu-thu
Collaborator

/rerun-failed-ci

2 similar comments

@yushengsu-thu yushengsu-thu enabled auto-merge (squash) April 13, 2026 08:36
@yushengsu-thu
Collaborator

/rerun-failed-ci

3 similar comments

@yushengsu-thu yushengsu-thu merged commit ff13dfe into sgl-project:main Apr 13, 2026
575 of 663 checks passed
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>
yushengsu-thu added a commit that referenced this pull request Apr 17, 2026
Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>
bingxche added a commit that referenced this pull request Apr 18, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>