
try to support vl model #13034

Closed
BBuf wants to merge 1 commit into main from try_to_support_vl_model

Conversation

@BBuf (Collaborator) commented Nov 11, 2025

This PR fixes bugs hit when serving VL models with piecewise CUDA graphs: the "assert positions.ndim == 1 or positions.ndim == 2" failure caused by mrope_positions, and the "torch._dynamo.exc.Unsupported: Skip calling torch.compiler.disable()" error caused by the @torch._dynamo.disable() decorator in triton_mrope_wrapper. However, a new bug remains unresolved and still prevents serving VL models:

Capturing num tokens (num_tokens=3712 avail_mem=59.57 GB):   5%|| 3/58 [00:06<02:04,  2.26s/it]
[2025-11-11 01:58:02] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/yineng/bbuf/sglang/python/sglang/srt/managers/scheduler.py", line 2679, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/managers/scheduler.py", line 312, in __init__
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/model_executor/model_runner.py", line 360, in __init__
    self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 217, in __init__
    self.capture()
  File "/home/yineng/bbuf/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 324, in capture
    self.capture_one_batch_size(num_tokens)
  File "/home/yineng/bbuf/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 424, in capture_one_batch_size
    run_once()
  File "/home/yineng/bbuf/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 413, in run_once
    self.model_runner.model.forward(
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/models/qwen2_5_vl.py", line 580, in forward
    hidden_states = general_mm_embed_routine(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/managers/mm_utils.py", line 701, in general_mm_embed_routine
    hidden_states = language_model(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/compilation/compile.py", line 207, in trampoline
    return compiled_callable(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/models/qwen2.py", line 335, in forward
    def forward(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.191", line 205, in forward
    submod_0 = self.submod_0(l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_input_embeds_, s47, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_bias_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_positions_, s18, s7, s80);  l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_bias_ = None
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yineng/bbuf/sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py", line 222, in __call__
    assert new_input_addresses == entry.input_addresses, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Input addresses for cudagraphs are different during replay. Expected [124452267358208, 124355008790528, 124445100343296, 124452267348992, 124321687142400, 124330174447616], got [124452267358208, 124320646955008, 124445100343296, 124452267348992, 124321687142400, 124330174447616]

[2025-11-11 01:58:02] Received sigquit from a child process. It usually means the child failed.
[1]    1566408 killed     SGLANG_USE_CUDA_IPC_TRANSPORT=1 SGLANG_VLM_CACHE_SIZE_MB=0 python -m    --hos
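
For background on that assertion: a captured CUDA graph replays the recorded kernels against the exact device pointers it saw at capture time, so every graph input must keep a stable address across replays; new data has to be copied into those static buffers rather than passed in as fresh tensors. A minimal standalone sketch with torch.cuda.CUDAGraph (not SGLang's runner) illustrating that contract:

    import torch

    # Static buffers: the graph records these exact device addresses.
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as the CUDA graph API requires.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    # Correct usage: copy new values INTO the captured buffer, then replay.
    static_in.copy_(torch.arange(8.0, device="cuda"))
    g.replay()  # static_out now holds [0, 2, 4, ...]

    # A freshly allocated tensor lives at a different address; a replayed
    # graph would keep reading the old buffer instead. That silent staleness
    # is what the input-address assertion above turns into a hard error.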

@yuan-luo (Collaborator) commented Nov 11, 2025

I encountered the "assert positions.ndim == 1 or positions.ndim == 2" error as well and tried to locate the root cause; the following findings are for your reference. From debugging, we confirmed that during CUDA graph capture, general_mm_embed_routine() enters the following branch:

    else:
        inputs_embeds = embed_tokens(input_ids)

I added an assertion, "positions must be set for Qwen2/2.5-VL (MRoPE)", in general_mm_embed_routine, which pinpoints the root cause:

    pos = kwargs.get("positions", getattr(forward_batch, "positions", None))
    assert pos is not None, "positions must be set for Qwen2/2.5-VL (MRoPE)" <<<<<<<<<<<
    hidden_states = language_model(
        input_ids=None,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        **kwargs,
    )
    return hidden_states
[2025-11-11 10:28:00 TP6] Scheduler hit an exception: Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2672, in run_scheduler_process
    scheduler = Scheduler(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 311, in __init__
    self.tp_worker = TpModelWorker(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 360, in __init__
    self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 210, in __init__
    self.warmup_and_capture()
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 267, in warmup_and_capture
    _ = self.model_runner.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen2_5_vl.py", line 580, in forward
    hidden_states = general_mm_embed_routine(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/mm_utils.py", line 733, in general_mm_embed_routine
    assert pos is not None, "positions must be set for Qwen2/2.5-VL (MRoPE)"
AssertionError: positions must be set for Qwen2/2.5-VL (MRoPE)

When entering language_model.forward(...), the positions required by Qwen2/2.5-VL's rotary positional encoding (MRoPE) aren't passed in, so rotary_embedding.forward() receives None. Then, when TorchDynamo captures it and evaluates positions.ndim == 1 or positions.ndim == 2, it crashes on None.ndim.

A fix is in progress. FYI @BBuf @ispobock
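
In the meantime, a hedged sketch of what that fix direction might look like: fall back to the batch's positions when the caller doesn't pass them, so rotary_embedding.forward never sees None under Dynamo capture. Names mirror the snippets in this thread, and the real signature of general_mm_embed_routine may differ:

    # Hedged sketch only: the signature and keyword names here are hypothetical.
    def general_mm_embed_routine(input_ids, forward_batch, language_model,
                                 embed_tokens, **kwargs):
        inputs_embeds = embed_tokens(input_ids)

        # Prefer explicitly passed positions; otherwise take the batch's.
        positions = kwargs.pop("positions", None)
        if positions is None:
            positions = getattr(forward_batch, "positions", None)
        assert positions is not None, "positions must be set for Qwen2/2.5-VL (MRoPE)"

        return language_model(
            input_ids=None,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            positions=positions,
            **kwargs,
        )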

@yhyang201 (Collaborator)

> I encountered the "assert positions.ndim == 1 or positions.ndim == 2" error as well and tried to locate the root cause; the following findings are for your reference. From debugging, we confirmed that during CUDA graph capture, general_mm_embed_routine() enters the following branch.

The issue you encountered was likely caused by the piecewise CUDA graph path not preparing the mrope data. This PR should resolve that problem.
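
For readers hitting the same assertion: plain RoPE positions are 1-D (one index per token), while Qwen2/2.5-VL's MRoPE positions are 2-D because each token carries several position components. A minimal illustration of the shape contract (the (3, seq_len) layout here is illustrative, not necessarily SGLang's exact layout):

    import torch

    positions_rope = torch.arange(16)                 # 1-D: shape (16,)
    positions_mrope = torch.arange(16).expand(3, 16)  # 2-D: shape (3, 16)

    for positions in (positions_rope, positions_mrope):
        assert positions.ndim == 1 or positions.ndim == 2  # the guard in question

    # If the capture path never prepares the mrope positions, `positions` is
    # None and evaluating None.ndim raises -- the crash reported above.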

@yhyang201 (Collaborator)

I believe this new error might be caused by general_mm_embed_routine in the multimodal model allocating a new inputs_embeds tensor on each call:

    inputs_embeds, other_info = embed_mm_inputs(

Because embed_mm_inputs allocates a fresh inputs_embeds tensor every time, its device address cannot remain fixed, which likely prevents us from capturing the entire VLM with a single CUDA Graph.
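
For context, the usual way CUDA-graph runners handle such per-call allocations is to stage them into a pre-allocated buffer so the captured graph always sees the same device address. A hedged sketch (the class, sizes, and names are hypothetical, not SGLang's actual runner code):

    import torch

    class StaticEmbedBuffer:
        """Hypothetical staging buffer for freshly allocated inputs_embeds."""

        def __init__(self, max_tokens: int, hidden_size: int, device="cuda"):
            self.buf = torch.zeros(max_tokens, hidden_size, device=device)

        def stage(self, fresh_embeds: torch.Tensor) -> torch.Tensor:
            n = fresh_embeds.shape[0]
            # Copy into the stable buffer; if the graph was captured against
            # self.buf, replay reads the right data at the right address.
            self.buf[:n].copy_(fresh_embeds)
            return self.buf[:n]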

I have two potential ideas:

  1. Use a separate CUDA Graph for the encoder, keep the construction of inputs_embeds in eager mode, and let the LLM backbone continue to run under the standard piecewise CUDA Graph.

  2. Only capture the piecewise CUDA Graph for the LLM.

I'm not entirely sure which approach would be better, but it might make sense to start with option (2) (a rough sketch follows below) and later build upon it to explore (1).
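
A hedged sketch of what option (2)'s control flow might look like, with all names (vision_encoder, merge_mm_embeds, llm_backbone_compiled) as hypothetical stand-ins for the existing components:

    import torch

    def vlm_forward(pixel_values, input_ids, positions,
                    vision_encoder, merge_mm_embeds, llm_backbone_compiled):
        # Eager region: fresh allocations are fine, nothing is captured here.
        image_embeds = vision_encoder(pixel_values)
        inputs_embeds = merge_mm_embeds(input_ids, image_embeds)

        # Captured region: only the LLM backbone runs under the piecewise
        # CUDA graph, so its inputs can be staged into stable buffers as usual.
        return llm_backbone_compiled(inputs_embeds=inputs_embeds,
                                     positions=positions)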

I’m not deeply familiar with CUDA Graph internals, so I’d really appreciate your thoughts and feedback. @BBuf @yuan-luo

Thank you!

@BBuf BBuf closed this Nov 19, 2025
@BBuf BBuf deleted the try_to_support_vl_model branch November 19, 2025 13:14
