
piecewise cuda graph support qwen3-moe #11845

Merged

ispobock merged 5 commits into main from qwen3_moe_support_piecewise_cuda_gprah on Oct 21, 2025
Conversation

BBuf (Collaborator) commented Oct 20, 2025:

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-FP8 --tp 4 --host 0.0.0.0 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager

➜  sglang python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
100%|████████████████████████████████████████████████| 1319/1319 [00:39<00:00, 33.29it/s]
Accuracy: 0.923
Invalid: 0.002
Latency: 39.746 s
Output throughput: 5340.152 token/s
CUDA_VISIBLE_DEVICES=4,5,6,7  python3 -m sglang.launch_server --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct --tp 4 --host 0.0.0.0 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager

python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8

100%|████████████████████████████████████████████████| 1319/1319 [00:16<00:00, 82.05it/s]
Accuracy: 0.928
Invalid: 0.001
Latency: 16.142 s
Output throughput: 10959.170 token/s

Fix two bugs:

    raise UnsupportedOperatorException(func)
RuntimeError: Dynamo failed to run FX node with fake tensors: call_function sgl_kernel.sgl_per_token_group_quant_fp8.default(*(FakeTensor(..., device='cuda:1', size=(s72, 4096), dtype=torch.bfloat16), FakeTensor(..., device='cuda:1', size=(s72, 4096), dtype=torch.float8_e4m3fn), FakeTensor(..., device='cuda:1', size=(s72, 32)), 128, 1e-10, -448.0, 448.0, False), **{}): got UnsupportedOperatorException(func=<OpOverload(op='sgl_kernel.sgl_per_token_group_quant_fp8', overload='default')>)

During handling of the above exception, another exception occurred:

torch._dynamo.exc.Unsupported: Operator does not support running with fake tensors
  Explanation: 
  Hint: see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix

  Developer debug context: unsupported operator: sgl_kernel.sgl_per_token_group_quant_fp8.default
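The first failure is the standard symptom of a custom op that has no fake-tensor (meta) implementation registered, so Dynamo cannot shape-propagate through it. Below is a minimal sketch of the usual remedy, assuming the kernel writes its quantized output and scales in place; the argument names are illustrative, not the actual sgl_kernel schema, and the real fix in this PR may differ:

```python
import torch

# Hedged sketch: register a fake (meta) implementation so Dynamo can trace
# shapes without executing the CUDA kernel. For an out-variant op that
# mutates its pre-allocated output tensors, the fake impl has nothing to
# allocate or compute.
@torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
def _(
    input,       # bf16 activations, e.g. (num_tokens, hidden_size)
    output_q,    # pre-allocated fp8 output, same shape as input
    output_s,    # pre-allocated per-group scales
    group_size,  # 128 in the trace above
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
):
    return None
```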
[2025-10-19 23:35:24 TP2] Scheduler hit an exception: torch._dynamo.exc.Unsupported: Unsupported method call
  Explanation: Dynamo does not know how to trace method `__call__` of class `_lru_cache_wrapper`
  Hint: Avoid calling `_lru_cache_wrapper.__call__` in your code.
  Hint: Please report an issue to PyTorch.

  Developer debug context: call_method UserDefinedObjectVariable(_lru_cache_wrapper) __call__ [LazyVariableTracker()] {}
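The second failure comes from calling a `functools.lru_cache`-wrapped function inside code that Dynamo traces. A minimal sketch of the usual workaround, hoisting the cached lookup out of the traced forward so the wrapper object never appears in the graph (names are illustrative, not the actual fix):

```python
import functools
import torch

@functools.lru_cache(maxsize=1)
def _query_device_capability() -> tuple:
    return torch.cuda.get_device_capability()

class Layer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Resolved eagerly at construction time, so Dynamo only ever sees a
        # plain tuple and never the _lru_cache_wrapper object.
        self._capability = _query_device_capability()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._capability >= (9, 0):  # trace-time constant
            return 2.0 * x
        return x + 1.0
```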

TODO

Fix IMA when:

python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 4096 --random-output-len 20 --random-range-ratio 1 --num-prompts 10 --max-concurrency 1 --warmup-requests 3
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.399", line 1421, in forward
    submod_0 = self.submod_0(l_input_ids_, s72, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_layer_communicator_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_scale_inv_, l_self_modules_layers_modules_0_modules_self_attn_modules_q_norm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_k_norm_parameters_weight_, l_forward_batch_token_to_kv_pool_k_buffer_0_, l_forward_batch_token_to_kv_pool_v_buffer_0_, l_positions_, s80, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_forward_batch_out_cache_loc, s67);  l_input_ids_ = l_self_modules_embed_tokens_parameters_weight_ = l_self_modules_layers_modules_0_layer_communicator_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_scale_inv_ = l_self_modules_layers_modules_0_modules_self_attn_modules_q_norm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_k_norm_parameters_weight_ = l_forward_batch_token_to_kv_pool_k_buffer_0_ = l_forward_batch_token_to_kv_pool_v_buffer_0_ = None
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhyncs/bbuf/sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py", line 227, in __call__
    entry.cudagraph.replay()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/graphs.py", line 117, in replay
    super().replay()
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


terminate called after throwing an instance of 'c10::AcceleratorError'
[2025-10-20 01:50:48 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/zhyncs/bbuf/sglang/python/sglang/srt/managers/scheduler.py", line 3065, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhyncs/bbuf/sglang/python/sglang/srt/managers/scheduler.py", line 1044, in event_loop_overlap
    batch_result = self.run_batch(batch)
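An illegal memory access raised from `entry.cudagraph.replay()` typically means a pointer baked into the graph at capture time no longer matches the live tensors at replay time, for example re-bound input buffers or cache indices that have since moved. A minimal sketch of the invariant a replayed graph relies on (generic PyTorch, unrelated to the actual sglang buffers):

```python
import torch

# Input buffer whose address is baked into the graph; it must never be
# re-bound, only written into.
static_in = torch.zeros(16, 4096, device="cuda")

# Warm up on a side stream before capture, as the CUDA graph docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_in * 2.0
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2.0  # addresses captured here

def run(x: torch.Tensor) -> torch.Tensor:
    static_in.copy_(x)  # copy into the captured buffer; never re-assign it
    g.replay()
    return static_out.clone()
```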

BBuf marked this pull request as ready for review on October 20, 2025 09:41.
ispobock (Collaborator) left a comment:

We can add a small model in the unit test.

BBuf (Collaborator, Author) commented Oct 20, 2025:

> We can add a small model in the unit test.

done.
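For reference, a hypothetical sketch of the kind of regression test discussed here; the helpers are assumed from sglang's test utilities, and the small MoE checkpoint is illustrative rather than what the PR necessarily added:

```python
import unittest
import requests

# Assumed helpers from sglang's test utilities.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

class TestPiecewiseCudaGraphMoE(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Any small MoE checkpoint works; this model path is illustrative.
        cls.process = popen_launch_server(
            "Qwen/Qwen1.5-MoE-A2.7B",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--enable-piecewise-cuda-graph"],
        )

    @classmethod
    def tearDownClass(cls):
        cls.process.terminate()

    def test_generate(self):
        # A single request is enough to exercise capture + replay.
        resp = requests.post(
            f"{DEFAULT_URL_FOR_TEST}/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"max_new_tokens": 8},
            },
        )
        self.assertEqual(resp.status_code, 200)
```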

ispobock merged commit 8374a96 into main on Oct 21, 2025 (100 of 107 checks passed) and deleted the qwen3_moe_support_piecewise_cuda_gprah branch on October 21, 2025 02:55.
sleepcoo (Collaborator) commented Oct 22, 2025:

I have tested the performance improvement of Qwen3-30B-A3B using piecewise CUDA graph on H20; the performance data is as follows:

| input 7k, output 0.2k | TTFT (ms) | TPOT (ms) |
| --- | --- | --- |
| with piecewise | 289 | 5.99 |
| without piecewise | 310 | 6.06 |

yansiyu550 commented Nov 27, 2025:

> CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-FP8 --tp 4 --host 0.0.0.0 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager

Hi, I am testing the --enable-piecewise-cuda-graph feature on different Qwen3 models.
I observed that:

  1. Qwen3-4B shows no performance improvement at all, even though capture succeeds.
  2. Qwen3-30B-A3B-FP8 fails to run when CUDA graph is enabled (with eager compiler).

Environment:

  • SGLang version: 0.5.5.post3
  • GPUs: 8× RTX 4090 (24 GB)

1. No Performance Gain on Qwen3-4B

Test Commands:
# baseline
CUDA_VISIBLE_DEVICES=4 python3 -m sglang.launch_server --model-path /work/Qwen3-4B --served-model-name qwen3 --tp 1 --port 39992 --host 0.0.0.0 --log-level info --enable-metrics --enable-p2p-check

# enable-piecewise-cuda-graph
CUDA_VISIBLE_DEVICES=6 python3 -m sglang.launch_server --model-path /work/Qwen3-4B --served-model-name qwen3 --tp 1 --port 39992 --host 0.0.0.0 --log-level info --enable-metrics --enable-p2p-check --enable-piecewise-cuda-graph

# Test
python3 -m sglang.bench_serving --backend sglang --model /work/Qwen3-4B --tokenizer /work/Qwen3-4B --num-prompts 128 --random-input-len 4096 --random-output-len 128 --dataset-name random --dataset-path  /work/ShareGPT_V3_unfiltered_cleaned_split.json --seed 42 --host 0.0.0.0 --port 39992 --random-range-ratio 1 --request-rate 3

python3 -m sglang.bench_serving --backend sglang --model /work/Qwen3-4B --tokenizer /work/Qwen3-4B --num-prompts 128 --random-input-len 2048 --random-output-len 128 --dataset-name random --dataset-path  /work/ShareGPT_V3_unfiltered_cleaned_split.json --seed 42 --host 0.0.0.0 --port 39992 --random-range-ratio 1 --request-rate 3

python3 -m sglang.bench_serving --backend sglang --model /work/Qwen3-4B --tokenizer /work/Qwen3-4B --num-prompts 128 --random-input-len 512 --random-output-len 128 --dataset-name random --dataset-path  /work/ShareGPT_V3_unfiltered_cleaned_split.json --seed 42 --host 0.0.0.0 --port 39992 --random-range-ratio 1 --request-rate 3

2. Runtime Error on Qwen3-30B-A3B-FP8 with CUDA Graph Enabled

Commands:

CUDA_VISIBLE_DEVICES=6,7 python3 -m sglang.launch_server --model-path /work/Qwen3-30B-A3B-FP8 --tp 2 --host 0.0.0.0 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager

Error:

[2025-11-27 09:23:37 TP0] Capture cuda graph end. Time elapsed: 3.94 s. mem usage=0.21 GB. avail mem=3.37 GB.
[2025-11-27 09:23:37 TP0] Capture cuda graph num tokens [4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096]
[2025-11-27 09:23:37 TP0] install_torch_compiled
[2025-11-27 09:23:37 TP1] Capture cuda graph end. Time elapsed: 3.98 s. mem usage=0.21 GB. avail mem=3.37 GB.
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1692: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1692: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
[2025-11-27 09:23:37 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/work/yansiyu/sglang-main/python/sglang/srt/managers/scheduler.py", line 2628, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/managers/scheduler.py", line 314, in __init__
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/managers/tp_worker.py", line 245, in __init__
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/model_executor/model_runner.py", line 387, in __init__
    self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 235, in __init__
    self.warmup_torch_compile()
  File "/work/yansiyu/sglang-main/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 315, in warmup_torch_compile
    _ = self.model_runner.model.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen3_moe.py", line 755, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/yansiyu/sglang-main/python/sglang/srt/compilation/compile.py", line 206, in trampoline
    _ensure_compiled(self, *args, **kwargs)
  File "/work/yansiyu/sglang-main/python/sglang/srt/compilation/compile.py", line 197, in _ensure_compiled
    compiled_callable(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor
  Explanation: torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output


  Developer debug context: example_value type: str; op: call_function; target: <function get_device_name at 0x7f8bcbdf04a0>

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0208.html

from user code:
   File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen2_moe.py", line 608, in forward
    hidden_states, residual = layer(
  File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen3_moe.py", line 607, in forward
    hidden_states = self.self_attn(
  File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen3_moe.py", line 496, in forward
    s = self.forward_prepare(
  File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen3_moe.py", line 464, in forward_prepare
    return self.forward_prepare_native(
  File "/work/yansiyu/sglang-main/python/sglang/srt/models/qwen3_moe.py", line 434, in forward_prepare_native
    qkv, _ = self.qkv_proj(hidden_states)
  File "/work/yansiyu/sglang-main/python/sglang/srt/layers/linear.py", line 429, in forward
    output_parallel = self.quant_method.apply(self, input_, bias)
  File "/work/yansiyu/sglang-main/python/sglang/srt/layers/quantization/fp8.py", line 510, in apply
    return self.w8a8_block_fp8_linear(
  File "/work/yansiyu/sglang-main/python/sglang/srt/layers/quantization/fp8_utils.py", line 321, in triton_w8a8_block_fp8_linear
    output = w8a8_block_fp8_matmul_triton(
  File "/work/yansiyu/sglang-main/python/sglang/srt/layers/quantization/fp8_kernel.py", line 1101, in w8a8_block_fp8_matmul_triton
    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/polyfills/__init__.py", line 264, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/work/yansiyu/sglang-main/python/sglang/srt/layers/quantization/fp8_kernel.py", line 954, in get_w8a8_block_fp8_configs
    device_name = get_device_name().replace(" ", "_")
  File "/work/yansiyu/sglang-main/python/sglang/srt/utils/common.py", line 1809, in get_device_name
    return torch.cuda.get_device_name(device_id)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


[2025-11-27 09:23:37] Received sigquit from a child process. It usually means the child failed.
Killed
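The graph break here is `torch.cuda.get_device_name()` returning a Python str inside the traced region. A minimal sketch of the usual remedy, resolving the device name once outside any compiled code so the traced path only reads a constant (illustrative, not the actual sglang fix):

```python
import torch

# Resolved once at import time, before torch.compile traces this module, so
# the compiled region never contains a call that returns a non-Tensor.
_DEVICE_NAME = (
    torch.cuda.get_device_name(0).replace(" ", "_")
    if torch.cuda.is_available()
    else "cpu"
)

def get_w8a8_config_key(n: int, k: int) -> str:
    # Hypothetical lookup key; no torch.cuda.get_device_name call ends up
    # in the Dynamo graph.
    return f"{_DEVICE_NAME}_N={n}_K={k}"
```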
