fix(memory): treat MLX fused SDPA as O(n) for all head_dim #1764

Merged: jundot merged 1 commit into jundot:main from fqx:fix/preflight-turboquant-kv-dtype on Jun 9, 2026

fqx (Contributor) commented Jun 9, 2026

Problem

After #1448 fixed VLM nested-config walking, the stale hd > 128 threshold in estimate_prefill_peak_bytes and estimate_chunk_transient_bytes became the dominant source of false-positive 413 rejections for models with head_dim > 128.

For Qwen3.6-VL (head_dim = 256) at 327 K tokens the SDPA estimate was:

| Path | Formula | Result |
| --- | --- | --- |
| Old (hd > 128), unfused fallback | n_q × chunk × full_kv_len × 4 | ~40 GB |
| Correct, fused kernel | n_q × chunk × hd × 4 | ~32 MB |

The ~1200x over-estimate pushed the preflight peak well past the memory ceiling → HTTP 413 on every request.
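
The two formulas in the table can be checked with a few lines of arithmetic. The sketch below is illustrative only: head_dim = 256 and the roughly 327 K context come from this PR, while the query-head count (16) and chunk size (2048) are assumed values chosen because they reproduce the quoted ~40 GB and ~32 MB figures. They are not taken from the model config or from omlx.

```python
# Assumed values: N_Q_HEADS and CHUNK are hypothetical; HEAD_DIM and the
# ~327 K context length are the figures quoted in the PR description.
N_Q_HEADS = 16
CHUNK = 2048
HEAD_DIM = 256
KV_LEN = 320 * 1024  # 327,680 tokens
FP32_BYTES = 4


def unfused_score_bytes(n_q: int, chunk: int, kv_len: int) -> int:
    """Full float32 score matrix of shape [n_q, chunk, kv_len]."""
    return n_q * chunk * kv_len * FP32_BYTES


def fused_output_bytes(n_q: int, chunk: int, head_dim: int) -> int:
    """Float32 output buffer of shape [n_q, chunk, head_dim]."""
    return n_q * chunk * head_dim * FP32_BYTES


old_bytes = unfused_score_bytes(N_Q_HEADS, CHUNK, KV_LEN)
new_bytes = fused_output_bytes(N_Q_HEADS, CHUNK, HEAD_DIM)

print(f"unfused: {old_bytes / 2**30:.1f} GiB")  # unfused: 40.0 GiB
print(f"fused:   {new_bytes / 2**20:.1f} MiB")  # fused:   32.0 MiB
print(f"ratio:   {old_bytes // new_bytes}x")    # ratio:   1280x
```

The ratio reduces to full_kv_len / hd, so it does not depend on the assumed head count or chunk size. That is why the over-estimate grows linearly with context length.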

Root cause

The hd > 128 check assumed MLX only provided a fused (O(n) tiled) SDPA kernel for head_dim ≤ 128, materialising the full float32 attention-score matrix for larger values. That was true for MLX < 0.22 but has been false since then: mx.fast.scaled_dot_product_attention uses the same O(n) Metal tiling for all head_dim values. omlx 0.4.x already requires MLX ≥ 0.31, so the fallback path is never reached.

Fix

Replace the hard-coded sentinel 128 with a module-level constant:

# MLX fused SDPA supports arbitrary head_dim as of MLX 0.22.
# omlx 0.4.x requires MLX >= 0.31 — unfused fallback never executes.
_SDPA_FALLBACK_HEAD_DIM: float = float("inf")

Both estimate_prefill_peak_bytes and estimate_chunk_transient_bytes now use the fused O(n) formula unconditionally. The dead fallback branch is kept in place so it can be revived if a future MLX version re-introduces the restriction.

Note

The turboquant KV dtype fix originally bundled in this PR was merged in #1763. This PR now contains only the SDPA threshold fix, rebased on top of that merge.

Tests

Added TestSdpaThreshold (3 tests):

  • Constant is +inf
  • head_dim=256 at 327 K tokens uses the fused formula in estimate_prefill_peak_bytes
  • head_dim=256 uses the fused formula in estimate_chunk_transient_bytes

All 17 tests in test_memory_monitor_vlm_config.py pass.

The `hd > 128` threshold in `estimate_prefill_peak_bytes` and
`estimate_chunk_transient_bytes` assumed MLX only provided a fused
(O(n) tiled) SDPA kernel for head_dim <= 128, materialising the full
float32 attention-score matrix for larger values.  That assumption was
true for MLX < 0.22 but has been false since then: the MLX Metal SDPA
kernel handles arbitrary head_dim with the same O(n) tiling.  omlx
0.4.x already requires MLX >= 0.31, so the fallback path is never
reached.

The stale threshold caused a ~1200x SDPA over-estimate for
Qwen3.6-VL (head_dim = 256) at 327 K tokens: ~40 GB instead of the
actual ~32 MB, which in turn pushed the preflight peak well past the
memory ceiling and rejected every such request with HTTP 413.

Fix: replace the hard-coded `128` sentinel with a module-level
constant `_SDPA_FALLBACK_HEAD_DIM = float("inf")`.  This leaves the
dead fallback branch in place (easy to revive if a future MLX version
re-introduces the restriction) while ensuring the fused formula is
always selected.

Tests: add TestSdpaThreshold covering the constant value, the fused
formula path for head_dim = 256 at 327 K tokens, and the chunk
transient estimator.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@fqx fqx force-pushed the fix/preflight-turboquant-kv-dtype branch from 98f18a3 to 52054ae Compare June 9, 2026 09:50
@fqx fqx changed the title from "fix(memory): account for turboquant KV dtype in preflight peak estimate" to "fix(memory): treat MLX fused SDPA as O(n) for all head_dim" Jun 9, 2026
jundot (Owner) commented Jun 9, 2026

Thanks for isolating this. The current estimator is wrong because it still assumes head_dim > 128 means MLX materializes the full attention score matrix, and that is causing valid long-context VLM requests to fail with false 413s.

This PR fixes the urgent over-rejection problem, so I am going to merge it. One part is still too optimistic, though: current MLX does not seem to be purely output-buffer-only for high head_dim; it still uses some tiled scratch memory. I will land a follow-up on main to model that tiled scratch term conservatively and update the memory-guard tests.

@jundot jundot merged commit a3080a1 into jundot:main Jun 9, 2026
0 of 4 checks passed
jundot (Owner) commented Jun 9, 2026

Adding more context on the follow-up I mentioned above.

The stale part in the old estimator was the assumption that head_dim > 128 means current MLX materializes the full fp32 score matrix [q_heads, chunk, kv_len]. That is no longer a good model for MLX 0.31.x, and it caused false 413s for valid long-context VLM requests. This PR fixed that urgent overestimate.

The remaining issue is that the replacement estimate makes high-head-dim SDPA effectively output-buffer-only: q_heads * query_tokens * head_dim * 4. That removes kv_len entirely. In oMLX, that is too optimistic because prefill can reuse a large prefix-cache KV span while only computing a small suffix. For example, a request with cached_tokens=99k and new_tokens=1k only allocates KV for the 1k suffix, but those 1k query tokens still attend over the full ~100k reconstructed KV context.

This matters especially for VLM text configs such as Qwen/Gemma, where the LM dimensions are nested under text_config / language_config and can have q_heads != kv_heads with head_dim=256. KV growth should be estimated from KV heads and new tokens, but SDPA transient is query-head based and still scales with the full attention span.
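
To make the distinction concrete, here is a rough sketch of the two quantities for the cached_tokens=99k / new_tokens=1k example. The layer count, KV-head count, and 2-byte KV dtype are assumed placeholder values, and both helper names are hypothetical; this is not the oMLX estimator.

```python
def kv_growth_bytes(
    num_layers: int,
    kv_heads: int,
    new_tokens: int,
    head_dim: int,
    dtype_bytes: int = 2,
) -> int:
    """New K and V entries allocated for the uncached suffix only."""
    return 2 * num_layers * kv_heads * new_tokens * head_dim * dtype_bytes


def attention_span(cached_tokens: int, new_tokens: int) -> int:
    """KV length the suffix queries attend over (prefix cache + suffix)."""
    return cached_tokens + new_tokens


# Hypothetical model dimensions; only head_dim = 256 appears in the thread.
NUM_LAYERS, KV_HEADS, HEAD_DIM = 32, 4, 256
cached_tokens, new_tokens = 99_000, 1_000

kv_bytes = kv_growth_bytes(NUM_LAYERS, KV_HEADS, new_tokens, HEAD_DIM)
span = attention_span(cached_tokens, new_tokens)

print(kv_bytes)  # 131072000 bytes of new KV, driven by kv_heads and new_tokens
print(span)      # 100000 tokens of attention span, driven by the full context
```

KV growth is independent of cached_tokens, while the attention span is dominated by it. An SDPA estimate that drops kv_len therefore cannot distinguish a 1k request with no cache from a 1k suffix on a 99k prefix.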

Local MLX 0.31.x measurements do not match the old full-score estimate, but they also do not match output-only for high head_dim. They are closer to a bounded tiled scratch term, roughly q_heads * min(query_tokens, tile) * kv_len * dtype_size, plus the fp32 output buffer. I merged this PR because it fixes the user-visible false rejection problem; the follow-up keeps that fix while adding the tiled scratch term so admission and adaptive prefill throttling remain conservative enough to avoid later Metal allocation failures.
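
The bounded tiled scratch term described above could be sketched as follows. This is an assumption-laden illustration of the rough formula in the comment, not the follow-up implementation: the function name, the tile size of 32, and the 2-byte scratch dtype are all hypothetical.

```python
FP32_BYTES = 4


def tiled_sdpa_bytes(
    q_heads: int,
    query_tokens: int,
    kv_len: int,
    head_dim: int,
    dtype_bytes: int = 2,  # assumed scratch dtype size
    tile: int = 32,        # assumed tile size along the query axis
) -> int:
    """Tiled scratch term plus the float32 output buffer."""
    # Scratch grows with kv_len but is capped along the query axis by the tile.
    scratch = q_heads * min(query_tokens, tile) * kv_len * dtype_bytes
    # Output buffer is the term this PR's estimate already accounts for.
    output = q_heads * query_tokens * head_dim * FP32_BYTES
    return scratch + output


# 1k new query tokens attending over a ~100k reconstructed KV context.
estimate = tiled_sdpa_bytes(q_heads=16, query_tokens=1_000, kv_len=100_000, head_dim=256)
print(estimate)  # 118784000
```

Under these assumed values the scratch term (102,400,000 bytes) outweighs the output buffer (16,384,000 bytes), yet the total stays far below a full score matrix of q_heads × query_tokens × kv_len elements.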
