[CPU] Optimize small oc GEMM for Qwen3-next on CPU #12446
Fridge003 merged 14 commits into sgl-project:main
Conversation
mingfeima left a comment:
Make as few changes as possible at the Python level; make the changes happen at the C++ level.
That is to say, we still use weight_packed_linear, but we use a different implementation for it under certain conditions.
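A minimal C++ sketch of the kind of dispatch being suggested, assuming a hypothetical `SMALL_OC_THRESHOLD` cutoff and hypothetical `*_kernel_impl` helper names (the real sgl-kernel entry points may differ):

```cpp
#include <torch/extension.h>
#include <optional>

// Hypothetical helpers, defined elsewhere in the extension:
at::Tensor weight_packed_linear_kernel_impl(
    const at::Tensor& input, const at::Tensor& packed_weight,
    const std::optional<at::Tensor>& bias);
at::Tensor small_oc_gemm_kernel_impl(
    const at::Tensor& input, const at::Tensor& packed_weight,
    const std::optional<at::Tensor>& bias);

// Assumed cutoff below which the specialized small-oc path is taken.
constexpr int64_t SMALL_OC_THRESHOLD = 16;

// The Python-facing op keeps its name and signature; only the C++
// implementation branches on the output-channel count.
at::Tensor weight_packed_linear(
    const at::Tensor& input,          // [M, K] BF16 activations
    const at::Tensor& packed_weight,  // prepacked weight, logical shape [OC, K]
    const std::optional<at::Tensor>& bias) {
  const int64_t oc = packed_weight.size(0);
  if (oc < SMALL_OC_THRESHOLD) {
    // small-oc path: GEMV-style kernel instead of the default brgemm tiles
    return small_oc_gemm_kernel_impl(input, packed_weight, bias);
  }
  // default path: existing brgemm-based implementation
  return weight_packed_linear_kernel_impl(input, packed_weight, bias);
}
```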
@jianan-gu, please put more detailed descriptions in the PR message.
One more thing: using FP32 brgemm might not be a good idea from a performance point of view, but this kernel doesn't contribute much to the end-to-end profiling results, so we can keep it as it is for now.
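For context, a standalone reference of what the small-oc path computes: BF16 inputs accumulated in FP32. This sketch uses plain scalar loops only to illustrate the numerics, not the actual brgemm-based kernel:

```cpp
#include <cstdint>
#include <cstring>

// Convert a bf16 value (stored as uint16_t) to float by widening it
// into the upper 16 bits of an IEEE-754 binary32.
static inline float bf16_to_fp32(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// out[m][n] = sum_k a[m][k] * w[n][k], accumulated in FP32.
// With oc (= N) < 16 the GEMM is close to a GEMV, so the FP32
// compute path is acceptable here even if it is not the fastest option.
void small_oc_gemm_ref(const uint16_t* a, const uint16_t* w, float* out,
                       int64_t M, int64_t N, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t k = 0; k < K; ++k) {
        acc += bf16_to_fp32(a[m * K + k]) * bf16_to_fp32(w[n * K + k]);
      }
      out[m * N + n] = acc;
    }
  }
}
```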
Noted, and added comments in the code.
Test commit: 0b97584
Hi, could you please fix the conflicts? Thanks!
Thanks, have fixed.
@jianan-gu rebase |
Co-authored-by: Zheng, Beilei <beilei.zheng@intel.com>
* [CPU] Optimize small oc GEMM for Qwen3-next on CPU (sgl-project#12446)
* port linear_gelu_linear kernel
* apply linear_gelu_linear for TP=1
* fix numa memory bind
* apply parallel partition patch

Co-authored-by: Zheng, Beilei <beilei.zheng@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
* port layernorm 3d
* apply layernorm
* support for bias
* fix
* intf fix
* add support for CPU
* fix tp=3/6 padding issue in encoder vision
* fix tp=3/6 padding issue in qwen3-omni
* refactor code
* add mrope
* change attention_mask shape to use flash attn
* add kernel apply_rotary_pos_emb_cpu
* replace nn.Linear with ReplicatedLinear
* enable torch.compile
* construct mask using query.dtype instead of bool on CPU
* add fast path for sparse attention
* fix double free segfault by wrong setting of BLOCK_M
* improve extend kernel performance for long context length
* update test_extend.py
* update comment
* fix topk softmax performance issue
* port optimization for image preprocessor in Qwen2VLImageProcessorFast
* apply optimization for image preprocessor
* update docker file
* optimize conv3d used in patch embedding
* resolve conflict
* apply optimized conv3d
* apply optimization for flash_attn_varlen_func (sgl-project#19)
* port optimization for flash_attn_varlen_func
* apply flash_attn_varlen_func
* remove contiguous before rope (sgl-project#20)
* Revert "resolve conflict" (This reverts commit 7622f6d.)
* fix after rebase
* Update pyproject_cpu.toml
* Update xeon.Dockerfile
* minor fix after rebase
* rope: add support for bf16 sincos (sgl-project#102)
* format
* Update xeon.Dockerfile
* odd tp for cpu
* Apply linear_gelu_linear and fix numa memory bind (sgl-project#22)
* [CPU] Optimize small oc GEMM for Qwen3-next on CPU (sgl-project#12446)
* port linear_gelu_linear kernel
* apply linear_gelu_linear for TP=1
* fix numa memory bind
* apply parallel partition patch
* Revert "Fix: test_vlm_offline_throughput output throughput (sgl-project#13279)" (sgl-project#101; this reverts commit 7ee3e36.)
* fix input dtype mismatch issue
* apply optimized layernorm

Co-authored-by: Zheng, Beilei <beilei.zheng@intel.com>
Co-authored-by: ZailiWang <zaili.wang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
This PR adds optimizations for BF16 GEMMs with small oc shapes (oc < 16), including:

* weight_packed_linear, which is common to all GEMMs with oc < 16
* fused_linear_sigmoid_mul, which is specific to cases found in Qwen3-next (shared_expert_gate, where N = 1 and there is a post-op sigmoid + mul); see the sketch below
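A minimal reference for the semantics of fused_linear_sigmoid_mul as described above. This is a sketch only; the FP32 types, row-major layouts, and the second operand `y` (taken to be the shared-expert output that the sigmoid gate multiplies) are assumptions, not the actual kernel:

```cpp
#include <cmath>
#include <cstdint>

// Reference semantics of the fused op for the shared_expert_gate case:
// gate = sigmoid(x @ w^T) with N = 1, then out = gate * y elementwise.
// x: [M, K], w: [1, K] (N = 1 in Qwen3-next), y and out: [M, H].
void fused_linear_sigmoid_mul_ref(const float* x, const float* w,
                                  const float* y, float* out,
                                  int64_t M, int64_t K, int64_t H) {
  for (int64_t m = 0; m < M; ++m) {
    // Linear part degenerates to a dot product because N = 1.
    float acc = 0.f;
    for (int64_t k = 0; k < K; ++k) {
      acc += x[m * K + k] * w[k];
    }
    const float gate = 1.f / (1.f + std::exp(-acc));  // sigmoid post-op
    for (int64_t h = 0; h < H; ++h) {
      out[m * H + h] = gate * y[m * H + h];           // mul post-op
    }
  }
}
```

Fusing the sigmoid and mul into the GEMV avoids materializing the [M, 1] gate tensor and a separate elementwise pass, which matters at these tiny output widths where the kernel is launch- and bandwidth-bound rather than compute-bound.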