Close the batch_generate / server decode gap + VLM fixes #1055

Merged
Blaizzy merged 13 commits into main from pc/fix-batch-gap on Apr 24, 2026

Blaizzy commented Apr 23, 2026

Owner

Summary

Brings BatchGenerator decode throughput at bs=1 to rough parity with stream_generate (was ~20% slower), and fixes several server + multi-image correctness issues surfaced along the way.

Key changes

  • Thread-local generation stream (mlx-lm#1090 port): switch module-level generation_stream to mx.new_thread_local_stream; BatchGenerator accepts an explicit stream= kwarg; server passes its generator-thread stream.
  • Async model init in ResponseGenerator: model resources now load on the generator thread, so their stream associations belong to that thread. This fixes cross-thread "no Stream(gpu, N)" crashes on mlx 0.31.2+.
  • 13 MRoPE VLMs rewritten to use cache._idx (Python int token counter maintained by BatchKVCache) instead of cache.offset.item(). Removes a per-decode-step GPU sync; matches upstream mlx-lm convention (rope positions count the full padded sequence; left_padding is handled by the attention mask via create_causal_mask).
    • qwen2_5_vl, qwen3_5, qwen2_vl, paddleocr_vl
    • qwen3_vl, qwen3_vl_moe, qwen3_omni_moe
    • glm4v, glm4v_moe, glm_ocr, ernie4_5_moe_vl
    • falcon_ocr, gemma3n, mllama, hunyuan_vl, moondream3
  • qwen3_vl preprocessor config hardening: ignore explicit null for min_pixels / max_pixels emitted by newer transformers (5.7+); previously overwrote the defaults derived from size.longest_edge, crashing _smart_resize_image.
  • qwen3_vl / qwen3_vl_moe _deepstack_process: slice visual_embeds per-sample by a running offset. This fixes the "Shapes (N,D) and (M,D) cannot be broadcast" crash on multi-image batches, plus a latent bug in qwen3_vl_moe that silently dropped the deepstack update.

Perf

Measured on Qwen2.5-VL-3B-Instruct-4bit, bs=1, 193 steady-state decode tokens, 6 interleaved runs:

| | Before | After |
| --- | --- | --- |
| BatchGenerator / stream_generate | 0.80× (19.6% overhead) | 0.94× (5% overhead; stdev 0.03) |

The residual 5% is Python-level bookkeeping in BatchGenerator.next → GenerationBatch._step (Response construction, stop_criteria, and BatchKVCache's per-layer mx.array offset update). It amortizes away at bs > 1.

End-to-end test

Generate CLI + server single request + server concurrent (3 requests) across 12 VLMs:

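| Model | CLI | Single | Concur | Repo |
| --- | --- | --- | --- | --- |
| qwen2_5_vl | pass | pass | pass | mlx-community/Qwen2.5-VL-3B-Instruct-4bit |
| qwen2_vl | pass | pass | pass | mlx-community/Qwen2-VL-2B-Instruct-4bit |
| qwen3_vl | pass | pass | pass | Qwen/Qwen3-VL-4B-Instruct |
| qwen3_vl_moe | pass | pass | pass | mlx-community/Qwen3-VL-30B-A3B-Instruct-4bit |
| qwen3_5 | pass | pass | pass | Qwen/Qwen3.5-4B |
| paddleocr_vl | pass | pass | pass | mlx-community/PaddleOCR-VL-8bit |
| gemma3n | pass | pass | pass | mlx-community/gemma-3n-E2B-it-8bit |
| falcon_ocr | pass | pass | pass | mlx-community/Falcon-OCR-bf16 |
| hunyuan_vl | pass | pass | pass | tencent/HunyuanOCR |
| mllama | pass | pass | pass | mlx-community/Llama-3.2-11B-Vision-Instruct-8bit |
| glm4v | pass | pass | pass | mlx-community/GLM-4.1V-9B-Thinking-4bit |
| glm_ocr | pass | pass | pass | mlx-community/GLM-OCR-8bit |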

12/12 pass all three paths. Multi-image regression test (Qwen3-VL-2B-4bit, cats.jpg + graph.png) also passes — exercises the _deepstack_process scatter-add fix with two different per-image visual-token counts.

Test plan

  • python -m mlx_vlm generate --model <any VLM above> --image <path> --prompt "Describe." --max-tokens 24 --temperature 0
  • python -m mlx_vlm.server --model <any VLM above> --port 8080 then a single OpenAI-style chat request
  • 3 concurrent chat requests to the same server to exercise BatchGenerator
  • Multi-image: pass two different-size images in one VLM request to the qwen3_vl family
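
The concurrent step of the test plan could be scripted roughly as follows. This is only a sketch: the route `/v1/chat/completions`, the payload schema, and the helper names are assumptions for illustration and are not confirmed by this PR; adjust them to whatever the server actually exposes.

```python
import json
import urllib.request
from concurrent.futures import ThreadPoolExecutor

# Assumptions (not confirmed by this PR): the endpoint path and the payload
# schema below are illustrative placeholders for an OpenAI-style server.
BASE_URL = "http://localhost:8080"
CHAT_PATH = "/v1/chat/completions"  # hypothetical route


def build_payload(model: str, prompt: str, image_url: str, max_tokens: int = 24) -> dict:
    """Build an OpenAI-style chat payload with one text part and one image part."""
    return {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": 0,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
    }


def post_chat(payload: dict, timeout: float = 120.0) -> dict:
    """POST one chat request and return the decoded JSON response."""
    req = urllib.request.Request(
        BASE_URL + CHAT_PATH,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def run_concurrent(payloads: list) -> list:
    """Send all payloads at the same time so the server has to batch them."""
    with ThreadPoolExecutor(max_workers=len(payloads)) as pool:
        return list(pool.map(post_chat, payloads))
```

Calling `run_concurrent([build_payload(...)] * 3)` against a running server would issue the three overlapping requests; nothing is sent until that call is made.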

🤖 Generated with Claude Code

Blaizzy and others added 13 commits April 23, 2026 01:55
Switch generation_stream to mx.new_thread_local_stream and let
BatchGenerator accept a stream= kwarg, so the server can pass the
generator thread's default stream explicitly. Keeps generation and
synchronization on the same stream.

Requires mlx>=0.31.2 (for mx.new_thread_local_stream).
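
As a rough analogy in plain Python (this is not MLX code, and none of the names below come from MLX), a thread-local handle means each thread lazily gets its own object instead of sharing one module-level instance. `threading.local` is swapped in here purely to illustrate the idea.

```python
import threading

# Plain-Python analogy only: `_local` stands in for a thread-local stream
# registry. All names here are hypothetical and unrelated to the MLX API.
_local = threading.local()


def current_stream() -> str:
    """Return this thread's stream handle, creating it on first use."""
    if not hasattr(_local, "stream"):
        _local.stream = f"stream-for-{threading.current_thread().name}"
    return _local.stream


def collect(results: dict) -> None:
    """Record the handle seen from inside the calling thread."""
    results[threading.current_thread().name] = current_stream()


results: dict = {}
workers = [
    threading.Thread(target=collect, args=(results,), name=f"gen-{i}")
    for i in range(2)
]
for w in workers:
    w.start()
for w in workers:
    w.join()

# Each worker saw its own handle; neither shares state with the other.
assert results["gen-0"] != results["gen-1"]
```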

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

- Updated ResponseGenerator to load model resources in a dedicated thread, improving responsiveness.
- Introduced a wait_until_ready method to ensure the model is fully loaded before generating responses.
- Added error handling for model loading failures, allowing for graceful degradation.
- Removed direct model loading from get_cached_model, streamlining the initialization process.
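
A minimal sketch of the pattern described in the list above, using only the standard library. `LazyGenerator` and `load_fn` are hypothetical names for illustration; only `wait_until_ready` is taken from the commit message, and the real ResponseGenerator may differ in every detail.

```python
import threading
from typing import Any, Callable, Optional


class LazyGenerator:
    """Hypothetical stand-in: loads its model on the worker thread it owns."""

    def __init__(self, load_fn: Callable[[], Any]):
        self._load_fn = load_fn
        self._ready = threading.Event()
        self._error: Optional[Exception] = None
        self.model: Any = None
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        # Loading happens here so that any thread-bound resources belong
        # to the generator thread rather than to the caller's thread.
        try:
            self.model = self._load_fn()
        except Exception as exc:
            self._error = exc  # remembered so waiters can see the failure
        finally:
            self._ready.set()
        # A real generator thread would continue into its request loop here.

    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
        """Block until loading finishes; re-raise a load failure in the caller."""
        if not self._ready.wait(timeout):
            raise TimeoutError("model did not finish loading in time")
        if self._error is not None:
            raise RuntimeError("model failed to load") from self._error
```

The constructor returns immediately, and callers that need the model call `wait_until_ready()` first, which either returns once loading is done or raises with the original load error attached as the cause.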

This change enhances the overall architecture by decoupling model loading from response generation, ensuring better performance and reliability.

Every VLM's decode __call__ already has an `isinstance(offset, int)` fast
path (used by plain KVCache) alongside an `isinstance(offset, mx.array)`
branch that ends in `.item()` — a blocking GPU sync on every decode step.
For uniform left-padding (the common case, including bs=1) the offset is
effectively scalar, so keeping it as a Python int lets the fast path win
for batch caches too.

Measured on Qwen2.5-VL-3B-Instruct-4bit, bs=1, 193 steady-state tokens:
  - Before: BatchGenerator / stream_generate ratio 0.80x (19.6% overhead)
  - After:  ratio ~1.0x (at parity)

- `_ScalarOffsetMixin` on upstream BatchKVCache / BatchRotatingKVCache via
  name-shadowing subclasses in mlx_vlm.models.cache — no monkey-patching.
- Same scalar-init + inflate-in-filter/extend applied inline to the local
  BatchQuantizedKVCache and BatchTurboQuantKVCache.
- `left_padding` stays as mx.array (feeds into create_causal_mask etc.).
- Non-uniform left-padding still uses mx.array offset; divergent-offset
  paths in __call__ are unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Upstream mlx-lm keeps a Python-int token counter `self._idx` on
`BatchKVCache`; it never calls `.item()` on `cache.offset` in the hot
path. Mirror that convention in mlx-vlm's VLMs so the decode step avoids
the per-step GPU sync without relying on the `_ScalarOffsetMixin` scalar
fast-path.

Affects the four VLMs that had the same 3-line `.item()` pattern:
  - qwen2_5_vl
  - qwen3_5
  - qwen2_vl
  - paddleocr_vl

Semantic note: `_idx` is the raw token count (includes padded positions).
`get_rope_index` already encodes the left-padding offset into
`mrope_position_deltas` (via `max_pos + 1 - seq_len`), and attention
masking for padded positions is handled by `cache.make_mask` →
`create_causal_mask(..., left_padding=...)`. So `_idx + rope_deltas`
yields the correct absolute rope position for both uniform and
left-padded batches — and fixes a latent bug where the previous
`max(offset, 0) + rope_deltas` would double-subtract the padding.
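
The arithmetic in this note can be checked with a small text-only example. This is a sketch under stated assumptions: real tokens are assumed to get positions 0..real_len-1, `seq_len` is assumed to be the padded row length, and the old cache offset is assumed to already exclude the left padding. The helper name is hypothetical.

```python
def text_only_rope_delta(real_len: int, padded_len: int) -> int:
    """Delta as described above: max_pos + 1 - seq_len, for a text-only row.

    Assumes real tokens occupy positions 0..real_len-1 (so max_pos is
    real_len - 1) and that seq_len is the padded length of the batch row.
    """
    max_pos = real_len - 1
    return max_pos + 1 - padded_len


padded_len = 10     # every row in the batch is padded to this length
left_padding = 3    # this row has 3 pad tokens on the left
real_len = padded_len - left_padding

delta = text_only_rope_delta(real_len, padded_len)  # equals -left_padding

# New convention: the raw token counter includes the padded positions.
idx = padded_len
new_position = idx + delta

# Old convention (assumption: the offset already excludes the padding),
# so adding the delta subtracts the padding a second time.
offset = padded_len - left_padding
old_position = max(offset, 0) + delta

assert new_position == real_len                   # next position after 0..6
assert old_position == real_len - left_padding    # padding subtracted twice
```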

The divergent-offset path (continuous-batching, per-sequence offsets)
still uses `mx.maximum(cache.offset, 0)` — unchanged.

Measured on Qwen2.5-VL-3B-Instruct-4bit, bs=1, 193 steady-state tokens:
BatchGenerator / stream_generate ratio 1.01x (at parity).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Extend the previous commit to the remaining VLMs that had the same GPU-
sync-per-decode pattern. Rope positions count across the full (possibly
padded) sequence — attention masking (via cache.make_mask →
create_causal_mask with left_padding) handles per-sequence padding.

Models rewritten:
  - qwen3_vl, qwen3_vl_moe, qwen3_omni_moe
  - glm4v, glm4v_moe, glm_ocr
  - ernie4_5_moe_vl
  - falcon_ocr
  - gemma3n, mllama, hunyuan_vl, moondream3

For the last four (which had simpler `.item()` call-sites feeding
`self.rope(offset=...)` or position arithmetic), the `.item()` paths are
kept as fallbacks for caches that don't expose `_idx` — the hot path
uses `int(cache._idx)`.
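
The hot-path/fallback split described above might look roughly like this. The cache classes and `decode_offset` are hypothetical stand-ins written in plain Python; `FakeArrayOffset.item()` merely counts calls to represent the blocking sync that the real `.item()` would trigger.

```python
class FakeArrayOffset:
    """Stand-in for an array-valued offset whose .item() would force a sync."""

    def __init__(self, value: int):
        self._value = value
        self.item_calls = 0

    def item(self) -> int:
        self.item_calls += 1
        return self._value


class BatchCacheLike:
    """Hypothetical batch cache exposing a Python-int counter `_idx`."""

    def __init__(self, n: int):
        self._idx = n
        self.offset = FakeArrayOffset(n)


class PlainCacheLike:
    """Hypothetical cache that does not expose `_idx`."""

    def __init__(self, n: int):
        self.offset = FakeArrayOffset(n)


def decode_offset(cache) -> int:
    """Prefer the int counter; fall back to .item() only when it is absent."""
    idx = getattr(cache, "_idx", None)
    if idx is not None:
        return int(idx)  # hot path: no device round trip
    offset = cache.offset
    return offset if isinstance(offset, int) else int(offset.item())
```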

Bench parity (Qwen2.5-VL-3B-Instruct-4bit, bs=1, 193 tokens):
BatchGenerator / stream_generate = 1.07x (BatchGenerator slightly ahead).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Every MRoPE VLM was rewritten in the previous two commits to use
``cache._idx`` (Python int) directly instead of ``cache.offset.item()``.
That made the scalar-int fast-path on ``cache.offset`` unreachable, so
the mixin — which collapsed uniform batch offsets to a Python int — is
dead weight.

Fix: drop the mixin + name-shadowing subclasses and restore upstream
``BatchKVCache`` / ``BatchRotatingKVCache``. Also revert the inline
scalar-offset init and int→mx.array inflation guards from the local
``BatchQuantizedKVCache`` and ``BatchTurboQuantKVCache`` for the same
reason.

Side benefit: fixes a crash on ``BatchRotatingKVCache(max_size,
left_padding)`` — the mixin's positional-argument assumption was wrong
for this class (its signature has ``max_size`` first), which surfaced
as ``TypeError: 'int' object is not iterable`` on gemma3n (and any
other model that uses a rotating KV cache) when starting the server.
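
The failure mode can be reproduced in isolation with hypothetical classes. This sketch does not reproduce the removed mixin; it only shows how assuming that `left_padding` is the first positional argument breaks on a class whose signature starts with `max_size`.

```python
class RotatingCacheLike:
    """Hypothetical cache whose signature puts max_size first."""

    def __init__(self, max_size: int, left_padding: list):
        self.max_size = max_size
        self.left_padding = left_padding


class NaiveScalarMixin:
    """Hypothetical mixin that wrongly treats args[0] as left_padding."""

    def __init__(self, *args, **kwargs):
        left_padding = args[0]
        # With max_size in first position this iterates over an int.
        self.uniform = len(set(left_padding)) == 1
        super().__init__(*args, **kwargs)


class ShadowedRotatingCache(NaiveScalarMixin, RotatingCacheLike):
    pass


try:
    ShadowedRotatingCache(256, [0, 0])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# The plain class, without the mixin, constructs normally.
plain = RotatingCacheLike(256, [0, 0])
```

Here `error_message` ends up as `'int' object is not iterable`, the same message quoted in the commit text.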

Verified: gemma3n server now serves requests; Qwen2.5-VL bench stays
at rough parity with ``stream_generate`` (0.93x steady-state,
within run-to-run variance of the previous 1.01-1.07x).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Updated variable names for clarity: changed `deepstack_image_embeds` to `deepstack_visual_embeds` in both `qwen3_vl.py` and `qwen3_vl_moe.py`. This change improves code readability and consistency across the models. No functional changes were made.

Newer ``transformers`` releases (5.7+) emit ``preprocessor_config.json``
entries like ``"min_pixels": null`` / ``"max_pixels": null`` alongside
``"size": {"shortest_edge": ..., "longest_edge": ...}``. The Qwen3-VL
config loader was unconditionally overwriting ``out[...]`` with those
``None`` values after deriving sensible defaults from ``size``, which
then crashed ``_smart_resize_image`` with:

  TypeError: '>' not supported between instances of 'int' and 'NoneType'

Reproduced with ``mlx-community/Qwen3-VL-30B-A3B-Instruct-4bit`` and
``mlx-community/Qwen3-VL-2B-Instruct-4bit`` on transformers 5.7.0.dev0.
Works on transformers 5.5.4 because that release doesn't emit the
explicit ``None`` keys.

Fix: only overwrite when the incoming value is not ``None``. Same
treatment for the video preprocessor kwargs. Also harden the ``size``
lookup against a ``None`` size dict.
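
A minimal sketch of the "only overwrite when not ``None``" rule, in plain Python. The function name is hypothetical, and the mapping of ``shortest_edge`` to ``min_pixels`` and ``longest_edge`` to ``max_pixels`` is an assumption made for illustration; the actual loader may derive its defaults differently.

```python
from typing import Optional


def merge_pixel_limits(config: Optional[dict]) -> dict:
    """Derive defaults from `size`, then apply only non-None explicit values."""
    config = config or {}
    size = config.get("size") or {}  # tolerate a missing or null size dict

    out: dict = {}
    # Assumed mapping (illustrative): edges of `size` provide the defaults.
    if size.get("shortest_edge") is not None:
        out["min_pixels"] = size["shortest_edge"]
    if size.get("longest_edge") is not None:
        out["max_pixels"] = size["longest_edge"]

    for key in ("min_pixels", "max_pixels"):
        value = config.get(key)
        if value is not None:  # an explicit null must not clobber the default
            out[key] = value
    return out


cfg = {
    "size": {"shortest_edge": 65536, "longest_edge": 16777216},
    "min_pixels": None,
    "max_pixels": None,
}
limits = merge_pixel_limits(cfg)
```

With the explicit nulls ignored, `limits` keeps the values derived from `size`, so a later comparison such as `pixels > limits["max_pixels"]` never sees ``None``.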

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

``visual_embeds`` is the concatenation of visual tokens across the whole
batch (e.g. (3334 + 1734, 2048) for two images), but the per-sample
loop was passing the full tensor to ``batch_result.at[indices].add(...)``
regardless of how many visual positions that sample has. For anything
beyond a single image (bs > 1 or a single sample with >1 image) this
crashed with:

  Shapes (N, D) and (M, D) cannot be broadcast.

Fix: carry a running offset and slice ``visual_embeds[offset:offset+n]``
for each sample, where ``n`` is that sample's visual-position count.
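
The running-offset fix can be sketched with plain Python lists standing in for MLX arrays. The function name and data layout are hypothetical; the real code operates on tensors with ``batch_result.at[indices].add(...)``, which this sketch only imitates.

```python
def deepstack_add(hidden, visual_mask, visual_embeds):
    """Add each sample's slice of visual_embeds to its visual positions.

    hidden:        list (batch) of lists (seq) of float vectors
    visual_mask:   list (batch) of lists (seq) of bools
    visual_embeds: flat list of vectors, concatenated across the whole batch
    """
    out = []
    offset = 0  # running index into the concatenated visual_embeds
    for rows, mask in zip(hidden, visual_mask):
        indices = [i for i, is_visual in enumerate(mask) if is_visual]
        n = len(indices)
        chunk = visual_embeds[offset:offset + n]  # only this sample's embeds
        offset += n
        result = [list(vec) for vec in rows]
        for pos, emb in zip(indices, chunk):
            result[pos] = [h + e for h, e in zip(result[pos], emb)]
        out.append(result)  # append the updated rows, not the originals
    return out


hidden = [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]
mask = [[True, False, True], [False, True, False]]
embeds = [[1.0], [2.0], [3.0]]  # 2 visual tokens for sample 0, 1 for sample 1
updated = deepstack_add(hidden, mask, embeds)
```

Sample 0 consumes the first two embedding rows and sample 1 the third, so samples with different visual-token counts never see each other's rows.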

Also fixes a latent bug in qwen3_vl_moe: the multi-image branch was
appending ``batch_hidden`` (pre-deepstack) instead of ``batch_result``,
so the deepstack add was silently dropped whenever the loop body ran.

Verified end-to-end with Qwen3-VL-2B-Instruct-4bit on two images of
different sizes (cats.jpg + graph.png).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Added assertions to ensure that the `_deepstack_process` method correctly handles multi-image batches without broadcasting errors. Also included checks to confirm that the RoPE offset is derived from `cache._idx` instead of `offset.item()`, addressing potential GPU synchronization issues. These changes improve the robustness of the model tests and guard against regressions related to recent refactors.

Blaizzy merged commit 410dcff into main on Apr 24, 2026
1 check passed
afanty2021 added a commit to afanty2021/mlx-vlm that referenced this pull request Apr 24, 2026
Merge changes from upstream:
- Blaizzy#1056: hunyuan_vl/gemma3n cache-offset optimization
- Blaizzy#1053: Fix DFlash speculative decoding (GPU hang, performance)
- Blaizzy#1050: Thread-local generation stream (port mlx-lm#1090)
- Blaizzy#1055: Close batch_generate/server decode gap + VLM fixes

Conflict resolution:
- requirements.txt: Mixed approach - mlx>=0.31.2 with transformers<5.4.0
  to maintain omlx compatibility while accepting mlx update

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
mdkirin pushed a commit to mdkirin/mlx-seori that referenced this pull request Apr 26, 2026
…MoE pos fix

Absorb upstream/main (4-19 to 4-25 batch). All of the fork's core assets are preserved:
MTP (ported from mlx-lm, Qwen3.5 dense+MoE), PrefixCache hybrid, server hardening
(MLX_MEMORY_LIMIT_GB env, /v1/status, /v1/models including loaded models, model
pinning, busy tracking, GC threshold, last_request, removal of the OOM-prone
startup warmup), server-side thinking strip + incremental streaming, and the
null tool_calls guard.

Absorbed from upstream: continuous batching server (Blaizzy#1027), DFlash speculative
decoding (Blaizzy#1029, Blaizzy#1053 fix), thread-local generation stream (Blaizzy#1050,
with a hasattr guard for mlx<0.32), batch_generate/server VLM fixes (Blaizzy#1055), Qwen3.5/3.6
MoE stale position IDs + gdn_sink compatibility (Blaizzy#1040), tool-call markup strip
(Blaizzy#1037), KV cache quantization (Blaizzy#1030), Qwen2-3.5 VL torch-free video
processors (Blaizzy#1048), Gemma4 LoRA NaN/freeze fix (Blaizzy#1052), Gemma4 video,
Youtu-VL, distributed inference, and more.

Conflict-resolution principles: the signature is extended so that the fork's MTP
n_confirmed and upstream's gdn_sink coexist in the same function. The fork
branched before Blaizzy#1029 (DFlash) was introduced, so the gdn_sink body logic is
inactive in our models (None is passed); the signature still accepts it to
maintain compatibility. When the position_ids cache is reused, the fork's
">= cache_offset + seq_length" check covers the Blaizzy#1040 fix more precisely.
The LanguageModelOutput.hidden_states/gdn_states fields are kept compatible with
the upstream additions.

Verification: syntax + import OK for 4 files. Compatibility with mlx 0.31.0 confirmed on an M3 with 96GB.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>