Fix(mlx): keep Qwen3-VL vision MLP fp32 when activation dtype is fp16#3
Open
BardiaKoopah wants to merge 1 commit into
Open
Conversation
Training Qwen3-VL with finetune_vision_layers=True on M1/M2 (where MLX defaults to float16 since these chips lack native bf16 support) produces grad=NaN at step 1, corrupting the adapter. Bisected to the MLP path in patched_qwen3_vision_block_call: casting the output of linear_fc1 back to fp16 reintroduces the NaN; keeping the full MLP in fp32 and casting only at the residual add produces finite gradients and a healthy loss curve. Fix: when residual_dtype == mx.float16, run the MLP in fp32. bf16/fp32 paths keep the original (cheaper) flow. Test: source-inspection guard in tests/test_mlx_trainer_internals.py, matching the existing fp32-rotary test for this same file.
danielhanchen
added a commit
that referenced
this pull request
Jun 11, 2026
…n, finalize_huggingface_model
- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
_call_create_lora_manager's signature inspection still sees vllm_config; pass model
positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
rotary_pos_emb on vision_config availability; mirror language_model detection from
set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
(defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
so non-tensor state-dict values pass through.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Training Qwen3-VL with
finetune_vision_layers=Trueon M1/M2 (where MLX defaults to float16 since these chips lack native bf16 support) producesgrad=NaNat step 1, corrupting the adapter.Diagnosis
Bisected
patched_qwen3_vision_block_callby selectively keeping fp32 at different points:linear_fc1output to fp16So the MLP path is the source — specifically, casting
linear_fc1's output back to fp16. Confirmed CUDA T4 (fp16) trains the same config cleanly, so the issue is specific to MLX + fp16 + vision tower forward.Fix
When
residual_dtype == mx.float16, upcast the MLP input to fp32 so the full MLP (linear_fc1, GELU,linear_fc2) runs in fp32. Cast back to source dtype at the residual add. bf16/fp32 paths keep the original (cheaper) flow.Verified
Test
Source-inspection guard added to
tests/test_mlx_trainer_internals.py, matching the existingtest_qwen3_vl_vision_rotary_uses_transformers_fp32_mathstyle for this same module.