Skip to content

Fix(mlx): keep Qwen3-VL vision MLP fp32 when activation dtype is fp16#3

Open
BardiaKoopah wants to merge 1 commit into
mmathew23:explore/mlxfrom
BardiaKoopah:fix/mlx-qwen3-vl-vision-mlp-fp16-overflow
Open

Fix(mlx): keep Qwen3-VL vision MLP fp32 when activation dtype is fp16#3
BardiaKoopah wants to merge 1 commit into
mmathew23:explore/mlxfrom
BardiaKoopah:fix/mlx-qwen3-vl-vision-mlp-fp16-overflow

Conversation

@BardiaKoopah

Copy link
Copy Markdown

Summary

Training Qwen3-VL with finetune_vision_layers=True on M1/M2 (where MLX defaults to float16 since these chips lack native bf16 support) produces grad=NaN at step 1, corrupting the adapter.

Diagnosis

Bisected patched_qwen3_vision_block_call by selectively keeping fp32 at different points:

Variant Result
Whole block default NaN
norm1 fp32 (attn input fp32) NaN
norm2 fp32 + downcast linear_fc1 output to fp16 NaN
norm2 fp32 + whole MLP fp32, cast only at residual add finite

So the MLP path is the source — specifically, casting linear_fc1's output back to fp16. Confirmed CUDA T4 (fp16) trains the same config cleanly, so the issue is specific to MLX + fp16 + vision tower forward.

Fix

When residual_dtype == mx.float16, upcast the MLP input to fp32 so the full MLP (linear_fc1, GELU, linear_fc2) runs in fp32. Cast back to source dtype at the residual add. bf16/fp32 paths keep the original (cheaper) flow.

Verified

  • Failing case: Qwen3-VL-2B + LaTeX_OCR + vision LoRA r=16 α=16 lr=2e-4 bs=1 on M2 → now clean, loss 0.94 → 0.69 over 5 steps, grad finite throughout
  • Vision frozen on M2: unchanged
  • Qwen3-0.6B text LoRA on M2: unchanged
  • Memory: 5.99 GB peak (vs 5.18 broken baseline = +0.8 GB)

Test

Source-inspection guard added to tests/test_mlx_trainer_internals.py, matching the existing test_qwen3_vl_vision_rotary_uses_transformers_fp32_math style for this same module.

Training Qwen3-VL with finetune_vision_layers=True on M1/M2 (where MLX
defaults to float16 since these chips lack native bf16 support) produces
grad=NaN at step 1, corrupting the adapter.

Bisected to the MLP path in patched_qwen3_vision_block_call: casting the
output of linear_fc1 back to fp16 reintroduces the NaN; keeping the full
MLP in fp32 and casting only at the residual add produces finite gradients
and a healthy loss curve.

Fix: when residual_dtype == mx.float16, run the MLP in fp32. bf16/fp32
paths keep the original (cheaper) flow.

Test: source-inspection guard in tests/test_mlx_trainer_internals.py,
matching the existing fp32-rotary test for this same file.
danielhanchen added a commit that referenced this pull request Jun 11, 2026
…n, finalize_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant