fix(mlx): pass a flat image list to mlx-vlm processor in VLM collator#749
fix(mlx): pass a flat image list to mlx-vlm processor in VLM collator#749BardiaKoopah wants to merge 1 commit into
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ed7b746c17
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| # the text placeholders pair positionally with each image. Using | ||
| # extend instead of append avoids the nested-list shape that crashes | ||
| # the per-image processor (image.shape unpacking). | ||
| all_images.extend(images) |
There was a problem hiding this comment.
Preserve ndarray images through the flat-image path
When a dataset supplies images as numpy arrays (which _extract_vlm_images preserves for the processor), flattening makes all_images a list of arrays; the subsequent _format_vlm_images_for_processor still does any(all_images), which evaluates bool(ndarray) and raises the usual ambiguous truth-value ValueError before the processor is called. This means array-backed VLM datasets still crash in the collator after this change; the emptiness check needs to avoid truth-testing the image objects themselves.
Useful? React with 👍 / 👎.
| # the text placeholders pair positionally with each image. Using | ||
| # extend instead of append avoids the nested-list shape that crashes | ||
| # the per-image processor (image.shape unpacking). | ||
| all_images.extend(images) |
There was a problem hiding this comment.
Keep per-sample groups for Gemma-style processors
With mlx-vlm processors that call make_nested_list_of_images (for example Gemma3Processor), this unconditional flattening turns a batch like [[img1], [img2]] into [img1, img2]; that helper treats a flat image list as one sample, so len(batched_images) becomes 1 while len(text) is 2 and the processor raises its inconsistent batch-size ValueError for batch_size > 1. Qwen needs a flat list, but applying it to every model regresses these processors; preserve/group images per processor family instead.
Useful? React with 👍 / 👎.
mlx-vlm 0.6.x iterates its 'images' kwarg as a flat list and calls _process_one on each element, which requires each element to be a single image (3-D ndarray or PIL.Image) so 'C, H, W = image.shape' succeeds. Two VLM batch collators were building all_images as a list-of-lists (one inner list per sample), so each top-level element handed to the processor was a list, not an image. _to_numpy_image wrapped it into a 4-D ndarray and _process_one then crashed with 'ValueError: too many values to unpack (expected 3)' before step 1. Fix both call sites in unsloth_zoo/mlx/utils.py: - _collate_vlm_batch (messages-based path) - _collate_vlm_prompt_completion_batch (prompt+completion path) Use list.extend instead of list.append per sample so all_images is a flat list across the batch. Order is preserved, and text placeholders pair positionally with images as before. Reproduced on unsloth/Qwen3-VL-2B-Instruct + unsloth/LaTeX_OCR with mlx-vlm 0.6.2 (M2, fp16); with the fix the collator no longer raises and training proceeds past the dataloader.
ed7b746 to
0d33c4d
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0d33c4d29e
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| # the text placeholders pair positionally with each image. Same fix | ||
| # as the message-collator path -- this path handles prompt+completion | ||
| # datasets where messages are split across two fields. | ||
| all_images.extend(images) |
There was a problem hiding this comment.
Preserve grouped images in prompt/completion batches
The separate prompt/completion path now flattens per-sample image groups before calling _processor_vlm_inputs for combined_texts and prompt_texts. For mlx-vlm processors that batch via make_nested_list_of_images (for example Gemma-style processors), a batch like two prompt/completion rows with one image each becomes a single image batch [img1, img2], while text still has length 2, so the processor raises an inconsistent batch-size error before training starts. The other thread covers the message-collator occurrence; this is the same regression in the prompt/completion branch.
Useful? React with 👍 / 👎.
Training Qwen3-VL on MLX crashes in value_and_grad with: ValueError: [Primitive::vjp] Not implemented for CustomKernel. The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a fused Metal kernel whenever Metal is available (mlx-vlm 0.6.x), and that kernel has no gradient implementation. The same situation exists in qwen3_5 and is solved by PR #738 via _disable_fused_mrope, which flips fused_apply off on each rotary module so apply_rotary takes its differentiable cos/sin fallback. Wiring: add a 'qwen3_vl in model_type' block in trainer.py that calls _disable_fused_mrope(model). The function is the same one introduced by PR #738 (also added here so this PR is self-contained for testing; on rebase after #738 lands the function definition will dedupe). Verified on M2 16GB with unsloth/Qwen3-VL-2B-Instruct + unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps: - Studio logs: 'Disabled fused MRoPE kernel on 28 modules for training' - Step 1 loss 1.61, grad 4.34, finite throughout - Avg loss 1.73 over 5 steps, adapter saved Note: testing also required #749 (mlx-vlm flat image list); both fixes are needed end-to-end for Qwen3-VL training, but they fix independent errors in different layers. Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Summary
Training any VLM (vision LoRA or vision-frozen) on
unslothai/unsloth-zoo:maincrashes before step 1 with:raised from
mlx_vlm/models/qwen3_vl/processing_qwen3_vl.py:193(C, H, W = image.shape).Root cause
Two VLM batch collators in
unsloth_zoo/mlx/utils.pybuildall_imagesas a list of lists, one inner list per sample:_collate_vlm_batch(line 2422, messages-based path)_collate_vlm_prompt_completion_batch(line 2335, prompt+completion path)Both do:
But mlx-vlm 0.6.x's processor iterates the outer list and calls
_process_oneon each element, expecting each element to be a single image (3-D ndarray or PIL.Image):When the outer element is a list,
_to_numpy_imageproduces a 4-D(1, C, H, W)ndarray and_process_onecrashes unpacking it into 3 dims.Fix
Use
extendinstead ofappendper sample in both collators soall_imagesis a flat list across the batch.Image order is preserved (texts and images are built sample-by-sample in lockstep), and the rendered text's
<image>placeholders pair positionally with the flat image list as mlx-vlm expects.Batch-shape verification
all_imagesafter extend[img][img, img][img, img][img, img, img]All shapes feed
mlx-vlmcorrectly: each top-level element is a single PIL image,_to_numpy_imagereturns 3-D, and_process_onesucceeds.Verified
unsloth/Qwen3-VL-2B-Instruct+unsloth/LaTeX_OCR, bs=1, seq_len=1024, mlx-vlm 0.6.2, MLX on M2 16GB → crashes at the collator before step 1