[Bugfix] Fix Qwen3.5 Marlin TP failure for GDN in_proj_ba#36199
[Bugfix] Fix Qwen3.5 Marlin TP failure for GDN in_proj_ba#36199AjAnubolu wants to merge 2 commits into
Conversation
The in_proj_ba linear layer has output dim = 2 * num_kv_heads which can be < GPTQ_MARLIN_MIN_THREAD_N (64) when sharded. Use disable_tp and quant_config=None for this layer, then manually slice b/a for the local TP rank. Fixes vllm-project#35924 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request addresses a Tensor Parallelism failure for Qwen3.5 with Marlin quantization. The fix involves disabling tensor parallelism for the in_proj_ba layer in the Gated Delta Network, which is too small to be sharded correctly with Marlin's constraints. Instead, the layer is replicated, and its output is manually sliced per TP rank. The changes in qwen3_5.py and qwen3_next.py are consistent with this approach. However, I've found a critical issue in qwen3_next.py where a reshape operation is incorrect for sequence lengths greater than 1, which will likely cause a runtime error during prefill.
| b = b.reshape(b.size(0), self.num_v_heads) | ||
| a = a.reshape(a.size(0), self.num_v_heads) |
There was a problem hiding this comment.
The reshape operation for b and a appears to be incorrect for sequence lengths (sq) greater than 1. The tensor b has a shape of (bs, sq, num_k_heads, num_v_heads // num_k_heads), but it's being reshaped to (bs, self.num_v_heads). This will raise a runtime error during prefill when sq > 1 because the number of elements won't match.
The reshape should probably flatten the batch and sequence dimensions (bs and sq) to get a tensor with shape (num_tokens, num_v_heads).
| b = b.reshape(b.size(0), self.num_v_heads) | |
| a = a.reshape(a.size(0), self.num_v_heads) | |
| b = b.reshape(-1, self.num_v_heads) | |
| a = a.reshape(-1, self.num_v_heads) |
Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
|
This pull request has merge conflicts that must be resolved before it can be |
|
This pull request has merge conflicts that must be resolved before it can be |
|
This pull request has merge conflicts that must be resolved before it can be |
Summary
Closes #35924
Split the GDN
in_proj_balinear into separatein_proj_bandin_proj_aso each column dimension meets Marlin's MIN_THREAD_N=64 constraint at TP>=4.