[MoE] Quantisation support and patches: Bnb and FP8#659
Upstream Bnb4bitQuantize.convert unwraps input_dict[key] twice (first .values(), then [0]). The patched version did only the first unwrap, which produced a list for most weight-converter dispatch paths and raised a TypeError inside the Params4bit constructor. The broad except Exception silently masked that error and fell back to original_convert, leaving the experts unquantized.

Empirically: tiny DeepSeek V3 happens to dispatch experts as bare Tensors, so the bug was invisible there. Qwen3.5-35B-A3B dispatches as list[Tensor] and triggers the TypeError on every expert param.

Reviewed-by: R1 (correctness), R3 (regression)
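The two-step unwrap can be illustrated with a small sketch. The helper names `first_value` and `unwrap_weight` are hypothetical, a plain string stands in for a tensor, and the `isinstance` check is an assumption made so that both dispatch shapes are handled; it is not the upstream code.

```python
def first_value(input_dict):
    """First unwrap: take the first value stored in the mapping."""
    return next(iter(input_dict.values()))


def unwrap_weight(input_dict):
    """Second unwrap: if the value is a list or tuple, take its first element."""
    value = first_value(input_dict)
    if isinstance(value, (list, tuple)):
        value = value[0]
    return value


# Stand-ins for the two dispatch shapes described in the commit message.
bare = {"experts.gate_up_proj": "TENSOR"}      # bare tensor
listed = {"experts.gate_up_proj": ["TENSOR"]}  # list[Tensor]
```

With only the first unwrap, the `listed` case hands a list to whatever consumes the weight, which is the shape that triggered the TypeError described above.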
…-MoE warning (M7+M8)

Package-wide convention (mxfp4.py, qwen3_moe.py, qwen3_5_moe.py, qwen3_vl_moe.py, qwen3_next_moe.py, glm4_moe.py, deepseek_v3_moe.py) gates every logger.info on UNSLOTH_ENABLE_LOGGING. The three info logs in moe_bnb_transformers.py and the unconditional warning emitted when no experts are found weren't gated, so they spammed every 4-bit load, including dense models (Phi3, GLM4 dense, Llama, Mistral, etc.).

Reviewed-by: R3 (regression), R4 (conventions)
…lause (m4+m6+M1)

- Add __all__ to stop wildcard re-export of bnb/torch/nn/Optional/etc. into the unsloth_zoo.temporary_patches namespace.
- Tighten param_needs_quantization: only return True when Params4bit AND bnb_quantized=False, protecting against re-invocation after the first quantize.
- Narrow the patched_param_needs_quantization except from bare Exception to (KeyError, AttributeError), the expected failures from get_module_from_name.

Reviewed-by: R1 (correctness), R3 (regression), R4 (conventions)
…h in detector (M3+B2)
M3: _get_base_weight calls bnb.functional.dequantize_4bit(param.data, None) when
param.quant_state has not been populated yet (meta-device placeholder, or before
.to(cuda)). Add explicit guard to fall through to other branches in that case.
B2: _is_moe_experts_module was tightened from (nn.Parameter|Tensor, ndim in (2,3))
to (Params4bit, ndim==2) OR (nn.Parameter, ndim==3). This contradicts the function's
own doc-comment ('After PEFT's nn.utils.parametrize wrapping, accessing gate_up_proj
returns torch.Tensor (not nn.Parameter), so we must accept both.'). Restore the
Tensor branch so the post-parametrize path is still detected.
Reviewed-by: R1 (correctness), R3 (regression)
…issing _original_shape (M5+M6+m5)

R2 found two related issues in patch_peft_param_wrapper_4bit_expert_shape:

M5/M6: _patched_get_param hard-codes the peft 0.18 dim-ordering convention (num_experts, in_features, out_features = shape) and reassigns self.in_features and self.out_features on every call. PEFT 0.19's ParamWrapper.update_layer swaps in_features<->out_features for 3D params (gated on _did_swap_in_out_features), then calls _move_adapter_to_device_of_base_layer, which calls get_param() again. The reassignment undoes PEFT's swap, breaking second-adapter add_adapter and any external reader of layer.in_features. Drop the in_features/out_features assignment; keep num_experts (which both peft versions derive identically).

m5: the 'else: # TODO: Can we raise an error here? pass' branch silently returned the packed Params4bit. PEFT's _get_in_out_features would then read the (K, 1) packed shape and create LoRA factors with wrong dims. Raise ValueError explaining the failure mode instead.

Reviewed-by: R2 (peft)
Reviewed-by: R4 (conventions)
Three changes from REV findings:

- REV-narrowed-except-loses-runtime-context: patched_convert still had the broad except Exception that masked B1 for months. Narrow to (KeyError, AttributeError) for expected get_module_from_name failures, and use logger.exception() for unexpected failures so the traceback is preserved.
- REV-m3-defer-not-fix: when _get_base_weight sees Params4bit with quant_state=None, falling through to subsequent branches just returns the raw packed uint8 data, which crashes grouped_mm later with a different error. Raise an actionable error instead.
- REV-typing-tuple-unused: drop unused Tuple import.

Reviewed-by: REV (independent post-fix review)
_get_moe_weight_and_quant_info imported _get_base_weight_and_quant_state and _try_attach_block_size from .moe_utils, but those symbols don't exist there. The ImportError raised inside forward_moe_backend_fp8 was caught by the broad 'except ImportError' in forward_moe_backend, silently falling through to forward_native_grouped_mm, which crashes on FP8 weights.

Two fixes:

1. Inline _try_attach_block_size and rewrite _get_moe_weight_and_quant_info to use _unwrap_param_attr (FP8-preserving) directly. It also reads quant_state from .quant_state when present and falls back to module-level scale attrs.
2. Narrow 'except ImportError' in forward_moe_backend to ONLY wrap the import statement, so runtime errors inside the bnb4bit/fp8 path now propagate.
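The difference between the broad and the narrowed handler can be sketched as follows. All function names here are hypothetical stand-ins, and the deliberately broken `from math import ...` line only simulates a lazy import of a symbol that does not exist.

```python
def broken_fp8_forward(x):
    # Stand-in for a backend whose body lazily imports a missing symbol.
    from math import no_such_symbol  # raises ImportError at call time
    return x


def native_forward(x):
    # Stand-in for the fallback path.
    return ("native", x)


def forward_broad(x):
    """Broad handler: an ImportError raised while running the backend is masked."""
    try:
        import math  # stand-in for importing the backend module
        return broken_fp8_forward(x)
    except ImportError:
        return native_forward(x)


def forward_narrow(x):
    """Narrow handler: only the import statement is guarded."""
    try:
        import math  # stand-in for importing the backend module
    except ImportError:
        return native_forward(x)
    # Any error raised here, including ImportError, reaches the caller.
    return broken_fp8_forward(x)
```

`forward_broad(1)` quietly returns the fallback result, while `forward_narrow(1)` raises the ImportError, which is the behavior that makes a missing helper visible.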
_try_attach_block_size is now defined locally in moe_utils_fp8.py (commit 87ba813), but three function bodies still tried to re-import it from .moe_utils. With the narrowed except in forward_moe_backend, those ImportErrors no longer silently fall through — they propagate and abort the FP8 forward. Use the module-local helper directly.
…mm path
_forward_scaled_grouped_mm_fp8 was importing _get_grouped_lora,
_apply_grouped_lora, and _expand_grouped_bias from .moe_utils — none of
those symbols exist anywhere in the codebase, making this whole path
unrunnable on Hopper. Didn't surface on B200 because the function is gated
behind _check_torch_scaled_grouped_mm_supported() (Hopper SM 9.x only).
Mirror the inline pattern from moe_utils.forward_native_grouped_mm (the
validated reference): add small local helpers _moe_separated_lora_delta
and _expand_grouped_bias, and read the (first, second, scaling) tuple
directly from _unsloth_lora_{gate_up,down}_proj injected by
_patched_param_wrapper_forward.
When the fused gate_up_proj has 2*I == H (common in models like DeepseekV3 and Qwen3.5MoE where intermediate_dim is half hidden_dim), both 'standard' and 'swapped' shape checks in _detect_moe_lora_layout pass, and the first-match-wins order picks 'swapped' incorrectly. For DeepseekV3, the only LoRA path is the fused MoE (target_modules doesn't attach to shared_experts), so the wrong slicing corrupts every expert and reload accuracy drops to zero. Qwen3.5MoE was masked because its shared_expert ALSO got Linear LoRA via standard PEFT matching, and the shared_expert deltas dominated the codename memorization task. Disambiguate using _did_swap_in_out_features (set by unsloth on the PEFT wrapper at training time), defaulting to 'standard' when absent.
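A minimal sketch of the ambiguity, assuming the layout is decided by comparing the second dimension of lora_A against in_features ("standard") or out_features ("swapped"). The function name, its signature, and the mapping of the flag to a layout are assumptions for illustration, not the code in this PR.

```python
def detect_layout(lora_a_cols, in_features, out_features,
                  did_swap=None, default="standard"):
    """Pick 'standard' or 'swapped' from the lora_A column count.

    did_swap mimics an optional _did_swap_in_out_features flag;
    None means the flag is absent.
    """
    fits_standard = lora_a_cols == in_features
    fits_swapped = lora_a_cols == out_features
    if fits_standard and fits_swapped:
        # Ambiguous: shapes alone cannot decide, so use the flag or the default.
        if did_swap is None:
            return default
        return "standard" if did_swap else "swapped"
    if fits_standard:
        return "standard"
    if fits_swapped:
        return "swapped"
    raise ValueError("lora_A shape matches neither layout")


# Fused gate_up_proj: in_features = H, out_features = 2 * I.
H, I = 64, 32                     # 2*I == H, so both checks pass
ambiguous = detect_layout(H, H, 2 * I)
real = detect_layout(2048, 2048, 2 * 768)  # 2*I != H, unambiguous
```

Checking the ambiguous case explicitly, instead of returning on the first matching branch, is what keeps the branch order from deciding the result.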
For FP8 base models with stored weight in float8_e4m3fn plus companion
weight_scale_inv tensors, _merge_moe_experts_file used to do
W.to(float32) which is a bit-level cast that drops the per-block
scale entirely. The LoRA delta was then added in un-scaled space
and the result re-written as fp8 with the original (now wrong)
scale, corrupting the merged checkpoint.
This patch adds two helpers in saving_utils.py:
- _fp8_dequant_blockwise(W_fp8, scale_inv): block-wise inverse
quant matching transformers.integrations.finegrained_fp8
Fp8Dequantize.convert (W_real = decode_fp8(W) * scale_inv).
- _fp8_requant_blockwise(W_bf16, block_shape, scale_dtype): per
block max-abs quant with new_scale_inv = max_abs / 448 and
W_fp8 = clamp(W / scale_inv, +/- 448), matching Fp8Quantize.
A new _fp8_load_for_merge helper detects fp8 + scale_inv companion
in the safetensors header and returns the dequantized bf16 weight
along with scale metadata. The per-expert merge loop in
_merge_moe_experts_file (gate_proj, up_proj, down_proj) calls this
on read, runs the existing _merge_moe_* math in bf16 / float32, and
then re-quantizes + overwrites both the weight tensor and its
weight_scale_inv slot on write.
Non-FP8 paths are unchanged: the helper short-circuits on
W.dtype != float8_e4m3fn and falls through to the existing
in-place fp8 / bf16 write logic with original dtype.
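The scale arithmetic described above can be sketched in plain Python. This is a simplification under stated assumptions: blocks are 1-D runs of a flat list, no actual FP8 rounding is performed, and the function names are hypothetical. Only the formulas (scale_inv = max_abs / 448, W_q = clamp(W / scale_inv, +/- 448), W_real = W_q * scale_inv) come from the commit message.

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def requant_blockwise(values, block_size):
    """Return (quantized, scale_inv) with one scale_inv per block."""
    quantized, scale_inv = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        max_abs = max(abs(v) for v in block)
        # Guard all-zero blocks so the division below stays finite.
        s = max_abs / FP8_E4M3_MAX if max_abs > 0 else 1e-12
        scale_inv.append(s)
        for v in block:
            quantized.append(max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / s)))
    return quantized, scale_inv


def dequant_blockwise(quantized, scale_inv, block_size):
    """Multiply each element by the scale_inv of the block it belongs to."""
    return [q * scale_inv[i // block_size] for i, q in enumerate(quantized)]


weights = [0.01, -0.02, 0.5, 1.5]
q, s = requant_blockwise(weights, block_size=2)
restored = dequant_blockwise(q, s, block_size=2)
```

Here `q` holds values on the order of hundreds while `weights` are below 2, which shows why reading the stored values without multiplying by `scale_inv` yields a tensor that is off by the per-block scale.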
Validation:
- Qwen3-Coder-30B-A3B-Instruct-FP8: LoRA-adapter reload token
overlap improved from 1/3 to 2/3 prompts vs the live model
(saved adapter is now correct in bf16-equivalent space).
- Merged-reload for the same cell still diverges from live on
this overfit synthetic test (loss 3.4 -> 0.0002). Root cause
is independent: _write_tensor_direct_torch writes in place
into fixed-size fp8 slots, so the bf16 merged result has to be
re-quantized to fp8 with ~3.5% per-block precision loss
(e4m3 mantissa). Proper fix is to emit bf16 shards for fp8
sources under save_method="merged_16bit", which is a larger
refactor and out of scope for this commit.
No regression on non-FP8 cells (bnb4bit, bf16) since the helper
is a no-op for those dtypes.
Code Review
This pull request implements comprehensive support for quantized Mixture of Experts (MoE) models, focusing on 4-bit (bitsandbytes) and FP8 formats. It introduces block-wise dequantization and re-quantization utilities, patches PEFT's ParamWrapper to handle 4-bit MoE shapes and weight merging, and adds a forward dispatcher for optimized backends like _scaled_grouped_mm and Triton kernels. The PR also improves MLX submodule resolution on non-Apple hosts and extends support to standard GLM4 MoE models. Review feedback correctly pointed out that ALL_FP8_EXPERTS_FUNCTIONS is a dictionary, requiring standard dictionary operations instead of getattr and setattr for managing the sentinel flag.
    except Exception:
        pass

    setattr(ALL_FP8_EXPERTS_FUNCTIONS, sentinel, True)

    return

    sentinel = "_unsloth_fp8_dispatcher"
    if getattr(ALL_FP8_EXPERTS_FUNCTIONS, sentinel, False):
The ALL_FP8_EXPERTS_FUNCTIONS object in transformers.integrations.finegrained_fp8 is a standard Python dictionary. getattr looks up attributes, not keys, so with a default of False this check returns the default whenever the sentinel is not an attribute of the dict object itself, and the guard never fires. Use the in operator instead.
    - if getattr(ALL_FP8_EXPERTS_FUNCTIONS, sentinel, False):
    + if sentinel in ALL_FP8_EXPERTS_FUNCTIONS:
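A small sketch of an idempotent patch that keeps its sentinel as a mapping key. `REGISTRY` is a plain dict standing in for the real registry, and `patch_registry` and `dispatcher` are hypothetical names; the only library-independent fact used is that `setattr` on a plain `dict` raises AttributeError.

```python
SENTINEL = "_unsloth_fp8_dispatcher"

# Stand-in for a registry that maps names to expert-forward functions.
REGISTRY = {"default": lambda x: x}


def dispatcher(x):
    # Stand-in for the rerouted forward.
    return ("dispatched", x)


def patch_registry(registry):
    """Reroute every entry once; return False if already patched."""
    if SENTINEL in registry:          # key lookup, not attribute lookup
        return False
    for name in list(registry):
        registry[name] = dispatcher
    registry[SENTINEL] = True         # store the flag as a key
    return True


first = patch_registry(REGISTRY)
second = patch_registry(REGISTRY)

# setattr on a plain dict fails, which is why the attribute form cannot work.
try:
    setattr({}, SENTINEL, True)
    setattr_failed = False
except AttributeError:
    setattr_failed = True
```

Storing a non-callable flag inside a registry of functions has its own cost if other code iterates the values; a module-level boolean would avoid that.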
When PEFT 0.19's ParamWrapper.get_delta_weight reshapes lora_B as
`reshape(out_features, rank, num_experts)`, per-expert lora_B is the
stride-E slice `lora_B[:, e::E]`. Following that convention here would
match a pure PEFT-trained model.
Unsloth-trained models do not store the per-expert split that way.
Unsloth ships `patch_param_wrapper_for_moe()` (in
`unsloth_zoo/temporary_patches/moe_utils.py`) which overrides
PEFT's `ParamWrapper.forward` with `_patched_param_wrapper_forward` and
routes per-expert LoRA extraction through
`_canonical_lora_weights_for_grouped_mm`. That helper does
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
i.e. contiguous-r per expert — the opposite of PEFT's stride-E. Since
the patch is in effect during both training AND `PeftModel.from_pretrained`
inference, the stored lora_A/lora_B encode the per-expert split in
contiguous-r terms. The merge here must match that convention so the
saved checkpoint reproduces the unsloth training-time forward; using
PEFT's stride-E slicing here garbles the merged output.
This commit adds inline pointers to the patch site in all four call
sites that slice per-expert lora_B (the bnb4bit + FP8 per-expert
helpers and the GPT-OSS / Gemma4 3D fused helpers). No behavior change.
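The two conventions can be compared by listing which columns of a 2D lora_B of shape (out_features, num_experts * rank) belong to one expert. This is an index-only sketch with hypothetical function names; it mirrors the `reshape(out_features, rank, num_experts)` and `view(dim_B, num_experts, rank_per_expert)` orderings quoted above.

```python
def stride_e_columns(expert, num_experts, rank):
    """Columns for one expert under reshape(out, rank, num_experts).

    Column index = r_idx * num_experts + expert, i.e. the slice e::E.
    """
    return list(range(expert, num_experts * rank, num_experts))


def contiguous_r_columns(expert, num_experts, rank):
    """Columns for one expert under view(out, num_experts, rank).

    Column index = expert * rank + r_idx, i.e. one contiguous run of length rank.
    """
    return list(range(expert * rank, (expert + 1) * rank))


E, r = 3, 2
pairs = [(stride_e_columns(e, E, r), contiguous_r_columns(e, E, r))
         for e in range(E)]
```

For E = 3 and r = 2, expert 1 owns columns 1 and 4 under stride-E but columns 2 and 3 under contiguous-r, so mixing the conventions between training and merge pulls rank components from the wrong experts.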
….18 + 0.19)

`_detect_moe_lora_layout` falls back to a default when both the standard and swapped shape checks pass. That happens whenever the fused gate_up_proj has `2*I == H`, i.e. the intermediate is exactly half the hidden dim, which is the common case for symmetric MoE tinies like Qwen3-MoE, Qwen3.5-MoE, GLM-4-MoE, DeepSeek-V3 used in our test suite.

Pre-fix that default was hard-coded to "standard", which is correct for PEFT 0.19+ (it swaps in/out for 3D non-transposed MoE in `ParamWrapper.update_layer` so the stored `lora_A.shape = (r*E, in_features)` matches the "standard" branch) but wrong for PEFT 0.18 (no swap → `lora_A.shape = (r*E, out_features)` matches the "swapped" branch). Under PEFT 0.18 the merge math therefore picked the wrong slicing convention and produced garbled merged checkpoints on every symmetric 4bit MoE cell tested.

Fix: at module import time, inspect `peft.tuners.lora.layer.ParamWrapper.update_layer` for the `_did_swap_in_out_features` substring (introduced in PEFT 0.19). If present, default the ambiguous case to "standard"; otherwise default to "swapped". A caller can still override per-call by passing a `lora_module` whose `_did_swap_in_out_features` attribute is set explicitly.

Empirical verification:

| Cell | PEFT 0.19 | PEFT 0.18 pre-fix | PEFT 0.18 post-fix |
|---|---|---|---|
| qwen3_moe / 4bit @ 200 | [1, 2, 3] | (all garbled) | [0, 0, 2] |
| glm4_moe / 4bit @ 200 | [0, 0, 3] | [0, 0, 0] | [1, 1, 1] |
| qwen3_5_moe / 4bit @ 200 | [3, 2, 2] | [0, 2, 0] | [3, 2, 2] |

(numbers are token-set overlap between LIVE first-5 tokens and MERGED-reload first-5 tokens, per prompt; PEFT 0.18 post-fix now matches PEFT 0.19 baseline).

For non-symmetric shapes (`2*I != H`, e.g. all real-world MoE checkpoints like Qwen3-Coder-30B-FP8 with I=768, H=2048) the shape check is unambiguous in either branch and the default is never used, so this fix is a no-op for those cases.
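One way to choose the ambiguous-case default is from the installed PEFT version string. This is a sketch under that assumption: the commit itself probes the `update_layer` source for the `_did_swap_in_out_features` substring, and the helper names below are hypothetical.

```python
import re


def parse_major_minor(version):
    """Extract (major, minor) from strings such as '0.19.0' or '0.19.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def default_ambiguous_layout(peft_version):
    """'standard' for PEFT 0.19+ (in/out swapped by PEFT), else 'swapped'."""
    return "standard" if parse_major_minor(peft_version) >= (0, 19) else "swapped"


defaults = {v: default_ambiguous_layout(v)
            for v in ("0.18.1", "0.19.0", "0.19.0.dev0", "1.0.0")}
```

Comparing integer tuples avoids the string-comparison trap where "0.9" sorts after "0.19".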
… 4.58 The patch unconditionally dereferenced `_trainer_utils_mod.validate_quantization_for_training` to read its sentinel, but the attribute only exists in HF 4.58+. On HF 4.57.6 (one of the CI matrix cells) this raised AttributeError at unsloth import time, breaking the Core (HF=4.57.6 + TRL<1) job. Use `getattr(..., None)` and return early if the attribute isn't present — older HF versions don't have the FP8 guard to relax in the first place.
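The guarded lookup can be sketched with a stand-in module object. `patch_guard`, the `_unsloth_patched` marker, and the pass-through wrapper are hypothetical; the real patch relaxes the FP8 check inside the wrapper, which is omitted here.

```python
from types import SimpleNamespace

MARKER = "_unsloth_patched"  # hypothetical sentinel attribute name


def patch_guard(trainer_utils_mod):
    """Wrap validate_quantization_for_training if it exists; return True if patched."""
    original = getattr(trainer_utils_mod, "validate_quantization_for_training", None)
    if original is None:
        return False              # older library version: nothing to relax
    if getattr(original, MARKER, False):
        return False              # already patched

    def wrapper(*args, **kwargs):
        # A real patch would skip the check for FP8 models here.
        return original(*args, **kwargs)

    setattr(wrapper, MARKER, True)
    trainer_utils_mod.validate_quantization_for_training = wrapper
    return True


old_module = SimpleNamespace()  # attribute missing, as on the older version
new_module = SimpleNamespace(validate_quantization_for_training=lambda m: "checked")
results = (patch_guard(old_module), patch_guard(new_module), patch_guard(new_module))
```

The sentinel is only read after the attribute is known to exist, so importing the patch module never raises on a version that lacks the guard.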
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f7afcb2274
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
    if os.path.isdir(model_name):
        single_path = os.path.join(model_name, "model.safetensors")
        if os.path.exists(single_path):
            return {k: single_path for k in needed_keys}
Check keys before accepting single-file scale shards
When an FP8 checkpoint is stored in a single model.safetensors but uses the per-expert *.weight_scale_inv layout, this helper reports that the first probe's stacked gate_up_proj_scale_inv/down_proj_scale_inv keys exist merely because the file exists. _maybe_attach_dropped_moe_fp8_scales then takes the stacked path and _attach_stacked_scales calls get_tensor on keys that are not in the file, so loading crashes instead of falling back to the per-expert layout this code is meant to support. The single-file Hub branch below has the same assumption; inspect the safetensors keys before returning success for a requested key set.
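A minimal sketch of checking the keys before accepting a single-file checkpoint. It reads the safetensors header directly (an 8-byte little-endian length followed by a JSON object keyed by tensor name), so it needs only the standard library. `safetensors_keys` and `resolve_single_file` are hypothetical helpers, not functions from this PR.

```python
import json
import struct


def safetensors_keys(path):
    """Return the tensor names listed in a safetensors file header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return {name for name in header if name != "__metadata__"}


def resolve_single_file(path, needed_keys):
    """Map each needed key to the file only if every key is actually present."""
    present = safetensors_keys(path)
    if all(key in present for key in needed_keys):
        return {key: path for key in needed_keys}
    return None  # caller can fall back to another layout
```

Returning `None` when any key is missing lets the caller try the per-expert layout instead of failing later on a tensor lookup.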
Addresses the silent-failure modes a code review flagged across the MoE merge save path and the FP8 forward dispatcher. None of these change the happy-path behavior; they convert silent corruption / silent LoRA-dropping into explicit `_record_moe_merge_fallback` calls or a hard RuntimeError so the user sees the problem instead of getting a wrong checkpoint.

- `_fp8_load_for_merge` now returns a `_FP8_MERGE_UNSAFE` sentinel when an FP8 weight has no usable scale companion. The previous behavior (return the raw FP8 tensor + None meta) caused the caller to call `W.to(torch.float32)`, which is a bit-level fp8 decode that drops the per-block scale, producing a merged checkpoint that is `1/scale_inv` off the real weight. The three call sites (gate/up/down) now record a fallback and skip the LoRA merge for that expert instead. Also accepts `weight_scale` (compressed-tensors naming) in addition to the existing `weight_scale_inv` (DeepSeek naming).
- `_merge_moe_gate_or_up_expert` / `_merge_moe_down_proj_expert` no longer fall back to `r = total_rank // 1 = total_rank → r=1` when `lora_stats.rank` is missing. The degenerate `r=1` slicing happens to pass the layout shape check and silently writes a wrong delta. Now records a fallback and returns the unmerged weight.
- `_merge_moe_fused_gate_up_expert` / `_merge_moe_fused_down_proj_expert` now route every early-return through `_record_moe_merge_fallback` and also touch the `_MOE_MERGE_STATE` counters (attempted/applied), so the fused GPT-OSS / Gemma4 path matches the non-fused path's loud-fail discipline introduced by #5410.
- `_forward_native_fp8_expert_loop` (last-resort FP8 forward) now raises if `_unsloth_lora_gate_up_proj` or `_unsloth_lora_down_proj` is attached on the expert module. Previously, hitting this fallback while LoRA was active silently dropped the adapter contribution, which is easy to mistake for "LoRA is converging slowly". Refusing is the only safe option until this fallback also applies the LoRA delta.
- `_peft_paramwrapper_swaps_in_out`: prefer `peft.__version__` over the brittle `inspect.getsource` source-string probe (which would silently flip the ambiguous-layout default if PEFT refactored the swap into a helper while still performing it). Source-inspect kept as last-resort fallback.
- `_fp8_requant_blockwise` docstring corrected: code clamps zero-blocks to `scale_inv = 1e-12`, not 1.0 as the comment claimed.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 978b0798f5
    if quant_kind == "weight_scale_inv":
        scale = scale.reciprocal()
Preserve scale_inv direction for scaled_grouped_mm
When the fast _scaled_grouped_mm path is used with row-wise *_weight_scale_inv/*_scale_inv tensors, this reciprocal flips the dequant scale into the quantization scale. The rest of this file treats weight_scale_inv as the multiplier for dequantization (for example _dequantize_full_expert_weights_vectorized returns w * s_expanded), and _manual_fp8_rowwise_quantize passes decode scales directly to _scaled_grouped_mm, so this branch silently produces outputs scaled by the inverse factor on affected FP8 checkpoints instead of falling back to the correct dequant path.
    method = getattr(cfg, "quant_method", None)
    return method is not None and "fp8" in str(method).lower()
Recognize compressed-tensors FP8 in Trainer guard
For FP8 MoE checkpoints loaded through compressed-tensors, quant_method is "compressed-tensors" (the same module explicitly checks that value in _maybe_patch_glm4_stacked_moe_fp8_scales), so this string test returns false and leaves HF Trainer's quantization guard active. In that environment, LoRA training for the compressed-tensors FP8 models this patch supports will still be rejected even though the new guard is intended to allow FP8+LoRA.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d0a069fea4
    except Exception:
        pass

    setattr(ALL_FP8_EXPERTS_FUNCTIONS, sentinel, True)
Use mapping state instead of setattr for FP8 dispatcher sentinel
ALL_FP8_EXPERTS_FUNCTIONS is a registry mapping, so calling setattr on it raises AttributeError when this patch runs on Transformers FP8 models. That aborts patch_fp8_experts_interface after writing some keys, leaving the patch non-idempotent and potentially preventing the FP8 expert-forward reroute from being applied reliably in startup patch loops. Store the sentinel in the mapping (or a module-level flag) instead of using setattr on the registry object.
…revert)

Commit 0bd6a8d reverted the merge created by a git pull, which undid the diffs from two already-merged PRs even though both still show as merged on GitHub:

- unslothai#588 [Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference
- unslothai#659 [MoE] Quantisation support and patches: Bnb and FP8

This reapplies their code (+5025 lines across 20 files, including moe_utils_fp8.py, moe_utils_bnb4bit.py and test_vllm_to_hf_conversion.py).

This reverts commit 0bd6a8d.
After the recent changes in Transformers V5 that moved MoE weights into an nn.Parameter, PEFT was no longer able to quantize them, so memory usage was blowing up for many MoE models, and a lot of MoE models (Qwen 3.5, Qwen 3.6, GLM-4.7, etc.) need Transformers V5 to function. This PR patches around it so we can do 4-bit training with on-the-fly dequant to bf16, and group/FP8 if available.

A lot of the bnb4bit work was done by @sensai99 in #527; I'm extending it to FP8 (as I did in #548) and combining both into a single PR.
Training-loss sanity (same model, same hyperparams, only base-weight quant differs)
Qwen3-Coder-30B bf16 vs FP8, 30 steps, rank=16, batch=2, ga=2, seq=1024, seed=3407
tiny_qwen3_moe bf16 vs 4bit, 200 steps, same hyperparams
Max absolute gap = 0.26 (at step 90); both curves converge to ~1e-3 by step 150.
→ bf16, 4bit, and FP8 are interchangeable for LoRA SFT training loss on the configs we tested. The LoRA adapters absorb whatever quant noise the base weights carry. The earlier `save_pretrained_merged` regressions were strictly save-time precision issues (per-block re-quant rounding the small delta away), not training-time issues.