Skip to content

[lora] Speedup triton backend sgemm calls with better grid#22386

Merged
Fridge003 merged 5 commits intosgl-project:mainfrom
klshuster:lora-sgemm-sorted-by-adapter
Apr 15, 2026
Merged

[lora] Speedup triton backend sgemm calls with better grid#22386
Fridge003 merged 5 commits intosgl-project:mainfrom
klshuster:lora-sgemm-sorted-by-adapter

Conversation

@klshuster
Copy link
Copy Markdown
Contributor

Motivation

During multi-LoRA decode, each sequence gets its own segment in the Triton sgemm grid — even when many sequences share the same adapter. This means the grid scales with batch_size instead of num_adapters, launching excessive blocks and wasting GPU cycles.

This PR sorts tokens by adapter and merges per-sequence segments into per-adapter segments, so the kernel grid scales with adapter count instead.

Modifications

  • kernel_utils.py (new): _resolve_token_positions() Triton JIT helper — gathers/scatters through a permutation when sorted, passes through otherwise.
  • All four sgemm kernels (sgemm_lora_a, sgemm_lora_b, qkv_lora_b, gate_up_lora_b): added SORTED_BY_ADAPTER constexpr path with indirection via
    _resolve_token_positions, plus early-exit for empty segments and excess grid blocks.
  • triton_backend.py: compute_sgemm_routing() builds merged per-adapter batch info using argsort + searchsorted; called during decode only. CUDA graph buffers pre-allocated in init_cuda_graph_batch_info().
  • test_sgemm_sorted_by_adapter.py (new): verifies numerical equivalence (bf16, atol=1e-4) between per-sequence and sorted-by-adapter paths for all four kernels, plus mixed-rank and single-adapter edge cases.

Accuracy Tests

Unit test compares original per-sequence output against sorted-by-adapter output across all kernels.

Speed Tests and Profiling

Checklist

Sort tokens by adapter during decode to merge per-sequence segments into
per-adapter segments. This reduces the number of kernel grid blocks and
improves GPU utilization for multi-LoRA batches.

Key changes:
- Add _resolve_token_positions() helper for indirection in all sgemm kernels
- Add SORTED_BY_ADAPTER constexpr and early-exit for empty/OOB segments
- Add compute_sgemm_routing() in TritonLoRABackend to build merged batch info
- Pre-allocate sgemm CUDA graph buffers in init_cuda_graph_batch_info()
- Add test_sgemm_sorted_by_adapter.py verifying correctness across all kernels
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Fridge003
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 9, 2026
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

6 similar comments
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

4 similar comments
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

2 similar comments
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

1 similar comment
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu yushengsu-thu enabled auto-merge (squash) April 15, 2026 08:47
@yushengsu-thu yushengsu-thu disabled auto-merge April 15, 2026 16:47
@yushengsu-thu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@yushengsu-thu yushengsu-thu enabled auto-merge (squash) April 15, 2026 18:06
@Fridge003 Fridge003 disabled auto-merge April 15, 2026 20:47
@Fridge003 Fridge003 merged commit 32d9fe5 into sgl-project:main Apr 15, 2026
258 of 342 checks passed
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants