
refactor(mp): replace job_id with request_id in query_prefetch_status#2996

Merged
maobaolong merged 4 commits into LMCache:dev from yoo-kumaneko:feature/replace-job-id-with-request-id
Apr 14, 2026

Conversation

Contributor

@yoo-kumaneko commented on Apr 10, 2026

The server-internal integer job_id was unnecessarily leaked to the vLLM adapter via _lookup_job_ids / _finished_lookup_jobs, adding a redundant layer of indirection. The external request_id (str) is the natural key and is already stored in every _PrefetchJob.

No changes to storage_manager.py or prefetch_controller.py; the internal prefetch_request_id (int) inside PrefetchHandle is unaffected.

Signed-off-by: crclq2018@gmail.com

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

Note

Medium Risk
Changes the multiprocess wire protocol for LOOKUP/QUERY_PREFETCH_STATUS (and related client/server state tracking) to use request_id, which could break older clients/servers and affects lookup/prefetch correctness under retries.

Overview
Refactors the multiprocess lookup/prefetch flow to stop leaking server-internal job_ids and instead track/poll prefetch jobs purely by external request_id.

LOOKUP now returns None and the server stores _prefetch_jobs keyed by request_id; QUERY_PREFETCH_STATUS/QUERY_PREFETCH_LOOKUP_HITS payloads switch from int job IDs to str request IDs, with unknown requests returning 0 to avoid infinite polling. The vLLM scheduler adapter updates its lookup bookkeeping to cache results by request_id (supporting repeated checks after server pop) and adjusts cleanup accordingly, with tests updated to match the new protocol.
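The server-side behavior described above can be sketched roughly as follows. Names like `_prefetch_jobs` and `query_prefetch_status` come from the PR description; the class shape and helper methods here are illustrative assumptions, not the actual implementation in `lmcache/v1/multiprocess/server.py`:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

class PrefetchServerSketch:
    """Sketch of the server-side bookkeeping after this PR: prefetch jobs
    are keyed by the external request_id (str) instead of a
    server-generated integer job_id."""

    def __init__(self) -> None:
        # request_id -> hit count, or None while the lookup phase is running
        self._prefetch_jobs: dict[str, Optional[int]] = {}

    def _register_prefetch_job(self, request_id: str) -> None:
        # Previously keyed by an auto-incremented _next_prefetch_job_id;
        # now the request_id itself is the key.
        self._prefetch_jobs[request_id] = None

    def _finish_lookup(self, request_id: str, hits: int) -> None:
        self._prefetch_jobs[request_id] = hits

    def query_prefetch_status(self, request_id: str) -> Optional[int]:
        """Return the hit count once the lookup phase is done (popping the
        job), or None while it is still in progress. An unknown request_id
        yields 0 rather than None: None would look like "still in
        progress" and make the caller poll forever."""
        if request_id not in self._prefetch_jobs:
            logger.warning(
                "query_prefetch_status: unknown request_id %s", request_id
            )
            return 0
        hits = self._prefetch_jobs[request_id]
        if hits is None:  # lookup phase still running
            return None
        del self._prefetch_jobs[request_id]  # consume the finished job
        return hits
```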

Reviewed by Cursor Bugbot for commit c3c13a3.

@gemini-code-assist Bot left a comment

Code Review

This pull request refactors the multiprocess lookup protocol to use client-provided request_ids instead of server-generated job_ids for tracking prefetch jobs. It removes the QUERY_PREFETCH_LOOKUP_HITS request type and updates the LOOKUP and QUERY_PREFETCH_STATUS protocols to use the request_id. A critical issue was identified in query_prefetch_status where returning None for an unknown request_id could lead to infinite polling; returning 0 is suggested instead to allow the client to terminate. Additionally, a private method should be moved to the end of the class to comply with the project's style guide.

The server-internal integer job_id was unnecessarily leaked to the vLLM
adapter via _lookup_job_ids / _finished_lookup_jobs, adding a redundant
layer of indirection. The external request_id (str) is the natural key
and is already stored in every _PrefetchJob.

Changes:
- _prefetch_jobs is now keyed by request_id (str) instead of job_id (int)
- _next_prefetch_job_id removed; _register_prefetch_job() stores by request_id
- lookup() now returns None; job tracking is entirely server-side
- query_prefetch_status(request_id: str) replaces the int-based version
- query_prefetch_lookup_hits removed: it was only called internally by
  sync_lookup (feature branch) and had no external callers on dev; the
  QUERY_PREFETCH_LOOKUP_HITS RPC endpoint is removed accordingly
- Adapter simplified: _lookup_job_ids replaced by _pending_lookups: set[str]
  and _finished_lookup_results: dict[str, int] (cache keyed by request_id);
  polls QUERY_PREFETCH_STATUS by request_id; cached result handles repeated
  calls to check_lookup_result after the server has popped the job
- Protocol updated: LOOKUP response_class=None, QUERY_PREFETCH_STATUS
  payload_classes=[str], QUERY_PREFETCH_LOOKUP_HITS removed
- tests/v1/multiprocess/test_query_lookup_hits.py deleted
- All other affected tests updated

No changes to storage_manager.py or prefetch_controller.py; the internal
prefetch_request_id (int) inside PrefetchHandle is unaffected.

Signed-off-by: crclq2018@gmail.com
Signed-off-by: rigginschen <rigginschen@tencent.com>
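The adapter-side bookkeeping listed in the commit message (`_pending_lookups`, `_finished_lookup_results`, polling by `request_id`, and the cached result for repeated `check_lookup_result` calls) could look roughly like this. The client stub and method names other than those in the commit message are assumptions for illustration:

```python
from typing import Optional

class LookupAdapterSketch:
    """Sketch of the vLLM adapter bookkeeping after this PR (the real
    class lives in vllm_multi_process_adapter.py)."""

    def __init__(self, client) -> None:
        self._client = client
        # request_ids with a lookup in flight on the server
        self._pending_lookups: set[str] = set()
        # results cached after the server has popped the job
        self._finished_lookup_results: dict[str, int] = {}

    def start_lookup(self, request_id: str) -> None:
        # LOOKUP now returns None; job tracking is entirely server-side.
        self._client.lookup(request_id)
        self._pending_lookups.add(request_id)

    def check_lookup_result(self, request_id: str) -> Optional[int]:
        # Cached result handles repeated calls after the server has
        # popped the job.
        if request_id in self._finished_lookup_results:
            return self._finished_lookup_results[request_id]
        if request_id not in self._pending_lookups:
            return None
        hits = self._client.query_prefetch_status(request_id)
        if hits is None:  # lookup phase still in progress
            return None
        self._pending_lookups.discard(request_id)
        self._finished_lookup_results[request_id] = hits
        return hits

    def cleanup(self, request_id: str) -> None:
        # Drop all bookkeeping when the request finishes or is aborted.
        self._pending_lookups.discard(request_id)
        self._finished_lookup_results.pop(request_id, None)
```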
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from ee3e7da to 47554d1 on April 11, 2026 17:47
@yoo-kumaneko changed the title from "[Draft] refactor(mp): replace job_id with request_id in query_prefetch_status" to "refactor(mp): replace job_id with request_id in query_prefetch_status" on Apr 12, 2026
Contributor

@KuntaiDu left a comment

LGTM

@chunxiaozheng added the "mp" label (Buildkite trigger for multi-processing mode test) on Apr 12, 2026
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from 375db3a to 4712304 on April 12, 2026 07:41
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from 4712304 to 9ae816c on April 12, 2026 07:49
Signed-off-by: rigginschen <rigginschen@tencent.com>
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from 9ae816c to 1038a60 on April 12, 2026 15:43
Contributor

@ApostaC left a comment

One potential problem with query_prefetch_status's return value. Please see the details below.

Comment on lmcache/v1/multiprocess/server.py, lines +665 to +668 (outdated):
The number of hits for the prefetched keys if the lookup phase is
done. None if the lookup phase is still in progress, or the prefetch
is already completed and consumed by query_prefetch_status, or the
request_id is invalid.
Contributor

One potential problem here:
Previously, since we had the job id, the caller could avoid calling this function with an unknown job id.
However, now that we use an external request id, the caller could call this function with an unknown request id. If we still return None here, the caller may spin on this request forever, because it may interpret the returned None as "lookup is in progress".

Therefore, I suggest we return 0 if the request ID is not found. WDYT?
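The hazard described in this thread can be illustrated with a typical client-side poll loop (hypothetical code, not from the PR; the timeout is a safety net a real caller might not even have):

```python
import time
from typing import Callable, Optional

def wait_for_hits(
    query_prefetch_status: Callable[[str], Optional[int]],
    request_id: str,
    timeout_s: float = 5.0,
) -> int:
    """Poll until the lookup phase finishes.

    If the server returned None for an *unknown* request_id, this loop
    would spin until the timeout, because None is indistinguishable from
    "still in progress". Returning 0 for unknown ids lets the loop
    terminate on the first iteration.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        hits = query_prefetch_status(request_id)
        if hits is not None:  # lookup done (or request unknown -> 0)
            return hits
        time.sleep(0.01)  # None means "still in progress": keep polling
    raise TimeoutError(f"lookup for {request_id} never completed")
```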

Contributor Author

Agreed. We should never return None in this case. Let's return 0 and emit a warning.

Contributor Author

updated

@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch 2 times, most recently from 85a7052 to 6562bbc, on April 13, 2026 03:57
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from 6562bbc to 85a7052 on April 13, 2026 05:03
When query_prefetch_status or query_prefetch_lookup_hits is called with
an unknown request_id, returning None would cause the caller to
interpret it as "still in progress" and spin forever. Return 0 instead
to signal no hits and allow the caller to terminate.

Signed-off-by: crclq2018@gmail.com
Signed-off-by: rigginschen <rigginschen@tencent.com>
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from 85a7052 to d176fd3 on April 13, 2026 06:02
@yoo-kumaneko requested a review from ApostaC on April 13, 2026 06:12
@cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.


Reviewed by Cursor Bugbot for commit e0535d8.

Signed-off-by: rigginschen <rigginschen@tencent.com>
@yoo-kumaneko force-pushed the feature/replace-job-id-with-request-id branch from e0535d8 to c3c13a3 on April 13, 2026 07:24
Contributor

@ApostaC left a comment

LGTM!

@ApostaC enabled auto-merge (squash) on April 13, 2026 17:39
@github-actions added the "full" label (Run comprehensive tests on this PR) on Apr 13, 2026
@maobaolong disabled auto-merge on April 14, 2026 03:06
@maobaolong merged commit 0784a72 into LMCache:dev on Apr 14, 2026
40 checks passed
@github-actions removed the "full" label (Run comprehensive tests on this PR) on Apr 14, 2026
ekaynar pushed a commit to ekaynar/LMCache that referenced this pull request Apr 15, 2026
…LMCache#2996)

ftian1 pushed a commit to ftian1/LMCache that referenced this pull request Apr 20, 2026
…LMCache#2996)


Labels

mp Buildkite trigger for multi-processing mode test

5 participants