[Bug] --enable-mixed-chunk breaks Qwen 2.5 VL #10179

@AlienKevin

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submit lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Enabling the --enable-mixed-chunk option triggers a torch Dynamo shape error that crashes the server when serving Qwen 2.5 VL.
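
For context, the traceback below is a Dynamo symbolic-shape failure: a view() is sized with a token count derived from one tensor (positions) while reshaping another (query), so fake-tensor tracing cannot prove the two sizes agree once mixed chunking turns them into independent symbols. A minimal standalone sketch of the same error class (hypothetical code, not from sglang; rotary_like and all sizes are made up):

import torch

# Hypothetical repro of the error class seen in the log: the view() size comes
# from positions, not from query itself, so Dynamo traces two unrelated symbols.
@torch.compile(dynamic=True)
def rotary_like(positions, query, head_size):
    num_tokens = positions.shape[0]               # a symbol like s18 in the log
    return query.view(num_tokens, -1, head_size)  # query is (s5, s83) in the log

positions = torch.arange(7)                        # length deliberately != query rows
query = torch.randn(8, 256, dtype=torch.bfloat16)
rotary_like(positions, query, 128)  # raises: shape '[s0, -1, 128]' is invalid for input of size s1*s2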

Reproduction

Server command:

python -m sglang.launch_server \
    --model-path Qwen/Qwen2.5-VL-7B-Instruct \
    --mem-fraction-static 0.8 \
    --tp 1 \
    --disable-radix-cache \
    --cuda-graph-bs 256 \
    --cuda-graph-max-bs 256 \
    --chunked-prefill-size 8192 \
    --max-prefill-tokens 8192 \
    --max-running-requests 256 \
    --enable-mixed-chunk

Client command:

python -m sglang.bench_serving \
    --backend sglang-oai-chat \
    --dataset-name random-image \
    --num-prompts 100 \
    --random-image-num-images 1 \
    --random-image-resolution 1080p \
    --random-input-len 512 \
    --random-output-len 512 \
    --port 30000

Full log (the error only appears after enabling --enable-mixed-chunk):
user@sglang-bench:/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang$ ./qwen_server.sh 
All deep_gemm operations loaded successfully!
W0908 20:08:10.294000 119663 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0908 20:08:10.294000 119663 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-09-08 20:08:10] server_args=ServerArgs(model_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.8, max_running_requests=256, max_queued_requests=9223372036854775807, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=8192, schedule_policy='fcfs', schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=845617692, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, gc_warning_threshold_secs=0.0, api_key=None, served_model_name='Qwen/Qwen2.5-VL-7B-Instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_backend='prefill', ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, 
hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', disable_radix_cache=True, cuda_graph_max_bs=256, cuda_graph_bs=[256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=True, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False)
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-08 20:08:11] MOE_RUNNER_BACKEND is not initialized, using triton backend
[2025-09-08 20:08:13] Using default HuggingFace chat template with detected content format: openai
All deep_gemm operations loaded successfully!
W0908 20:08:16.337000 119942 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0908 20:08:16.337000 119942 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
All deep_gemm operations loaded successfully!
W0908 20:08:16.872000 119943 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0908 20:08:16.872000 119943 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-08 20:08:18] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-09-08 20:08:18] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-08 20:08:18] Init torch distributed ends. mem usage=0.00 GB
[2025-09-08 20:08:19] MOE_RUNNER_BACKEND is not initialized, using triton backend
[2025-09-08 20:08:19] Load weight begin. avail mem=139.28 GB
[2025-09-08 20:08:19] Multimodal attention backend not set. Use fa3.
[2025-09-08 20:08:19] Using fa3 as multimodal attention backend.
[2025-09-08 20:08:19] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:02,  1.83it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:01,  1.73it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:01<00:01,  1.72it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.65it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  2.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  1.94it/s]

[2025-09-08 20:08:22] Load weight end. type=Qwen2_5_VLForConditionalGeneration, dtype=torch.bfloat16, avail mem=123.49 GB, mem usage=15.79 GB.
[2025-09-08 20:08:22] KV Cache is allocated. #tokens: 1790758, K size: 47.82 GB, V size: 47.82 GB
[2025-09-08 20:08:22] Memory pool end. avail mem=27.59 GB
[2025-09-08 20:08:22] Capture cuda graph begin. This can take up to several minutes. avail mem=27.01 GB
[2025-09-08 20:08:22] Capture cuda graph bs [256]
Capturing batches (bs=256 avail_mem=26.71 GB): 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.25s/it]
[2025-09-08 20:08:25] Capture cuda graph end. Time elapsed: 2.45 s. mem usage=0.50 GB. avail mem=26.51 GB.
[2025-09-08 20:08:27] max_total_num_tokens=1790758, chunked_prefill_size=8192, max_prefill_tokens=8192, max_running_requests=256, context_len=128000, available_gpu_mem=26.51 GB
[2025-09-08 20:08:27] INFO:     Started server process [119663]
[2025-09-08 20:08:27] INFO:     Waiting for application startup.
[2025-09-08 20:08:27] INFO:     Application startup complete.
[2025-09-08 20:08:27] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-09-08 20:08:28] INFO:     127.0.0.1:57254 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-09-08 20:08:28] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-08 20:08:31] INFO:     127.0.0.1:57266 - "POST /generate HTTP/1.1" 200 OK
[2025-09-08 20:08:31] INFO:     127.0.0.1:57276 - "GET /v1/models HTTP/1.1" 200 OK
[2025-09-08 20:08:31] The server is fired up and ready to roll!
[2025-09-08 20:08:43] INFO:     127.0.0.1:45760 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:43] Prefill batch. #new-seq: 1, #new-token: 2757, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-08 20:08:44] Decode batch. #running-req: 1, #token: 0, token usage: 0.00, cuda graph: True, gen throughput (token/s): 2.29, #queue-req: 0, 
[2025-09-08 20:08:46] INFO:     127.0.0.1:45788 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:46] INFO:     127.0.0.1:45776 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:46] INFO:     127.0.0.1:45782 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:46] INFO:     127.0.0.1:45794 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:46] INFO:     127.0.0.1:45804 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:47] Prefill batch. #new-seq: 1, #new-token: 3128, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-08 20:08:47] INFO:     127.0.0.1:45822 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:47] INFO:     127.0.0.1:45824 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:47] Prefill batch. #new-seq: 2, #new-token: 5726, #cached-token: 0, token usage: 0.00, #running-req: 1, #queue-req: 1, 
[2025-09-08 20:08:47] INFO:     127.0.0.1:45836 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:47] INFO:     127.0.0.1:45838 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:48] INFO:     127.0.0.1:45846 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:48] INFO:     127.0.0.1:45854 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:48] INFO:     127.0.0.1:45880 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2] failed while attempting to run meta for aten.view.default
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2] Traceback (most recent call last):
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     r = func(*args, **kwargs)
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]         ^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     return self._op(*args, **kwargs)
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 364, in _view_meta
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3823, in _reshape_view_helper
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     shape = utils.infer_size(shape, a.numel())
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1018, in infer_size
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     torch._check(
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1684, in _check
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     _check_with(RuntimeError, cond, message)
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1666, in _check_with
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2]     raise error_type(message_evaluated)
[rank0]:E0908 20:08:48.141000 119942 torch/_subclasses/fake_tensor.py:2721] [0/2] RuntimeError: shape '[s18, -1, 128]' is invalid for input of size s5*s83
[2025-09-08 20:08:48] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 141, in forward_thread_func
    self.forward_thread_func_()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 176, in forward_thread_func_
    self.worker.forward_batch_generation(
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker.py", line 244, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1785, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1836, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1730, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2_5_vl.py", line 586, in forward
    hidden_states = general_mm_embed_routine(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/mm_utils.py", line 664, in general_mm_embed_routine
    hidden_states = language_model(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 340, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 244, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 182, in forward
    q, k = self.rotary_emb(positions, q, k)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 753, in transform
    tracer.run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
    self._call(inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/misc.py", line 1111, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 712, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2559, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2625, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2723, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3355, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3253, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 2753, in wrap_fake_exception
    return fn()
           ^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3254, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3462, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3432, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 364, in _view_meta
    return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3823, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1018, in infer_size
    torch._check(
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1684, in _check
    _check_with(RuntimeError, cond, message)
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1666, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method view(*(FakeTensor(..., device='cuda:0', size=(s5, s83), dtype=torch.bfloat16), s18, -1, 128), **{}): got RuntimeError("shape '[s18, -1, 128]' is invalid for input of size s5*s83")

from user code:
   File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/layers/rotary_embedding.py", line 1066, in forward
    query = query.view(num_tokens, -1, self.head_size)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


[2025-09-08 20:08:48] INFO:     127.0.0.1:45864 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:48] INFO:     127.0.0.1:45874 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-08 20:08:48] Received sigquit from a child process. It usually means the child failed.
./qwen_server.sh: line 16: 119663 Killed                  python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --mem-fraction-static 0.8 --tp 1 --disable-radix-cache --cuda-graph-bs 256 --cuda-graph-max-bs 256 --chunked-prefill-size 8192 --max-prefill-tokens 8192 --max-running-requests 256 --enable-mixed-chunk
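
The server comes up and serves requests normally with the identical command once --enable-mixed-chunk is removed. If a stop-gap is needed before a proper fix, one possible mitigation (an untested assumption on my part, not a confirmed fix) is to size the failing view from the tensor being reshaped rather than from the positions-derived count, assuming the two are equal at runtime and only diverge as Dynamo symbols under mixed chunking:

# Hypothetical sketch around sglang/srt/layers/rotary_embedding.py:1066
# (assumes num_tokens == query.shape[0] at runtime; verify before relying on it).
query = query.view(query.shape[0], -1, self.head_size)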

Environment

Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H200
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
CUDA Driver Version: 575.57.08
PyTorch: 2.8.0+cu128
sglang: 0.5.2rc2
sgl_kernel: 0.3.8
flashinfer_python: 0.3.1
triton: 3.4.0
transformers: 4.56.1
torchao: 0.9.0
numpy: 2.3.2
aiohttp: 3.12.15
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.34.4
interegular: 0.3.3
modelscope: 1.29.0
orjson: 3.11.2
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.2
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.24
openai: 1.99.1
tiktoken: 0.11.0
anthropic: 0.66.0
litellm: Module Not Found
decord: 0.6.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    NIC8    NIC9    NIC10   NIC11   CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    PXB     NODE    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    PXB     NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    PXB     NODE    SYS     SYS     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    NODE    PXB     SYS     SYS     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     SYS     SYS     PXB     NODE    NODE    NODE    NODE    NODE    56-111,168-223  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    PXB     NODE    NODE    56-111,168-223  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    PXB     NODE    56-111,168-223  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    NODE    PXB     56-111,168-223  1               N/A
NIC0    PXB     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS
NIC1    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS
NIC2    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    PIX      X      NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS
NIC3    NODE    PXB     NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS
NIC4    NODE    NODE    PXB     NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE     X      NODE    SYS     SYS     SYS     SYS     SYS     SYS
NIC5    NODE    NODE    NODE    PXB     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    NODE     X      SYS     SYS     SYS     SYS     SYS     SYS
NIC6    SYS     SYS     SYS     SYS     PXB     NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    NODE    NODE
NIC7    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE     X      PIX     NODE    NODE    NODE
NIC8    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE    PIX      X      NODE    NODE    NODE
NIC9    SYS     SYS     SYS     SYS     NODE    PXB     NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      NODE    NODE
NIC10   SYS     SYS     SYS     SYS     NODE    NODE    PXB     NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE     X      NODE
NIC11   SYS     SYS     SYS     SYS     NODE    NODE    NODE    PXB     SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    NODE     X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9
  NIC10: mlx5_10
  NIC11: mlx5_11


ulimit soft: 1048576