[VLM][LLM] Optimize fused_moe triton kernel tma #18782

Merged
BBuf merged 1 commit into sgl-project:main from antgroup:optimize_fused_moe_tma
Feb 14, 2026

Conversation

@yuan-luo (Collaborator) commented Feb 13, 2026

Motivation

While profiling Qwen3-VL-30B-A3B, we found that fused_moe produced significant GPU bubbles in the prefill phase. One reason is that CUDA Graph is not enabled by default, so launching many small kernels introduces overhead. Another is that the scheduler's CPU thread participates in computing some small ops, which delays kernel launches. This PR focuses on optimizing fused_moe.

[Profiling timeline omitted: shows GPU bubbles around fused_moe during the prefill phase.]

Summary

This PR reduces host-side overhead in the Triton TMA path for fused_moe by:

  • setting the Triton allocator only once per process,
  • caching the weight TensorDescriptor (b_desc) with a bounded LRU to avoid re-creating it on every call.

Key improvement details

  1. Set triton.set_allocator(...) once (process-wide)
    Previously, triton.set_allocator(alloc_fn) was executed on every invoke_fused_moe_kernel() call whenever a_use_tma or b_use_tma was enabled.
    This PR introduces a set-once guard so the allocator is configured only once per process (global within the process), reducing repeated Python overhead on the hot path. A sketch of the guard appears after this list.

    Note: this only affects the TMA path; non-TMA execution is unchanged.

  2. Cache b_desc (weight TensorDescriptor) with a bounded LRU
    B is a persistent weight tensor. Previously we created a new TensorDescriptor(B, ...) on every invocation when b_use_tma=True.
    This PR adds an LRU cache keyed by the weight's storage identity and layout plus the tile shape, i.e. (B.data_ptr, B.shape, B.stride, B.dtype, BLOCK_SIZE_N, BLOCK_SIZE_K), to reuse descriptors across calls. A sketch of the cache appears after this list.
    The cache is bounded by SGLANG_TMA_BDESC_CACHE_MAX (default: 64), so it does not grow without bound. Setting it to 0 disables caching.
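
A minimal sketch of the set-once allocator guard (step 1), assuming the alloc_fn signature from Triton's TMA tutorials; the flag and helper names are illustrative, not the exact ones in the PR:

```python
import torch
import triton

_ALLOCATOR_SET = False  # process-wide set-once flag (hypothetical name)


def _alloc_fn(size: int, alignment: int, stream):
    # Triton calls this to obtain scratch space for TMA descriptors.
    return torch.empty(size, device="cuda", dtype=torch.int8)


def _ensure_triton_allocator():
    # Configure the allocator only on the first call; later calls are no-ops.
    global _ALLOCATOR_SET
    if not _ALLOCATOR_SET:
        triton.set_allocator(_alloc_fn)
        _ALLOCATOR_SET = True
```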
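
And a minimal sketch of the bounded b_desc LRU cache (step 2), assuming Triton's TensorDescriptor.from_tensor from triton.tools.tensor_descriptor; the helper name and the 2-D weight are illustrative (the block shape must match B's rank):

```python
import os
from collections import OrderedDict

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

# Bounded cache size; 0 disables caching entirely.
_BDESC_CACHE_MAX = int(os.environ.get("SGLANG_TMA_BDESC_CACHE_MAX", "64"))
_BDESC_CACHE: OrderedDict = OrderedDict()


def _get_b_desc(B: torch.Tensor, block_n: int, block_k: int):
    # Hypothetical helper: return a cached (or fresh) TMA descriptor for B.
    if _BDESC_CACHE_MAX == 0:
        return TensorDescriptor.from_tensor(B, [block_n, block_k])
    key = (B.data_ptr(), tuple(B.shape), tuple(B.stride()), B.dtype,
           block_n, block_k)
    desc = _BDESC_CACHE.get(key)
    if desc is None:
        desc = TensorDescriptor.from_tensor(B, [block_n, block_k])
        _BDESC_CACHE[key] = desc
        if len(_BDESC_CACHE) > _BDESC_CACHE_MAX:
            _BDESC_CACHE.popitem(last=False)  # evict the least-recently used
    else:
        _BDESC_CACHE.move_to_end(key)  # mark as most-recently used
    return desc
```

Keying on B.data_ptr() means a reallocated weight naturally misses the cache, and OrderedDict eviction keeps memory bounded without a lock, matching the GIL-safety argument in the design considerations below.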

Design considerations

We intentionally do not cache a_desc:

  • A is an activation tensor and typically changes per call (shape/stride/layout), so the cache hit rate would be low.
  • Caching activation descriptors can also inadvertently extend the lifetime of temporary tensors.

Locks are not used on the cache to avoid overhead on the hot path. Under Python’s GIL this is safe for correctness; in the worst case under concurrency we may create duplicate descriptors, but the results remain correct.

Results

No quality regression observed (gsm8k unchanged).

TTFT improvement: ~5–10% in our serving benchmark.

Server:

root@c7e9bb6a6789:/sgl-workspace/sglang# FLASHINFER_DISABLE_VERSION_CHECK=1 python -m sglang.launch_server --model-path Qwen/Qwen3-VL-30B-A3B-Instruct --host 0.0.0.0 --port 30000 --trust-remote-code --tp-size 4 --enable-cache-report --max-running-requests 128 --mem-fraction-static 0.7 --chunked-prefill-size 8192 --attention-backend fa3 --mm-attention-backend fa3 --log-level debug --log-requests --log-requests-level 1

Client:

python3 -m sglang.bench_serving \
  --backend sglang-oai-chat \
  --dataset-name image \
  --num-prompts 256 \
  --apply-chat-template \
  --random-input-len 128 \
  --random-output-len 32 \
  --image-resolution 560x560 \
  --image-format jpeg \
  --image-count 1 \
  --image-content random \
  --random-range-ratio 0.1 \
  --port 30000 \
  --max-concurrency 32

Baseline:
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 32        
Successful requests:                     256       
Benchmark duration (s):                  19.55     
Total input tokens:                      104007    
Total input text tokens:                 20551     
Total input vision tokens:               83456     
Total generated tokens:                  4541      
Total generated tokens (retokenized):    4533      
Request throughput (req/s):              13.09     
Input token throughput (tok/s):          5319.10   
Output token throughput (tok/s):         232.23    
Peak output token throughput (tok/s):    448.00    
Peak concurrent requests:                54        
Total token throughput (tok/s):          5551.33   
Concurrency:                             31.68     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2419.69   
Median E2E Latency (ms):                 2417.90   
P90 E2E Latency (ms):                    3745.81   
P99 E2E Latency (ms):                    5511.77   
---------------Time to First Token----------------
Mean TTFT (ms):                          951.86    
Median TTFT (ms):                        786.24    
P99 TTFT (ms):                           2403.74   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          77.12     
Median TPOT (ms):                        87.69     
P99 TPOT (ms):                           140.49    
---------------Inter-Token Latency----------------
Mean ITL (ms):                           87.86     
Median ITL (ms):                         7.86      
P95 ITL (ms):                            669.69    
P99 ITL (ms):                            814.92    
Max ITL (ms):                            1368.38   
==================================================

PR:
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 32        
Successful requests:                     256       
Benchmark duration (s):                  17.85     
Total input tokens:                      104035    
Total input text tokens:                 20579     
Total input vision tokens:               83456     
Total generated tokens:                  4541      
Total generated tokens (retokenized):    4537      
Request throughput (req/s):              14.34     
Input token throughput (tok/s):          5827.77   
Output token throughput (tok/s):         254.38    
Peak output token throughput (tok/s):    461.00    
Peak concurrent requests:                54        
Total token throughput (tok/s):          6082.15   
Concurrency:                             31.76     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2214.52   
Median E2E Latency (ms):                 2204.39   
P90 E2E Latency (ms):                    3490.61   
P99 E2E Latency (ms):                    4691.24   
---------------Time to First Token----------------
Mean TTFT (ms):                          858.76    
Median TTFT (ms):                        697.06    
P99 TTFT (ms):                           2183.01   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          71.99     
Median TPOT (ms):                        80.33     
P99 TPOT (ms):                           126.40    
---------------Inter-Token Latency----------------
Mean ITL (ms):                           81.09     
Median ITL (ms):                         9.05      
P95 ITL (ms):                            533.15    
P99 ITL (ms):                            758.23    
Max ITL (ms):                            1581.80   
==================================================

There are still other areas worth emphasizing, such as the JIT kernel being recompiled for every layer. We will keep optimizing the VL MoE models' performance.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@yuan-luo yuan-luo force-pushed the optimize_fused_moe_tma branch from 8b58110 to e6b489e Compare February 13, 2026 07:05
@yuan-luo yuan-luo changed the title [VLM][LLM] Optimize fused_moe TMA [WIP][VLM][LLM] Optimize fused_moe TMA Feb 13, 2026
@yuan-luo (Collaborator, Author)

/tag-and-rerun-ci

@yuan-luo yuan-luo changed the title [WIP][VLM][LLM] Optimize fused_moe TMA [VLM][LLM] Optimize fused_moe triton kernel tma Feb 13, 2026
Two outdated comment threads on python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
@BBuf (Collaborator) left a comment

Good job.

@yuan-luo yuan-luo force-pushed the optimize_fused_moe_tma branch from e6b489e to 4cdc472 Compare February 13, 2026 08:37
@yuan-luo (Collaborator, Author)

/rerun-failed-ci

7 similar comments

@BBuf BBuf merged commit fa0ef6e into sgl-project:main Feb 14, 2026
378 of 414 checks passed
@yuan-luo yuan-luo deleted the optimize_fused_moe_tma branch February 14, 2026 15:44
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>