[AWQ] Evaluate fused vs unfused GEMM on actual shape #30783
mgehre-amd wants to merge 3 commits into vllm-project:main
Conversation
Code Review
This pull request introduces a performance optimization for AWQ GEMM operations by deferring the choice between fused and unfused kernels to runtime. By wrapping awq_gemm in a torch.custom_op, the heuristic condition is evaluated on the actual input shape during CUDA graph capture, rather than at compile time. This change correctly moves the dispatch logic into _custom_ops.py and is expected to significantly improve performance for non-batched decoding, as demonstrated by the provided benchmarks. The implementation appears correct and the changes are well-justified.
Hi, thanks for the PR. I might be missing some design context here, but I'm seeing the following error:
@yuttian1, thanks for your feedback.
I've looked into this further and found the issue was on my side. My local vLLM version was outdated, and the argument order of scales and qzeros didn't match the current implementation. After fixing that, the issue no longer reproduces. With this PR applied, I do see a performance improvement under my test setup (concurrency = 1, ~88k input tokens).
Friendly ping: is anyone available to review?
Before this PR, the condition
```python
# Heuristic: for large batches, dequantize to fp16 and use a regular matmul;
# otherwise use the fused AWQ GEMM kernel.
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
    out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return torch.matmul(input, out)
else:
    return awq_gemm(...)
```
was evaluated during `torch.compile` based on `max-num-batched-tokens`.
By default, `max-num-batched-tokens` is over the threshold,
which meant that `awq_gemm` was never taken, even when doing a single
request decode.
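As a toy sketch of what goes wrong (illustrative names, not vLLM's code): the heuristic is an ordinary Python branch on a tensor shape. In eager mode it is re-evaluated on every call, but when `torch.compile` traces the function, the comparison is resolved once against the traced example's shape, so only one branch survives in the compiled graph.

```python
import torch

branch_taken = []

def dispatch(x: torch.Tensor) -> torch.Tensor:
    # Plain Python `if` on a shape: in eager mode this runs per call;
    # under tracing it would be baked in for the traced shape.
    if x.shape[0] >= 256:
        branch_taken.append("unfused")
        return x * 2.0  # stand-in for the dequantize + matmul path
    branch_taken.append("fused")
    return x * 3.0      # stand-in for the fused awq_gemm kernel

# Eager: each call re-evaluates the heuristic on the actual batch size.
dispatch(torch.ones(512, 8))  # batch 512 -> "unfused" branch
dispatch(torch.ones(1, 8))    # batch 1   -> "fused" branch
```

This per-shape behavior is what the PR restores under compilation by hiding the branch from the tracer.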
To evaluate the condition for each specific shape during CUDA graph capture, this PR wraps `awq_gemm` into a torch custom op, which shields it from being traced through.
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Force-pushed from ba590e4 to 5e6c5a2.
Use input.dtype instead of scales.dtype to match the underlying _C::awq_gemm fake implementation for torch.compile type inference. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
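To illustrate why the fake implementation must use `input.dtype` rather than `scales.dtype`: `torch.compile` infers output dtypes from the fake, so if the real kernel returns the activation dtype (e.g. fp16) while the fake reports the scales dtype (e.g. fp32), the compiled graph's types diverge from runtime. A hypothetical sketch (the function name, argument layout, and the 4-bit pack factor of 8 are assumptions, not vLLM's actual code):

```python
import torch

def awq_gemm_fake(inp: torch.Tensor, qweight: torch.Tensor,
                  qzeros: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Assumed layout: 4-bit weights packed 8-per-int32 along the last dim.
    out_features = qweight.shape[-1] * 8
    return torch.empty(inp.shape[0], out_features,
                       dtype=inp.dtype,   # match the real kernel, not scales.dtype
                       device=inp.device)
```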
…lize Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgoin, are you available for review?
This yields a 2.5x speed-up for non-batched decode (single requests).

Before this PR, the condition was evaluated during `torch.compile` based on `max-num-batched-tokens`. By default, `max-num-batched-tokens` is over the threshold, which meant that `awq_gemm` was never taken, even when later doing a single-request decode (`input.shape[0] == 1`). To evaluate the condition for each specific shape during CUDA graph capture, this PR moves the decision into `awq_gemm` and wraps that into a torch custom op, which shields it from being traced through.

Test results with `max-num-seq=1` on Strix Halo (gfx1151):