
Support Qwen3 MoE context parallel#18233

Merged
Fridge003 merged 13 commits into sgl-project:main from Shunkangz:support_qwen3
Mar 22, 2026

Conversation

@Shunkangz
Contributor

@Shunkangz Shunkangz commented Feb 4, 2026

Motivation

Context parallelism is essential in long-context LLM inference. It splits a long input sequence across multiple GPUs so attention can be computed in parallel, drastically reducing latency and enabling practical million-token context windows.
In this PR, we add support for the context-parallel form of Qwen3-MoE. With this update, context parallelism can now be enabled during the prefill phase under various parallel configurations.
For the attention layer, users can use CP, TP, or CP + TP.
For the MoE layer, users can use CP, TP, CP + TP, or CP + EP.
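The chunking that context parallelism performs during prefill can be illustrated with a minimal sketch. This is not SGLang's actual API; the function name and even-chunking policy are assumptions for illustration only.

```python
# Illustrative sketch (not SGLang's actual API): split a long prompt's
# token positions evenly across context-parallel (CP) ranks, so each
# GPU prefills attention for only its own chunk of the sequence.
def split_for_cp(num_tokens: int, cp_size: int, cp_rank: int) -> range:
    """Return the token positions owned by `cp_rank` under even chunking."""
    base, rem = divmod(num_tokens, cp_size)
    # Earlier ranks absorb the remainder, one extra token each.
    start = cp_rank * base + min(cp_rank, rem)
    end = start + base + (1 if cp_rank < rem else 0)
    return range(start, end)

# 10 tokens over 4 CP ranks -> chunks of sizes 3, 3, 2, 2
chunks = [split_for_cp(10, 4, r) for r in range(4)]
```

Together the chunks cover every token position exactly once, which is the invariant any CP split must preserve regardless of the chunking policy.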

Modifications

In this implementation, we allocate a full-sequence KV cache on each CP rank. This approach simplifies both KV cache management and reuse by replicating the KV cache across all CP ranks. Before performing the attention computation, we use an allgather operation to collect the KV cache from all ranks, and then apply the FlashAttention backend for the calculation.
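The allgather step described above can be pictured with a single-process stand-in: numpy concatenation plays the role of the distributed allgather (in SGLang this would be a `torch.distributed` collective), and all shapes are illustrative.

```python
import numpy as np

# Single-process stand-in for the CP allgather: each CP rank holds the
# K/V cache for its own sequence chunk; before attention, every rank
# gathers all chunks so it can attend over the full sequence.
cp_size, chunk_len, num_heads, head_dim = 2, 4, 2, 8
rng = np.random.default_rng(0)

# Per-rank KV shards: [chunk_len, num_heads, head_dim]
k_shards = [rng.standard_normal((chunk_len, num_heads, head_dim))
            for _ in range(cp_size)]

def allgather_kv(shards):
    """Concatenate per-rank KV shards along the sequence axis, mimicking
    an allgather that leaves every rank with the full-sequence cache."""
    return np.concatenate(shards, axis=0)

k_full = allgather_kv(k_shards)  # [cp_size * chunk_len, num_heads, head_dim]
```

After this step every rank holds an identical full-sequence KV cache, which is what lets the FlashAttention backend run unmodified on each rank.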

Accuracy Tests

Command
sglang serve --model-path /home/scratch.trt_llm_data/llm-models/Qwen3/Qwen3-30B-A3B-FP8/ --trust-remote-code --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 64}' --tp=4 --moe-dp-size=2 --ep-size=2 --attn-cp-size=2 --enable-prefill-context-parallel --cuda-graph-max-bs=32 --max-running-requests=32
Results
Accuracy: 0.785 Invalid: 0.000 Latency: 43.704 s Output throughput: 1027.630 token/s

H200 Qwen3-235B
(benchmark screenshot, 2026-03-10)

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions github-actions Bot added the documentation (Improvements or additions to documentation) and deepseek labels Feb 4, 2026
@Shunkangz Shunkangz force-pushed the support_qwen3 branch 4 times, most recently from 50c7181 to 84572ca on February 24, 2026 05:18
@Shunkangz Shunkangz marked this pull request as ready for review February 24, 2026 05:19

@Shunkangz
Contributor Author

/tag-and-rerun-ci

@Shunkangz
Contributor Author

/tag-and-rerun-ci

@Shunkangz
Contributor Author

/tag-and-rerun-ci

@Shunkangz
Contributor Author

/rerun-failed-ci

@Fridge003
Collaborator

/rerun-ut registered/spec/eagle/test_eagle_dp_attention.py

@github-actions
Contributor

✅ Triggered /rerun-ut on 4-gpu-h100 runner:

cd test/ && python3 registered/spec/eagle/test_eagle_dp_attention.py

@github-actions
Contributor

🔗 View workflow run

@Fridge003 Fridge003 merged commit bb737d7 into sgl-project:main Mar 22, 2026
35 of 63 checks passed
OrangeRedeng pushed a commit to OrangeRedeng/sglang that referenced this pull request Mar 22, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
@yuan-luo
Collaborator

This PR introduced test_qwen3_30b.py, which broke CI. @Fridge003 @alisonshao @Kangyan-Zhou

root@e1448ef40573:/sgl-workspace/sglang_dev# python ./test/registered/4-gpu-models/test_qwen3_30b.py


command=sglang serve --model-path Qwen/Qwen3-30B-A3B-FP8 --tp-size 4 --moe-dp-size 2 --ep-size 2 --attn-cp-size 2 --enable-prefill-context-parallel --cuda-graph-max-bs 32 --max-running-requests 32 --trust-remote-code --disable-piecewise-cuda-graph --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 64}' --device cuda --host 127.0.0.1 --port 21000
CI_OFFLINE: Launching server HF_HUB_OFFLINE=0 model=Qwen/Qwen3-30B-A3B-FP8
usage: sglang serve [-h] --model-path MODEL_PATH [--tokenizer-path TOKENIZER_PATH] [--tokenizer-mode {auto,slow}] [--tokenizer-worker-num TOKENIZER_WORKER_NUM] [--skip-tokenizer-init]
                    [--load-format {auto,pt,safetensors,npcache,dummy,sharded_state,gguf,bitsandbytes,mistral,layered,flash_rl,remote,remote_instance,fastsafetensors,private}]
                    [--model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG] [--trust-remote-code] [--context-length CONTEXT_LENGTH] [--is-embedding] [--enable-multimodal] [--revision REVISION]
                    [--model-impl MODEL_IMPL] [--host HOST] [--port PORT] [--fastapi-root-path FASTAPI_ROOT_PATH] [--grpc-mode] [--skip-server-warmup] [--warmups WARMUPS] [--nccl-port NCCL_PORT]
                    [--checkpoint-engine-wait-weights-before-ready] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--ssl-ca-certs SSL_CA_CERTS] [--ssl-keyfile-password SSL_KEYFILE_PASSWORD]
                    [--enable-ssl-refresh] [--dtype {auto,half,float16,bfloat16,float,float32}]
                    [--quantization {awq,fp8,mxfp8,gptq,marlin,gptq_marlin,awq_marlin,bitsandbytes,gguf,modelopt,modelopt_fp8,modelopt_fp4,modelopt_mixed,petit_nvfp4,w8a8_int8,w8a8_fp8,moe_wna16,qoq,w4afp8,mxfp4,auto-round,compressed-tensors,modelslim,quark_int4fp8_moe}]
                    [--quantization-param-path QUANTIZATION_PARAM_PATH] [--kv-cache-dtype {auto,fp8_e5m2,fp8_e4m3,bf16,bfloat16,fp4_e2m1}] [--enable-fp32-lm-head] [--modelopt-quant MODELOPT_QUANT]
                    [--modelopt-checkpoint-restore-path MODELOPT_CHECKPOINT_RESTORE_PATH] [--modelopt-checkpoint-save-path MODELOPT_CHECKPOINT_SAVE_PATH] [--modelopt-export-path MODELOPT_EXPORT_PATH]
                    [--quantize-and-serve] [--rl-quant-profile RL_QUANT_PROFILE] [--mem-fraction-static MEM_FRACTION_STATIC] [--max-running-requests MAX_RUNNING_REQUESTS]
                    [--max-queued-requests MAX_QUEUED_REQUESTS] [--max-total-tokens MAX_TOTAL_TOKENS] [--chunked-prefill-size CHUNKED_PREFILL_SIZE] [--prefill-max-requests PREFILL_MAX_REQUESTS]
                    [--enable-dynamic-chunking] [--max-prefill-tokens MAX_PREFILL_TOKENS] [--schedule-policy {lpm,random,fcfs,dfs-weight,lof,priority,routing-key}] [--enable-priority-scheduling]
                    [--disable-priority-preemption] [--default-priority-value DEFAULT_PRIORITY_VALUE] [--abort-on-priority-when-disabled] [--schedule-low-priority-values-first]
                    [--priority-scheduling-preemption-threshold PRIORITY_SCHEDULING_PREEMPTION_THRESHOLD] [--schedule-conservativeness SCHEDULE_CONSERVATIVENESS] [--page-size PAGE_SIZE]
                    [--hybrid-kvcache-ratio] [--swa-full-tokens-ratio SWA_FULL_TOKENS_RATIO] [--disable-hybrid-swa-memory] [--radix-eviction-policy {lru,lfu,slru}] [--enable-prefill-delayer]
                    [--prefill-delayer-max-delay-passes PREFILL_DELAYER_MAX_DELAY_PASSES] [--prefill-delayer-token-usage-low-watermark PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK]
                    [--prefill-delayer-forward-passes-buckets PREFILL_DELAYER_FORWARD_PASSES_BUCKETS [PREFILL_DELAYER_FORWARD_PASSES_BUCKETS ...]]
                    [--prefill-delayer-wait-seconds-buckets PREFILL_DELAYER_WAIT_SECONDS_BUCKETS [PREFILL_DELAYER_WAIT_SECONDS_BUCKETS ...]] [--device DEVICE] [--tensor-parallel-size TENSOR_PARALLEL_SIZE]
                    [--attention-context-parallel-size ATTENTION_CONTEXT_PARALLEL_SIZE] [--moe-data-parallel-size MOE_DATA_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
                    [--pp-max-micro-batch-size PP_MAX_MICRO_BATCH_SIZE] [--pp-async-batch-depth PP_ASYNC_BATCH_DEPTH] [--stream-interval STREAM_INTERVAL] [--incremental-streaming-output] [--stream-output]
                    [--enable-streaming-session] [--random-seed RANDOM_SEED] [--constrained-json-whitespace-pattern CONSTRAINED_JSON_WHITESPACE_PATTERN] [--constrained-json-disable-any-whitespace]
                    [--watchdog-timeout WATCHDOG_TIMEOUT] [--soft-watchdog-timeout SOFT_WATCHDOG_TIMEOUT] [--dist-timeout DIST_TIMEOUT] [--download-dir DOWNLOAD_DIR] [--model-checksum [MODEL_CHECKSUM]]
                    [--base-gpu-id BASE_GPU_ID] [--gpu-id-step GPU_ID_STEP] [--sleep-on-idle] [--use-ray] [--custom-sigquit-handler CUSTOM_SIGQUIT_HANDLER] [--log-level LOG_LEVEL]
                    [--log-level-http LOG_LEVEL_HTTP] [--log-requests] [--log-requests-level {0,1,2,3}] [--log-requests-format {text,json}]
                    [--log-requests-target LOG_REQUESTS_TARGET [LOG_REQUESTS_TARGET ...]] [--uvicorn-access-log-exclude-prefixes [UVICORN_ACCESS_LOG_EXCLUDE_PREFIXES ...]]
                    [--crash-dump-folder CRASH_DUMP_FOLDER] [--show-time-cost] [--enable-metrics] [--enable-metrics-for-all-schedulers]
                    [--tokenizer-metrics-custom-labels-header TOKENIZER_METRICS_CUSTOM_LABELS_HEADER]
                    [--tokenizer-metrics-allowed-custom-labels TOKENIZER_METRICS_ALLOWED_CUSTOM_LABELS [TOKENIZER_METRICS_ALLOWED_CUSTOM_LABELS ...]] [--extra-metric-labels EXTRA_METRIC_LABELS]
                    [--bucket-time-to-first-token BUCKET_TIME_TO_FIRST_TOKEN [BUCKET_TIME_TO_FIRST_TOKEN ...]] [--bucket-inter-token-latency BUCKET_INTER_TOKEN_LATENCY [BUCKET_INTER_TOKEN_LATENCY ...]]
                    [--bucket-e2e-request-latency BUCKET_E2E_REQUEST_LATENCY [BUCKET_E2E_REQUEST_LATENCY ...]] [--collect-tokens-histogram]
                    [--prompt-tokens-buckets PROMPT_TOKENS_BUCKETS [PROMPT_TOKENS_BUCKETS ...]] [--generation-tokens-buckets GENERATION_TOKENS_BUCKETS [GENERATION_TOKENS_BUCKETS ...]]
                    [--gc-warning-threshold-secs GC_WARNING_THRESHOLD_SECS] [--decode-log-interval DECODE_LOG_INTERVAL] [--enable-request-time-stats-logging] [--kv-events-config KV_EVENTS_CONFIG]
                    [--enable-trace] [--otlp-traces-endpoint OTLP_TRACES_ENDPOINT] [--export-metrics-to-file] [--export-metrics-to-file-dir EXPORT_METRICS_TO_FILE_DIR] [--api-key API_KEY]
                    [--admin-api-key ADMIN_API_KEY] [--served-model-name SERVED_MODEL_NAME] [--weight-version WEIGHT_VERSION] [--chat-template CHAT_TEMPLATE] [--hf-chat-template-name HF_CHAT_TEMPLATE_NAME]
                    [--completion-template COMPLETION_TEMPLATE] [--file-storage-path FILE_STORAGE_PATH] [--enable-cache-report]
                    [--reasoning-parser {deepseek-r1,deepseek-v3,glm45,gpt-oss,kimi,kimi_k2,qwen3,qwen3-thinking,minimax,minimax-append-think,step3,step3p5,mistral,nemotron_3,interns1}]
                    [--tool-call-parser {deepseekv3,deepseekv31,deepseekv32,glm,glm45,glm47,gpt-oss,kimi_k2,lfm2,llama3,mimo,mistral,pythonic,qwen,qwen25,qwen3_coder,step3,step3p5,minimax-m2,trinity,interns1,hermes,gigachat3}]
                    [--tool-server TOOL_SERVER] [--sampling-defaults {openai,model}] [--data-parallel-size DATA_PARALLEL_SIZE]
                    [--load-balance-method {auto,round_robin,follow_bootstrap_room,total_requests,total_tokens}] [--prefill-round-robin-balance] [--dist-init-addr DIST_INIT_ADDR] [--nnodes NNODES]
                    [--node-rank NODE_RANK] [--json-model-override-args JSON_MODEL_OVERRIDE_ARGS] [--preferred-sampling-params PREFERRED_SAMPLING_PARAMS] [--enable-lora] [--enable-lora-overlap-loading]
                    [--max-lora-rank MAX_LORA_RANK] [--lora-target-modules [{q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,qkv_proj,gate_up_proj,embed_tokens,lm_head,all} ...]]
                    [--lora-paths [LORA_PATHS ...]] [--max-loras-per-batch MAX_LORAS_PER_BATCH] [--max-loaded-loras MAX_LOADED_LORAS] [--lora-eviction-policy {lru,fifo}]
                    [--lora-backend {triton,csgmv,ascend,torch_native}] [--max-lora-chunk-size {16,32,64,128}]
                    [--attention-backend {triton,torch_native,flex_attention,nsa,cutlass_mla,fa3,fa4,flashinfer,flashmla,trtllm_mla,trtllm_mha,dual_chunk_flash_attn,aiter,wave,intel_amx,ascend,intel_xpu}]
                    [--prefill-attention-backend {triton,torch_native,flex_attention,nsa,cutlass_mla,fa3,fa4,flashinfer,flashmla,trtllm_mla,trtllm_mha,dual_chunk_flash_attn,aiter,wave,intel_amx,ascend,intel_xpu}]
                    [--decode-attention-backend {triton,torch_native,flex_attention,nsa,cutlass_mla,fa3,fa4,flashinfer,flashmla,trtllm_mla,trtllm_mha,dual_chunk_flash_attn,aiter,wave,intel_amx,ascend,intel_xpu}]
                    [--sampling-backend {pytorch,ascend,flashinfer}] [--grammar-backend {xgrammar,outlines,llguidance,none}]
                    [--mm-attention-backend {sdpa,fa3,fa4,triton_attn,ascend_attn,aiter_attn,flashinfer_cudnn}] [--nsa-prefill-backend {flashmla_sparse,flashmla_kv,flashmla_auto,fa3,tilelang,aiter,trtllm}]
                    [--nsa-decode-backend {flashmla_sparse,flashmla_kv,flashmla_auto,fa3,tilelang,aiter,trtllm}]
                    [--fp8-gemm-backend {auto,deep_gemm,flashinfer_trtllm,flashinfer_cutlass,flashinfer_deepgemm,cutlass,triton,aiter}]
                    [--fp4-gemm-backend {auto,flashinfer_cudnn,flashinfer_cutlass,flashinfer_trtllm}] [--disable-flashinfer-autotune] [--speculative-algorithm {EAGLE,EAGLE3,NEXTN,STANDALONE,NGRAM}]
                    [--speculative-draft-model-path SPECULATIVE_DRAFT_MODEL_PATH] [--speculative-draft-model-revision SPECULATIVE_DRAFT_MODEL_REVISION]
                    [--speculative-draft-load-format {auto,pt,safetensors,npcache,dummy,sharded_state,gguf,bitsandbytes,mistral,layered,flash_rl,remote,remote_instance,fastsafetensors,private}]
                    [--speculative-num-steps SPECULATIVE_NUM_STEPS] [--speculative-eagle-topk SPECULATIVE_EAGLE_TOPK] [--speculative-num-draft-tokens SPECULATIVE_NUM_DRAFT_TOKENS]
                    [--speculative-accept-threshold-single SPECULATIVE_ACCEPT_THRESHOLD_SINGLE] [--speculative-accept-threshold-acc SPECULATIVE_ACCEPT_THRESHOLD_ACC]
                    [--speculative-token-map SPECULATIVE_TOKEN_MAP] [--speculative-attention-mode {prefill,decode}] [--speculative-draft-attention-backend SPECULATIVE_DRAFT_ATTENTION_BACKEND]
                    [--speculative-moe-runner-backend {auto,deep_gemm,triton,triton_kernel,flashinfer_trtllm,flashinfer_trtllm_routed,flashinfer_cutlass,flashinfer_mxfp4,flashinfer_cutedsl,cutlass}]
                    [--speculative-moe-a2a-backend {none,deepep,mooncake,nixl,mori,ascend_fuseep,flashinfer}]
                    [--speculative-draft-model-quantization {awq,fp8,mxfp8,gptq,marlin,gptq_marlin,awq_marlin,bitsandbytes,gguf,modelopt,modelopt_fp8,modelopt_fp4,modelopt_mixed,petit_nvfp4,w8a8_int8,w8a8_fp8,moe_wna16,qoq,w4afp8,mxfp4,auto-round,compressed-tensors,modelslim,quark_int4fp8_moe,unquant}]
                    [--speculative-ngram-min-match-window-size SPECULATIVE_NGRAM_MIN_MATCH_WINDOW_SIZE] [--speculative-ngram-max-match-window-size SPECULATIVE_NGRAM_MAX_MATCH_WINDOW_SIZE]
                    [--speculative-ngram-min-bfs-breadth SPECULATIVE_NGRAM_MIN_BFS_BREADTH] [--speculative-ngram-max-bfs-breadth SPECULATIVE_NGRAM_MAX_BFS_BREADTH]
                    [--speculative-ngram-match-type {BFS,PROB}] [--speculative-ngram-branch-length SPECULATIVE_NGRAM_BRANCH_LENGTH] [--speculative-ngram-capacity SPECULATIVE_NGRAM_CAPACITY]
                    [--enable-multi-layer-eagle] [--expert-parallel-size EXPERT_PARALLEL_SIZE] [--moe-a2a-backend {none,deepep,mooncake,nixl,mori,ascend_fuseep,flashinfer}]
                    [--moe-runner-backend {auto,deep_gemm,triton,triton_kernel,flashinfer_trtllm,flashinfer_trtllm_routed,flashinfer_cutlass,flashinfer_mxfp4,flashinfer_cutedsl,cutlass}]
                    [--flashinfer-mxfp4-moe-precision {default,bf16}] [--enable-flashinfer-allreduce-fusion] [--enable-aiter-allreduce-fusion] [--deepep-mode {normal,low_latency,auto}]
                    [--ep-num-redundant-experts EP_NUM_REDUNDANT_EXPERTS] [--ep-dispatch-algorithm EP_DISPATCH_ALGORITHM] [--init-expert-location INIT_EXPERT_LOCATION] [--enable-eplb]
                    [--eplb-algorithm EPLB_ALGORITHM] [--eplb-rebalance-num-iterations EPLB_REBALANCE_NUM_ITERATIONS] [--eplb-rebalance-layers-per-chunk EPLB_REBALANCE_LAYERS_PER_CHUNK]
                    [--eplb-min-rebalancing-utilization-threshold EPLB_MIN_REBALANCING_UTILIZATION_THRESHOLD] [--expert-distribution-recorder-mode EXPERT_DISTRIBUTION_RECORDER_MODE]
                    [--expert-distribution-recorder-buffer-size EXPERT_DISTRIBUTION_RECORDER_BUFFER_SIZE] [--enable-expert-distribution-metrics] [--deepep-config DEEPEP_CONFIG]
                    [--moe-dense-tp-size MOE_DENSE_TP_SIZE] [--elastic-ep-backend {none,mooncake,nixl}] [--enable-elastic-expert-backup] [--mooncake-ib-device MOONCAKE_IB_DEVICE]
                    [--max-mamba-cache-size MAX_MAMBA_CACHE_SIZE] [--mamba-ssm-dtype {float32,bfloat16,float16}] [--mamba-full-memory-ratio MAMBA_FULL_MEMORY_RATIO]
                    [--mamba-scheduler-strategy {auto,no_buffer,extra_buffer}] [--mamba-track-interval MAMBA_TRACK_INTERVAL] [--mamba-backend {triton,flashinfer}]
                    [--linear-attn-backend {triton,cutedsl,flashinfer}] [--linear-attn-decode-backend {triton,cutedsl,flashinfer}] [--linear-attn-prefill-backend {triton,cutedsl,flashinfer}]
                    [--enable-hierarchical-cache] [--hicache-ratio HICACHE_RATIO] [--hicache-size HICACHE_SIZE] [--hicache-write-policy {write_back,write_through,write_through_selective}]
                    [--hicache-io-backend {direct,kernel,kernel_ascend}] [--hicache-mem-layout {layer_first,page_first,page_first_direct,page_first_kv_split,page_head}] [--disable-hicache-numa-detect]
                    [--hicache-storage-backend {file,mooncake,hf3fs,nixl,aibrix,dynamic,eic}] [--hicache-storage-prefetch-policy {best_effort,wait_complete,timeout}]
                    [--hicache-storage-backend-extra-config HICACHE_STORAGE_BACKEND_EXTRA_CONFIG] [--hierarchical-sparse-attention-extra-config HIERARCHICAL_SPARSE_ATTENTION_EXTRA_CONFIG]
                    [--enable-lmcache] [--kt-weight-path KT_WEIGHT_PATH] [--kt-method KT_METHOD] [--kt-cpuinfer KT_CPUINFER] [--kt-threadpool-count KT_THREADPOOL_COUNT]
                    [--kt-num-gpu-experts KT_NUM_GPU_EXPERTS] [--kt-max-deferred-experts-per-token KT_MAX_DEFERRED_EXPERTS_PER_TOKEN] [--dllm-algorithm DLLM_ALGORITHM]
                    [--dllm-algorithm-config DLLM_ALGORITHM_CONFIG] [--enable-double-sparsity] [--ds-channel-config-path DS_CHANNEL_CONFIG_PATH] [--ds-heavy-channel-num DS_HEAVY_CHANNEL_NUM]
                    [--ds-heavy-token-num DS_HEAVY_TOKEN_NUM] [--ds-heavy-channel-type DS_HEAVY_CHANNEL_TYPE] [--ds-sparse-decode-threshold DS_SPARSE_DECODE_THRESHOLD] [--cpu-offload-gb CPU_OFFLOAD_GB]
                    [--offload-group-size OFFLOAD_GROUP_SIZE] [--offload-num-in-group OFFLOAD_NUM_IN_GROUP] [--offload-prefetch-step OFFLOAD_PREFETCH_STEP] [--offload-mode OFFLOAD_MODE]
                    [--multi-item-scoring-delimiter MULTI_ITEM_SCORING_DELIMITER] [--disable-radix-cache] [--cuda-graph-max-bs CUDA_GRAPH_MAX_BS] [--cuda-graph-bs CUDA_GRAPH_BS [CUDA_GRAPH_BS ...]]
                    [--disable-cuda-graph] [--disable-cuda-graph-padding] [--enable-profile-cuda-graph] [--enable-cudagraph-gc] [--enable-layerwise-nvtx-marker] [--enable-nccl-nvls] [--enable-symm-mem]
                    [--disable-flashinfer-cutlass-moe-fp4-allgather] [--enable-tokenizer-batch-encode] [--disable-tokenizer-batch-decode] [--disable-outlines-disk-cache] [--disable-custom-all-reduce]
                    [--enable-mscclpp] [--enable-torch-symm-mem] [--pre-warm-nccl] [--disable-overlap-schedule] [--enable-mixed-chunk] [--enable-dp-attention] [--enable-dp-lm-head]
                    [--enable-two-batch-overlap] [--enable-single-batch-overlap] [--tbo-token-distribution-threshold TBO_TOKEN_DISTRIBUTION_THRESHOLD] [--enable-torch-compile]
                    [--enable-torch-compile-debug-mode] [--disable-piecewise-cuda-graph] [--enable-piecewise-cuda-graph] [--enforce-piecewise-cuda-graph]
                    [--piecewise-cuda-graph-tokens PIECEWISE_CUDA_GRAPH_TOKENS [PIECEWISE_CUDA_GRAPH_TOKENS ...]] [--piecewise-cuda-graph-compiler {eager,inductor}]
                    [--torch-compile-max-bs TORCH_COMPILE_MAX_BS] [--piecewise-cuda-graph-max-tokens PIECEWISE_CUDA_GRAPH_MAX_TOKENS] [--torchao-config TORCHAO_CONFIG] [--enable-nan-detection]
                    [--enable-p2p-check] [--triton-attention-reduce-in-fp32] [--triton-attention-num-kv-splits TRITON_ATTENTION_NUM_KV_SPLITS]
                    [--triton-attention-split-tile-size TRITON_ATTENTION_SPLIT_TILE_SIZE] [--num-continuous-decode-steps NUM_CONTINUOUS_DECODE_STEPS] [--delete-ckpt-after-loading] [--enable-memory-saver]
                    [--enable-weights-cpu-backup] [--enable-draft-weights-cpu-backup] [--allow-auto-truncate] [--enable-custom-logit-processor] [--flashinfer-mla-disable-ragged]
                    [--disable-shared-experts-fusion] [--disable-chunked-prefix-cache] [--disable-fast-image-processor] [--keep-mm-feature-on-device] [--enable-return-hidden-states]
                    [--enable-return-routed-experts] [--scheduler-recv-interval SCHEDULER_RECV_INTERVAL] [--numa-node NUMA_NODE [NUMA_NODE ...]] [--enable-deterministic-inference]
                    [--rl-on-policy-target {fsdp}] [--enable-attn-tp-input-scattered] [--enable-nsa-prefill-context-parallel] [--nsa-prefill-cp-mode {in-seq-split,round-robin-split}]
                    [--enable-fused-qk-norm-rope] [--enable-precise-embedding-interpolation] [--enable-fused-moe-sum-all-reduce] [--enable-dynamic-batch-tokenizer]
                    [--dynamic-batch-tokenizer-batch-size DYNAMIC_BATCH_TOKENIZER_BATCH_SIZE] [--dynamic-batch-tokenizer-batch-timeout DYNAMIC_BATCH_TOKENIZER_BATCH_TIMEOUT]
                    [--debug-tensor-dump-output-folder DEBUG_TENSOR_DUMP_OUTPUT_FOLDER] [--debug-tensor-dump-layers DEBUG_TENSOR_DUMP_LAYERS [DEBUG_TENSOR_DUMP_LAYERS ...]]
                    [--debug-tensor-dump-input-file DEBUG_TENSOR_DUMP_INPUT_FILE] [--debug-tensor-dump-inject DEBUG_TENSOR_DUMP_INJECT] [--disaggregation-mode {null,prefill,decode}]
                    [--disaggregation-transfer-backend {mooncake,nixl,ascend,fake,mori}] [--disaggregation-bootstrap-port DISAGGREGATION_BOOTSTRAP_PORT]
                    [--disaggregation-ib-device DISAGGREGATION_IB_DEVICE] [--disaggregation-decode-enable-offload-kvcache] [--num-reserved-decode-tokens NUM_RESERVED_DECODE_TOKENS]
                    [--disaggregation-decode-polling-interval DISAGGREGATION_DECODE_POLLING_INTERVAL] [--encoder-only] [--language-only]
                    [--encoder-transfer-backend {zmq_to_scheduler,zmq_to_tokenizer,mooncake}] [--encoder-urls ENCODER_URLS [ENCODER_URLS ...]] [--enable-adaptive-dispatch-to-encoder]
                    [--custom-weight-loader [CUSTOM_WEIGHT_LOADER ...]] [--weight-loader-disable-mmap] [--remote-instance-weight-loader-seed-instance-ip REMOTE_INSTANCE_WEIGHT_LOADER_SEED_INSTANCE_IP]
                    [--remote-instance-weight-loader-seed-instance-service-port REMOTE_INSTANCE_WEIGHT_LOADER_SEED_INSTANCE_SERVICE_PORT]
                    [--remote-instance-weight-loader-send-weights-group-ports REMOTE_INSTANCE_WEIGHT_LOADER_SEND_WEIGHTS_GROUP_PORTS]
                    [--remote-instance-weight-loader-backend {transfer_engine,nccl,modelexpress}] [--remote-instance-weight-loader-start-seed-via-transfer-engine]
                    [--modelexpress-config MODELEXPRESS_CONFIG] [--enable-pdmux] [--pdmux-config-path PDMUX_CONFIG_PATH] [--sm-group-num SM_GROUP_NUM] [--config CONFIG]
                    [--mm-max-concurrent-calls MM_MAX_CONCURRENT_CALLS] [--mm-per-request-timeout MM_PER_REQUEST_TIMEOUT] [--enable-broadcast-mm-inputs-process] [--mm-process-config MM_PROCESS_CONFIG]
                    [--mm-enable-dp-encoder] [--limit-mm-data-per-request LIMIT_MM_DATA_PER_REQUEST] [--decrypted-config-file DECRYPTED_CONFIG_FILE]
                    [--decrypted-draft-config-file DECRYPTED_DRAFT_CONFIG_FILE] [--enable-prefix-mm-cache] [--enable-mm-global-cache] [--forward-hooks FORWARD_HOOKS]
sglang serve: error: unrecognized arguments: --enable-prefill-context-parallel

Impacted CI:
https://github.com/sgl-project/sglang/actions/runs/23399815739/job/68121364614?pr=21019

0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
Kangyan-Zhou added a commit to Kangyan-Zhou/sglang that referenced this pull request Apr 2, 2026
PR sgl-project#18233 (bb737d7) switched MoE allreduce from _TP group to the
dedicated _MOE_TP group but did not add _MOE_TP to graph_capture().
During CUDA graph replay the custom-allreduce kernel dereferences
unregistered IPC handles, causing illegal-memory-access crashes on
every MoE model launched with 1 < ep < tp (e.g. Qwen3-235B-FP8
--tp=8 --ep=2).

Nightly CI confirms the breakpoint:
  • Mar 20 (before sgl-project#18233): model loads, different test-framework error
  • Mar 23 (after sgl-project#18233): exit code -9 (OOM-killed / segfault)

Validated on 8xH200 with Qwen3-235B-A22B-Instruct-2507-FP8 --tp=8
--ep=2 (gsm8k accuracy 96.4%, 3124 tok/s).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Fridge003 added a commit that referenced this pull request Apr 2, 2026
When ep_size > 1 and ep_size < tp_size, the _MOE_TP group is distinct
from _TP. PR #18233 switched MoE allreduce to use _MOE_TP but forgot
to register it in graph_capture(). This causes illegal memory access
during CUDA graph replay because custom allreduce IPC handles from
_MOE_TP are never registered.

Use ExitStack to register both _MOE_EP and _MOE_TP groups (when they
differ from _TP) during graph capture.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Fridge003 added a commit that referenced this pull request Apr 2, 2026
When ep_size > 1 and ep_size < tp_size, the _MOE_TP group is distinct
from _TP. PR #18233 switched MoE allreduce to use _MOE_TP but forgot
to register it in graph_capture(). This causes illegal memory access
during CUDA graph replay because custom allreduce IPC handles from
_MOE_TP are never registered.

Use ExitStack to register both _MOE_EP and _MOE_TP groups (when they
differ from _TP) during graph capture.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
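The ExitStack pattern described in this fix can be sketched as follows. The context-manager and group names here are illustrative stand-ins, not SGLang's actual GroupCoordinator API; the point is that every distinct communication group must be entered during capture so its IPC handles are registered before replay.

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def graph_capture_ctx(group_name, registered):
    # Stand-in for a process group's graph_capture() context manager,
    # which registers the group's custom-allreduce IPC handles.
    registered.append(group_name)
    try:
        yield
    finally:
        registered.remove(group_name)

def capture(groups):
    """Enter graph capture for every distinct group via one ExitStack,
    skipping duplicates (e.g. when _MOE_TP is the same group as _TP)."""
    registered = []
    with ExitStack() as stack:
        seen = set()
        for g in groups:
            if g not in seen:
                seen.add(g)
                stack.enter_context(graph_capture_ctx(g, registered))
        snapshot = list(registered)  # groups active during capture
    return snapshot, registered     # ExitStack has unwound all of them

active, after = capture(["_TP", "_MOE_TP", "_MOE_EP", "_TP"])
```

With the bug described above, `_MOE_TP` would be missing from `active`, so its allreduce handles were never registered and replay crashed; the ExitStack version guarantees all distinct groups are entered and cleanly exited.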
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>

Labels

deepseek, documentation (Improvements or additions to documentation), high priority, run-ci

6 participants