[AWQ] Evaluate fused vs unfused GEMM on actual shape#30783

Closed
mgehre-amd wants to merge 3 commits intovllm-project:mainfrom
mgehre-amd:matthias.awq_specialize

Conversation

@mgehre-amd
Contributor

@mgehre-amd mgehre-amd commented Dec 16, 2025

Evaluating the fused vs. unfused GEMM heuristic on the actual shape gets a speedup of 2.5x for non-batched decode (single requests).

Before this PR, the condition

```
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
    out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return torch.matmul(input, out)
else:
    return awq_gemm(...)
```

was evaluated during torch.compile based on max-num-batched-tokens. By default, max-num-batched-tokens is over the threshold, which meant that awq_gemm was never taken, even when later doing a single request decode (input.shape[0] == 1).

To evaluate the condition on each specific shape during CUDA graph capture, this PR moves the decision into awq_gemm and wraps it in a torch custom op, which shields it from being traced through.

Test results on Strix Halo (gfx1151):

```
$ vllm bench throughput --model Qwen/Qwen3-4B-AWQ --input-len 1024 \
     --gpu-memory-utilization 0.4 --max-model-len 4096 --num-prompts 30

# on main
Throughput: 0.75 requests/s, 860.06 total tokens/s, 95.56 output tokens/s
Total num prompt tokens:  30720
Total num output tokens:  3840

# on this PR
Throughput: 0.85 requests/s, 978.43 total tokens/s, 108.71 output tokens/s
Total num prompt tokens:  30720
Total num output tokens:  3840
```

Test results with `max-num-seqs=1` on Strix Halo (gfx1151):

```
$ vllm bench throughput --model Qwen/Qwen3-4B-AWQ --input-len 1024 \
   --gpu-memory-utilization 0.4 --max-model-len 4096 --num-prompts 30 \
   --max-num-seqs=1

# on main
Throughput: 0.08 requests/s, 96.89 total tokens/s, 10.77 output tokens/s
Total num prompt tokens:  30720
Total num output tokens:  3840

# on this PR
Throughput: 0.20 requests/s, 231.17 total tokens/s, 25.69 output tokens/s
Total num prompt tokens:  30720
Total num output tokens:  3840
```

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for AWQ GEMM operations by deferring the choice between fused and unfused kernels to runtime. By wrapping awq_gemm in a torch.custom_op, the heuristic condition is evaluated on the actual input shape during CUDA graph capture, rather than at compile time. This change correctly moves the dispatch logic into _custom_ops.py and is expected to significantly improve performance for non-batched decoding, as demonstrated by the provided benchmarks. The implementation appears correct and the changes are well-justified.

@mgehre-amd mgehre-amd changed the title AWQ: Evaluate fused vs unfused GEMM on actual shape [AWQ] Evaluate fused vs unfused GEMM on actual shape Dec 17, 2025
@yuttian1
Contributor

yuttian1 commented Dec 18, 2025

Hi, thanks for the PR. I might be missing some design context here, but I'm seeing the following error:
`assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8`
The model I’m using is Qwen3-235B-VL-AWQ.

@mgehre-amd
Contributor Author

> assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8

@yuttian1, thanks for your feedback. I'm not sure how this could be caused by this PR. Can you please share

1. the full commands to reproduce
2. the full output/traceback
3. whether the same thing works on vLLM's main branch?

Thanks!

@yuttian1
Contributor

> assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
>
> @yuttian1, thanks for your feedback. I'm not sure how this could be caused by this PR. Can you please share
> 1. the full commands to reproduce
> 2. the full output/traceback
> 3. whether the same thing works on vLLM's main branch?

Sorry for the late reply! On my side, I'm running vLLM under ROCm. Both awq_dequant + gemm and awq_gemm work fine when run independently, but this error appears after applying your changes. The full output is:
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "<eval_with_key>.190", line 1046, in forward
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] submod_0 = self.submod_0(l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_inputs_embeds_, s59, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_qweight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_scales_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_qzeros_, l_self_modules_layers_modules_0_modules_self_attn_modules_q_norm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_k_norm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_positions_, s7); l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_qweight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_scales_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_qzeros_ = l_self_modules_layers_modules_0_modules_self_attn_modules_q_norm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_k_norm_parameters_weight_ = None
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/opt/vllm/vllm/compilation/cuda_graph.py", line 125, in call
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return self.runnable(*args, **kwargs)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/opt/vllm/vllm/compilation/piecewise_backend.py", line 92, in call
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return self.compiled_graph_for_general_shape(*args)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return fn(*args, **kwargs)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1207, in forward
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return compiled_fn(full_args)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 331, in runtime_wrapper
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] all_outs = call_func_at_runtime_with_args(
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] out = normalize_as_list(f(args))
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 692, in inner_fn
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] outs = compiled_fn(args)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 498, in wrapper
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return compiled_fn(runtime_args)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 561, in call
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return self.current_callable(inputs)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 2444, in run
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return model(new_inputs)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/root/.cache/vllm/torch_compile_cache/ccf95ef384/rank_0_0/inductor_cache/2i/c2iqegvoziigca2hnpiv4jhaa4ywkmdsbcdatjahluc3c5gkscpz.py", line 321, in call
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] buf2 = torch.ops.vllm.awq_gemm.default(buf1, arg3_1, arg4_1, arg5_1, 8)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 776, in call
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return self._op(*args, **kwargs)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/opt/vllm/vllm/_custom_ops.py", line 461, in _awq_gemm
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/opt/vllm/vllm/_custom_ops.py", line 442, in awq_dequantize
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] return awq_dequantize_triton(qweight, scales, zeros)
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] File "/opt/vllm/vllm/model_executor/layers/quantization/awq_triton.py", line 246, in awq_dequantize_triton
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
    (EngineCore_DP0 pid=12538) ERROR 12-19 06:40:14 [core.py:790] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@yuttian1
Contributor

yuttian1 commented Dec 19, 2025


I’ve looked into this further and found the issue was on my side. My local vLLM version was outdated, and the argument order of scales and qzeros didn’t match the current implementation. After fixing that, the issue no longer reproduces.

With this PR applied, I do see a performance improvement under my test setup (concurrency = 1, ~88k input tokens).

@mgehre-amd
Contributor Author

Friendly ping, is anyone available as a reviewer?

Before this PR, the condition
```
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
    out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return torch.matmul(input, out)
else:
    return awq_gemm(...)
```
was evaluated during `torch.compile` based on `max-num-batched-tokens`.
By default, `max-num-batched-tokens` is over the threshold,
which meant that `awq_gemm` was never taken, even when doing a single
request decode.

To evaluate the condition for each specific shape during
CUDA graph capture, this PR wraps `awq_gemm` into a torch custom op,
which shields it from being traced through.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd force-pushed the matthias.awq_specialize branch from ba590e4 to 5e6c5a2 Compare January 15, 2026 13:10

@cursor cursor bot left a comment


Comment @cursor review or bugbot run to trigger another review on this PR

Use `input.dtype` instead of `scales.dtype` to match the underlying
`_C::awq_gemm` fake implementation for torch.compile type inference.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
…lize

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd
Contributor Author

@mgoin, are you available for review?
I have follow up PRs in the pipeline that will further improve Qwen/Qwen3-4B-AWQ from 25 tok/s (this PR) to 57 tok/s.

zhandaz pushed a commit to CentML/vllm that referenced this pull request Feb 4, 2026
@mgehre-amd mgehre-amd closed this Feb 18, 2026
