
Add KV cache quantization for continuous batching #1030

Merged: Blaizzy merged 73 commits into main from pc/batch-kv-quant on Apr 19, 2026

Blaizzy (Owner) commented Apr 17, 2026

Summary

  • Adds BatchQuantizedKVCache (uniform mx.quantize) and BatchTurboQuantKVCache with full extend()/filter() support for continuous batching
  • Wires --kv-bits, --kv-group-size, --kv-quant-scheme through BatchGenerator → ResponseGenerator → server
  • Skips quantizing the last layer (sensitive in deep models like gemma-4-31b)
  • Adds attention dispatch in base.py for batch TurboQuant (dequantize + standard sdpa)

Memory savings (B=4, H=32, S=512, D=128)

| Config | KV Memory | Reduction |
|---|---|---|
| Unquantized | 67.1 MB | 1x |
| Uniform 8-bit | 18.9 MB | 3.6x |
| Uniform 4-bit | 10.5 MB | 6.4x |
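
The figures in this table are consistent with one layer's keys and values stored as 32-bit floats, quantized with a group size of 64 and one 32-bit scale plus one 32-bit bias per group. Those three settings are assumptions inferred from the numbers, not stated in the PR, and the helper name `kv_cache_bytes` is hypothetical. A rough sketch of the arithmetic:

```python
def kv_cache_bytes(batch, heads, seq_len, head_dim, bits=None,
                   group_size=64, full_precision_bytes=4, scale_bias_bytes=4):
    """Estimate the bytes held by one layer's keys plus values.

    Assumed, not taken from the PR: fp32 baseline, group_size=64,
    and one fp32 scale plus one fp32 bias per quantization group.
    """
    elements = 2 * batch * heads * seq_len * head_dim  # keys and values
    if bits is None:
        return elements * full_precision_bytes
    packed = elements * bits // 8                      # quantized payload
    groups = elements // group_size
    overhead = groups * 2 * scale_bias_bytes           # scale + bias per group
    return packed + overhead


shape = dict(batch=4, heads=32, seq_len=512, head_dim=128)
baseline = kv_cache_bytes(**shape)
for label, bits in [("Unquantized", None), ("Uniform 8-bit", 8), ("Uniform 4-bit", 4)]:
    size = kv_cache_bytes(**shape, bits=bits)
    print(f"{label}: {size / 1e6:.1f} MB, {baseline / size:.1f}x")
```

Under these assumptions the per-group scale and bias account for about 2.1 MB in both quantized rows, which is why halving the bit width does not halve the total.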

Short context (gemma-4-26b-a4b-it, ~80 prompt tokens)

| Config | Response | Gen tok/s | KV Cache | Peak Memory |
|---|---|---|---|---|
| No quant | The capital of Japan is Tokyo. | 65.2 | 62.4 MB | 51.69 GB |
| Uniform 8-bit | The capital of Japan is Tokyo. | 63.8 | 60.4 MB | 51.69 GB |
| TurboQuant 3.5-bit | The capital of Japan is Tokyo. | 63.9 | 58.3 MB | 51.69 GB |

Long context (gemma-4-26b-a4b-it, 20K prompt tokens)

| Config | Gen tok/s | KV Cache | KV Reduction | Peak Memory |
|---|---|---|---|---|
| No quant | 50.3 | 0.624 GB | 1x | 54.58 GB |
| Uniform 8-bit | 52.6 | 0.469 GB | 1.33x | 54.69 GB |
| TurboQuant 3.5-bit | 25.6 | 0.365 GB | 1.71x | 54.69 GB |

Note: gemma-4 has 25 sliding-window layers (fixed-size RotatingKVCache, unquantized) and only 5 full-attention layers that get quantized. Models with all full-attention layers (e.g. Qwen, LLaMA) will see larger KV reductions. TurboQuant gen tok/s is lower due to dequantize-at-attention (custom Metal kernels not yet batch-aware).
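
The note implies that the overall KV reduction is bounded by the share of the cache that actually gets quantized. A minimal sketch of that relationship follows; the function name is hypothetical and the inputs are illustrative, not measurements from this PR:

```python
def overall_kv_reduction(quantized_fraction, quantized_ratio):
    """Overall reduction when only part of the KV cache is quantized.

    quantized_fraction: share of the unquantized KV bytes that belongs to
        layers which get quantized (0.0 to 1.0).
    quantized_ratio: compression ratio achieved on those layers.
    """
    if not 0.0 <= quantized_fraction <= 1.0:
        raise ValueError("quantized_fraction must be between 0 and 1")
    if quantized_ratio <= 0:
        raise ValueError("quantized_ratio must be positive")
    # Bytes that stay as-is plus bytes that shrink, relative to the baseline.
    remaining = (1.0 - quantized_fraction) + quantized_fraction / quantized_ratio
    return 1.0 / remaining
```

When every layer is quantized the function returns the per-layer ratio itself; as the quantized share shrinks, the result approaches 1x regardless of how aggressive the scheme is.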

Test plan

  • 16 new unit tests for BatchQuantizedKVCache (update, filter, extend, state, make_cache, BatchGenerator integration)
  • All 429 existing tests pass (no regressions)
  • End-to-end with gemma-4-26b-a4b-it: uniform 8-bit and TurboQuant 3.5-bit produce coherent identical output
  • 20K token context: KV cache savings scale with sequence length (1.33x uniform, 1.71x TurboQuant)
  • Test TurboQuant extend/filter with multi-sequence continuous batching

🤖 Generated with Claude Code

Blaizzy and others added 5 commits April 17, 2026 03:09
Enable both uniform (mx.quantize) and TurboQuant KV cache quantization
in the continuous batching path (BatchGenerator / ResponseGenerator).

- Add BatchQuantizedKVCache in cache.py with extend()/filter() for
  uniform quantization (4-bit: 6.4x, 8-bit: 3.6x memory reduction)
- Add BatchTurboQuantKVCache in turboquant.py with extend()/filter()
  using generic _map_state/_map_state_pair helpers
- Wire kv_bits, kv_group_size, kv_quant_scheme through _make_cache →
  BatchGenerator → ResponseGenerator → server
- Add attention dispatch for BatchTurboQuantKVCache in base.py
  (dequantize + standard sdpa; custom Metal kernels not yet batch-aware)
- Skip quantizing last layer (sensitive in deep models like gemma-4-31b)
- 16 new tests for BatchQuantizedKVCache operations

Tested with gemma-4-26b-a4b-it: identical coherent output across
no-quant, uniform 8-bit, and TurboQuant 3.5-bit modes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds Python logging throughout the server matching mlx-lm's approach:
- DEBUG: request params, response text (truncated), timing, token counts
- INFO: model loading, KV quant config, startup status
- Use --log-level DEBUG to enable verbose output
- Add KV Cache Quantization subsection under Continuous Batching
  with server examples and gemma-4 benchmark table
- Add --log-level, --kv-group-size, --max-kv-size to Server Options
- Note TurboQuant continuous batching support in Compatibility section
- Update table of contents

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The generation loop ran for 500ms before checking for new requests,
causing incoming requests to wait up to half a second before joining
the active batch. Two fixes:

- Reduce deadline from 500ms to 50ms
- Break out of the loop early if new requests are queued

New requests now join the batch within one decode step (~20ms).
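
The two fixes above can be sketched as a small decode loop. This is an illustration only: `run_decode_window`, `step`, and `pending` are hypothetical names, and the real server loop is not shown in this PR text.

```python
import queue
import time


def run_decode_window(step, pending, budget_s=0.05, clock=time.monotonic):
    """Run decode steps until the time budget expires or a request is queued.

    step: callable that performs one decode step.
    pending: queue.Queue of requests waiting to join the batch.
    Returns the number of decode steps that ran.
    """
    deadline = clock() + budget_s
    steps = 0
    while True:
        step()
        steps += 1
        if not pending.empty():   # early exit: let the new request join now
            break
        if clock() >= deadline:   # short budget bounds the worst-case wait
            break
    return steps
```

The clock is injectable so the loop can be exercised deterministically; `queue.Queue.empty()` is only a snapshot, which is sufficient here because a request missed in one window is picked up at the next check.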

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Blaizzy Blaizzy force-pushed the pc/batch-kv-quant branch from f6cc5a8 to f01346a Compare April 17, 2026 02:04
Blaizzy and others added 16 commits April 17, 2026 04:07
When --log-level DEBUG is set, the fallback generate() calls now print
token-by-token output and timing stats matching CLI behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When existing requests are decoding and a new request arrives, _next()
previously prefilled ALL pending prompts before doing any decode step,
stalling existing requests for the entire prefill duration.

Now when an active batch exists, _next() prefills only one batch of
prompts then immediately does a decode step. The next _next() call
handles remaining prompts. This interleaves prefill and decode so
existing requests keep generating tokens while new requests are
prefilled incrementally.

First-batch behavior (no active batch) is unchanged — prompts are
still batched together for efficient initial prefill.
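
A toy model of this scheduling policy, assuming uids stand in for prompts and a log stands in for GPU work. The class and its attributes are hypothetical and much simpler than the real `_next()`:

```python
from collections import deque


class InterleavedScheduler:
    """Toy scheduler that records which work each next() call performs."""

    def __init__(self, prefill_batch_size=2):
        self.prefill_batch_size = prefill_batch_size
        self.pending = deque()  # prompts waiting for prefill
        self.active = []        # sequences currently decoding
        self.log = []

    def insert(self, uid):
        self.pending.append(uid)

    def _prefill_one_batch(self):
        count = min(self.prefill_batch_size, len(self.pending))
        batch = [self.pending.popleft() for _ in range(count)]
        if batch:
            self.log.append(("prefill", tuple(batch)))
            self.active.extend(batch)

    def next(self):
        if self.active:
            # Sequences are already decoding: prefill at most one batch,
            # then fall through to a decode step so they are not stalled.
            self._prefill_one_batch()
        else:
            # No active batch: prefill everything that is pending up front.
            while self.pending:
                self._prefill_one_batch()
        if self.active:
            self.log.append(("decode", tuple(self.active)))
```

With one sequence decoding and three new prompts queued, two consecutive `next()` calls each record one prefill followed by one decode, rather than two prefills before the first decode.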

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Store per-prompt kwargs via insert(prompt_kwargs=) so vision
embeddings survive across deferred prefill steps. Split chunked
prefill into incremental (one chunk per _next() call when an active
batch exists) vs blocking (full prefill when no active batch).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Compile geglu activation for fused GPU execution
- Use mx.fast.rms_norm in Router, compute softmax only over top-k experts
- Simplify Experts forward pass (remove unnecessary reshapes)
- Refactor KV sharing to pass shared KV tensors through layers directly
  instead of using shared cache indices, avoiding redundant cache lookups
- Simplify attention mask creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e server throughput

Rewrite BatchGenerator internals to match mlx-lm's architecture:
- GenerationBatch: double-buffered _step() with async_eval/eval pipelining
- PromptProcessingBatch: separate prompt processing with decode-first ordering
- Batched logprob gathering (one GPU sync vs N per-token syncs)
- Conditional compute_logprobs flag (batch_generate skips logprobs entirely)

Server optimizations (46.3 -> 56.0 tok/s, +21% for 4-request staggered):
- Time-budget loop: run generation steps for up to 0.5s before polling queue
- Non-blocking queue polling when actively generating
- Simplified logprob extraction via pre-computed token_logprob field
- Lazy peak_memory query (only on finish)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of async_eval'ing the full (batch, vocab_size) logprobs tensor
and gathering per-token values later, gather right after sampling in
the same computation graph. This keeps the 262k-element logprobs tensor
as a GPU intermediate that MLX can free immediately, only materializing
the sampled token + its scalar logprob.
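
The idea can be illustrated in plain Python, with lists standing in for MLX arrays and greedy selection standing in for the sampler (both are simplifying assumptions; the function names are hypothetical). The full-vocabulary row is a local temporary and only the token and its scalar logprob are returned:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def sample_with_logprob(batch_logits):
    """Pick a token per row and gather its logprob in the same pass."""
    results = []
    for logits in batch_logits:
        logprobs = log_softmax(logits)  # full-vocab row, discarded after use
        token = max(range(len(logprobs)), key=logprobs.__getitem__)
        results.append((token, logprobs[token]))
    return results
```

Returning `(token, logprob)` pairs means no caller ever holds a reference to the vocabulary-sized row, which is the property the commit relies on to let the intermediate be freed.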

Results (Gemma 4 26B-A4B 5-bit):
- Single request: 71.7 -> 84.1 tok/s (matches mlx-lm's 84.7)
- 4-request concurrent: 54.5 -> 70.0 tok/s (exceeds mlx-lm's 69.2)
- Logprobs overhead: 15.5% -> ~0%

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The time-budget loop ran generation steps for up to 500ms before
checking the queue for new requests, causing incoming requests to wait
~250ms (avg) to join an active batch.

Measured impact (4-request staggered test, gemma-4-26b-a4b-it-5bit):
- Insertion latency: 254ms -> 121ms (-52%)
- Throughput: unchanged (~74 tok/s for R0 in both cases)

The +21% throughput attributed to this loop in cfdca51 was actually
from the bundled pre-gathered logprobs and simplified logprob path
optimizations; the budget loop contributes nothing measurable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
BatchGenerator was duplicating the set_wired_limit/restore logic that
already exists in the wired_limit context manager. Consolidate on the
context manager pattern via contextlib.ExitStack so there's one source
of truth for the wired-limit behavior.

Side benefits:
- Same "model size near working set" warning now also emitted for the
  BatchGenerator path (server and batch_generate).
- Proper synchronize-before-restore inherited from wired_limit.
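
A sketch of the pattern, using a stand-in context manager that only records events. The `wired_limit` defined here is a hypothetical placeholder for the real one, whose signature is not shown in this PR text:

```python
import contextlib


@contextlib.contextmanager
def wired_limit(events, limit):
    """Stand-in for the real context manager: set on entry, restore on exit."""
    events.append(f"set {limit}")
    try:
        yield
    finally:
        events.append("synchronize")  # finish pending work before restoring
        events.append("restore")


class BatchRunner:
    """Long-lived object that holds the context open via an ExitStack."""

    def __init__(self, events, limit):
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(wired_limit(events, limit))

    def close(self):
        # Unwinds every context entered on the stack, in reverse order.
        self._stack.close()
```

`ExitStack` is useful here because the object outlives any single `with` block: the context is entered in the constructor and exited later from `close()`.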

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1. First-token loss in GenerationBatch.__init__
   The constructor called _step() to "warm up" the double-buffer
   pipeline, but the return value (the first sampled token from
   prefill) was discarded. Effect: every generated sequence was
   missing its first token. Example: "Hello, world!" → ", world!",
   "10" → "0". Fix: don't call _step() in __init__; let the first
   next() call do it and emit T1.

2. IndexError when extending a batch with logprobs enabled
   PromptProcessingBatch.generate() didn't populate _next_lps on the
   fresh GenerationBatch, so extend() took the `if other._next_lps
   is not None` branch as False and left self._next_lps at its old
   size while _next_tokens grew. Next _step then returned a stale
   smaller lp_list that crashed with `list index out of range` when
   indexed for the larger self.uids. Symptom: parallel requests
   hanging until the 30s client timeout. Fix: gather and store
   per-token logprobs for the first tokens before returning the
   new GenerationBatch.
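
The second fix amounts to keeping tokens and logprobs in lockstep from the moment a batch is created. A minimal sketch with plain lists follows; `GenState` and `from_prefill` are hypothetical names, not the PR's classes:

```python
from dataclasses import dataclass


@dataclass
class GenState:
    uids: list
    next_tokens: list
    next_lps: list  # always populated, one entry per sequence

    def extend(self, other):
        """Append another batch, growing all three lists together."""
        self.uids += other.uids
        self.next_tokens += other.next_tokens
        self.next_lps += other.next_lps
        assert len(self.uids) == len(self.next_tokens) == len(self.next_lps)


def from_prefill(uids, first_tokens, first_logprobs):
    """Build a fresh state with first-token logprobs already gathered.

    first_logprobs holds one row of per-token logprobs per sequence.
    """
    lps = [row[tok] for row, tok in zip(first_logprobs, first_tokens)]
    return GenState(list(uids), list(first_tokens), lps)
```

Because a freshly built state never has missing logprobs, `extend()` needs no conditional branch and the lists cannot drift to different lengths.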

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Two related bugs caused parallel/staggered batched decoding to produce
wrong tokens (garbage or silent divergence) for any sequence after the
first:

1) qwen2_5_vl.py get_input_embeddings() reset the model's cached
   _position_ids and _rope_deltas to None on every text-only request.
   When a new text request arrived while other sequences were actively
   decoding, this destroyed state the ongoing decodes depended on —
   their next forward pass fell into the prefill recomputation path
   and used position_ids for the new (1-token) input, not for their
   real position in the sequence.

2) language.py LanguageModel.__call__ used cache[0].offset[0] as a
   single scalar cache_offset for the whole batch during decode, so
   every sequence got the same RoPE position even though BatchKVCache
   tracks per-sequence offsets (different left_padding per seq).

Fixes:
- Don't clear _position_ids/_rope_deltas on text-only embeddings; the
  prefill path already recomputes them via the cache_offset==0 branch.
- Read the full per-element offset array when the cache is batched
  (ported from qwen3_5/language.py) and build delta/position_ids
  per-sequence so each batch item gets its correct RoPE position.
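
The difference between the scalar and per-sequence offset paths can be shown with plain lists standing in for arrays. Function names are hypothetical and any rope delta is omitted for brevity:

```python
def positions_scalar_offset(offsets, seq_len):
    """Old behavior: every row reuses the first sequence's offset."""
    first = offsets[0]
    return [[first + i for i in range(seq_len)] for _ in offsets]


def positions_per_sequence(offsets, seq_len):
    """Fixed behavior: row b starts at offsets[b]."""
    return [[off + i for i in range(seq_len)] for off in offsets]


offsets = [5, 12]  # two sequences at different depths in their caches
print(positions_scalar_offset(offsets, 1))  # [[5], [5]]
print(positions_per_sequence(offsets, 1))   # [[5], [12]]
```

In the scalar version the second sequence is rotated as if it were at position 5 instead of 12, which matches the silent divergence described for every sequence after the first.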

Verified: 4 identical parallel prompts now produce 4 identical
correct responses (previously all diverged), and mixed image+text
parallel/staggered tests pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Same fix as the Qwen2.5-VL patch: three related bugs caused parallel /
staggered batched decoding to produce garbage or slightly-off tokens
for any sequence after the first.

1) qwen2_vl.py get_input_embeddings() cleared the model's cached
   _position_ids / _rope_deltas on every text-only call, destroying
   state that other sequences still decoding in the batch depended on.

2) language.py LanguageModel.__call__ used a single scalar cache_offset
   (cache[0].offset[0]) for the whole batch, so every sequence got the
   same RoPE position even though BatchKVCache tracks per-sequence
   offsets.

3) Qwen2Model.__call__ passed cache[0] to create_attention_mask, which
   emitted an explicit boolean mask via BatchKVCache.make_mask. That
   path interacted badly with the rest of the attention kernel for
   this model (produced repeated-token garbage). Passing the cache
   list falls through to the "causal" string path, matching how
   qwen2_5_vl already works.

Verified: 4 parallel different prompts → Paris/4/Blue/10 (all correct),
4 staggered → all correct, 4 identical parallel → all produce the same
counting output (previously all diverged into garbled tokens).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Same get_input_embeddings reset bug as qwen2_vl / qwen2_5_vl — every
text-only call cleared the LanguageModel's cached _position_ids /
_rope_deltas, destroying state that other sequences in the current
batch were relying on mid-decode.

qwen3_5_moe inherits get_input_embeddings from qwen3_5.Model, so this
fix covers both qwen3_5 and qwen3_5_moe (verified on
mlx-community/Qwen3.5-35B-A3B-4bit: 4 parallel different prompts now
produce 4 correct distinct answers; previously they all got the same
stale response).

The rest of the qwen3_5 code already handles batched per-sequence
cache offsets correctly — the LanguageModel.__call__ was where this
patch series borrowed its per-element offsets pattern from.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a client aborted a /v1/chat/completions stream mid-generation, the
server kept decoding tokens for that uid until natural EOS / max_tokens
and pushed them into an abandoned rqueue. That wasted compute, leaked
memory, and tied up batch slots that other requests could have used
(worst case: max_tokens=8192 aborted after 1 token → ~8000 tokens of
wasted decode plus a slot lost for the whole duration).

Implements cooperative cancellation end-to-end:

- BatchGenerator.remove(uid): removes a sequence by uid from whichever
  stage it's in (unprocessed queue, in-flight prefill when it's the
  only sequence, or generation batch). Falls through to the gen batch
  if the uid was in a multi-sequence prefill so a concurrent removal
  still catches it after transition.

- ResponseGenerator: new thread-safe _cancelled set + _cancel(uid).
  The per-request token_iterator tracks whether it ended naturally
  (None sentinel / exception) and calls _cancel(uid) from its finally
  block otherwise. GeneratorExit from fastapi closing the async
  response now reaches here via the standard generator protocol.

- _run drains cancelled uids every outer iteration and calls
  batch_gen.remove(uid) + drops the active entry (pushing a final
  None onto the rqueue so any lingering consumer unblocks cleanly).

Verified end-to-end: after aborting a 2000-token stream at 0.5s and
immediately issuing a 500-token follow-up, the follow-up runs at
baseline single-request tok/s (83 vs 82), i.e., the slot was
actually released. Concurrent case: when one of two parallel streams
aborts, the other immediately reverts to batch-size-1 throughput.

Adds unit tests for BatchGenerator.remove covering the unprocessed
queue path and the missing-uid case.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Two fixes for the cancellation path from the previous commit:

1) token_iterator.finally was calling _cancel(uid) on any close, including
   the clean case where stream_generator breaks out of its loop after
   seeing finish_reason and then closes the iterator. The None sentinel
   the GPU thread pushes after the final token wasn't being consumed
   before close, so `ended` stayed False and every completed stream
   triggered a bogus cancellation request.

   Fix: mark `ended = True` as soon as a token with finish_reason is
   received (before yielding it) and break after that yield, so normal
   completion is distinguishable from a client abort.

2) GenerationBatch.filter now clears _current_tokens / _current_lps when
   the batch is emptied, matching the existing _next_tokens / _next_lps
   clearing. Without this, stale tensors from the removed sequences were
   kept alive until the next _step overwrote them.
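
The first fix can be sketched as a generator that distinguishes a normal finish from a client abort. The item shape (dicts with `token` and `finish_reason` keys) and the function signature are assumptions made for illustration:

```python
import queue


def token_iterator(rqueue, uid, cancel):
    """Yield items for one request; call cancel(uid) only on an early close."""
    ended = False
    try:
        while True:
            item = rqueue.get()
            if item is None:
                # Sentinel pushed by the producer after the final token.
                ended = True
                return
            if item.get("finish_reason") is not None:
                # Mark completion before yielding, so a close() right after
                # this yield is not mistaken for an abort.
                ended = True
                yield item
                return
            yield item
    finally:
        if not ended:
            cancel(uid)
```

Closing a generator raises `GeneratorExit` at the suspended `yield`, so the `finally` block runs in both cases; the `ended` flag is what tells the two apart.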

Verified on Gemma-4-26B-A4B: cancellations now only fire on actual
aborts. First follow-up after an abort takes a ~10% one-time hit
(pending async work for the aborted sequence draining on the GPU
stream); subsequent runs return to baseline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
GenerationBatch.next() wraps its work in `with mx.stream(generation_stream)`
so that the cache filter at the end of a natural completion runs on the
same stream as the preceding async evals. remove() was calling
GenerationBatch.filter() (and the unprocessed / prompt_batch cleanups)
from whichever stream the caller happened to be on, which meant the
cache slice ops for a cancelled sequence could reorder against pending
async work.

Wrap the whole body of BatchGenerator.remove() in the same stream
context so cancellation cleanup is ordered identically to natural
completion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Removes comments that just label what the next line does
("Batch state", "Stats counters", "Create batch cache", "Insert new
requests into batch", "Non-blocking poll when actively generating",
etc.) and trims a few multi-line blocks down to one or two lines
focused on the non-obvious reasoning. No behavior change.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Blaizzy and others added 7 commits April 17, 2026 23:09
Merges the chunked-prefill fix from #1032 (thanks @pcuenca) with the
parallel-batching fix for stateful _rope_deltas already on this branch.
The two interact through the model's cached _position_ids/_rope_deltas;
combining them cleanly needs a small adjustment on each side:

- PromptProcessingBatch no longer resets _position_ids / _rope_deltas
  on construction. get_input_embeddings is the only owner now.

- qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe LanguageModel.__call__
  drop the `cache_offset > 0` guard and always slice from the
  precomputed _position_ids when available. Needed so the first chunk
  of a chunked VLM prefill uses the same positional layout that
  get_input_embeddings laid out for the full input (fixes PR #1032's
  repro: mlx-community/Qwen3-VL-30B-A3B-Instruct-bf16 describing a
  large image).

- qwen2_vl / qwen2_5_vl / qwen3_5 Model.get_input_embeddings, when
  pixel_values is None, now resets _position_ids but leaves
  _rope_deltas alone. Clearing only _position_ids keeps the chunked
  prefill slice safe (no stale image-shaped tensor leaking into a
  text prefill) while preserving _rope_deltas so other sequences that
  are still decoding in the batch keep their correct per-sequence
  positional state — which is what the original fix on this branch
  needed for parallel identical-text to be deterministic.

Verified across Qwen2.5-VL-3B, Qwen3.5-35B-A3B and Gemma-4-26B:
- 4 parallel identical prompts → all identical
- 4 parallel different prompts → distinct correct answers
- Staggered image+text → all coherent
- Mixed image+text concurrent decoding → both sequences complete;
  text now matches single-run baseline on Qwen3.5, image now matches
  on Gemma-4 (Qwen2.5-VL still diverges slightly from single-run
  baseline due to the single-slot _rope_deltas during concurrent
  decode — a separate, pre-existing limitation)
- Abort + follow-up → unchanged

Gemma-4 throughput (single + parallel 2/3/4) unchanged within noise.

Co-Authored-By: Pedro Cuenca <pedro@huggingface.co>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extend the parallel-batching fix from earlier to the six VLMs that
still reset both _position_ids and _rope_deltas in their text-only
get_input_embeddings path:

- glm4v, glm4v_moe, glm_ocr
- paddleocr_vl
- qwen3_vl, qwen3_vl_moe

Each now only clears _position_ids (to keep the chunked-prefill slice
safe) and leaves _rope_deltas alone so other sequences still decoding
in the current batch keep their per-sequence positional state.

Verified qwen3_vl parallel identical text decodes deterministically
for three of four slots — the remaining divergence is floating-point
non-determinism at a close-call argmax (batched matmul vs single),
not a fix gap.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds per-sequence _rope_deltas tracking on GenerationBatch (concatenated
by extend, sliced by filter) so sequences with different deltas can
share a batch without aliasing self._rope_deltas on the model.

- PromptProcessingBatch.generate snapshots the model's _rope_deltas
  right after prefill and hands it to the new GenerationBatch, shape
  normalized to (B, 1). Models without _rope_deltas stay untouched.

- GenerationBatch._step passes rope_deltas=self._rope_deltas as a
  kwarg on every decode forward, mirroring how n_to_process is
  threaded through prefill.

- qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe / qwen3_5
  LanguageModel.__call__ now accept the kwarg and prefer it over
  self._rope_deltas in the decode branch.

Verified on concurrent image + text decoding (image stream started,
text stream joins mid-decode):
  Qwen2.5-VL-3B : image matches single-run baseline (was diverging).
  Qwen3.5-35B   : both image and text match baseline.
  Gemma-4       : unaffected (no _rope_deltas on its LM).

All 49 existing unit tests still pass. Parallel identical / different
text, staggered, and abort + follow-up all continue to work.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Apply the same kwarg plumbing as the Qwen family to the rest of the
VLMs that cache _rope_deltas on the LanguageModel:

  - glm4v, glm4v_moe, glm_ocr
  - paddleocr_vl
  - ernie4_5_moe_vl
  - qwen3_omni_moe

Each LanguageModel.__call__ now pops `rope_deltas` from kwargs and
prefers that over self._rope_deltas in the decode branch, so
sequences with different deltas can share a batch without aliasing
the single model-level slot.

qwen3_5_moe inherits __call__ from qwen3_5.LanguageModel and is
covered transitively.
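
A stripped-down illustration of the kwarg preference, with lists standing in for arrays. The class and its return value are hypothetical; a real `LanguageModel.__call__` does far more than compute positions:

```python
class LanguageModelSketch:
    def __init__(self):
        self._rope_deltas = None  # single model-level slot

    def __call__(self, inputs, cache_offsets, **kwargs):
        # Prefer the per-batch deltas passed by the caller; fall back to
        # the model-level slot only when none were supplied.
        rope_deltas = kwargs.pop("rope_deltas", None)
        if rope_deltas is None:
            rope_deltas = self._rope_deltas
        if rope_deltas is None:
            rope_deltas = [0] * len(cache_offsets)
        # Decode position for each sequence in the batch.
        return [off + delta for off, delta in zip(cache_offsets, rope_deltas)]
```

With the model-level slot holding deltas from an image request, a text sequence in the same batch would inherit them; passing `rope_deltas=[-3, 0]` lets each row keep its own value.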

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Both models guard _position_ids reuse behind `cache_offset > 0`, so
the first chunk of a long image prefill falls into the get_rope_index
recompute path and sees only a slice of input_ids — missing image
token context and producing wrong positions. Dropping the guard lets
the precomputed _position_ids (set up by get_input_embeddings with
the full input) be sliced for every chunk, matching the fix from
#1032 already applied to qwen2_vl / qwen2_5_vl / qwen3_vl /
qwen3_vl_moe.

(The two falcon VLMs that also had `cache_offset > 0` guards use a
different _rope_delta path that's decode-only, not the chunked-
prefill bug.)

Verified:
  mlx-community/PaddleOCR-VL-bfloat16  +  webpage_1.png (3557 tok,
    chunked) produces the expected Chinese news summary.
  tencent/HunyuanOCR  +  webpage_1.png (4161 tok, chunked) produces
    the expected "俄中传媒资讯网" site description.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Both models assumed cache.offset was always a scalar and called
offset.item(), which blows up on BatchKVCache's multi-element offset
arrays (ValueError: Only length-1 arrays can be converted to Python
scalars). Using the first element's offset would have avoided the
crash but left every sequence in a batch sharing the same RoPE
encoding — correct only when all sequences happen to be at the same
position.

Now when cache.offset has more than one element, build per-sequence
position_ids of shape (B, L) = base_offset + arange(L), and let
apply_rotary_emb_1d reshape cos/sin with the batch dim when the
indexed tables come out 3D. Single-sequence usage still takes the
original 2D reshape path.

Verified:
  mlx-community/Falcon-OCR-bf16 single-request paper.png → unchanged
    correct author-list extraction.
  4 parallel identical image requests → all 4 sequences emit the same
    correct paper header (previously server errored on offset.item()).
  tiiuae/Falcon-Perception generate_perception() on cats.jpg → still
    returns the two correct detections.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Blaizzy and others added 13 commits April 18, 2026 23:43
microsoft/Phi-4-reasoning-vision-15B ships a custom
modeling_phi4_visionr.py referenced from config.json via auto_map.
AutoTokenizer.from_pretrained resolves its class through AutoConfig,
which imports that remote file at class-definition time -- and the file
uses `siglip2_ips.filter_out_non_signature_kwargs()` as a decorator, a
helper that was moved in transformers 5.6+. The import therefore raises
AttributeError before our install_auto_processor_patch gets to call our
own Phi4SigLipProcessor implementation.

Phi4SigLipProcessor is self-contained and doesn't need the remote
modeling classes -- only the tokenizer artifacts (tokenizer.json /
tokenizer_config.json) which are standard. Pop `trust_remote_code` from
the forwarded kwargs and call AutoTokenizer.from_pretrained with
trust_remote_code=False, bypassing the remote auto_map chain.

Verified: Phi-4-reasoning-vision-15B now loads and serves; single and
concurrent same-image requests produce coherent image descriptions with
zero generation errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Moondream3's attention did `mx.arange(cache.offset, cache.offset + L)`
and `self.rope(..., offset=cache.offset)`, both of which expect a Python
int. Under continuous batching our BatchKVCache always stores
`cache.offset` as an `mx.array`, so the server crashed with
`arange(): incompatible function arguments` on every decode step.

Collapse the offset array to an int (the single value for B=1, or max
across sequences for B>1) before calling arange/RoPE. Single-request
serving now works end-to-end. Concurrent same-image requests decode
correctly for the first sequence, but subsequent sequences still degrade,
mirroring the pre-existing qwen2_vl-2B concurrent sensitivity -- proper
per-row positions for heterogeneous-length batches need a larger refactor
of Tau/RoPE to accept (B, L) positions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Builds on the previous arange-offset coercion so moondream3 handles
BatchKVCache properly, not just degenerate single-seq cases:

  - Attention: keep the full per-seq offset array when the cache carries
    one. Build positions as `offsets[:, None] + arange(L)` so each row
    receives its own position values instead of the batch-wide max.
  - Tau.__call__: accept 1-D shared positions (existing path) or 2-D
    per-sequence positions (B, L). For the 2-D case, broadcast alpha
    across the batch and keep tau_pos shape (B, n_heads, L).
  - RoPE: mx.fast.rope only accepts scalar offsets, so for multi-seq
    batches apply RoPE per row and concatenate. Single-seq continues to
    hit the fast path.

Verified on moondream3-preview: 3x same-image concurrent now produces
coherent per-seq descriptions (previously seqs 1-2 emitted
<|md_reserved_4|> repetitions).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…he.make_mask

Two interacting issues:

1. Attention.__call__ sliced `mask` to `keys.shape[-2]` *before*
   `cache.update_and_fetch`. During chunked prefill (keys just the new
   chunk, but the cache already holds previous chunks), this trimmed
   the mask to the current-chunk column count. After update_and_fetch
   the returned keys include both cached and new tokens, so the
   (sliced) mask and (extended) keys had mismatched lengths -- fine
   when mask was `None`/`"causal"` from the old unmasked decode path,
   but broken as soon as a real tensor mask reached this point.

2. TextModel.__call__ passed the cache *list* to
   `create_attention_mask`, which bypassed `BatchKVCache.make_mask`
   and produced an undersized mask for concurrent batches with
   per-sequence left_padding, contributing to the garbled tail on
   mixed-modality text responses.

Swap the order so the mask is sliced against the post-update key
length, and route mask construction through `cache[0]`. Verified on
InternVL3-8B-bf16: single, concurrent same-image x3 (now identical
per-seq outputs), and mixed 2-img + 2-txt all complete with zero
generation errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Continuous-batching chunked prefill passes `n_to_process=<chunk_size>`
(and any other prompt kwargs stashed on the VLM wrapper) when it calls
`self.model(...)`. aya_vision's `LanguageModel.__call__` had a
strictly-positional signature, so the server crashed every decode step
with
`LanguageModel.__call__() got an unexpected keyword argument 'n_to_process'`.

Add **kwargs so the unused hints are accepted. Verified on
aya-vision-8b-8bit: single and concurrent same-image x3 now produce
coherent descriptions, zero generation errors.
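
The failure and the fix can be reproduced with two toy classes (hypothetical names, trivial bodies):

```python
class StrictLanguageModel:
    def __call__(self, inputs, cache=None):
        return len(inputs)


class TolerantLanguageModel:
    def __call__(self, inputs, cache=None, **kwargs):
        # Hints such as n_to_process are accepted and ignored.
        return len(inputs)


tokens = [1, 2, 3]
try:
    StrictLanguageModel()(tokens, n_to_process=2)
except TypeError as exc:
    print("strict signature:", exc)

print("tolerant signature:", TolerantLanguageModel()(tokens, n_to_process=2))
```

Accepting `**kwargs` trades strictness for compatibility: a misspelled keyword is silently ignored too, so it suits call sites where the caller forwards optional hints.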

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
aya_vision interleaves global KVCache layers with RotatingKVCache
sliding-window layers (sliding_window_pattern=4), but TextModel.__call__
built a single mask from `cache[j - 1 : j]` -- a Python list slice that
doesn't expose `make_mask`, so upstream `create_attention_mask` fell
through to "causal" / None and dropped the per-sequence `left_padding`.
For mixed-modality batches where a short text sequence joins a long
image sequence, the text row's 2000+ zero-padded K/V slots got
attention weight, drowning out its real context and garbling tokens.

Copy gemma3's approach: build two masks -- `global_mask` from the
global BatchKVCache (cache[j-1]) and `sliding_mask` from a
BatchRotatingKVCache with `window_size=self.config.sliding_window`
supplied -- and dispatch per layer based on `(i + 1) % j == 0`. Both
masks are now left-padding-aware via `BatchKVCache.make_mask`, and the
sliding-window layers get a correctly-sized mask for their rotating
cache.
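
A simplified sketch of both pieces: a left-padding-aware causal mask and the per-layer dispatch. It assumes a prefill over an empty cache, uses nested lists of booleans instead of arrays, and all function names are hypothetical:

```python
def left_padded_causal_mask(left_padding, seq_len, window=None):
    """mask[b][q][k] is True where query q of sequence b may attend to key k."""
    masks = []
    for pad in left_padding:
        rows = []
        for q in range(seq_len):
            row = []
            for k in range(seq_len):
                allowed = pad <= k <= q                     # skip padding, stay causal
                if window is not None:
                    allowed = allowed and (q - k) < window  # sliding window
                row.append(allowed)
            rows.append(row)
        masks.append(rows)
    return masks


def pick_layer_masks(num_layers, pattern, global_mask, sliding_mask):
    """Every pattern-th layer is global; the others use the sliding mask."""
    return [
        global_mask if (i + 1) % pattern == 0 else sliding_mask
        for i in range(num_layers)
    ]
```

With `left_padding=[0, 2]`, the second sequence never attends to its first two key slots, which is the property that was lost when mask construction fell through to the plain "causal" path.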

Verified on aya-vision-8b-8bit:
- single, concurrent same-image x3, and mixed 2img + 2txt all complete
  with zero generation errors.
- mixed-batch text no longer emits unicode salad; the text
  sequences produce clean English (the model sometimes asks for an
  image, which is a template / training quirk, not a batching bug).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Deepseek-VL v1 (model_type: multi_modality) crashed every generation
step under continuous batching with
`LanguageModel.__call__() got an unexpected keyword argument 'attention_mask_4d'`.

The VLM wrapper ships attention_mask_4d / pixel_values / etc. through
`InputEmbeddingsFeatures.to_dict()` as decode kwargs, and chunked
prefill also passes `n_to_process`. Adding **kwargs lets the fixed
positional signature absorb those hints.

Verified on mlx-community/deepseek-vl-7b-chat-4bit: single and
concurrent same-image x3 now produce coherent descriptions with zero
generation errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match the defensive pattern added to aya_vision / multi_modality so
llava's LanguageModel absorbs kwargs forwarded by continuous batching
(attention_mask_4d from InputEmbeddingsFeatures.to_dict, n_to_process
from chunked prefill, etc).

Note: llava-1.5 still cannot be served end-to-end in transformers 5.6+
because LlavaProcessor.__call__ routes `padding` into ImagesKwargs and
hits additional `int // None` math inside the image processor; those
are upstream transformers API drifts separate from this plumbing fix
and need their own investigation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mllama's cross-attention used `cache.offset > 0` (scalar compare) and
self-attention used `self.rope(..., offset=cache.offset)` (int expected).
Both crash under continuous batching where BatchKVCache stores offset as
a per-sequence mx.array.

Fix cross-attention guard with isinstance dispatch, and coerce
self-attention RoPE offset to int (max across seqs for multi-seq
batches). Verified on Llama-3.2-11B-Vision-Instruct-4bit: single,
concurrent same-image x3, and mixed 2-img + 2-text all complete with
coherent output and zero generation errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…p broadcast

Three fixes for gemma3n BatchKVCache compatibility:

1. Copy cache.offset before update_and_fetch to prevent in-place mutation
   from corrupting the RoPE offset used for queries (the root cause of
   generating EOS immediately on the server).

2. Add make_cache() to Model so _make_cache creates the correct mix of
   BatchKVCache (full attention) and BatchRotatingKVCache (sliding).

3. Fix altup correct() broadcast: transpose(2,0,1) keeps batch dim
   consistent for batch size > 1.

Also: default pixel_values to None, route inputs_embeds through
language_model when provided, and use single cache entries for mask
construction.
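
Fix 1 addresses an aliasing problem that can be reproduced without MLX. The sketch below uses a toy cache whose offset is a mutable list (standing in for an `mx.array`); `ToyBatchCache` and the two helper functions are hypothetical names, and the toy `update_and_fetch` only advances the offset.

```python
class ToyBatchCache:
    """Toy stand-in for a batch cache with a mutable per-sequence offset."""

    def __init__(self, batch_size):
        self.offset = [0] * batch_size

    def update_and_fetch(self, n_new_tokens):
        # In-place update: every alias of self.offset sees the new values.
        for i in range(len(self.offset)):
            self.offset[i] += n_new_tokens


def query_offset_buggy(cache, n_new_tokens):
    offset = cache.offset              # alias, not a snapshot
    cache.update_and_fetch(n_new_tokens)
    return offset                      # already advanced by n_new_tokens


def query_offset_fixed(cache, n_new_tokens):
    offset = list(cache.offset)        # snapshot taken before the update
    cache.update_and_fetch(n_new_tokens)
    return offset                      # still the pre-update positions
```

With a fresh two-sequence cache and three new tokens, the buggy helper returns the advanced offsets while the fixed helper returns the positions the queries should actually be rotated with.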

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The conv_1d helper was broadcasting the full kernel array against each
padded slice, which fails when the spatial dimension differs from the
kernel size. Use per-element kernel[i] instead of the reshaped kernel.
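
As a rough illustration of the per-element approach, the sketch below builds a same-length 1D correlation from shifted slices, each scaled by a single kernel tap. It uses plain Python lists; the padding scheme and function shape are assumptions and do not reproduce the actual `conv_1d` helper.

```python
def conv_1d(signal, kernel):
    """Same-length 1D correlation from shifted, scaled slices (illustrative)."""
    k = len(kernel)
    length = len(signal)
    # Assumed zero padding so the output keeps the input length.
    pad_left = (k - 1) // 2
    pad_right = k - 1 - pad_left
    padded = [0.0] * pad_left + list(signal) + [0.0] * pad_right

    out = [0.0] * length
    for i in range(k):
        window = padded[i : i + length]  # slice of the signal's length
        tap = kernel[i]                  # one scalar weight, not the full kernel
        out = [acc + tap * x for acc, x in zip(out, window)]
    return out
```

Because each slice is multiplied by a scalar, the slice length never has to match the kernel length: `conv_1d([1, 2, 3, 4], [1, 1, 1])` works even though the signal has four elements and the kernel three.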

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Blaizzy and others added 13 commits April 19, 2026 15:33
Remove the elif branch that handled None cache entries — caches are
always initialized before this code path runs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_index

Follows the Qwen pattern: position_ids, pos_hw, rope_deltas, and the
full attention mask are now computed in LanguageModel.get_rope_index()
instead of Model._precompute_positions().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace verbose isinstance/ndim/size branching with a clean is_batch
flag and single extraction path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…g through kwargs

Remove the (B,3,L) → transpose → (3,B,L) round-trip. The position_ids
are now set directly on language_model._position_ids during
get_input_embeddings, matching the Qwen2-VL pattern.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…thod documentation. Removed verbose comments to enhance code clarity and maintainability.
…ts for clarity and maintainability. Removed redundant comments to enhance code readability.
Only copy cache.offset when it's an mx.array (batch mode) to prevent
in-place mutation. Leave int offsets as-is so RoPE uses the fast scalar
path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Blaizzy Blaizzy merged commit b027538 into main Apr 19, 2026
1 check passed
mdkirin pushed a commit to mdkirin/mlx-seori that referenced this pull request Apr 26, 2026
…MoE pos fix

Absorb upstream/main (4-19 to 4-25 batch). All of the fork's core
assets are preserved: MTP (ported from mlx-lm, Qwen3.5 dense+MoE),
PrefixCache hybrid, server hardening (MLX_MEMORY_LIMIT_GB env,
/v1/status, /v1/models including loaded models, model pinning, busy
tracking, GC threshold, last_request, removal of the OOM-prone startup
warmup), server-side thinking strip + incremental streaming, and the
null tool_calls guard.

Absorbed from upstream: continuous batching server (Blaizzy#1027), DFlash
speculative decoding (Blaizzy#1029, Blaizzy#1053 fix), thread-local
generation stream (Blaizzy#1050, with a hasattr guard for mlx<0.32),
batch_generate/server VLM fixes (Blaizzy#1055), Qwen3.5/3.6 MoE stale
position IDs + gdn_sink compatibility (Blaizzy#1040), tool-call markup
strip (Blaizzy#1037), KV cache quantization (Blaizzy#1030), torch-free
video processors for Qwen2-3.5 VL (Blaizzy#1048), Gemma4 LoRA NaN/freeze
fix (Blaizzy#1052), Gemma4 video, Youtu-VL, distributed inference, and
more.

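Conflict-resolution principle: the fork's MTP n_confirmed and upstream's
gdn_sink coexist in the same function through an extended signature.
The fork branched off before Blaizzy#1029 (DFlash) was introduced, so the
gdn_sink body logic is inactive for our models (None is passed); the
signature still accepts it to preserve compatibility. When cached
position_ids are reused, the fork's ">= cache_offset + seq_length" check
covers the Blaizzy#1040 fix more precisely. The
LanguageModelOutput.hidden_states/gdn_states fields are kept for
compatibility with the upstream additions.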

Verification: syntax + import OK for 4 files. Compatibility with
mlx 0.31.0 confirmed on an M3 with 96 GB.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>