
Fix Qwen3Moe GGUF loading#1

Merged
ctcanbol merged 1 commit into ctcanbol:main from jusjinuk:main
Aug 10, 2025

Conversation

@jusjinuk

What does this PR do?

I encountered the following error after trying out the latest commit in this branch to load a bf16 GGUF file of the Qwen3-30B-A3B model:

  File "/data_fast/home/jusjinuk/codes/transformers/src/transformers/modeling_gguf_pytorch_utils.py", line 121, in _split_moe_expert_tensor
    name = tensor_key_mapping[name]
KeyError: 'blk.24.ffn_down_exps.weight'

Upon investigation, I found that the issue arises because the code does not read the config metadata from the GGUF file. As a result, it falls back to the default Qwen3MoeConfig when instantiating the Qwen3Moe model in Transformers, which sets num_hidden_layers to 24 by default. This does not match the actual number of layers in the GGUF file.
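To illustrate the failure mode: the tensor-name mapping is built from the config's num_hidden_layers, so GGUF tensors for layers beyond that range have no entry. This is a simplified sketch, not the actual Transformers code; the Hugging Face-side names here are illustrative placeholders.

```python
# Minimal sketch of the KeyError: a name mapping built for only 24 layers
# (the default) cannot resolve tensors from a deeper GGUF checkpoint.
num_hidden_layers = 24  # default from Qwen3MoeConfig when the GGUF config isn't read

# Hypothetical GGUF-name -> HF-name mapping, one entry per layer in the config.
tensor_key_mapping = {
    f"blk.{i}.ffn_down_exps.weight": f"model.layers.{i}.mlp.experts.down_proj.weight"
    for i in range(num_hidden_layers)
}

# The checkpoint's layer 24 falls outside range(24), reproducing the error.
try:
    tensor_key_mapping["blk.24.ffn_down_exps.weight"]
except KeyError as err:
    print("unmapped tensor:", err)
```

Once the real layer count is read from the GGUF metadata, the mapping covers every `blk.N` prefix and the lookup succeeds.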

The reason the correct config is not loaded is that the code looks up GGUF_CONFIG_MAPPING using the updated architecture name, while GGUF_CONFIG_MAPPING still keys its entries by the old architecture name (qwen3moe), so the necessary config fields are never read.

To address this, I updated L105 so that GGUF_CONFIG_MAPPING uses the updated architecture name, as shown in this commit.
I believe qwen2moe should be updated similarly, but I left it unchanged since I have not personally tested the Qwen2Moe architecture. I also noticed a missing parameter (head_dim) and added it at L112 in the same commit.


After these changes, I confirmed that the output is generated without any issues, tested with the following code:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch

gguf_model_path = "gguf_files/Qwen_Qwen3-30B-A3B-Instruct-2507-bf16"
gguf_file_name = "qwen3-30b-a3b-instruct-2507-bf16.gguf"
tokenizer_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"

# Load the model from the local GGUF file; the tokenizer comes from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    gguf_model_path,
    gguf_file=gguf_file_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
streamer = TextStreamer(tokenizer)

prompt = "Explain to me what Large Language Models are."
chat = [
    {"role": "user", "content": prompt},
]
# Build the chat-formatted input and generate, streaming tokens as they arrive.
tokenized_chat = tokenizer.apply_chat_template(
    chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
tokenized_chat = tokenized_chat.to("cuda")
outputs = model.generate(
    tokenized_chat,
    max_new_tokens=100,
    streamer=streamer,
    eos_token_id=tokenizer.eos_token_id,
)

Thank you.

Fixes huggingface#39638

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Owner

@ctcanbol ctcanbol left a comment


LGTM!

@ctcanbol ctcanbol merged commit fec7841 into ctcanbol:main Aug 10, 2025
