Add KV cache quantization for continuous batching#1030
Merged
Conversation
Enable both uniform (mx.quantize) and TurboQuant KV cache quantization in the continuous batching path (BatchGenerator / ResponseGenerator). - Add BatchQuantizedKVCache in cache.py with extend()/filter() for uniform quantization (4-bit: 6.4x, 8-bit: 3.6x memory reduction) - Add BatchTurboQuantKVCache in turboquant.py with extend()/filter() using generic _map_state/_map_state_pair helpers - Wire kv_bits, kv_group_size, kv_quant_scheme through _make_cache → BatchGenerator → ResponseGenerator → server - Add attention dispatch for BatchTurboQuantKVCache in base.py (dequantize + standard sdpa; custom Metal kernels not yet batch-aware) - Skip quantizing last layer (sensitive in deep models like gemma-4-31b) - 16 new tests for BatchQuantizedKVCache operations Tested with gemma-4-26b-a4b-it: identical coherent output across no-quant, uniform 8-bit, and TurboQuant 3.5-bit modes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds Python logging throughout the server matching mlx-lm's approach: - DEBUG: request params, response text (truncated), timing, token counts - INFO: model loading, KV quant config, startup status - Use --log-level DEBUG to enable verbose output
- Add KV Cache Quantization subsection under Continuous Batching with server examples and gemma-4 benchmark table - Add --log-level, --kv-group-size, --max-kv-size to Server Options - Note TurboQuant continuous batching support in Compatibility section - Update table of contents Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The generation loop ran for 500ms before checking for new requests, causing incoming requests to wait up to half a second before joining the active batch. Two fixes: - Reduce deadline from 500ms to 50ms - Break out of the loop early if new requests are queued New requests now join the batch within one decode step (~20ms). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
f6cc5a8 to
f01346a
Compare
When --log-level DEBUG is set, the fallback generate() calls now print token-by-token output and timing stats matching CLI behavior. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When existing requests are decoding and a new request arrives, _next() previously prefilled ALL pending prompts before doing any decode step, stalling existing requests for the entire prefill duration. Now when an active batch exists, _next() prefills only one batch of prompts then immediately does a decode step. The next _next() call handles remaining prompts. This interleaves prefill and decode so existing requests keep generating tokens while new requests are prefilled incrementally. First-batch behavior (no active batch) is unchanged — prompts are still batched together for efficient initial prefill. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Store per-prompt kwargs via insert(prompt_kwargs=) so vision embeddings survive across deferred prefill steps. Split chunked prefill into incremental (one chunk per _next() call when an active batch exists) vs blocking (full prefill when no active batch). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Compile geglu activation for fused GPU execution - Use mx.fast.rms_norm in Router, compute softmax only over top-k experts - Simplify Experts forward pass (remove unnecessary reshapes) - Refactor KV sharing to pass shared KV tensors through layers directly instead of using shared cache indices, avoiding redundant cache lookups - Simplify attention mask creation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e server throughput Rewrite BatchGenerator internals to match mlx-lm's architecture: - GenerationBatch: double-buffered _step() with async_eval/eval pipelining - PromptProcessingBatch: separate prompt processing with decode-first ordering - Batched logprob gathering (one GPU sync vs N per-token syncs) - Conditional compute_logprobs flag (batch_generate skips logprobs entirely) Server optimizations (46.3 -> 56.0 tok/s, +21% for 4-request staggered): - Time-budget loop: run generation steps for up to 0.5s before polling queue - Non-blocking queue polling when actively generating - Simplified logprob extraction via pre-computed token_logprob field - Lazy peak_memory query (only on finish) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of async_eval'ing the full (batch, vocab_size) logprobs tensor and gathering per-token values later, gather right after sampling in the same computation graph. This keeps the 262k-element logprobs tensor as a GPU intermediate that MLX can free immediately, only materializing the sampled token + its scalar logprob. Results (Gemma 4 26B-A4B 5-bit): - Single request: 71.7 -> 84.1 tok/s (matches mlx-lm's 84.7) - 4-request concurrent: 54.5 -> 70.0 tok/s (exceeds mlx-lm's 69.2) - Logprobs overhead: 15.5% -> ~0% Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The time-budget loop ran generation steps for up to 500ms before checking the queue for new requests, causing incoming requests to wait ~250ms (avg) to join an active batch. Measured impact (4-request staggered test, gemma-4-26b-a4b-it-5bit): - Insertion latency: 254ms -> 121ms (-52%) - Throughput: unchanged (~74 tok/s for R0 in both cases) The +21% throughput attributed to this loop in cfdca51 was actually from the bundled pre-gathered logprobs and simplified logprob path optimizations; the budget loop contributes nothing measurable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
BatchGenerator was duplicating the set_wired_limit/restore logic that already exists in the wired_limit context manager. Consolidate on the context manager pattern via contextlib.ExitStack so there's one source of truth for the wired-limit behavior. Side benefits: - Same "model size near working set" warning now also emitted for the BatchGenerator path (server and batch_generate). - Proper synchronize-before-restore inherited from wired_limit. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1. First-token loss in GenerationBatch.__init__ The constructor called _step() to "warm up" the double-buffer pipeline, but the return value (the first sampled token from prefill) was discarded. Effect: every generated sequence was missing its first token. Example: "Hello, world!" → ", world!", "10" → "0". Fix: don't call _step() in __init__; let the first next() call do it and emit T1. 2. IndexError when extending a batch with logprobs enabled PromptProcessingBatch.generate() didn't populate _next_lps on the fresh GenerationBatch, so extend() took the `if other._next_lps is not None` branch as False and left self._next_lps at its old size while _next_tokens grew. Next _step then returned a stale smaller lp_list that crashed with `list index out of range` when indexed for the larger self.uids. Symptom: parallel requests hanging until the 30s client timeout. Fix: gather and store per-token logprobs for the first tokens before returning the new GenerationBatch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Two related bugs caused parallel/staggered batched decoding to produce wrong tokens (garbage or silent divergence) for any sequence after the first: 1) qwen2_5_vl.py get_input_embeddings() reset the model's cached _position_ids and _rope_deltas to None on every text-only request. When a new text request arrived while other sequences were actively decoding, this destroyed state the ongoing decodes depended on — their next forward pass fell into the prefill recomputation path and used position_ids for the new (1-token) input, not for their real position in the sequence. 2) language.py LanguageModel.__call__ used cache[0].offset[0] as a single scalar cache_offset for the whole batch during decode, so every sequence got the same RoPE position even though BatchKVCache tracks per-sequence offsets (different left_padding per seq). Fixes: - Don't clear _position_ids/_rope_deltas on text-only embeddings; the prefill path already recomputes them via the cache_offset==0 branch. - Read the full per-element offset array when the cache is batched (ported from qwen3_5/language.py) and build delta/position_ids per-sequence so each batch item gets its correct RoPE position. Verified with 4 identical parallel prompts now produce 4 identical correct responses (previously all diverged), and mixed image+text parallel/staggered tests pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Same fix as the Qwen2.5-VL patch: three related bugs caused parallel / staggered batched decoding to produce garbage or slightly-off tokens for any sequence after the first. 1) qwen2_vl.py get_input_embeddings() cleared the model's cached _position_ids / _rope_deltas on every text-only call, destroying state that other sequences still decoding in the batch depended on. 2) language.py LanguageModel.__call__ used a single scalar cache_offset (cache[0].offset[0]) for the whole batch, so every sequence got the same RoPE position even though BatchKVCache tracks per-sequence offsets. 3) Qwen2Model.__call__ passed cache[0] to create_attention_mask, which emitted an explicit boolean mask via BatchKVCache.make_mask. That path interacted badly with the rest of the attention kernel for this model (produced repeated-token garbage). Passing the cache list falls through to the "causal" string path, matching how qwen2_5_vl already works. Verified: 4 parallel different prompts → Paris/4/Blue/10 (all correct), 4 staggered → all correct, 4 identical parallel → all produce the same counting output (previously all diverged into garbled tokens). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Same get_input_embeddings reset bug as qwen2_vl / qwen2_5_vl — every text-only call cleared the LanguageModel's cached _position_ids / _rope_deltas, destroying state that other sequences in the current batch were relying on mid-decode. qwen3_5_moe inherits get_input_embeddings from qwen3_5.Model, so this fix covers both qwen3_5 and qwen3_5_moe (verified on mlx-community/Qwen3.5-35B-A3B-4bit: 4 parallel different prompts now produce 4 correct distinct answers; previously they all got the same stale response). The rest of the qwen3_5 code already handles batched per-sequence cache offsets correctly — the LanguageModel.__call__ was where this patch series borrowed its per-element offsets pattern from. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a client aborted a /v1/chat/completions stream mid-generation, the server kept decoding tokens for that uid until natural EOS / max_tokens and pushed them into an abandoned rqueue. That wasted compute, leaked memory, and tied up batch slots that other requests could have used (worst case: max_tokens=8192 aborted after 1 token → ~8000 tokens of wasted decode plus a slot lost for the whole duration). Implements cooperative cancellation end-to-end: - BatchGenerator.remove(uid): removes a sequence by uid from whichever stage it's in (unprocessed queue, in-flight prefill when it's the only sequence, or generation batch). Falls through to the gen batch if the uid was in a multi-sequence prefill so a concurrent removal still catches it after transition. - ResponseGenerator: new thread-safe _cancelled set + _cancel(uid). The per-request token_iterator tracks whether it ended naturally (None sentinel / exception) and calls _cancel(uid) from its finally block otherwise. GeneratorExit from fastapi closing the async response now reaches here via the standard generator protocol. - _run drains cancelled uids every outer iteration and calls batch_gen.remove(uid) + drops the active entry (pushing a final None onto the rqueue so any lingering consumer unblocks cleanly). Verified end-to-end: aborting a 2000-token stream after 0.5s and immediately issuing a 500-token follow-up, the follow-up runs at baseline single-request tok/s (83 vs 82) — i.e., the slot was actually released. Concurrent case: when one of two parallel streams aborts, the other immediately reverts to batch-size-1 throughput. Adds unit tests for BatchGenerator.remove covering the unprocessed queue path and the missing-uid case. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Two fixes for the cancellation path from the previous commit: 1) token_iterator.finally was calling _cancel(uid) on any close, including the clean case where stream_generator breaks out of its loop after seeing finish_reason and then closes the iterator. The None sentinel the GPU thread pushes after the final token wasn't being consumed before close, so `ended` stayed False and every completed stream triggered a bogus cancellation request. Fix: mark `ended = True` as soon as a token with finish_reason is received (before yielding it) and break after that yield, so normal completion is distinguishable from a client abort. 2) GenerationBatch.filter now clears _current_tokens / _current_lps when the batch is emptied, matching the existing _next_tokens / _next_lps clearing. Without this, stale tensors from the removed sequences were kept alive until the next _step overwrote them. Verified on Gemma-4-26B-A4B: cancellations now only fire on actual aborts. First follow-up after an abort takes a ~10% one-time hit (pending async work for the aborted sequence draining on the GPU stream); subsequent runs return to baseline. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
GenerationBatch.next() wraps its work in `with mx.stream(generation_stream)` so that the cache filter at the end of a natural completion runs on the same stream as the preceding async evals. remove() was calling GenerationBatch.filter() (and the unprocessed / prompt_batch cleanups) from whichever stream the caller happened to be on, which meant the cache slice ops for a cancelled sequence could reorder against pending async work. Wrap the whole body of BatchGenerator.remove() in the same stream context so cancellation cleanup is ordered identically to natural completion. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Removes comments that just label what the next line does
("Batch state", "Stats counters", "Create batch cache", "Insert new
requests into batch", "Non-blocking poll when actively generating",
etc.) and trims a few multi-line blocks down to one or two lines
focused on the non-obvious reasoning. No behavior change.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Merges the chunked-prefill fix from #1032 (thanks @pcuenca) with the parallel-batching fix for stateful _rope_deltas already on this branch. The two interact through the model's cached _position_ids/_rope_deltas; combining them cleanly needs a small adjustment on each side: - PromptProcessingBatch no longer resets _position_ids / _rope_deltas on construction. get_input_embeddings is the only owner now. - qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe LanguageModel.__call__ drop the `cache_offset > 0` guard and always slice from the precomputed _position_ids when available. Needed so the first chunk of a chunked VLM prefill uses the same positional layout that get_input_embeddings laid out for the full input (fixes PR #1032's repro: mlx-community/Qwen3-VL-30B-A3B-Instruct-bf16 describing a large image). - qwen2_vl / qwen2_5_vl / qwen3_5 Model.get_input_embeddings, when pixel_values is None, now resets _position_ids but leaves _rope_deltas alone. Clearing only _position_ids keeps the chunked prefill slice safe (no stale image-shaped tensor leaking into a text prefill) while preserving _rope_deltas so other sequences that are still decoding in the batch keep their correct per-sequence positional state — which is what the original fix on this branch needed for parallel identical-text to be deterministic. Verified across Qwen2.5-VL-3B, Qwen3.5-35B-A3B and Gemma-4-26B: - 4 parallel identical prompts → all identical - 4 parallel different prompts → distinct correct answers - Staggered image+text → all coherent - Mixed image+text concurrent decoding → both sequences complete; text now matches single-run baseline on Qwen3.5, image now matches on Gemma-4 (Qwen2.5-VL still diverges slightly from single-run baseline due to the single-slot _rope_deltas during concurrent decode — a separate, pre-existing limitation) - Abort + follow-up → unchanged Gemma-4 throughput (single + parallel 2/3/4) unchanged within noise. Co-Authored-By: Pedro Cuenca <pedro@huggingface.co> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extend the parallel-batching fix from earlier to the six VLMs that still reset both _position_ids and _rope_deltas in their text-only get_input_embeddings path: - glm4v, glm4v_moe, glm_ocr - paddleocr_vl - qwen3_vl, qwen3_vl_moe Each now only clears _position_ids (to keep the chunked-prefill slice safe) and leaves _rope_deltas alone so other sequences still decoding in the current batch keep their per-sequence positional state. Verified qwen3_vl parallel identical text decodes deterministically for three of four slots — the remaining divergence is floating-point non-determinism at a close-call argmax (batched matmul vs single), not a fix gap. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds per-sequence _rope_deltas tracking on GenerationBatch (concatenated by extend, sliced by filter) so sequences with different deltas can share a batch without aliasing self._rope_deltas on the model. - PromptProcessingBatch.generate snapshots the model's _rope_deltas right after prefill and hands it to the new GenerationBatch, shape normalized to (B, 1). Models without _rope_deltas stay untouched. - GenerationBatch._step passes rope_deltas=self._rope_deltas as a kwarg on every decode forward, mirroring how n_to_process is threaded through prefill. - qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe / qwen3_5 LanguageModel.__call__ now accept the kwarg and prefer it over self._rope_deltas in the decode branch. Verified on concurrent image + text decoding (image stream started, text stream joins mid-decode): Qwen2.5-VL-3B : image matches single-run baseline (was diverging). Qwen3.5-35B : both image and text match baseline. Gemma-4 : unaffected (no _rope_deltas on its LM). All 49 existing unit tests still pass. Parallel identical / different text, staggered, and abort + follow-up all continue to work. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Apply the same kwarg plumbing as the Qwen family to the rest of the VLMs that cache _rope_deltas on the LanguageModel: - glm4v, glm4v_moe, glm_ocr - paddleocr_vl - ernie4_5_moe_vl - qwen3_omni_moe Each LanguageModel.__call__ now pops `rope_deltas` from kwargs and prefers that over self._rope_deltas in the decode branch, so sequences with different deltas can share a batch without aliasing the single model-level slot. qwen3_5_moe inherits __call__ from qwen3_5.LanguageModel and is covered transitively. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Both models guard _position_ids reuse behind `cache_offset > 0`, so the first chunk of a long image prefill falls into the get_rope_index recompute path and sees only a slice of input_ids — missing image token context and producing wrong positions. Dropping the guard lets the precomputed _position_ids (set up by get_input_embeddings with the full input) be sliced for every chunk, matching the fix from #1032 already applied to qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe. (The two falcon VLMs that also had `cache_offset > 0` guards use a different _rope_delta path that's decode-only, not the chunked- prefill bug.) Verified: mlx-community/PaddleOCR-VL-bfloat16 + webpage_1.png (3557 tok, chunked) produces the expected Chinese news summary. tencent/HunyuanOCR + webpage_1.png (4161 tok, chunked) produces the expected "俄中传媒资讯网" site description. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Both models assumed cache.offset was always a scalar and called
offset.item(), which blows up on BatchKVCache's multi-element offset
arrays (ValueError: Only length-1 arrays can be converted to Python
scalars). Using the first element's offset would have avoided the
crash but left every sequence in a batch sharing the same RoPE
encoding — correct only when all sequences happen to be at the same
position.
Now when cache.offset has more than one element, build per-sequence
position_ids of shape (B, L) = base_offset + arange(L), and let
apply_rotary_emb_1d reshape cos/sin with the batch dim when the
indexed tables come out 3D. Single-sequence usage still takes the
original 2D reshape path.
Verified:
mlx-community/Falcon-OCR-bf16 single-request paper.png → unchanged
correct author-list extraction.
4 parallel identical image requests → all 4 sequences emit the same
correct paper header (previously server errored on offset.item()).
tiiuae/Falcon-Perception generate_perception() on cats.jpg → still
returns the two correct detections.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
microsoft/Phi-4-reasoning-vision-15B ships a custom modeling_phi4_visionr.py referenced from config.json via auto_map. AutoTokenizer.from_pretrained resolves its class through AutoConfig, which imports that remote file at class-definition time -- and the file uses `siglip2_ips.filter_out_non_signature_kwargs()` as a decorator, a helper that was moved in transformers 5.6+. The import therefore raises AttributeError before our install_auto_processor_patch gets to call our own Phi4SigLipProcessor implementation. Phi4SigLipProcessor is self-contained and doesn't need the remote modeling classes -- only the tokenizer artifacts (tokenizer.json / tokenizer_config.json) which are standard. Pop `trust_remote_code` from the forwarded kwargs and call AutoTokenizer.from_pretrained with trust_remote_code=False, bypassing the remote auto_map chain. Verified: Phi-4-reasoning-vision-15B now loads and serves; single and concurrent same-image requests produce coherent image descriptions with zero generation errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Moondream3's attention did `mx.arange(cache.offset, cache.offset + L)` and `self.rope(..., offset=cache.offset)`, both of which expect a Python int. Under continuous batching our BatchKVCache always stores `cache.offset` as an `mx.array`, so the server crashed with `arange(): incompatible function arguments` on every decode step. Collapse the offset array to an int (the single value for B=1, or max across sequences for B>1) before calling arange/RoPE. Single-request serving now works end-to-end; same-image concurrent decodes correctly for the first sequence but subsequent sequences still degrade, mirroring the pre-existing qwen2_vl-2B concurrent sensitivity -- proper per-row positions for heterogeneous-length batches need a larger refactor of Tau/RoPE to accept (B, L) positions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Builds on the previous arange-offset coercion so moondream3 handles
BatchKVCache properly, not just degenerate single-seq cases:
- Attention: keep the full per-seq offset array when the cache carries
one. Build positions as `offsets[:, None] + arange(L)` so each row
receives its own position values instead of the batch-wide max.
- Tau.__call__: accept 1-D shared positions (existing path) or 2-D
per-sequence positions (B, L). For the 2-D case, broadcast alpha
across the batch and keep tau_pos shape (B, n_heads, L).
- RoPE: mx.fast.rope only accepts scalar offsets, so for multi-seq
batches apply RoPE per row and concatenate. Single-seq continues to
hit the fast path.
Verified on moondream3-preview: 3x same-image concurrent now produces
coherent per-seq descriptions (previously seqs 1-2 emitted
<|md_reserved_4|> repetitions).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…he.make_mask Two interacting issues: 1. Attention.__call__ sliced `mask` to `keys.shape[-2]` *before* `cache.update_and_fetch`. During chunked prefill (keys just the new chunk, but the cache already holds previous chunks), this trimmed the mask to the current-chunk column count. After update_and_fetch the returned keys include both cached and new tokens, so the (sliced) mask and (extended) keys had mismatched lengths -- fine when mask was `None`/`"causal"` from the old unmasked decode path, but broken as soon as a real tensor mask reached this point. 2. TextModel.__call__ passed the cache *list* to `create_attention_mask`, which bypassed `BatchKVCache.make_mask` and produced an undersized mask for concurrent batches with per-sequence left_padding, contributing to the garbled tail on mixed-modality text responses. Swap the order so the mask is sliced against the post-update key length, and route mask construction through `cache[0]`. Verified on InternVL3-8B-bf16: single, concurrent same-image x3 (now identical per-seq outputs), and mixed 2-img + 2-txt all complete with zero generation errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Continuous-batching chunked prefill passes `n_to_process=<chunk_size>` (and any other prompt kwargs stashed on the VLM wrapper) when it calls `self.model(...)`. aya_vision's `LanguageModel.__call__` had a strictly-positional signature, so the server crashed every decode step with `LanguageModel.__call__() got an unexpected keyword argument 'n_to_process'`. Add **kwargs so the unused hints are accepted. Verified on aya-vision-8b-8bit: single and concurrent same-image x3 now produce coherent descriptions, zero generation errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
aya_vision interleaves global KVCache layers with RotatingKVCache sliding-window layers (sliding_window_pattern=4), but TextModel.__call__ built a single mask from `cache[j - 1 : j]` -- a Python list slice that doesn't expose `make_mask`, so upstream `create_attention_mask` fell through to "causal" / None and dropped the per-sequence `left_padding`. For mixed-modality batches where a short text sequence joins a long image sequence, the text row's 2000+ zero-padded K/V slots got attention weight, drowning out its real context and garbling tokens. Copy gemma3's approach: build two masks -- `global_mask` from the global BatchKVCache (cache[j-1]) and `sliding_mask` from a BatchRotatingKVCache with `window_size=self.config.sliding_window` supplied -- and dispatch per layer based on `(i + 1) % j == 0`. Both masks are now left-padding-aware via `BatchKVCache.make_mask`, and the sliding-window layers get a correctly-sized mask for their rotating cache. Verified on aya-vision-8b-8bit: - single, concurrent same-image x3, and mixed 2img + 2txt all complete with zero generation errors. - mixed-batch text no longer emits unicode salad; the text sequences produce clean English (the model sometimes asks for an image, which is a template / training quirk, not a batching bug). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Deepseek-VL v1 (model_type: multi_modality) crashed every generation step under continuous batching with `LanguageModel.__call__() got an unexpected keyword argument 'attention_mask_4d'`. The VLM wrapper ships attention_mask_4d / pixel_values / etc. through `InputEmbeddingsFeatures.to_dict()` as decode kwargs, and chunked prefill also passes `n_to_process`. Adding **kwargs lets the fixed positional signature absorb those hints. Verified on mlx-community/deepseek-vl-7b-chat-4bit: single and concurrent same-image x3 now produce coherent descriptions with zero generation errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match the defensive pattern added to aya_vision / multi_modality so llava's LanguageModel absorbs kwargs forwarded by continuous batching (attention_mask_4d from InputEmbeddingsFeatures.to_dict, n_to_process from chunked prefill, etc). Note: llava-1.5 still cannot be served end-to-end in transformers 5.6+ because LlavaProcessor.__call__ routes `padding` into ImagesKwargs and hits additional `int // None` math inside the image processor; those are upstream transformers API drifts separate from this plumbing fix and need their own investigation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mllama's cross-attention used `cache.offset > 0` (scalar compare) and self-attention used `self.rope(..., offset=cache.offset)` (int expected). Both crash under continuous batching where BatchKVCache stores offset as a per-sequence mx.array. Fix cross-attention guard with isinstance dispatch, and coerce self-attention RoPE offset to int (max across seqs for multi-seq batches). Verified on Llama-3.2-11B-Vision-Instruct-4bit: single, concurrent same-image x3, and mixed 2-img + 2-text all complete with coherent output and zero generation errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…p broadcast Three fixes for gemma3n BatchKVCache compatibility: 1. Copy cache.offset before update_and_fetch to prevent in-place mutation from corrupting the RoPE offset used for queries (the root cause of generating EOS immediately on the server). 2. Add make_cache() to Model so _make_cache creates the correct mix of BatchKVCache (full attention) and BatchRotatingKVCache (sliding). 3. Fix altup correct() broadcast: transpose(2,0,1) keeps batch dim consistent for batch size > 1. Also: default pixel_values to None, route inputs_embeds through language_model when provided, and use single cache entries for mask construction. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The conv_1d helper was broadcasting the full kernel array against each padded slice, which fails when the spatial dimension differs from the kernel size. Use per-element kernel[i] instead of the reshaped kernel. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove the elif branch that handled None cache entries — caches are always initialized before this code path runs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_index Follows the Qwen pattern: position_ids, pos_hw, rope_deltas, and the full attention mask are now computed in LanguageModel.get_rope_index() instead of Model._precompute_positions(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace verbose isinstance/ndim/size branching with a clean is_batch flag and single extraction path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…g through kwargs Remove the (B,3,L) → transpose → (3,B,L) round-trip. Position_ids are now set directly on language_model._position_ids during get_input_embeddings, matching the Qwen2-VL pattern. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…thod documentation. Removed verbose comments to enhance code clarity and maintainability.
…ts for clarity and maintainability. Removed redundant comments to enhance code readability.
Only copy cache.offset when it's an mx.array (batch mode) to prevent in-place mutation. Leave int offsets as-is so RoPE uses the fast scalar path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…nd maintainability.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mdkirin
pushed a commit
to mdkirin/mlx-seori
that referenced
this pull request
Apr 26, 2026
…MoE pos fix upstream/main 흡수 (4-19 ~ 4-25 batch). Fork의 핵심 자산은 모두 보존: MTP (mlx-lm 포팅, Qwen3.5 dense+MoE), PrefixCache hybrid, server hardening (MLX_MEMORY_LIMIT_GB env, /v1/status, /v1/models 로드 모델 포함, model pinning, busy tracking, GC threshold, last_request, OOM-위험 startup warmup 제거), 서버사이드 thinking strip + 스트리밍 incremental, null tool_calls 가드. Upstream 흡수: continuous batching server (Blaizzy#1027), DFlash speculative decoding (Blaizzy#1029, Blaizzy#1053 fix), thread-local generation stream (Blaizzy#1050, mlx<0.32 hasattr 가드), batch_generate/server VLM fixes (Blaizzy#1055), Qwen3.5/3.6 MoE stale position IDs + gdn_sink 호환 (Blaizzy#1040), tool-call markup strip (Blaizzy#1037), KV cache quantization (Blaizzy#1030), Qwen2-3.5 VL torch-free 비디오 processors (Blaizzy#1048), Gemma4 LoRA NaN/freeze fix (Blaizzy#1052), Gemma4 video, Youtu-VL, distributed inference 등. 충돌 해결 원칙: fork의 MTP n_confirmed와 upstream의 gdn_sink는 같은 함수에서 공존하도록 시그니처 확장. fork는 Blaizzy#1029(DFlash) 도입 전 시점에서 분기되어 gdn_sink 본체 로직은 우리 모델에서 비활성(None 전달); 단 시그니처는 받아두어 호환성 유지. position_ids 캐시 재사용 시 fork의 ">= cache_offset + seq_length" 체크가 Blaizzy#1040 fix를 더 정교하게 커버. LanguageModelOutput.hidden_states/gdn_states 필드는 upstream 추가분 호환. 검증: 4개 파일 syntax + import OK. M3 96GB에서 mlx 0.31.0 호환 확인. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
BatchQuantizedKVCache(uniformmx.quantize) andBatchTurboQuantKVCachewith fullextend()/filter()support for continuous batching--kv-bits,--kv-group-size,--kv-quant-schemethroughBatchGenerator→ResponseGenerator→ serverbase.pyfor batch TurboQuant (dequantize + standard sdpa)Memory savings (B=4, H=32, S=512, D=128)
Short context (gemma-4-26b-a4b-it, ~80 prompt tokens)
Long context (gemma-4-26b-a4b-it, 20K prompt tokens)
Test plan
BatchQuantizedKVCache(update, filter, extend, state, make_cache, BatchGenerator integration)🤖 Generated with Claude Code