Fix GLM 4.7 Lite MoE gating func #18980
Conversation
|
Thank you!! 🫡 |
|
Please bro we need this! |
|
@pwilkin Does this mean anyone running GLM 4.7 Flash right now via llama.cpp is running a faulty implementation? |
|
Yes. |
|
The ideal solution would be fixing the GGUF on conversion, in addition to this fix for existing GGUF files. On the original config.json, the line
This fix can be applied on current llama.cpp builds with |
|
@pwilkin Nice find - re: DeepSeek V3 / R1 uses

```python
def route_tokens_to_experts(self, router_logits):
    router_logits = router_logits.sigmoid()
```

I'm not fully sure, but because DeepSeek V3.1 defined

```python
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
    if score_func == "sigmoid":
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
    elif score_func == "softmax":
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
    else:
        raise ValueError(f"Unsupported expert score gating function value: {score_func}")
    logger.info(f"gguf: expert score gating function = {score_func}")
```

I checked MiniMax (softmax), GPT-OSS (softmax), Qwen3 MoE (softmax), Qwen3 VL MoE (softmax), Qwen3 Next (softmax) and Ernie 4.5 MoE (softmax). For Llama4 it's sigmoid as well, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L145, but I think (not 100% sure) in llama.cpp it's softmax: https://github.com/ggml-org/llama.cpp/blob/master/src/models/llama.cpp#L134C21-L134C58

TLDR:

Can folks verify if this is correct?

Update: @theo77186 mentioned below that GLM 4.5 through GLM 4.7 are hardcoded to use sigmoid in convert_hf_to_gguf.py |
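To illustrate why the gating function matters, here is a toy sketch (my own illustration, not llama.cpp code) of top-k expert routing under both score functions. With softmax the routing weights are coupled and sum to 1 across all experts; with sigmoid each expert is scored independently, so the selected experts' contributions end up scaled differently:

```python
import math

def softmax(xs):
    # standard softmax with max-subtraction for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(xs):
    # independent per-expert scores; these do NOT sum to 1
    return [1.0 / (1.0 + math.exp(-x)) for x in xs]

def top_k_experts(logits, k, score_func):
    # rank experts by score and keep the k best, as a MoE router would
    scores = score_func(logits)
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return top, [scores[i] for i in top]

logits = [2.0, 1.0, 0.5, -1.0]
soft_idx, soft_w = top_k_experts(logits, 2, softmax)
sig_idx, sig_w = top_k_experts(logits, 2, sigmoid)
# The same experts may be selected under both functions, but the routing
# weights differ, which shifts every expert's contribution in the weighted
# sum of expert outputs - hence the wrong gating function degrades quality
# without crashing anything.
```

This is why a model converted with the wrong gating function still runs and produces plausible text, just with measurably worse logprobs.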
|
GLM 4.5 (and its Lite variant), 4.6 and 4.7 (not Flash) use glm4moe instead of deepseek2: llama.cpp/convert_hf_to_gguf.py line 8375 in 5bd341c |
|
Getting started on recomputing a new imatrix and re-cooking any imatrix quants too (GLM-4.7-Flash) |
|
@theo77186 I was just about to ask! Thanks, yes, I can see "glm4moe.expert_gating_func" is set to 2, which is correct. I'm unsure on Llama4 though since it's hardcoded directly to |
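For anyone who wants to check their existing GGUF files for this key, a minimal sketch using the gguf-py package (the `read_gating_func` helper is hypothetical, not part of this PR; it assumes gguf-py from llama.cpp is installed):

```python
# Mapping follows gguf.ExpertGatingFuncType (SOFTMAX = 1, SIGMOID = 2).
GATING_FUNC_NAMES = {1: "softmax", 2: "sigmoid"}

def read_gating_func(path: str, arch: str = "glm4moe"):
    # Hypothetical helper: inspect the expert gating function stored in a GGUF.
    # Imported lazily so this module loads even without gguf-py installed.
    from gguf import GGUFReader  # pip install gguf (ships with llama.cpp)

    reader = GGUFReader(path)
    field = reader.get_field(f"{arch}.expert_gating_func")
    if field is None:
        # Key absent: arch may not be MoE, or the file predates this metadata.
        return None
    value = int(field.parts[field.data[0]][0])
    return GATING_FUNC_NAMES.get(value, f"unknown ({value})")
```

So a correctly converted GLM 4.7 Lite file should report "sigmoid" (value 2) here.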
|
|
Good job! The logprobs test now reports much better accuracy, with just minor differences:
|
|
Confirmed it's working much, much better - for eg as a vibe check of a long convo for UD-Q4_K_XL |
Any clue what these minor differences could be? 👀 |
|
Not sure about GLM 4.7 in general, but it seems GLM 4.7 Flash is still doing all PP (prompt processing) on CPU with the latest llama.cpp and GGUF when using the CUDA backend. I tried the Vulkan backend and there it is all done on GPU, but it is 10x slower than the CUDA backend (Vulkan is about 2-3x faster than the CPU backend). To get PP off of the CPU, I had to set: fa = off, cache-type-k/cache-type-v = bf16, context = 32768 (down from 98304) |
* Fix GLM 4.7 MoE gating func

* Update src/models/deepseek2.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>


GLM 4.7 Lite uses SIGMOID, not SOFTMAX like DeepSeek.