
feat: Support Qwen 2.5 vl #3258

Merged
zhaochenyang20 merged 1 commit into sgl-project:main from mickqian:qwen-2.5-vl
Feb 16, 2025

Conversation


@mickqian mickqian commented Feb 2, 2025

Motivation

Address #3247

Modifications

  1. Qwen2.5-VL modeling

Checklist


mickqian commented Feb 2, 2025

To use the default processor class, either transformers must be upgraded (Qwen2.5-VL support has not shipped in a release yet), or the processor code must be copied from it.
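A minimal sketch of that fallback (the class name Qwen2_5_VLProcessor matches transformers main at the time of writing; the vendored module path here is hypothetical):

```python
# Sketch of the fallback described above (assumptions: Qwen2_5_VLProcessor is the
# class name on transformers main; the vendored sglang path is hypothetical).
try:
    # Available once a transformers release ships Qwen2.5-VL support.
    from transformers import Qwen2_5_VLProcessor
except ImportError:
    # Otherwise fall back to a copy of the processor code vendored into sglang.
    from sglang.srt.configs.qwen2_5_vl import Qwen2_5_VLProcessor
```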

@grahama1970 grahama1970 mentioned this pull request Feb 5, 2025

halexan commented Feb 8, 2025

Any updates?

Comment thread python/sglang/srt/models/qwen2.py Outdated
Comment thread test/srt/test_vision_openai_server.py Outdated
Comment thread python/sglang/srt/configs/qwen2_5_vl.py Outdated
Comment thread python/sglang/srt/models/qwen2_5_vl.py Outdated
@yizhang2077 yizhang2077 self-assigned this Feb 11, 2025
@mickqian mickqian force-pushed the qwen-2.5-vl branch 3 times, most recently from 6c76d0d to 802a401 Compare February 13, 2025 13:41

yizhang2077 commented Feb 14, 2025

LGTM, I think we can run a local benchmark like MMMU and paste the result here, and then we can merge it. (Here is the benchmark result: #3258 (comment).) cc @zhaochenyang20 @mickqian. BTW, I think maybe we need to add a VLM benchmark?
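A VLM benchmark driver can stay small by going through the OpenAI-compatible endpoint; a sketch (the port matches the default server settings below, while the model name and image URL are placeholders, and an MMMU-style harness would just loop over this call):

```python
# Minimal sketch of scoring a VLM through sglang's OpenAI-compatible endpoint.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": "https://example.com/sample_question.png"}},
                {"type": "text",
                 "text": "Answer the multiple-choice question in the image."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)
```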


@yizhang2077 yizhang2077 left a comment


We can update supported_models.md.


mickqian commented Feb 14, 2025

Benchmark result posted here, and the MMMU benchmark will be submitted here.


e1ijah1 commented Feb 14, 2025

I tried to install sglang from source using this branch and run the sglang server with the Qwen 2.5 VL model specified, but encountered a KeyError from transformers AutoConfig.

root@cm02-187:~/projects/sglang# python3.10 -m pip freeze|grep transformers
transformers==4.48.3
root@cm02-187:~/projects/sglang# python3.10 -m sglang.launch_server --model-path /data1/Qwen2.5-VL-72B-Instruct --tp 2
INFO 02-14 06:33:58 __init__.py:190] Automatically detected platform cuda.
[2025-02-14 06:34:02] server_args=ServerArgs(model_path='/data1/Qwen2.5-VL-72B-Instruct', tokenizer_path='/data1/Qwen2.5-VL-72B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='/data1/Qwen2.5-VL-72B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=30000, mem_fraction_static=0.87, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=2, stream_interval=1, stream_output=False, random_seed=852285615, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, return_hidden_states=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py", line 1071, in from_pretrained
    config_class = CONFIG_MAPPING[config_dict["model_type"]]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py", line 773, in __getitem__
    raise KeyError(key)
KeyError: 'qwen2_5_vl'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/projects/sglang/python/sglang/launch_server.py", line 14, in <module>
    launch_server(server_args)
  File "/root/projects/sglang/python/sglang/srt/entrypoints/http_server.py", line 491, in launch_server
    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
  File "/root/projects/sglang/python/sglang/srt/entrypoints/engine.py", line 426, in _launch_subprocesses
    tokenizer_manager = TokenizerManager(server_args, port_args)
  File "/root/projects/sglang/python/sglang/srt/managers/tokenizer_manager.py", line 134, in __init__
    self.model_config = ModelConfig(
  File "/root/projects/sglang/python/sglang/srt/configs/model_config.py", line 53, in __init__
    self.hf_config = get_config(
  File "/root/projects/sglang/python/sglang/srt/hf_transformers_utils.py", line 66, in get_config
    config = AutoConfig.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py", line 1073, in from_pretrained
    raise ValueError(
ValueError: The checkpoint you are trying to load has model type `qwen2_5_vl` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

You can update Transformers with the command `pip install --upgrade transformers`. If this does not work, and the checkpoint is very new, then there may not be a release version that supports this model yet. In this case, you can get the most up-to-date code by installing Transformers from source with the command `pip install git+https://github.com/huggingface/transformers.git`
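The failing lookup can be reproduced directly against the same CONFIG_MAPPING the traceback goes through:

```python
# Check whether the installed transformers knows the qwen2_5_vl model type
# (the same mapping the AutoConfig lookup above fails on).
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

print("qwen2_5_vl" in CONFIG_MAPPING)  # False on transformers==4.48.3
```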

@yizhang2077

@e1ijah1 I think maybe you need to specify the chat template, e.g. with --chat-template qwen2-vl?
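For example, with the model path from the report above:

$ python3.10 -m sglang.launch_server --model-path /data1/Qwen2.5-VL-72B-Instruct --tp 2 --chat-template qwen2-vl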


e1ijah1 commented Feb 14, 2025

(quoting the KeyError report from the comment above)

My bad, I noticed that the Qwen2.5-VL code is still only on the transformers main branch and hasn't been released yet. After installing transformers from the main branch, I was able to serve normally, but there still seems to be an issue with AWQ.

@e1ijah1

e1ijah1 commented Feb 14, 2025

I tried to serve this unofficial AWQ model https://huggingface.co/PointerHQ/Qwen2.5-VL-72B-Instruct-Pointer-AWQ and got the following error:

$ python3.10 -m sglang.launch_server --model-path /data1/Qwen2.5-VL-72B-Instruct-Pointer-AWQ/ --tp 2 --dtype float16

INFO 02-14 07:01:15 __init__.py:190] Automatically detected platform cuda.
[2025-02-14 07:01:19] server_args=ServerArgs(model_path='/data1/Qwen2.5-VL-72B-Instruct-Pointer-AWQ/', tokenizer_path='/data1/Qwen2.5-VL-72B-Instruct-Pointer-AWQ/', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='/data1/Qwen2.5-VL-72B-Instruct-Pointer-AWQ/', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=30000, mem_fraction_static=0.87, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=2, stream_interval=1, stream_output=False, random_seed=738640062, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, return_hidden_states=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False)
[2025-02-14 07:01:19] Casting torch.bfloat16 to torch.float16.
INFO 02-14 07:01:19 awq_marlin.py:111] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 02-14 07:01:21 __init__.py:190] Automatically detected platform cuda.
INFO 02-14 07:01:22 __init__.py:190] Automatically detected platform cuda.
INFO 02-14 07:01:22 __init__.py:190] Automatically detected platform cuda.
[2025-02-14 07:01:25 TP0] Casting torch.bfloat16 to torch.float16.
INFO 02-14 07:01:25 awq_marlin.py:111] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-02-14 07:01:25 TP1] Casting torch.bfloat16 to torch.float16.
INFO 02-14 07:01:25 awq_marlin.py:111] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-02-14 07:01:25 TP0] Overlap scheduler is disabled for multimodal models.
[2025-02-14 07:01:25 TP0] Casting torch.bfloat16 to torch.float16.
INFO 02-14 07:01:25 awq_marlin.py:111] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-02-14 07:01:25 TP0] Automatically reduce --mem-fraction-static to 0.827 because this is a multimodal model.
[2025-02-14 07:01:25 TP0] Init torch distributed begin.
[2025-02-14 07:01:26 TP1] Overlap scheduler is disabled for multimodal models.
[2025-02-14 07:01:26 TP1] Casting torch.bfloat16 to torch.float16.
INFO 02-14 07:01:26 awq_marlin.py:111] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
[2025-02-14 07:01:26 TP1] Automatically reduce --mem-fraction-static to 0.827 because this is a multimodal model.
[2025-02-14 07:01:26 TP1] Init torch distributed begin.
[2025-02-14 07:01:26 TP1] sglang is using nccl==2.21.5
[2025-02-14 07:01:26 TP0] sglang is using nccl==2.21.5
[2025-02-14 07:01:26 TP1] Load weight begin. avail mem=43.65 GB
[2025-02-14 07:01:26 TP0] Load weight begin. avail mem=43.65 GB
[2025-02-14 07:01:26 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/root/projects/sglang/python/sglang/srt/managers/scheduler.py", line 1816, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/root/projects/sglang/python/sglang/srt/managers/scheduler.py", line 240, in __init__
    self.tp_worker = TpWorkerClass(
  File "/root/projects/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in __init__
    self.model_runner = ModelRunner(
  File "/root/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 194, in __init__
    self.load_model()
  File "/root/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 317, in load_model
    self.model = get_model(
  File "/root/projects/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/root/projects/sglang/python/sglang/srt/model_loader/loader.py", line 357, in load_model
    model = _initialize_model(
  File "/root/projects/sglang/python/sglang/srt/model_loader/loader.py", line 138, in _initialize_model
    return model_class(
  File "/root/projects/sglang/python/sglang/srt/models/qwen2_5_vl.py", line 513, in __init__
    self.model = Qwen2Model(config, quant_config)
  File "/root/projects/sglang/python/sglang/srt/models/qwen2.py", line 241, in __init__
    self.layers = make_layers(
  File "/root/projects/sglang/python/sglang/srt/utils.py", line 313, in make_layers
    [
  File "/root/projects/sglang/python/sglang/srt/utils.py", line 314, in <listcomp>
    maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
  File "/root/projects/sglang/python/sglang/srt/models/qwen2.py", line 243, in <lambda>
    lambda idx, prefix: Qwen2DecoderLayer(
  File "/root/projects/sglang/python/sglang/srt/models/qwen2.py", line 190, in __init__
    self.mlp = Qwen2MLP(
  File "/root/projects/sglang/python/sglang/srt/models/qwen2.py", line 69, in __init__
    self.down_proj = RowParallelLinear(
  File "/root/projects/sglang/python/sglang/srt/layers/linear.py", line 1159, in __init__
    self.quant_method.create_weights(
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/awq_marlin.py", line 188, in create_weights
    verify_marlin_supports_shape(
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 110, in verify_marlin_supports_shape
    raise ValueError(f"Weight input_size_per_partition = "
ValueError: Weight input_size_per_partition = 14784 is not divisible by min_thread_k = 128. Consider reducing tensor_parallel_size or running with --quantization gptq.

[2025-02-14 07:01:26] Received sigquit from a child proces. It usually means the child failed.
Killed
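The failing check is plain shape arithmetic: each TP rank holds intermediate_size / tp columns of down_proj, and that per-partition width must be a multiple of Marlin's min_thread_k = 128. A sketch (intermediate_size inferred from the 14784-at-tp-2 figure in the error; the divisor set is illustrative):

```python
# Reproduce the Marlin shape check from the error above.
intermediate_size = 14784 * 2  # 29568, implied by the tp=2 figure in the error
for tp in (1, 2, 4):
    per_partition = intermediate_size // tp
    print(f"tp={tp}: {per_partition} % 128 = {per_partition % 128}")
# tp=1 -> 0 (ok); tp=2 -> 64 (fails, as reported); tp=4 -> 96 (also fails)
```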

@mickqian

(quoting the AWQ error report from the comment above)

Noticed. Could you move this to a new issue? Thanks.

@mickqian mickqian force-pushed the qwen-2.5-vl branch 2 times, most recently from 8529af2 to eea7629 Compare February 14, 2025 09:12
@mickqian mickqian force-pushed the qwen-2.5-vl branch 9 times, most recently from e5f7b87 to 3a938fe Compare February 16, 2025 08:40
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.hidden_size
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(

May I kindly ask why the order needs to be adjusted here? I wasn't able to find any reference material on this.
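One plausible reason, as an assumption rather than a confirmed answer: the checkpoint stores the fused qkv weight with q, k, and v stacked along the output axis, while tensor-parallel loading slices by attention head, so heads have to be moved to the outermost axis before sharding. A sketch with representative Qwen2.5-VL vision dimensions (not the PR's exact code):

```python
# Illustrative reorder of a fused qkv weight: reshape
# (3 * num_heads * head_size, embed_dim) so heads become the outermost axis,
# letting each tensor-parallel rank slice contiguous per-head q/k/v blocks.
import torch

visual_num_heads = 16    # vision_config.num_heads for Qwen2.5-VL
visual_embed_dim = 1280  # vision_config.hidden_size
head_size = visual_embed_dim // visual_num_heads  # 80

loaded_weight = torch.randn(3 * visual_embed_dim, visual_embed_dim)
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size, visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)  # (num_heads, 3, head_size, embed_dim)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)  # back to 2-D
```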
