
[Bugfix] fix crash in wait_for_save when retrieve fail from lmcache_engine#2516

Merged
sammshen merged 9 commits into LMCache:dev from liubj77:fix/crash_when_retrieve_fail
Mar 9, 2026

Conversation

@liubj77
Contributor

@liubj77 liubj77 commented Jan 30, 2026

Fixes #2294 #2327 #1732 #2204

@DongDongJu @sammshen Please take a look

What this PR does / why we need it:

In LMCacheConnectorV1Impl, if retrieval from the lmcache_engine fails, vLLM will call _handle_invalid_blocks to process the computed tokens of the request. However, this scenario is not currently handled by LMCacheConnectorV1Impl.

The crash stack is shown below:

(EngineCore_DP0 pid=85772)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=85772)   File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
(EngineCore_DP0 pid=85772)     next(self.gen)
(EngineCore_DP0 pid=85772)   File "/home/baojun.lbj/pycharm/vllm/vllm/vllm/v1/worker/kv_connector_model_runner_mixin.py", line 131, in _get_kv_connector_output
(EngineCore_DP0 pid=85772)     kv_connector.wait_for_save()
(EngineCore_DP0 pid=85772)   File "/home/baojun.lbj/pycharm/vllm/vllm/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py", line 187, in wait_for_save
(EngineCore_DP0 pid=85772)     self._lmcache_engine.wait_for_save()
(EngineCore_DP0 pid=85772)   File "/usr/local/lib/python3.12/dist-packages/lmcache/integration/vllm/vllm_v1_adapter.py", line 1105, in wait_for_save
(EngineCore_DP0 pid=85772)     assert len(slot_mapping) == len(token_ids)
(EngineCore_DP0 pid=85772)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=85772) AssertionError

It can be easily reproduced:

  • the token length of the request is greater than 1 block,

  • force the request to be scheduled in multiple steps,

  • get_num_new_matched_tokens returns a value greater than 0,

    • we can simply hack the code in get_num_new_matched_tokens:
    # num_external_hit_tokens = self.lookup_client.lookup(
    #     token_ids,
    #     lookup_id=req_id,
    #     request_configs=request_configs,
    # )
    num_external_hit_tokens = 2
    
  • self.lmcache_engine.retrieve fails in start_load_kv

    • just comment out _process_tokens_internal in LMCacheEngine.retrieve to simulate a retrieve failure

The vLLM start command looks like this:

LMCACHE_CHUNK_SIZE=1 vllm serve facebook/opt-125m --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' --chat-template=/home/baojun.lbj/work/chat-tmpl.jinja --max-num-batched-tokens=4 --no-async-scheduling
#  set LMCACHE_CHUNK_SIZE=1 
#  set max-num-batched-tokens=4, so the request can be easily scheduled in multiple steps

The prompt looks like this:

curl -X POST "http://localhost:8000/v1/completions" -H "Content-Type: application/json" -d '{
    "model": "facebook/opt-125m",
    "prompt": "Where is the capital of China? And tell me what the whether of Hangzhou, please as detail as possible",
    "temperature": 0.0,
    "best_of": 1,
    "max_tokens": 100
}'

The same crash will appear when the above request is received.

Special notes for your reviewers:

After this change, I compared request.num_computed_tokens in self._unfinished_requests with len(self._request_trackers[req_id].token_ids); they are exactly the same, and there was no way to guarantee that before.

I have been running this modification online for more than 4 days now. Before this change, the same pressure tests would crash within 3 to 4 hours.
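
For reviewers who want the gist of the change before reading the diff, here is a minimal sketch of the idea (not the exact patch: the helper name sync_tracker_after_rollback is hypothetical, and the field names follow the RequestTracker discussed in this thread):

    # Minimal sketch, not the actual adapter code: truncate tracker state
    # after vLLM rolls a request back because a retrieve failed.
    def sync_tracker_after_rollback(request_tracker, request, block_size):
        num_current_tokens = request.num_computed_tokens
        if num_current_tokens >= len(request_tracker.token_ids):
            return  # no rollback happened, nothing to fix

        # Safeguard: never keep more tokens than the allocated blocks can hold.
        num_token_slots = len(request_tracker.allocated_block_ids) * block_size
        tokens_to_keep = min(num_current_tokens, num_token_slots)

        # Align the tracker with vLLM's rolled-back progress so wait_for_save
        # later builds a slot_mapping whose length matches token_ids.
        request_tracker.token_ids = list(request.all_token_ids[:tokens_to_keep])
        request_tracker.num_saved_tokens = min(
            request_tracker.num_saved_tokens, tokens_to_keep
        )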

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

@gemini-code-assist
Contributor

Summary of Changes

Hello @liubj77, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the LMCache integration with vLLM, specifically addressing a crash that occurs when the cache retrieval mechanism fails and vLLM rolls back a request's computed tokens. The changes ensure that the internal state of the LMCache connector remains consistent with vLLM's token tracking, thereby preventing system instability and improving the robustness of the caching layer.

Highlights

  • Crash Fix: Fixes a crash in LMCacheConnectorV1Impl.wait_for_save that occurred when lmcache_engine.retrieve failed, leading to an assertion error due to mismatched token counts.
  • State Synchronization: Synchronizes request_tracker.token_ids and request_tracker.num_saved_tokens with the actual num_current_tokens after vLLM rolls back a request, ensuring consistency between vLLM's internal state and the LMCache connector.
  • Token Slot Safeguard: Adds safeguards to truncate token_ids if they exceed the allocated block slots, preventing further inconsistencies and potential memory issues.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses a critical crash that occurs when a retrieval from lmcache_engine fails, causing a desynchronization between vLLM's request state and LMCache's internal tracker. The fix correctly handles this token rollback scenario by truncating the tracker state to match vLLM's state. The implementation is robust, including a check for available token slots. I've included one suggestion to refactor the new logic for improved clarity and efficiency.

Comment thread lmcache/integration/vllm/vllm_v1_adapter.py Outdated
@liubj77 liubj77 force-pushed the fix/crash_when_retrieve_fail branch 7 times, most recently from 051cc4c to c831d95 on February 2, 2026
…ngine

Signed-off-by: liubj77 <liubj77@gmail.com>
@liubj77 liubj77 force-pushed the fix/crash_when_retrieve_fail branch from 105291c to ed36ab6 on February 2, 2026
@DongDongJu
Collaborator

DongDongJu commented Feb 2, 2026

Hello @liubj77,
Thank you for your hard work!
Do you know when retrieve fails in vLLM?

@hlin99
Contributor

hlin99 commented Feb 3, 2026

Hello @hlin99, Thank you for your hard work! Do you know when retrieve fails in vLLM?

hi @DongDongJu, you probably put this comment in the wrong context? This PR is not from me :)

ps, my pending PRs are #2467, #2478, #2509, please also help review. thanks! 👍

@hlin99
Contributor

hlin99 commented Feb 3, 2026

btw, for the retrieve failure... I'm really not aware of such handling in vLLM. Maybe we're not running the exact same version or cases... so I can't tell more about this patch.

@DongDongJu
Collaborator

My bad. Sorry! I will take a look too.

@liubj77
Contributor Author

liubj77 commented Feb 3, 2026

Hello @liubj77, Thank you for your hard work! Do you know when retrieve fails in vLLM?

@DongDongJu I discovered this issue while using eic_connector to proxy requests to our backend. The hack code above is to reproduce this scenario.

  • In lmcache, if retrieve in start_load_kv fails, _invalid_block_ids will be updated.
  • vllm calls output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() in function KVConnectorModelRunnerMixin._get_kv_connector_output to obtain these invalid_block_ids for processing.
  • The scheduler then processes them in _handle_invalid_blocks, called from update_from_output, which modifies request.num_computed_tokens; lmcache does not handle this scenario.

This logic has been present in vllm from version 0.12 to 0.14; I haven't checked earlier versions.
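
A minimal sketch of the first two steps above (illustrative only: the class body and signatures are invented, only the method and attribute names come from the flow just described):

    # Illustrative sketch, not the real LMCacheConnectorV1Impl: the connector
    # remembers blocks whose load failed and hands them back to vLLM.
    class ConnectorSketch:
        def __init__(self, lmcache_engine):
            self.lmcache_engine = lmcache_engine
            self._invalid_block_ids: set[int] = set()

        def start_load_kv(self, request):
            # Step 1: a failed retrieve marks the request's blocks as invalid.
            ok = self.lmcache_engine.retrieve(request)
            if not ok:
                self._invalid_block_ids.update(request.block_ids)

        def get_block_ids_with_load_errors(self) -> set[int]:
            # Step 2: vLLM reads this in _get_kv_connector_output; the
            # scheduler then rolls back request.num_computed_tokens in
            # _handle_invalid_blocks (step 3), which lmcache must mirror.
            ids, self._invalid_block_ids = self._invalid_block_ids, set()
            return ids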

Collaborator

@DongDongJu DongDongJu left a comment


This makes sense to me.

When KV retrieval fails, vLLM can roll a request back by adjusting request.num_computed_tokens to the longest valid prefix (invalid-block handling).

Before this change, request_tracker.token_ids could remain ahead of num_computed_tokens, which matches the reported wait_for_save crash: len(slot_mapping) != len(token_ids).

With this patch, on rollback (num_current_tokens < len(token_ids)) the tracker state is truncated so token_ids stays aligned with vLLM's rolled-back progress, preventing the assertion.

Nit: Should num_saved_tokens be clamped to tokens_to_keep to keep it <= len(token_ids) after truncation?

@liubj77
Contributor Author

liubj77 commented Feb 3, 2026

@DongDongJu I don't think these two are the same concept. num_saved_tokens simply indicates the number of tokens already saved to lmcache. Both new requests and preempted requests retrieve this value from the last lookup result, and its length will not exceed num_computed_tokens.

Collaborator

@DongDongJu DongDongJu left a comment


I've marked the line I mentioned earlier:

request.all_token_ids[:tokens_to_keep]
)
request_tracker.num_saved_tokens = min(
request_tracker.num_saved_tokens, num_current_tokens
Collaborator


My concern is here.
IIUC, vLLM uses num_saved_tokens as skip_leading_tokens; please correct me if I'm wrong.
Can we change num_current_tokens to tokens_to_keep?

Contributor Author

@liubj77 liubj77 Feb 4, 2026


I've reviewed the logic again; as you said, using tokens_to_keep instead of num_current_tokens is more robust, and I've updated the PR.

When retrieval fails, the computed number of tokens must be less than both the previous num_saved_tokens and last_allocated_block_ids * block_size. Therefore, the condition num_token_slots < num_current_tokens will never be triggered in this case; the check here is just a safeguard.

Collaborator


Yes exactly. Thanks!

@DongDongJu
Collaborator

please check the DCO

Signed-off-by: liubj77 <liubj77@gmail.com>
@liubj77 liubj77 force-pushed the fix/crash_when_retrieve_fail branch from 53ead90 to b709b6d on February 5, 2026
@liubj77
Contributor Author

liubj77 commented Feb 5, 2026

please check the DCO

@DongDongJu done

@ziruiliu
Contributor

I applied this change to my test in #2294, but it does not fix the problem. The assertion disappeared but the next line failed:

(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982]   File "/work/ziliu/vllm_exp/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py", line 187, in wait_for_save
(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982]     self._lmcache_engine.wait_for_save()
(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982]   File "/work/ziliu/LMCache-0.3.12/lmcache/integration/vllm/vllm_v1_adapter.py", line 1123, in wait_for_save
(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982]     slot_mapping = slot_mapping.to(self.device)
(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2978907) ERROR 02-10 09:33:27 [core.py:982] torch.AcceleratorError: CUDA error: device-side assert triggered

I believe other lines should also be modified to completely fix this issue, like this:

    def update(...)
    ...
            self.allocated_block_ids.extend(new_block_ids)
            self.token_ids.extend(new_token_ids)

I added logs here and found it still mistakenly extends self.token_ids.

Hope this helps

@liubj77
Contributor Author

liubj77 commented Feb 11, 2026

Can you confirm that these are the same issue? This fixes the assert len(slot_mapping) == len(token_ids) error. allocated_block_ids cannot be guaranteed to align with token_ids because blocks allocated in the last vLLM allocation were not reclaimed, and slot_mapping is pruned based on len(token_ids).

@ziruiliu
Contributor

I guess so. Your fix does get rid of assert len(slot_mapping) == len(token_ids), but there might be something else remaining, so the next line slot_mapping = slot_mapping.to(self.device) throws an exception in my test.
As you mentioned at the beginning, issue #2294 is similar to yours: vLLM crashes when lmcache finds but fails to retrieve the cache. Both hit the same place, assert len(slot_mapping) == len(token_ids), because the request is rescheduled but slot_mapping is not aligned with token_ids.
However, these two issues may have different root causes. I did not have much time to dive deep into it.

@DongDongJu
Collaborator

Hello @ziruiliu, thanks for the callout. Let me take a more detailed look today.

@deng451e
Collaborator

Thanks for the fix, @liubj77! 🙏
Just wanted to add a bit more context on the code flow behind this bug.

When retrieval fails for some blocks, lmcache reports invalid blocks and the vLLM scheduler rolls back request.num_computed_tokens in Scheduler._update_requests_with_invalid_blocks, but the allocated block_ids aren't evicted for the request. However, lmcache does not roll back RequestTracker.token_ids accordingly.

Later, in lmcache ReqMeta.from_request_tracker, slot_mapping is derived as:

    slot_mapping = (
        block_offsets.reshape((1, block_size))
        + block_ids.reshape((num_blocks, 1)) * block_size
    ).flatten()[: len(token_ids)]

When ceil(len(token_ids) / block_size) > ceil(request.num_computed_tokens / block_size), no new blocks are assigned by the scheduler because num_blocks is already sufficient for request.num_computed_tokens, causing len(slot_mapping) < len(token_ids) and triggering the assertion failure.
With your fix, I think simply rolling back RequestTracker.token_ids to match request.num_computed_tokens should already be sufficient:

    request_tracker.token_ids = list(request.all_token_ids[:num_current_tokens])
num_saved_tokens can probably be ignored for now since it mainly affects KV saving (lmcache already has the tokens but retrieval failed), and those cases could be handled separately — e.g., removing the failed keys or verifying their existence (please correct me if I’m mistaken).
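
To make the mismatch concrete, here is a small hypothetical numeric example of the derivation above (values chosen purely for illustration):

    import torch

    # Suppose block_size = 4 and the scheduler kept 2 blocks for the
    # rolled-back request (enough for num_computed_tokens = 6)...
    block_size = 4
    block_ids = torch.tensor([7, 8])        # 2 blocks -> 8 token slots
    num_blocks = len(block_ids)
    block_offsets = torch.arange(block_size)

    # ...while the tracker still holds 10 token_ids from before the rollback.
    num_token_ids = 10

    slot_mapping = (
        block_offsets.reshape((1, block_size))
        + block_ids.reshape((num_blocks, 1)) * block_size
    ).flatten()[:num_token_ids]

    # ceil(10 / 4) = 3 blocks would be needed, but only ceil(6 / 4) = 2 exist:
    print(len(slot_mapping), num_token_ids)  # 8 10 -> the assertion fails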

@sammshen
Contributor

@deng451e would you like to create a PR for the fix discussed offline?

@liubj77
Contributor Author

liubj77 commented Feb 14, 2026

@deng451e Modifying only request_tracker.token_ids is fine, but I prefer to modify num_saved_tokens synchronously, as it is used to handle skip_leading_tokens.
If a previous retrieve failed, it is highly likely that the tokens no longer exist in the backend storage, so resaving them is better. Also, people tracking logs might wonder why num_saved_tokens is larger than len(token_ids); modifying num_saved_tokens synchronously is preferable.
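
A small hypothetical illustration of why I prefer clamping it (names follow the discussion, not the exact save path):

    # If token_ids was truncated on rollback but num_saved_tokens was not,
    # the save path would skip past the end of the shorter list and store
    # nothing, instead of resaving the tokens whose retrieval failed.
    token_ids = list(range(6))      # tracker truncated to 6 tokens
    num_saved_tokens = 8            # stale value from before the rollback
    skip_leading_tokens = num_saved_tokens
    tokens_to_store = token_ids[skip_leading_tokens:]  # [] -- nothing resaved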

new_token_ids,
new_block_ids,
preempted=preempted,
lmcache_cached_tokens=lmcache_cached_tokens,
Collaborator

@deng451e deng451e Feb 17, 2026


Maybe we should also update load_spec.lmcache_cached_tokens here? Otherwise this change could overwrite num_saved_tokens when updating the request tracker.

Contributor Author


I think this is unnecessary. The num_saved_tokens in RequestTracker is only updated in preempted scenarios, which is not the case here.

Collaborator


I see — LGTM

Comment thread lmcache/integration/vllm/vllm_v1_adapter.py
Contributor

@sammshen sammshen left a comment


LGTM! This is a great fix

Collaborator

@DongDongJu DongDongJu left a comment


Thanks for revising. Good to go.

@DongDongJu
Collaborator

@liubj77 lots of CI failed. PTAL

@liubj77
Contributor Author

liubj77 commented Mar 6, 2026

@liubj77 lots of CI failed. PTAL

@DongDongJu The code formatting has already been handled. Is there anything else I need to do? The workflow requires approval to continue running, and links like buildkite/k3-comprehensive-test are showing "page not found".

@DongDongJu DongDongJu enabled auto-merge (squash) March 6, 2026 17:46
@github-actions github-actions Bot added the full Run comprehensive tests on this PR label Mar 6, 2026
@sammshen
Contributor

sammshen commented Mar 6, 2026

please do one more pre-commit @liubj77

@sammshen
Contributor

sammshen commented Mar 6, 2026

updating the branch to address the "page not found", but the k3 tests are non-blocking for now

auto-merge was automatically disabled March 9, 2026 02:47

Head branch was pushed to by a user without write access

@github-actions github-actions Bot removed the full Run comprehensive tests on this PR label Mar 9, 2026
@yanok
Contributor

yanok commented Mar 9, 2026

Thanks! That indeed solves the issues with unreliable backends for us.

@sammshen sammshen enabled auto-merge (squash) March 9, 2026 17:10
@github-actions github-actions Bot added the full Run comprehensive tests on this PR label Mar 9, 2026
@sammshen sammshen merged commit 7f2df61 into LMCache:dev Mar 9, 2026
26 of 29 checks passed
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
…ngine (LMCache#2516)

* [bugfix] fix crash in wait_for_save when retrieve fail from lmcache_engine

Signed-off-by: liubj77 <liubj77@gmail.com>

* update

Signed-off-by: liubj77 <liubj77@gmail.com>

* format code

Signed-off-by: liubj77 <liubj77@gmail.com>

* format code

Signed-off-by: liubj77 <liubj77@gmail.com>

---------

Signed-off-by: liubj77 <liubj77@gmail.com>
Co-authored-by: Samuel Shen <slshen@tensormesh.ai>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026

Labels

full Run comprehensive tests on this PR


Development

Successfully merging this pull request may close these issues.

[Bug] vllm crashes if retrieve() fails

7 participants