
GptOss experts implementation #43227

Merged
ArthurZucker merged 35 commits into main from gpt-oss-experts-impl
Jan 22, 2026

Conversation

@IlyasMoutawwakil (Member) commented Jan 12, 2026

What does this PR do?

[wip] Fixes #43193 @vasqu @mattteochen

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as draft January 12, 2026 09:06
@IlyasMoutawwakil (Member, Author)

run-slow: gpt_oss

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as ready for review January 13, 2026 17:50
@ArthurZucker (Collaborator) left a comment


Looks good! Let's agree on transposed vs not transposed (shape could help?)

Comment on lines +82 to +88
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
gated_output = (up + 1) * glu
return gated_output
Collaborator

This is fine, makes it more readable!

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
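For readers following along, the gating above can be exercised as a standalone function; a minimal sketch, where the `alpha` and `limit` defaults are illustrative placeholders rather than values read from the model config:

```python
import torch

def apply_gate(gate_up: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Gate and up projections are interleaved along the last dimension.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(min=None, max=limit)    # clamp the gate from above only
    up = up.clamp(min=-limit, max=limit)      # clamp the up projection symmetrically
    glu = gate * torch.sigmoid(gate * alpha)  # scaled sigmoid gating
    return (up + 1) * glu

# Sanity check: a zero gate shuts the unit off completely.
out = apply_gate(torch.zeros(2, 8))
print(out.shape)  # torch.Size([2, 4])
```

Note how the last dimension is halved: gate and up each take every other column of the fused projection.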
@vasqu (Contributor) left a comment


Hmm, so we no longer transpose during weight conversion but do it manually? Not against it but we will need to solve the loading/reverse issue either way, not necessarily here

Left a few comments but they are mostly aesthetic / nits

@IlyasMoutawwakil (Member, Author)

Not against it but we will need to solve the loading/reverse issue either way, not necessarily here

Yes, and the transposition can be achieved and works fine (all tests pass) with the temporary weight-renaming trick. However, the biggest problem here is that megablocks expects the transposed format of expert weights, so it might make more sense to transpose the other MoEs to be compatible with megablocks (if the speedup is worth it).

@vasqu (Contributor) commented Jan 14, 2026

the biggest problem here is that megablocks expect the transposed format of expert weights

Argh, yeah, this will be hard to solve otherwise (we would need to rewrite that kernel). Maybe we transpose the other way around for the other models at some point; it depends on how much improvement that kernel provides.

On another note, is the transpose expensive at runtime? I would expect a slight perf regression from having to call transpose here and there (in the grouped_mm implementation, I see a few).

@IlyasMoutawwakil (Member, Author)

On another note, is the transpose expensive at runtime?

No, grouped_mm works on strided tensors, so the transpose is just a view (a no-op).
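A quick sanity check that the transpose only changes strides (illustrative shapes, not the actual expert weights):

```python
import torch

w = torch.randn(4, 8, 16)             # e.g. (num_experts, in_features, out_features)
wt = w.transpose(-2, -1)              # swap the last two dims: strides change, data does not

assert wt.data_ptr() == w.data_ptr()  # same underlying storage, so the transpose is free
assert not wt.is_contiguous()         # a strided view, which grouped-mm-style kernels accept
print(wt.shape)  # torch.Size([4, 16, 8])
```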

Comment on lines +346 to +348
# Add the ones from the quantizer as well if provided
if hf_quantizer is not None:
-        weight_conversions.extend(hf_quantizer.get_weight_conversions())
+        weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions
Member Author

It seems that the order of operations matters when both quantization conversions and weight conversions are defined. E.g. with gpt_oss, if I want to apply the Force16BytesAlignment conversion, it has to be placed after the dequantization, otherwise the loader is confused (see output below). @vasqu @ArthurZucker

import torch
from transformers import GptOssForCausalLM, Mxfp4Config

device = torch.device("cuda:1")
model_id = "openai/gpt-oss-20b"
model = GptOssForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
).eval()

results in:

Loading weights: 100%|█████████████████████████████████████████████| 363/363 [00:00<00:00, 1003.90it/s, Materializing param=model.norm.weight]
GptOssForCausalLM LOAD REPORT from: openai/gpt-oss-20b
Key                                                   | Status     | 
------------------------------------------------------+------------+-
model.layers.{0...23}.mlp.experts.down_proj_scales    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.down_proj_blocks    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_scales | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_blocks | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj        | MISSING    | 
model.layers.{0...23}.mlp.experts.down_proj           | MISSING    | 

Notes:
- UNEXPECTED : can be ignored when loading from a different task/architecture; not ok if you expect an identical arch.
- MISSING    : those params were newly initialized because they were missing from the checkpoint. Consider training on your downstream task.


Member

Yes! Since get_weight_conversions is only there to DEQUANTIZE, it makes sense that it comes first. However, as discussed offline, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequant converter and 1 from the hardcoded mapping). So it means that any model with a registered mapping cannot be dequantized 😭 So in theory you're right that it should come first, but as it does not work anyway currently, it does not make a difference 😭 Let's still keep this change, adding a comment on all that please!

Member Author

Thanks @Cyrilvallez, added a comment mostly rewording your explanation!

@vasqu (Contributor) left a comment


LGTM, I guess the outputs slightly change now with the interface's grouped_mm/batched_mm or even the eager rewrite?

@IlyasMoutawwakil (Member, Author)

LGTM, I guess the outputs slightly change now with the interface's grouped_mm/batched_mm or even the eager rewrite?

The eager path for training is the same; I kept everything as is. When I realized that the slow tests are broken (see #43246), I removed the batched implementation from there as it is redundant, so users will have to set the experts impl to batched to get the same eval output as before.

@vasqu (Contributor) commented Jan 16, 2026

Probably merging next week, currently don't have the permission to merge with red CI 😄

@ArthurZucker (Collaborator) left a comment


Great work! LGTM, let's make sure the quantized tests pass, but let's go!

Comment on lines +363 to +367
# NOTE: Since get_weight_conversions() only serves to dequantize, we need to put them first in the list.
# However, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequantization converter
# and 1 model-specific converter). This means that if a model has model-specific conversions and is being
# dequantized, any model-specific conversion whose patterns match the dequantization patterns will be ignored.
weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions
Collaborator

I am not sure it makes a difference at all, no? Because the operations are ordered by length of collected tensors, I think.

Member Author

Yeah, now that I'm digging deeper into the weight loader, I think the reason I got the error above is that it's not possible to cascade converters (i.e., applying model-specific conversions on top of tensors created by the dequantization conversions). Not because you can't match one tensor with two converters (that's a valid limitation, but not the one happening here in gpt_oss).

I added $ to the end of the Force16BytesAlignment source pattern to fix my error without changing this order, basically making sure that the sources of the mxfp4 dequant converter and Force16BytesAlignment are exclusive.
I will revert the line change and make this comment instead:

        # NOTE: Since the conversions from get_weight_conversions() only serve to dequantize, we normally want to apply them first.
        # However, for now it's not possible to cascade converters (i.e., applying model-specific conversions on top
        # of tensors created by the dequantization conversions)
        # This means that if a model has model-specific conversions and is being dequantized, the model-specific conversion
        # that relies on tensors created by dequantization conversions will not be applied.
        # GptOss example: with Mxfp4Config(dequantize=True), Force16BytesAlignment converters are ignored because the tensors
        # "mlp.experts.gate_up_proj$" and "mlp.experts.down_proj$" are only created after dequantization conversions are applied.
        weight_conversions.extend(hf_quantizer.get_weight_conversions())

Collaborator

Cool, as we talked about offline, let's add more comments about how to handle the weight converter.

Comment on lines +444 to +447
"""
Ensures that the given tensor is 16-bytes aligned in memory and clones it if not.
This garantees 16-bytes alignmenet for kernels / implementations that use TMA or SIMD instructions like torch._grouped_mm.
"""
Collaborator

very nice 🫡
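For illustration, a minimal sketch of what such an alignment guard could look like (the helper name and exact logic here are assumptions, not the PR's actual implementation):

```python
import torch

def force_16_bytes_alignment(tensor: torch.Tensor) -> torch.Tensor:
    # A fresh clone goes through the default allocator, which hands back
    # memory aligned well beyond 16 bytes, satisfying TMA/SIMD kernels
    # such as torch._grouped_mm.
    if tensor.data_ptr() % 16 != 0:
        return tensor.clone()
    return tensor

t = force_16_bytes_alignment(torch.randn(37, 19))
assert t.data_ptr() % 16 == 0
```

The clone is only taken on the misaligned path, so already-aligned tensors pass through for free.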

@IlyasMoutawwakil (Member, Author)

Should be good to merge

@vasqu (Contributor) commented Jan 19, 2026

#43353 leads me to believe that a simple weight converter won't be enough. Somehow, under the concurrent workers, we hit the same 16-byte alignment issue on our CI (with qwen 3 omni moe).

I'll take a proper look tomorrow, but sadly it might make sense to add the fallback regardless. Just wanted to put it out before we hastily merge.

@IlyasMoutawwakil (Member, Author) commented Jan 20, 2026

@vasqu I think what you are dealing with might be the cpu+safetensors memory misalignment issue I have seen with gpt_oss; there are two kinds of misalignment issues I have encountered so far:

  • on cuda: this only happens if the strided tensor has shapes that are not divisible by 8 (for fp16/bf16) or 4 (for fp32); otherwise the cuda allocator will allocate an aligned tensor. That's the issue I was solving in the first experts PR, because I was testing on the DGX, so tests were automatically running on cuda.
  • on cpu: safetensors adds an 8-byte-aligned metadata header on top of the file (see this issue, from which the 8 bytes were decided), so when reloaded on cpu the metadata can "misalign" the tensors in the memmapped file. This is what the weight converter defined here solves.

I would suggest you try the weight converter defined here, because the CI runner uses cpu (I had to set TRANSFORMERS_TEST_DEVICE=cpu to reproduce the gpt_oss issue locally).

@IlyasMoutawwakil (Member, Author) commented Jan 20, 2026

To sum up the conversation I had with @vasqu internally:

All MoEs that have a converter that creates new tensors, for example all the ones that merge their experts using cat/stack, will create 16-byte-aligned tensors on cpu. This is because they use the allocator to create the new merged tensors, and the allocator returns 16-byte-aligned memory.
If any MoE fails because of 16-byte misalignment, it's gonna be because it's not allocating its tensors, for example when using safetensors memmapped tensors (which are only 8-byte aligned). Those are the ones that will need the Force16BytesAlignment converter.
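The two situations can be simulated with plain tensors (this mimics the safetensors header offset rather than using safetensors itself):

```python
import torch

# Simulate a memmapped safetensors file: the data region starts at an
# 8-byte (but not necessarily 16-byte) offset because of the header.
buf = torch.empty(1024, dtype=torch.uint8)   # allocator-backed, well aligned
shifted = buf[8:520].view(torch.bfloat16)    # 256 bf16 values, 8 bytes into the buffer
print(shifted.data_ptr() % 16)               # 8 on the default (64-byte-aligned) CPU allocator

# cat/stack/clone re-allocate through the allocator, restoring alignment.
realigned = shifted.clone()
assert realigned.data_ptr() % 16 == 0
```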

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_oss

@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43227&sha=711a65

@ArthurZucker ArthurZucker merged commit 2d4d8fe into main Jan 22, 2026
32 of 34 checks passed
@ArthurZucker ArthurZucker deleted the gpt-oss-experts-impl branch January 22, 2026 09:34
vaibhav-research pushed a commit to vaibhav-research/transformers that referenced this pull request Jan 22, 2026
* experts impl gpt oss

* no need to transpose dequantized experts

* skip test_reverse_loading_mapping

* fix custom gating

* revert transposition and simply support transposed experts to avoid modifying eager

* style

* don't rely on weight shapes as they can be square matrices

* no need to relaod

* fallback to eager

* Update src/transformers/models/gpt_oss/modeling_gpt_oss.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* force 16 bytes alignmenet during weight loading

* simplify logic

* quantization conversions should be applied first

* avoid baddbmm as it is less performant / less optimizable by max-autotune

* no need for logger

* add comment explaining limitation

* standarize operations and only reshape when needed

* fixup conversion and test

* Update src/transformers/conversion_mapping.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* force alignment docstring

* move default apply gate

* offsets

* add docs and make kernel_config optional

* use reshapes as they are equivalent to views when memory is contiguous

* fix and better notes

* reshapes instead of views

* keep model saving and reloading in grouped_mm test to catch misalignment issues

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026


Development

Successfully merging this pull request may close these issues.

torch._grouped_mm support for GptOssExperts

5 participants