fix(dflash): report prefix-cache hits as cached_tokens (#1441) #1768
Merged
DFlashEngine plumbs the prefix snapshot into stream_dflash_generate (so prefill IS skipped on a hit), but never set cached_tokens on its GenerationOutput. As a result, the API reported cached_tokens: 0 on every turn with DFlash enabled, and the count returned the moment DFlash was disabled (BatchedEngine sets it). The underlying cache works; only the reporting was missing.

Surface PrefixCacheFlow.hit_tokens (the matched prompt-token count) as cached_tokens, mirroring BatchedEngine:

- _cached_tokens_from_flow() maps a prefix flow to its hit-token count.
- non-streaming generate(): thread prefix_flow out of the executor and set cached_tokens on the output.
- streaming: carry the count on the final (usage) chunk's metrics so the server's per-chunk sum isn't inflated.

Tests: pure-mapping unit tests (hit/miss/None/missing/negative) run locally; a CI-gated end-to-end test asserts generate() sets cached_tokens from a hit.
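For readers following along, the mapping described above (hit gives the count; miss, None, a missing attribute, or a negative value gives 0) could look roughly like the sketch below. It is not the code from this PR: FakePrefixFlow is a hypothetical stand-in for PrefixCacheFlow that models only hit_tokens, and cached_tokens_from_flow is an illustrative name for the helper.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakePrefixFlow:
    """Hypothetical stand-in for PrefixCacheFlow; only hit_tokens is modelled."""

    hit_tokens: Optional[int] = None


def cached_tokens_from_flow(prefix_flow: Any) -> int:
    """Map a prefix flow to a non-negative cached-token count.

    Assumed behaviour, taken from the description above:
    hit -> count; miss / None / missing attribute / negative -> 0.
    """
    if prefix_flow is None:
        return 0
    # getattr with a default covers objects that lack hit_tokens entirely.
    hit = getattr(prefix_flow, "hit_tokens", None)
    # Reject non-integers (bool is excluded explicitly) and negative counts.
    if not isinstance(hit, int) or isinstance(hit, bool) or hit < 0:
        return 0
    return hit
```

Clamping every unexpected input to 0 keeps the usage field well-formed even when the flow object is absent or malformed, which matches the "reporting only" scope of the change.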
Owner
Thanks for the focused fix. I verified that this keeps the change scoped to DFlash usage reporting. I also ran the focused DFlash tests and a local DFlash smoke check; repeated prompts now report cached tokens on the second request. This looks good to me, and I'm going to merge it.
Summary

Fixes #1441 (DFlash breaks KV prefix cache: cached_tokens: 0 when DFlash is enabled, restored when disabled).

After verifying the issue, I found that it has split in two since it was filed (v0.3.12):

- The cache itself works. DFlashEngine plumbs the prefix snapshot into stream_dflash_generate (PrefixCacheFlow + a persistent L1/L2 runtime_context), so prefill is skipped on a hit. I asserted this on the dflash-mlx side: a cold lookup misses, and a repeat of the same 4273-token prompt returns l1_exact with matched_tokens == 4273 (prefill skipped). The existing test_prefix_cache_hit_kind.py suite covers it (87 passed).
- Only the reporting was missing. The hit count (PrefixCacheFlow.hit_tokens) was computed but never mapped onto the output's cached_tokens, so the API always reported 0 with DFlash on. BatchedEngine/VLM already set cached_tokens=output.cached_tokens; DFlashEngine had zero cached_tokens references.

Fix

Surface PrefixCacheFlow.hit_tokens (matched prompt tokens) as cached_tokens, mirroring BatchedEngine:

- _cached_tokens_from_flow(prefix_flow): pure mapping (hit → count, miss/None/missing/negative → 0).
- generate(): thread prefix_flow out of the executor _run and set cached_tokens on the output.
- stream_generate(): carry the count on the final (usage) chunk's metrics only, so the server's per-chunk total_cached_tokens += output.cached_tokens sum isn't inflated; token deltas report 0.

Test plan

- pytest tests/test_dflash_engine.py -k "TestDFlashCachedTokens and not Wiring": 5 pure-mapping unit tests (run locally)
- test_generate_sets_cached_tokens_from_hit: asserts generate() sets cached_tokens from a prefix hit (skips where dflash-mlx is unavailable, runs in CI)
- tests/test_dflash_engine.py: 55 passed (4 pre-existing _build_runtime_context env failures are unrelated, pass in CI); test_output_collector.py / test_server_metrics.py: 63 passed

Rebased onto current main.
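To illustrate why the streaming path reports the count on the final chunk only, here is a minimal sketch of a server that sums cached_tokens across chunks. Everything in it is hypothetical: Chunk, stream_chunks, and server_total are illustrative names, and the count sits directly on the chunk rather than in a metrics object as the PR describes.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List


@dataclass
class Chunk:
    """Hypothetical streamed output; real chunks carry more fields."""

    text: str
    cached_tokens: int = 0
    finished: bool = False


def stream_chunks(pieces: List[str], hit_tokens: int) -> Iterator[Chunk]:
    """Yield token deltas with cached_tokens=0, then one final usage chunk."""
    for piece in pieces:
        yield Chunk(text=piece)  # token deltas report 0
    # Only the final (usage) chunk carries the prefix-hit count.
    yield Chunk(text="", cached_tokens=hit_tokens, finished=True)


def server_total(chunks: Iterable[Chunk]) -> int:
    """Sum cached_tokens per chunk, as a server-side accumulator might."""
    total = 0
    for chunk in chunks:
        total += chunk.cached_tokens
    return total
```

With three token deltas and a 4273-token hit, server_total(stream_chunks(["a", "b", "c"], 4273)) sums to 4273. If every chunk repeated the count, the same accumulator would report four times that.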