
Refactor weight loading #41580

Merged
ArthurZucker merged 387 commits into main from refactor-weight-loading on Nov 13, 2025

@ArthurZucker

@ArthurZucker ArthurZucker commented Oct 14, 2025

Collaborator

CORE REFACTORING, loading, converting, logging

More helpful debugging report when loading weights

If you just want to fuse qkv, it can. You just need to make sure you change the model code accordingly, and poof!

            WeightConverter(
                ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                "self_attn.qkv_proj",
                operations=[Concatenate(dim=0)],  # more like stack?
            ),
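
As a rough illustration of what Concatenate(dim=0) amounts to for a fused QKV projection, here is a minimal sketch that uses plain nested lists in place of tensors. The helper concatenate_dim0 and the toy shapes are assumptions made for illustration, not the library's implementation.

```python
# Toy stand-in for tensors: a weight matrix is a list of rows.
def concatenate_dim0(weights):
    """Concatenate 2-D weights along dim 0 (rows), similar to torch.cat(..., dim=0)."""
    fused = []
    for w in weights:
        fused.extend(w)
    return fused

# Hypothetical projections with 2 output rows each and hidden size 3.
q_proj = [[1, 1, 1], [2, 2, 2]]
k_proj = [[3, 3, 3], [4, 4, 4]]
v_proj = [[5, 5, 5], [6, 6, 6]]

qkv_proj = concatenate_dim0([q_proj, k_proj, v_proj])
print(len(qkv_proj), len(qkv_proj[0]))  # 6 3
```

The fused weight stays two-dimensional (rows are appended), so in this sketch the operation is a concatenation rather than a stack.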

For DeepSeek, we will embed the RoPE permute:

            WeightConverter(
                ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                operations=[RopePermute()],  # more like stack?
            ),

WeightConverter API:

The API allows you to define a mapping using WeightConverter. You can define many-to-one source/target keys, quantization operations and distributed operations along with normal operations. For now there are MergeModulelist and Concatenate; RopePermute will be added soon.

_checkpoint_conversion_mapping = {
    "mixtral": [
        WeightConverter(
            source_keys=[
                "mlp.experts.*.w1.weight",
                "mlp.experts.*.w3.weight",
            ],
            target_keys="mlp.experts.gate_up_proj",
            operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
        ),
        WeightConverter(
            source_keys=["mlp.experts.*.w2.weight"],
            target_keys="mlp.experts.down_proj",
            operations=[MergeModulelist(dim=0)],
        ),
    ],
}
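
To make the shapes in the mixtral mapping concrete, here is a minimal sketch of the two operations chained on toy data: per-expert weights are first stacked along a new leading dimension, then w1 and w3 are concatenated along dim 1. Nested lists stand in for tensors, and the helper names and sizes are assumptions for illustration only.

```python
# Toy stand-in for tensors: nested lists. Shapes are illustrative assumptions.
def merge_modulelist_dim0(per_expert):
    """Stack per-expert 2-D weights into one 3-D weight: [num_experts, out, in]."""
    return list(per_expert)

def concatenate_dim1(stacked_a, stacked_b):
    """Concatenate two 3-D weights along dim 1 (the per-expert output rows)."""
    return [a + b for a, b in zip(stacked_a, stacked_b)]

# Two experts, each with w1 and w3 of shape [out=2, in=3].
w1 = {0: [[1, 1, 1], [1, 1, 1]], 1: [[2, 2, 2], [2, 2, 2]]}
w3 = {0: [[3, 3, 3], [3, 3, 3]], 1: [[4, 4, 4], [4, 4, 4]]}

stacked_w1 = merge_modulelist_dim0([w1[i] for i in sorted(w1)])
stacked_w3 = merge_modulelist_dim0([w3[i] for i in sorted(w3)])
gate_up_proj = concatenate_dim1(stacked_w1, stacked_w3)

shape = (len(gate_up_proj), len(gate_up_proj[0]), len(gate_up_proj[0][0]))
print(shape)  # (2, 4, 3), i.e. (num_experts, 2 * out, in)
```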

We used to have this:

https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4545-L4568

But now it's just explicit:

        "legacy": [
            WeightConverter(
                source_keys="LayerNorm.gamma",
                target_keys="LayerNorm.weight",
            ),
            WeightConverter(
                source_keys="LayerNorm.beta",
                target_keys="LayerNorm.bias",
            ),
        ],
    }
    if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
        mapping["legacy"] += [
            WeightConverter(
                source_keys="weight_g",
                target_keys="parametrizations.weight.original0",
            ),
            WeightConverter(
                source_keys="weight_v",
                target_keys="parametrizations.weight.original1",
            ),
        ]
    else:
        mapping["legacy"] += [
            WeightConverter(
                source_keys="parametrizations.weight.original0",
                target_keys="weight_g",
            ),
            WeightConverter(
                source_keys="parametrizations.weight.original1",
                target_keys="weight_v",
            ),
        ]

And it's faster because we don't iterate over the whole checkpoint.
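
For intuition, an explicit rename such as the LayerNorm.gamma / LayerNorm.beta entries can be pictured as a plain (source, target) substitution on checkpoint keys. The sketch below is a simplified assumption of that idea; LEGACY_RENAMES and rename_key are hypothetical names, and the example keys are made up.

```python
# Hypothetical helper: rename checkpoint keys using explicit (source, target) pairs.
LEGACY_RENAMES = [
    ("LayerNorm.gamma", "LayerNorm.weight"),
    ("LayerNorm.beta", "LayerNorm.bias"),
]

def rename_key(key, renames=LEGACY_RENAMES):
    """Return the key with the first matching source fragment replaced."""
    for source, target in renames:
        if source in key:
            return key.replace(source, target)
    return key  # keys that match nothing are left untouched

state_dict_keys = [
    "encoder.layer.0.output.LayerNorm.gamma",
    "encoder.layer.0.output.LayerNorm.beta",
    "encoder.layer.0.output.dense.weight",
]
renamed = [rename_key(k) for k in state_dict_keys]
print(renamed)
```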

The core logic is:
Iterate over all of the dict keys:

  1. Collect the keys that match the glob patterns from all source keys into a dict keyed by target key. Patterns that come from the same weight converter are piped together: (mlp.experts.*.gate_proj.weight|mlp.experts.*.up_proj.weight).

This produces:

{ 
"mlp.experts.gate_up_proj" : 
    {"mlp.experts.*.w1.weight":
        { "mlp.experts.0.w1.weight": [t0, t1, t2, etc], "mlp.experts.1.w1.weight": [t0, t1, t2, etc]},
     "mlp.experts.*.w3.weight":
        { "mlp.experts.0.w3.weight": [t0, t1, t2, etc], "mlp.experts.1.w3.weight": [t0, t1, t2, etc]},
    }
  ....
}

We need to keep track of which layers were collected, and from which source pattern.
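
The collection step can be sketched roughly as follows, assuming each converter is described by its source globs and one target key. The names CONVERTERS, glob_to_regex and collect are hypothetical, the * is assumed to stand for an expert index, and the layer prefix is ignored for brevity; the actual implementation may differ.

```python
import re
from collections import defaultdict

# Hypothetical converter description: (source glob patterns, target key).
CONVERTERS = [
    (["mlp.experts.*.w1.weight", "mlp.experts.*.w3.weight"], "mlp.experts.gate_up_proj"),
    (["mlp.experts.*.w2.weight"], "mlp.experts.down_proj"),
]

def glob_to_regex(pattern):
    """Turn a glob such as 'experts.*.w1' into a regex where * matches an index."""
    return re.escape(pattern).replace(r"\*", r"\d+")

def collect(checkpoint_keys):
    """Group checkpoint keys by target key, then by the source pattern they matched."""
    collected = defaultdict(lambda: defaultdict(dict))
    for patterns, target in CONVERTERS:
        # Pipe the patterns of one converter into a single alternation, with one
        # named group per pattern so we know which source pattern matched.
        alternation = "|".join(
            f"(?P<p{i}>{glob_to_regex(p)})" for i, p in enumerate(patterns)
        )
        regex = re.compile(alternation)
        for key in checkpoint_keys:
            match = regex.search(key)
            if match is None:
                continue
            source = patterns[int(match.lastgroup[1:])]
            collected[target][source][key] = []  # tensors would be appended here
    return collected

keys = [
    "model.layers.0.mlp.experts.0.w1.weight",
    "model.layers.0.mlp.experts.1.w3.weight",
    "model.layers.0.mlp.experts.0.w2.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
result = collect(keys)
print(sorted(result))  # ['mlp.experts.down_proj', 'mlp.experts.gate_up_proj']
```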

1bis. Schedule tensor materialization without blocking the GIL (as this takes the most time). We distribute the tensor at this stage, before any operations. This is the trickiest part. We do this during collection so as not to waste time.

  2. We collect the results of materialization, and we apply the operations on all the collected values (at this point { "mlp.experts.0.w1.weight": [t0, t1, t2, etc], "mlp.experts.1.w1.weight": [t0, t1, t2, etc]}.values() gives a list of lists).
  3. We create a dict with the target_key and the output values, and pass it to the quantizer.
  4. We quantize the input tensors, outputting the final dict.
  5. We set the param into the model.
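
The overall flow (schedule materialization, gather the results, apply the operations, set the parameter) can be sketched with a thread pool as below. This is only an assumption of the general shape: materialize, merge_modulelist and load are hypothetical stand-ins, the "model" is a plain dict, and quantization and tensor distribution are left out.

```python
from concurrent.futures import ThreadPoolExecutor

def materialize(lazy_tensor):
    """Stand-in for reading a tensor from disk; here it just calls a thunk."""
    return lazy_tensor()

def merge_modulelist(per_expert):
    """Stand-in operation: stack per-expert weights along a new leading dim."""
    return list(per_expert)

def load(collected, operations, model_params):
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Schedule materialization for every collected entry up front.
        futures = {
            target: [pool.submit(materialize, lazy) for lazy in lazies]
            for target, lazies in collected.items()
        }
        for target, pending in futures.items():
            # Gather the materialized values, then apply the operations in order.
            value = [f.result() for f in pending]
            for op in operations[target]:
                value = op(value)
            # A quantizer would transform the value here (omitted in this sketch).
            # Finally, set the parameter on the "model".
            model_params[target] = value
    return model_params

collected = {
    "mlp.experts.down_proj": [lambda: [[1, 2]], lambda: [[3, 4]]],
}
operations = {"mlp.experts.down_proj": [merge_modulelist]}
params = load(collected, operations, {})
print(params)  # {'mlp.experts.down_proj': [[[1, 2]], [[3, 4]]]}
```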

Keys are handled a lot better!

Enable MoE quantization for FP8

This script does not work on main

import torch
from transformers import MixtralForCausalLM, AutoTokenizer, FineGrainedFP8Config
import time 
quantization_config = FineGrainedFP8Config(modules_to_not_convert=["model.layers.*.mlp.gate"])
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", quantization_config=quantization_config, tp_plan="auto")
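
To show which modules a glob such as model.layers.*.mlp.gate would select, here is a small sketch using the standard library's fnmatch. Whether the library matches module names this way is an assumption; the module names below are illustrative.

```python
from fnmatch import fnmatchcase

pattern = "model.layers.*.mlp.gate"
module_names = [
    "model.layers.0.mlp.gate",
    "model.layers.31.mlp.gate",
    "model.layers.0.mlp.experts.gate_up_proj",
    "model.layers.0.self_attn.q_proj",
]

# Modules matching the pattern would be left unquantized in this sketch.
skipped = [name for name in module_names if fnmatchcase(name, pattern)]
print(skipped)  # ['model.layers.0.mlp.gate', 'model.layers.31.mlp.gate']
```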

Enable TP + MoE without OOM

This script does not work on main

from transformers import MixtralForCausalLM
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", tp_plan="auto")

Enable device_map="auto" + MoE + FP8

This script does not work on main

from transformers import FineGrainedFP8Config, MixtralForCausalLM
quantization_config = FineGrainedFP8Config(modules_to_not_convert=["model.layers.*.mlp.gate"])
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", quantization_config=quantization_config, device_map="auto")

Refactor the way we load weights, faster, flexible and better overall

Uses staging buffers per conversion op

  • 4x speedup with device_map="auto"
  • Full MoE quantization with FP8

TODOS:

  • Test with TP / EP
  • Add TQDM!
  • Test with DeepSpeed
  • Test with LoRAs and PEFT
  • Test with vLLM backend
  • Test with FSDP
  • Add saving

Script:

import torch
from torch import nn
from transformers import MixtralForCausalLM, AutoTokenizer

import time 
start = time.time()
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
end = time.time() 
print("loading took ", end-start)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
inputs = tokenizer("hey how are you?", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(out))
loading took  14.271092891693115
['<s> hey how are you?\n\nI am a 20 year old male and I have been having']

⬆️ is with: merge modulelist, concat gate_up
⬇️ is naive loading.

loading took  54.271092891693115
['<s> hey how are you?\n\nI am a 20 year old male and I have been having']
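
As a quick check on the two timings reported above, the ratio can be computed directly. The variable names are illustrative; the numbers are the ones printed by the two runs.

```python
optimized_seconds = 14.271092891693115  # merge modulelist + concat gate_up
naive_seconds = 54.271092891693115      # naive loading

speedup = naive_seconds / optimized_seconds
print(f"speedup: {speedup:.1f}x")  # speedup: 3.8x
```

That is close to the roughly 4x figure quoted for device_map="auto".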

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@LysandreJik LysandreJik left a comment

Member


Impressive effort

ArthurZucker pushed a commit that referenced this pull request Nov 27, 2025
* FIX Minimal fix for loading PEFT weights

After the weight conversion PR #41580, some adjustments were still
required for loading PEFT weights. This PR presents a minimal fix to
make it work again.

Besides renaming keys, this PR does not address possible conversions
that might be required to be applied to the PEFT weights
themselves (most wouldn't work anyway, but e.g. chunking should be
possible to implement).

As for test, the existing test_peft_from_pretrained in
test_peft_integration.py actually fails on main right now, this PR fixes
it. As the tests are slow tests, normal CI won't pick this up though.

* Allow n:n matching

* Reviewer feedback
@IlyasMoutawwakil

IlyasMoutawwakil commented Dec 4, 2025

Member

This is actually super cool and could allow for batched experts inference, which is traceable/exportable and faster in some cases when memory is not an issue! Great work @ArthurZucker! Is there a plan for enabling pure-PyTorch batched experts? I can imagine something like moe_implementation, which can be sequential/batched.

@ArthurZucker

Collaborator Author

Yes! cc @3outeille if he has time, but if you want to tackle it you should! (torch._bmm they have a native op now)

@stevhliu stevhliu mentioned this pull request Dec 4, 2025
@3outeille

Member

@IlyasMoutawwakil yes, definitely. How time sensitive is it?

@IlyasMoutawwakil

Member

@3outeille great! Nothing time sensitive (since we are patching the MoEs in optimum-onnx/optimum-intel for now), but it would make life much easier to control this behavior with a single argument. Lmk if you would like to tackle this!

sarathc-cerebras pushed a commit to sarathc-cerebras/transformers that referenced this pull request Dec 7, 2025
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* ah actually we don't discard lm head if missing -> needs to be moved to correct device and etc

* fix some tests

* small fixes

* up

* up

* dik why we tie weights twice but,..,,.

* ups

* removeunused

* fix hunyuan

* small fix

* nits

* ish

* up

* rev

* fix more tie weights keys

* small fixes

* nit

* update

* fix and fix

* fix a test

* glubs

* current shitty changes

* ship validated ones

* more

* more update

* more

* more

* more

* mllama

* more up

* fix ernie

* fix xopies

* up more

* more fixes

* up

* up

* fix-copies

* fix more

* more updates

* AI UPDATE

* up

* hoey

* make it fast

* fix

* lol

* fix asjusting

* more fixes

* _dtype nit

* up

* nit

* update

* update

* remove semaphores

* fix import to avoid jit execution

* try to remove custom tiing logic when its stupid

* fix more individual models

* fix whisper as well

* fix?

* fox umt5

* improve tqdm bar

* cleanup a bit

* oupsi

* some updates

* improve

* remove all buffering -> much faster without it

* remove some tie_weights custome funcs when not needed

* more fixes related to strict matching regex

* remove ALL custom tie weights

* small update

* revert change to init scheme (no need for params)

* mixtral init

* try less strict source check

* tied weight first shot to the fiiiixxxxxx

* does this help?

* :)

* fix some ppolry defined tied_weights_keys for now

* subclass nn.Parameters

* up

* lol

* Ouiiii

* fix led

* fix long cat flash

* fix qwen and long cat flash

* properly fix qwen init

* just push this for now

* propnet is dumb

* update

* push

* remove explict sharing of some tied keys.

* update decoder.bias

* moe case

* more changes to untangle old hardcoded ting

* fixup

* fix big faileurs

* fix prophnet

* fix resize token embeddings

* nits

* fix xcodex

* asyncio?

* fix smart apply

* fix data-2-vec

* [build-ci-image]

* checkout

* uupdate

* fix hunyuan

* update error message

* fix deformable detr

* fixes

* fix init weights for non param gate up projs

* shared todo?

* update some models

* big revert, don't break this behaviour

* ty @SunMarc this fixes the buffers

Co-authored-by: SunMarc <SunMarc@users.noreply.github.com>

* mt5 fuck

* fix lxmbert

* nuke slow test fetcher

* fix zamba and deepcopy for now

* fix zamba tied weight keys! ~

* fix-copies

* update fetch terst

* fix gradient for test modeling common!

* break "shared" for now I will fix tomorrow changes are properly isoalted now :)

* does this fix marian? probably not

* fix some vlms

* D fine seems to handle this well

* glob is fine actually

* fix dab detr

* small steps

* opusy

* fix some more models?

* yups

* better erro

* fix?

* fix double escape

* escape wehere it makes sense

* ??

* fix ibert

* fix tvp as well

* more fxes

* try always download ref PR

* ONONONO

* big fixup

* more fixup

* small step

* small nits

* nits

* brut force some stuff

* fix vilt

* make sure special models that always need tie always tie

* cleaning up

* small nits

* fix zamba and bridge tower!

* just fixup

* potential culprits

* revert bark and fix bridgetower

* remove now non existant tie_weights

* ?

* lol reformer actually had nothing tied!

* wow these two fucking models were really not well made

* fix sam family!

* fix bark revision

* fix speech2test ?

* push this for now....

* upsy

* the fuck

* fix rtdetr

* update

* proper

* wow that one 's annoying

* update

* try to find the culprit

* get some help on common

* nit about general init and cls.padding_idx

* revert num workers update

* remove old loading func

* fix glob

* add annotations

* fix re

* small improvements

* clean some stuff

* improvements

* someone did not understannnnnnd what I tried to dooo or does BNB not support that either?

* gluos

* fix case when `.` is just not there

* remove unused arg

* recover orignal parameter/buffer using _original

* fix glob issu

* this?

* deepspeed best-effort

* remove unused stuff

* Update tie weight keys as they were just wroong

Co-authored-by: Benjamin Bossan <benjaminbossan@users.noreply.github.com>"

* up

* augustuc clauss, a gloubs gloups gloubs

* fixup

* fixup

* there was fucking typo

* mrain

* nits

* fix marian 3 remaining tests

* one more

* fix some of the copies, not all :)

* small cleanup

* one propertest

* fix core model loadig tes

* attempt a new test

* fix some of the annoying tests by supporting reading .bin sometimes

* push

* push more small fixes

* remove 1 useless test

* up

* fix audio flamingo post rebase

* fixup

* some small updatess

* fix sam models

* nits

* up

* updates

* onem ore

* skip this stupid test

* some other fixes

* fixup

* update

* skip more offloaded stuff

* oups

* ups

* update mixtral

* skip this one

* LET"SGO

* fixup

* rope delta order

* fix csm

* small nit

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: SunMarc <SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
@winglian winglian mentioned this pull request Mar 5, 2026

Labels

Core: Modeling Internals of the library; Models. for_v5?
