
[5/n] Lora support cuda graph #21647

Merged
Fridge003 merged 23 commits into sgl-project:main from yushengsu-thu:lora-support-cuda-graph
Apr 4, 2026

Conversation

@yushengsu-thu
Collaborator

Motivation

MoE LoRA inference does not support CUDA graph because the forward path dynamically allocates intermediate tensors (torch.empty()) on every call. CUDA graph requires fixed tensor addresses between capture and replay, so these dynamic allocations break graph replay.
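
For context, a minimal standalone sketch of the capture/replay contract this PR works around (plain PyTorch, not SGLang code; buffer names are illustrative):

```python
import torch

# Buffers allocated once, before capture: their addresses are baked into the graph.
static_in = torch.zeros(8, 16, device="cuda")
static_out = torch.zeros(8, 16, device="cuda")

# Warmup on a side stream so lazy kernel/handle init happens outside capture.
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # OK: writes into pre-allocated memory.
    static_out.copy_(static_in * 2)
    # A torch.empty() here would get a fresh address on every eager call,
    # while replay keeps using the captured one -- hence the breakage.

static_in.copy_(torch.randn(8, 16, device="cuda"))
g.replay()  # reruns the recorded kernels against the same fixed addresses
```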

Modifications

  • Pre-allocate MoE CG buffers (Phase 1): Add init_cuda_graph_moe_buffers() to BaseLoRABackend, implemented in TritonLoRABackend and ChunkedSgmvLoRABackend. Called from ModelRunner before init_memory_pool() so memory profiling accounts for the buffers. All MoE LoRA layers share one buffer set since they execute sequentially.
  • Reuse buffers in runner (Phase 2): TritonRunnerCoreWithLoRA.run() slices pre-allocated buffers instead of calling torch.empty() when in CG mode. Covers intermediate caches, alignment tensors, and output buffers (a combined sketch follows this list).
  • CG-aware LoRA path selection: Force the LoRA kernel path during capture (so the kernels are recorded in the graph); skip the LoRA path entirely when no adapter is active in the batch.
  • In-place adapter_enabled update: get_lora_info() reuses the CG buffer with .zero_() + .index_fill_() instead of allocating a new tensor.
  • Kernel dtype fix: Cast operand in fused_moe_lora_kernel to avoid dtype mismatch under mixed precision.
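
A compressed sketch of the two-phase pattern described above (class and method bodies here are hypothetical; only init_cuda_graph_moe_buffers() is named in the PR):

```python
import torch

class MoELoRABufferMixin:
    def init_cuda_graph_moe_buffers(self, max_tokens: int, hidden_dim: int, device):
        # Phase 1: allocate one worst-case buffer set up front, before
        # init_memory_pool(), so KV-cache profiling accounts for it.
        self.cg_intermediate = torch.zeros(max_tokens, hidden_dim, device=device)
        self.cg_adapter_enabled = torch.zeros(max_tokens, dtype=torch.int32, device=device)

    def get_intermediate(self, num_tokens: int, in_cuda_graph: bool) -> torch.Tensor:
        if in_cuda_graph:
            # Phase 2: slice a view out of the fixed buffer -- same base
            # address at capture and at every replay.
            return self.cg_intermediate[:num_tokens]
        # Eager mode may still allocate dynamically.
        return torch.empty_like(self.cg_intermediate[:num_tokens])

    def update_adapter_enabled(self, active_rows: torch.Tensor):
        # In-place update, mirroring the .zero_() + .index_fill_() idiom,
        # so no new tensor is allocated inside the captured region.
        self.cg_adapter_enabled.zero_()
        self.cg_adapter_enabled.index_fill_(0, active_rows, 1)
```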


When adapter_config.json uses PEFT shorthands like "all-linear" or "all",
SGLang previously required users to explicitly specify --lora-target-modules
on the CLI. This change adds a model-scanning approach that inspects the
loaded base model to discover all LoRA-compatible linear modules automatically.

Changes:
- utils.py: add auto_detect_lora_target_modules() that walks the model graph,
  collects LinearBase/FusedMoE/ParallelLMHead module suffixes, normalizes
  them, and filters to the set supported by get_hidden_dim and init_buffers
  (a standalone sketch follows this commit message).
- lora_manager.py: in init_lora_shapes(), resolve "all-linear"/"all" via
  model scanning instead of raising ValueError when CLI target modules are
  not provided. In init_lora_modules(), guard against modules outside decoder
  layers (layer_id is None) to prevent TypeError on non-layer modules.

Made-with: Cursor
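
A standalone illustration of the scanning idea (nn.Linear stands in for SGLang's LinearBase/FusedMoE/ParallelLMHead checks, and SUPPORTED is a hypothetical stand-in for the set derived from get_hidden_dim/init_buffers support):

```python
import torch.nn as nn

SUPPORTED = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}

def auto_detect_lora_target_modules(model: nn.Module) -> set[str]:
    targets = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # "model.layers.0.self_attn.q_proj" -> "q_proj"
            targets.add(name.rsplit(".", 1)[-1])
    # Keep only module names the LoRA memory pool knows how to shape.
    return targets & SUPPORTED
```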
…fallbacks

1. layers.py: fix RowParallelLinearWithLoRA bias handling to pass bias
   into quant_method.apply(), matching base RowParallelLinear behavior;
   add interleaved gate/up layout support in FusedMoEWithLoRA for models
   using gemm1_alpha (e.g. gpt-oss-20b)

2. mem_pool.py: zero-initialize all LoRA buffers (torch.empty ->
   torch.zeros) to prevent garbage values in unused slots

3. utils.py: fall back to config.intermediate_size when
   moe_intermediate_size is not available in get_hidden_dim (supports
   GptOss, Mixtral, OLMoE, PhiMoE, GraniteMoE, Grok, etc.; sketched after
   this commit message); accept PEFT shorthand "all-linear" in
   get_normalized_target_modules; fix isinstance order in
   auto_detect_lora_target_modules so ParallelLMHead is checked before
   VocabParallelEmbedding

4. gpt_oss.py: add should_apply_lora() to GptOssForCausalLM for
   explicit LoRA module filtering, consistent with Qwen3VLMoe

Made-with: Cursor
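
Item 3's fallback is essentially a getattr chain; a hypothetical sketch (the real logic lives in get_hidden_dim in python/sglang/srt/lora/utils.py):

```python
def moe_intermediate_size(config) -> int:
    # Prefer the MoE-specific field; fall back to the dense intermediate_size
    # for configs (GptOss, Mixtral, OLMoE, ...) that do not define it.
    size = getattr(config, "moe_intermediate_size", None)
    return size if size is not None else config.intermediate_size
```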
Regression test comparing SGLang LoRA logprobs against reference
training logprobs (KL threshold 1e-2). Uses 8-GPU H200 suite with
triton MoE runner and shared outer LoRA mode.

Adapter checkpoint: yushengsu/lora-diff-gpt-oss-20b

Made-with: Cursor
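
A sketch of the KL-based check described above (function and tensor names are assumptions; the threshold comes from the commit message):

```python
import torch

def assert_logprobs_close(sglang_logprobs: torch.Tensor,
                          reference_logprobs: torch.Tensor,
                          threshold: float = 1e-2) -> None:
    # Inputs: per-token log-probabilities over the vocab, [num_tokens, vocab].
    # KL(reference || sglang), averaged over tokens.
    kl = (reference_logprobs.exp() * (reference_logprobs - sglang_logprobs)).sum(-1)
    assert kl.mean().item() < threshold, f"mean KL {kl.mean():.4g} >= {threshold}"
```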
Pre-allocate MoE intermediate buffers before memory profiling so
KV cache sizing accounts for them. Reuse fixed buffers during
capture/replay instead of dynamic torch.empty() allocations.
Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for shared outer LoRAs in MoE models and optimizes MoE LoRA performance within CUDA graphs by pre-allocating intermediate buffers. Key changes include auto-detection of shared LoRA formats, normalization of expert weight names, and kernel-level fixes for type safety and out-of-bounds access. Review feedback recommends reducing code duplication by moving shared buffer initialization logic to the base backend class and suggests restoring type hints for weight slicing functions to improve maintainability.

Contributor

Copilot AI left a comment


Pull request overview

This PR extends SGLang’s LoRA inference to better support CUDA Graphs for MoE models by eliminating capture-time dynamic allocations, while also adding support for “shared outer LoRA” MoE adapters and improving LoRA target-module inference.

Changes:

  • Add a two-phase LoRA CUDA-graph initialization flow, including pre-allocation of shared MoE intermediate buffers and reuse of those buffers during capture.
  • Add support for MoE “shared outer LoRA” weight formats (expert_dim=1) and broaden LoRA allow-lists for MoE models (incl. experts/embed/lm_head); a hypothetical detection sketch follows this list.
  • Add new registered regression tests comparing LoRA logprobs against precomputed reference datasets for several large models.
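
For the shared-outer point, a purely hypothetical illustration of how an expert_dim=1 adapter might be told apart from a per-expert one (the actual detection lives in lora_manager.py/mem_pool.py):

```python
import torch

def is_shared_outer(lora_b: torch.Tensor, num_experts: int) -> bool:
    # Per-expert MoE LoRA B weights carry a leading expert dimension equal to
    # num_experts; shared outer adapters collapse that dimension to 1.
    return lora_b.dim() >= 3 and lora_b.shape[0] == 1 and num_experts > 1
```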

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 11 comments.

Summary per file:

  • test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py: New registered logprob regression test using a downloaded reference dataset.
  • test/registered/lora/test_lora_qwen3_8b_logprob_diff.py: New registered logprob regression test using a downloaded reference dataset.
  • test/registered/lora/test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff.py: New registered logprob regression test using a downloaded reference dataset.
  • test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py: New registered logprob regression test using a downloaded reference dataset.
  • test/manual/lora/test_lora_qwen3_vl.py: Updates manual regex expectations for expanded LoRA module coverage.
  • python/sglang/srt/server_args.py: Adds --experts-shared-outer-loras CLI/server arg (override for shared-outer mode).
  • python/sglang/srt/models/qwen3_vl_moe.py: Expands should_apply_lora allow-list to cover MoE experts + embeddings/head.
  • python/sglang/srt/models/gpt_oss.py: Adds should_apply_lora allow-list for MoE experts + embeddings/head.
  • python/sglang/srt/model_executor/model_runner.py: Adds phase-1 LoRA CUDA-graph MoE buffer preallocation during runner init.
  • python/sglang/srt/model_executor/cuda_graph_runner.py: Documents phase-2 LoRA CUDA-graph init (dense batch metadata).
  • python/sglang/srt/lora/utils.py: Adds base-model scanning to resolve PEFT shorthand target modules (auto-detect).
  • python/sglang/srt/lora/triton_ops/sgemm_lora_b.py: Fixes Triton masking for safe loads/stores on non-multiple-of-block shapes.
  • python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py: Casts operands to avoid mixed-precision dtype mismatches in tl.dot.
  • python/sglang/srt/lora/mem_pool.py: Adds shared-outer MoE buffer shaping + loading logic; adjusts expert detection; changes buffer init behavior.
  • python/sglang/srt/lora/lora_moe_runners.py: Makes MoE LoRA runner CUDA-graph capture-aware and reuses preallocated buffers.
  • python/sglang/srt/lora/lora_manager.py: Adds shared-outer mode detection/override; adds MoE CUDA-graph init hooks; resolves PEFT shorthand targets via auto-detect.
  • python/sglang/srt/lora/lora.py: Normalizes MoE expert weight naming and fixes stacking logic for higher-rank tensors.
  • python/sglang/srt/lora/layers.py: Reuses CUDA-graph buffers for adapter-enabled mask; enhances MoE slicing for shared/per-expert formats; aligns RowParallel bias handling with base layer.
  • python/sglang/srt/lora/backend/triton_backend.py: Implements MoE CUDA-graph buffer preallocation for the Triton LoRA backend.
  • python/sglang/srt/lora/backend/chunked_backend.py: Implements MoE CUDA-graph buffer preallocation for the ChunkedSGMV LoRA backend.
  • python/sglang/srt/lora/backend/base_backend.py: Adds backend hook init_cuda_graph_moe_buffers() (phase 1).
  • python/sglang/jit_kernel/moe_lora_align.py: Allows passing/reusing preallocated work buffers for CUDA-graph friendliness.
Comments suppressed due to low confidence (1)

python/sglang/srt/lora/mem_pool.py:341

  • init_buffers() now uses torch.zeros() for large LoRA A/B buffers (per-layer, per-adapter, and potentially per-expert). This will eagerly touch/zero-fill potentially huge GPU allocations and can significantly slow startup and increase memory bandwidth pressure compared to torch.empty(). If zero-init is only needed for specific safety cases (e.g., cleared slots on eviction / missing weights), consider reverting most allocations to torch.empty() and explicitly zeroing only the required slices when (re)loading adapters.
                        buffer[module_name] = [
                            torch.zeros(
                                get_lora_shape_fn(
                                    module_name, base_model, self.max_lora_rank, idx
                                ),
                                dtype=self.dtype,
                                device=device,
                            )
                            for idx in range(self.num_layer)
                        ]

                    # MoE expert version (4D)
                    moe_key = f"{module_name}_moe"
                    buffer[moe_key] = [
                        torch.zeros(
                            get_lora_shape_fn(
                                moe_key, base_model, self.max_lora_rank, idx
                            ),
                            dtype=self.dtype,
                            device=device,
                        )
                        for idx in range(self.num_layer)
                    ]
                else:
                    # Standard allocation for unambiguous modules
                    buffer[module_name] = [
                        torch.zeros(
                            get_lora_shape_fn(
                                module_name,
                                base_model,
                                self.max_lora_rank,
                                idx,
                            ),
                            dtype=self.dtype,
                            device=device,
                        )
                        for idx in range(self.num_layer)
                    ]
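
The reviewer's suggested alternative amounts to something like this (hypothetical sketch reusing the names from the excerpt above; layer_idx and slot_idx are illustrative):

```python
# Allocate without eager zero-fill ...
buffer[module_name] = [
    torch.empty(
        get_lora_shape_fn(module_name, base_model, self.max_lora_rank, idx),
        dtype=self.dtype,
        device=device,
    )
    for idx in range(self.num_layer)
]
# ... and zero only the slices that actually need it, e.g. when a slot is
# evicted or an adapter ships without weights for this module:
buffer[module_name][layer_idx][slot_idx].zero_()
```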


@yushengsu-thu
Collaborator Author

/tag-run-ci-label

@github-actions (Bot) added the run-ci label Apr 2, 2026
@yushengsu-thu
Collaborator Author

/tag-run-ci-label

@yushengsu-thu
Collaborator Author

/rerun-failed-ci

@yushengsu-thu
Collaborator Author

/rerun-failed-ci

@Fridge003
Collaborator

Please fix lint, thx

@Fridge003 reopened this Apr 3, 2026
raw_names.add("down_proj")
elif isinstance(module, ParallelLMHead):
raw_names.add("lm_head")
elif isinstance(module, VocabParallelEmbedding):
Collaborator


This operation is risky, since ParallelLMHead is a subclass of VocabParallelEmbedding: if a set of LoRA modules includes only lm_head, embed_tokens will also be added by mistake.

Can be checked later
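
A standalone illustration of the pitfall (class names mirror SGLang's; the bodies are stubs):

```python
class VocabParallelEmbedding: ...
class ParallelLMHead(VocabParallelEmbedding): ...

def classify(module) -> str:
    # The subclass must be checked first: isinstance(lm_head, VocabParallelEmbedding)
    # is also True, so the embedding branch would otherwise shadow it.
    if isinstance(module, ParallelLMHead):
        return "lm_head"
    if isinstance(module, VocabParallelEmbedding):
        return "embed_tokens"
    return "other"

assert classify(ParallelLMHead()) == "lm_head"
```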



@dataclass
class LoRAInfo:
Collaborator


We should rename LoRAInfo to MoELoRAInfo in a future PR.

@Fridge003 merged commit ff8e47e into sgl-project:main Apr 4, 2026
205 of 221 checks passed
sundar24295s pushed a commit to sundar24295s/sglang that referenced this pull request Apr 4, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Fridge003 pushed a commit that referenced this pull request Apr 7, 2026
xiezhq-hermann pushed a commit to antgroup/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
