fix(memory): treat MLX fused SDPA as O(n) for all head_dim #1764
The `hd > 128` threshold in `estimate_prefill_peak_bytes` and
`estimate_chunk_transient_bytes` assumed MLX only provided a fused
(O(n) tiled) SDPA kernel for head_dim <= 128, materialising the full
float32 attention-score matrix for larger values. That assumption was
true for MLX < 0.22 but has been false since then: the MLX Metal SDPA
kernel handles arbitrary head_dim with the same O(n) tiling. omlx
0.4.x already requires MLX >= 0.31, so the fallback path is never
reached.
The stale threshold caused a ~1200x SDPA over-estimate for
Qwen3.6-VL (head_dim = 256) at 327 K tokens: ~40 GB instead of the
actual ~32 MB, which in turn pushed the preflight peak well past the
memory ceiling and rejected every such request with HTTP 413.
Fix: replace the hard-coded `128` sentinel with a module-level
constant `_SDPA_FALLBACK_HEAD_DIM = float("inf")`. This leaves the
dead fallback branch in place (easy to revive if a future MLX version
re-introduces the restriction) while ensuring the fused formula is
always selected.
Tests: add TestSdpaThreshold covering the constant value, the fused
formula path for head_dim = 256 at 327 K tokens, and the chunk
transient estimator.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
98f18a3 to
52054ae
Thanks for isolating this. The current estimator is wrong because it still assumes head_dim > 128 means MLX materializes the full attention score matrix, and that is causing valid long-context VLM requests to fail with false 413s. This PR fixes the urgent over-rejection problem, so I am going to merge it. One part is still too optimistic, though: current MLX does not seem to be purely output-buffer-only for high head_dim; it still uses some tiled scratch memory. I will land a follow-up on main to model that tiled scratch term conservatively and update the memory-guard tests.
Adding more context on the follow-up I mentioned above.
The stale part in the old estimator was the assumption that
The remaining issue is that the replacement estimate makes high-head-dim SDPA effectively output-buffer-only:
This matters especially for VLM text configs such as Qwen/Gemma, where the LM dimensions are nested under
Local MLX 0.31.x measurements do not match the old full-score estimate, but they also do not match output-only for high head_dim. They are closer to a bounded tiled scratch term, roughly
Problem
After #1448 fixed VLM nested-config walking, the stale `hd > 128` threshold in `estimate_prefill_peak_bytes` and `estimate_chunk_transient_bytes` became the dominant source of false-positive 413 rejections for models with `head_dim > 128`.
For Qwen3.6-VL (`head_dim = 256`) at 327 K tokens the SDPA estimate was:

| Path | Formula | Estimate |
| --- | --- | --- |
| Unfused fallback (`hd > 128`) | `n_q × chunk × full_kv_len × 4` | ~40 GB |
| Fused (O(n) tiled) | `n_q × chunk × hd × 4` | ~32 MB |

The ~1200x over-estimate pushed the preflight peak well past the memory ceiling, so every such request was rejected with HTTP 413.
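The two figures can be reproduced with a few lines of arithmetic. This is an illustrative sketch only: `n_q = 16` and `chunk = 2048` are assumed values, picked because their product reproduces the quoted sizes, 327 K is read as 327,680 tokens, and the helper names are hypothetical rather than omlx functions.

```python
# Illustrative arithmetic only; helper names and n_q/chunk values are assumptions.

def unfused_sdpa_bytes(n_q: int, chunk: int, full_kv_len: int) -> int:
    """Full float32 attention-score matrix: one score per (head, query, key)."""
    return n_q * chunk * full_kv_len * 4


def fused_sdpa_bytes(n_q: int, chunk: int, hd: int) -> int:
    """Fused formula from the table: n_q x chunk x hd x 4 bytes."""
    return n_q * chunk * hd * 4


n_q, chunk = 16, 2048        # assumed: chosen so the totals match the PR text
full_kv_len = 320 * 1024     # 327,680 tokens, i.e. "327 K"
hd = 256                     # head_dim from the PR description

unfused = unfused_sdpa_bytes(n_q, chunk, full_kv_len)
fused = fused_sdpa_bytes(n_q, chunk, hd)

print(f"unfused: {unfused / 2**30:.0f} GiB")
print(f"fused:   {fused / 2**20:.0f} MiB")
print(f"ratio:   {unfused // fused}x")
```

Under these assumptions the ratio reduces to `full_kv_len / hd`, which is independent of `n_q` and `chunk` and consistent with the roughly three-orders-of-magnitude gap quoted above.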
Root cause
The `hd > 128` check assumed MLX only provided a fused (O(n) tiled) SDPA kernel for `head_dim ≤ 128`, materialising the full float32 attention-score matrix for larger values. That was true for MLX < 0.22 but has been false since then: `mx.fast.scaled_dot_product_attention` uses the same O(n) Metal tiling for all head_dim values. omlx 0.4.x already requires MLX ≥ 0.31, so the fallback path is never reached.
Fix
Replace the hard-coded sentinel `128` with a module-level constant, `_SDPA_FALLBACK_HEAD_DIM = float("inf")`.
Both `estimate_prefill_peak_bytes` and `estimate_chunk_transient_bytes` now use the fused O(n) formula unconditionally. The dead fallback branch is kept in place so it can be revived if a future MLX version re-introduces the restriction.
Note
The turboquant KV dtype fix originally bundled in this PR was merged in #1763. This PR now contains only the SDPA threshold fix, rebased on top of that merge.
Tests
Added `TestSdpaThreshold` (3 tests):
- `_SDPA_FALLBACK_HEAD_DIM` is `+inf`
- `head_dim=256` at 327 K tokens uses the fused formula in `estimate_prefill_peak_bytes`
- `head_dim=256` uses the fused formula in `estimate_chunk_transient_bytes`

All 17 tests in `test_memory_monitor_vlm_config.py` pass.