
Fix GLM 4.7 Lite MoE gating func #18980

Merged
pwilkin merged 3 commits into ggml-org:master from pwilkin:glm47fixrouter
Jan 21, 2026

Conversation

@pwilkin
Contributor

@pwilkin pwilkin commented Jan 21, 2026

GLM 4.7 Lite uses SIGMOID, not SOFTMAX like Deepseek.

@pwilkin pwilkin requested a review from CISC as a code owner January 21, 2026 01:12
@ddh0
Contributor

ddh0 commented Jan 21, 2026

Thank you!! 🫡

@github-actions github-actions bot added the model Model specific label Jan 21, 2026
@itzpingcat

Please bro we need this!

@qingy1337

@pwilkin Does this mean anyone running GLM 4.7 Flash right now via llama.cpp is running a faulty implementation?

@ddh0
Contributor

ddh0 commented Jan 21, 2026

Yes.

@theo77186
Contributor

The ideal solution would be to fix the GGUF at conversion time, in addition to this fix for existing GGUF files. In the original config.json, the line "scoring_func": "sigmoid" is missing because the default value differs between GLM and DeepSeek.

This fix can be applied on current llama.cpp builds with --override-kv deepseek2.expert_gating_func=int:2.
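As a stopgap on already-converted GGUFs, the override can be passed at load time; a sketch, assuming a hypothetical model filename:

```shell
# Workaround for existing GLM 4.7 Flash GGUFs on current llama.cpp builds:
# force the deepseek2-arch expert gating function to sigmoid (int value 2).
# The model path below is just an example.
./llama-cli \
  -m ./models/GLM-4.7-Flash-Q4_K_M.gguf \
  --override-kv deepseek2.expert_gating_func=int:2
```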

@danielhanchen
Contributor

danielhanchen commented Jan 21, 2026

@pwilkin Nice find. Note that DeepSeek V3 / R1 use sigmoid as well (not softmax). See https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L214:

    def route_tokens_to_experts(self, router_logits):
        router_logits = router_logits.sigmoid()

I'm not fully sure, but because DeepSeek V3.1 defines "scoring_func": "sigmoid" in https://huggingface.co/deepseek-ai/DeepSeek-V3.1/blob/main/config.json#L59, llama.cpp automatically picks the correct one via https://github.com/ggml-org/llama.cpp/blob/master/convert_hf_to_gguf.py#L920, i.e.: (haha, @theo77186 wrote his comment while I was writing this one)

if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
    if score_func == "sigmoid":
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
    elif score_func == "softmax":
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
    else:
        raise ValueError(f"Unsupported expert score gating function value: {score_func}")
    logger.info(f"gguf: expert score gating function = {score_func}")

### So I think all GLM models from 4.5 through 4.7 need updating to use sigmoid.

I checked MiniMax (softmax), GPT-OSS (softmax), Qwen3 MoE (softmax), Qwen3 VL MoE (softmax), Qwen 3 Next (softmax), Ernie 4.5 MoE (softmax)

For Llama4 as well it's sigmoid see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L145, but I think (not 100% sure) in llama.cpp it's softmax: https://github.com/ggml-org/llama.cpp/blob/master/src/models/llama.cpp#L134C21-L134C58

TLDR:

  1. GLM 4.5, GLM 4.6, GLM 4.7 etc use sigmoid (not just GLM Flash) -> all need updating. It's because GLM's config.json doesn't emit "scoring_func": "sigmoid". [Needs fixing]
  2. DeepSeek V3, R1 etc use sigmoid (not softmax), but this is fine since "scoring_func": "sigmoid" was seen in all config.json files.
  3. Llama4 also uses sigmoid, but it's hardcoded to LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX [Needs fixing I think]
  4. MiniMax, GPT-OSS, Qwen3 MoE, Qwen3 VL MoE, Qwen 3 Next, Ernie 4.5 MoE all use softmax (fine)

Can folks verify if this is correct?

Update: as @theo77186 mentions below, GLM 4.5 through GLM 4.7 are hardcoded to use sigmoid in convert_hf_to_gguf.py.
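To see why the wrong gating function corrupts outputs even when the top-k expert choice often agrees, here is a minimal sketch of top-k MoE routing under each scoring function (plain Python, not llama.cpp code; the logits are made up):

```python
import math

def topk_weights(logits, k, scoring="softmax"):
    """Score router logits, pick the top-k experts, renormalize their weights."""
    if scoring == "softmax":
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        s = sum(exps)
        scores = [e / s for e in exps]
    elif scoring == "sigmoid":
        scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    else:
        raise ValueError(f"Unsupported scoring func: {scoring}")
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    norm = sum(scores[i] for i in top)
    return {i: scores[i] / norm for i in top}

logits = [2.0, 1.0, 0.5, -1.0]
print(topk_weights(logits, 2, "softmax"))
print(topk_weights(logits, 2, "sigmoid"))
```

Both pick the same two experts here, but softmax weights them much more sharply than sigmoid, so every MoE layer mixes its experts with the wrong proportions when the gating function is misconfigured.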

@theo77186
Contributor

theo77186 commented Jan 21, 2026

GLM 4.5 (and its Lite variant), 4.6 and 4.7 (not Flash) use glm4moe instead of deepseek2, which is already hardcoded to sigmoid, so those GGUFs carry the correct gating function.
edit: the sigmoid gating function is applied there at GGUF conversion:

self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

@ubergarm
Contributor

ubergarm commented Jan 21, 2026

Getting started on recomputing a new imatrix and re-cooking any imatrix quants too (GLM-4.7-Flash).

@danielhanchen
Contributor

danielhanchen commented Jan 21, 2026

@theo77186 I was just about to ask! Thanks, yes: I can see "glm4moe.expert_gating_func" is set to 2, which is correct.

I'm unsure about Llama4 though, since it's hardcoded directly to LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX in https://github.com/ggml-org/llama.cpp/blob/master/src/models/llama.cpp#L134C21-L134C58. Maybe @ngxson can help with Llama4 verification.

@danielhanchen
Contributor

danielhanchen commented Jan 21, 2026

  1. Kimi K2 also uses sigmoid and is deepseekv2 arch, but luckily https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/config.json#L128 defines "scoring_func": "sigmoid", so Kimi K2 is fine.
  2. Mimo V2 MoE is fine as well (sigmoid)
  3. Nemotron 3 uses sigmoid (hardcoded)
  4. Mistral 3 Large uses DeepSeekV3 MoE but uses expert_gating_func = 1 ie Softmax - see https://github.com/vllm-project/vllm/pull/29757/changes which specifies config["scoring_func"] = "softmax" (fine)
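For reference, the integer values behind these names, as used elsewhere in this thread (`--override-kv deepseek2.expert_gating_func=int:2` for sigmoid, and `expert_gating_func = 1` i.e. softmax for Mistral 3 Large), can be sketched as:

```python
# Expert gating function IDs as implied by this thread:
# softmax = 1, sigmoid = 2 (mirroring gguf.ExpertGatingFuncType).
EXPERT_GATING_FUNC = {"softmax": 1, "sigmoid": 2}

def gating_func_id(scoring_func: str) -> int:
    """Map a config.json scoring_func string to its GGUF metadata value."""
    try:
        return EXPERT_GATING_FUNC[scoring_func]
    except KeyError:
        raise ValueError(f"Unsupported expert score gating function: {scoring_func}")
```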

@ngxson
Contributor

ngxson commented Jan 21, 2026

Good job! The logprobs test now reports much better accuracy, just minor differences:

| idx | logits_llama.log | logprob_1 | logits_other.log | logprob_2 | diff (abs) |
|-----|------------------|-----------|------------------|-----------|------------|
| 1 | '\' | -2.5015 | '\' | -2.5100 | 0.0085 |
| 2 | ' here' | -1.7050 | ' here' | -1.6765 | 0.0286 |
| 3 | ' AI' | -0.8085 | ' AI' | -0.9069 | 0.0985 |
| 4 | ' AI' | -0.9470 | ' AI' | -0.9261 | 0.0209 |
| 5 | ' assistant' | -0.6059 | ' assistant' | -0.6661 | 0.0602 |
| 6 | ' specialized' | -1.2674 | ' designed' | -1.3800 | 0.1126 |
| 7 | ' of' | -0.0039 | ' of' | -0.0041 | 0.0003 |
| 8 | '\' | -2.7554 | '\' | -2.7089 | 0.0466 |
| 9 | ' the' | -2.1011 | ' the' | -1.9421 | 0.1589 |
| 10 | ' to' | -0.9308 | '.' | -1.0014 | 0.0706 |
| 1011 | ' you' | -0.0000 | ' you' | -0.0000 | 0.0000 |
| 1012 | ' need' | -0.0000 | ' need' | -0.0000 | 0.0000 |
| 1013 | ' to' | -0.0000 | ' to' | -0.0000 | 0.0000 |
| 1014 | ' use' | -0.0002 | ' use' | -0.0000 | 0.0001 |
| 1015 | ' a' | -0.0000 | ' a' | -0.0000 | 0.0000 |
| 1016 | ' tool' | -0.0000 | ' tool' | -0.0000 | 0.0000 |
| 1017 | ' output' | -0.0000 | ' output' | -0.0000 | 0.0000 |
| 1018 | ' the' | -0.0000 | ' the' | -0.0000 | 0.0000 |
| 1019 | ' call' | -0.0000 | ' call' | -0.0000 | 0.0000 |
| 1020 | ' in' | -0.0000 | ' in' | -0.0000 | 0.0000 |
| 5021 | ' requires' | -0.0000 | ' requires' | -0.0000 | 0.0000 |
| 5022 | ' external' | -0.0000 | ' external' | -0.0000 | 0.0000 |
| 5023 | ' data' | -0.0000 | ' data' | -0.0000 | 0.0000 |
| 5024 | ' computation' | -0.0000 | ' computation' | -0.0000 | 0.0000 |
| 5025 | ' or' | -0.0000 | ' or' | -0.0000 | 0.0000 |
| 5026 | ' actions' | -0.0000 | ' actions' | -0.0000 | 0.0000 |
| 5027 | ' beyond' | -0.0000 | ' beyond' | -0.0000 | 0.0000 |
| 5028 | ' your' | -0.0000 | ' your' | -0.0000 | 0.0000 |
| 5029 | ' internal' | -0.0000 | ' internal' | -0.0000 | 0.0000 |
| 5030 | ' knowledge' | -0.0000 | ' knowledge' | -0.0000 | 0.0000 |

@danielhanchen
Contributor

danielhanchen commented Jan 21, 2026

Confirmed it's working much, much better. I added "scoring_func": "sigmoid" to config.json so it converts correctly on mainline llama.cpp even without this PR: https://huggingface.co/unsloth/GLM-4.7-Flash/commit/3fd53b491e04f707f307aef2f70f8a7520511e6d

For example, as a vibe check with a long conversation on UD-Q4_K_XL: `./llama.cpp/llama-cli --model GLM-4.7-Flash-GGUF/GLM-4.7-Flash-UD-Q4_K_XL.gguf --temp 1.0 --top-p 0.95 --min-p 0.01 --jinja`, with the prompts: "Hi", "What is 2+2", "Create a Python Flappy Bird game", "Create a totally different game in Rust", "Find bugs in both", "Make the 1st game I mentioned but in a standalone HTML file", "Find bugs and show the fixed game"
[image]

pwilkin and others added 2 commits January 21, 2026 12:34
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
@pwilkin pwilkin merged commit 12a4a47 into ggml-org:master Jan 21, 2026
7 checks passed
@ngxson
Contributor

ngxson commented Jan 21, 2026

Can be quite a perfect meme to sum up the situation

[image]

@qingy1337

> Good job! logprobs test now report much better accuracy, just minor differences: […]

any clue what these minor differences could be? 👀

@pfn

pfn commented Jan 22, 2026

Not sure about GLM 4.7 in general, but it seems GLM 4.7 Flash is still doing all prompt processing (PP) on CPU with the latest llama.cpp and GGUF when using the CUDA backend. I tried the Vulkan backend, and there it is all done on GPU, but it is 10x slower than the CUDA backend (Vulkan is about 2-3x faster than the CPU backend).

To get PP off of the CPU, I had to set: fa = off, cache-type-k/cache-type-v = bf16, context = 32768 (down from 98304).
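Translated into flags, those settings would look roughly like this (flag names assumed from recent llama.cpp builds; model path hypothetical):

```shell
# Settings that moved prompt processing off the CPU (CUDA backend):
# flash attention off, bf16 KV cache, and context reduced to 32768.
./llama-server \
  -m ./models/GLM-4.7-Flash-Q4_K_M.gguf \
  -fa off \
  --cache-type-k bf16 --cache-type-v bf16 \
  -c 32768
```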

shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026
* Fix GLM 4.7 MoE gating func

* Update src/models/deepseek2.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
michaelw9999 pushed a commit to michaelw9999/llama.cpp that referenced this pull request Mar 4, 2026
* Fix GLM 4.7 MoE gating func

* Update src/models/deepseek2.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

Labels

model Model specific
