[Bugfix] fix some memory leak in cache_engine and eic connector #2544

Merged
sammshen merged 11 commits into LMCache:dev from liubj77:fix/fix_some_memoy_leak on Apr 16, 2026

Conversation

@liubj77 (Contributor) commented on Feb 5, 2026

Fixes #2017

@DongDongJu @sammshen @ApostaC Please take a look.

This might be related to that bug: although I didn't use RedisConnector, from reading the code it appears that RedisConnector would have the same memory leak.

After running for a long time, the following log line appears:

No eviction candidates found in local cpu backend

What this PR does / why we need it:

When async_loading is not used, memory leaks can occur because RemoteConnector implementations differ in whether they support batched_get and in how batched_get itself is implemented.

For example, RedisConnector does not support batched_get (support_batched_get=False), so batched_get_blocking falls back to the logic below, which queries multiple keys concurrently. This creates a problem: an earlier key may fail while a later key succeeds.

futures = [
    asyncio.run_coroutine_threadsafe(self.connection.get(key), self.loop)
    for key in keys
]

Similarly, although EicConnector supports batched_get (support_batched_get=True), its batched_get implementation also queries each key independently in parallel, leading to the same problem.

However, _process_tokens_internal in LMCacheEngine does not handle this scenario; it only breaks and truncates, as shown below, which leaks the memory_obj instances that were fetched after the failed key.

for (key, start, end), memory_obj in zip(blocks, memory_objs, strict=False):
    if memory_obj is None:
        logger.warning(
            "The cache block is in the storage, but it can't be retrieved"
        )
        if (
            last_failed_block_start is None
            or last_failed_block_start < start
        ):
            last_failed_block_start = start
        break
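
To make the leak concrete, here is a minimal self-contained sketch of the old behavior (a toy MemoryObj class and invented block values, not the actual LMCache classes): the block after the gap was fetched successfully, but the break means its reference count is never decremented.

class MemoryObj:
    """Toy stand-in for a ref-counted KV-cache buffer."""

    def __init__(self, name):
        self.name = name
        self.ref_count = 1  # allocated with one reference held

    def ref_count_down(self):
        self.ref_count -= 1


# batched_get returned a non-contiguous result: the middle key failed.
blocks = [("k0", 0, 256), ("k1", 256, 512), ("k2", 512, 768)]
memory_objs = [MemoryObj("obj0"), None, MemoryObj("obj2")]

for (key, start, end), memory_obj in zip(blocks, memory_objs):
    if memory_obj is None:
        break  # old behavior: bail out without touching later blocks
    # ... copy the block into the paged KV cache ...
    memory_obj.ref_count_down()  # blocks before the gap are released normally

leaked = [m.name for m in memory_objs if m is not None and m.ref_count > 0]
print(leaked)  # ['obj2'] -- fetched successfully, never used, never released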

Special notes for your reviewers:

  1. Verification

As shown in the PR's new code, I added logging for unreleased memory. After running it in the test environment for a long time, I confirmed that some memory_obj instances are not being released.

[screenshots: logs showing unreleased memory_obj instances]
  2. Another leak

This PR also fixes the memory_obj leak in EICConnector when get_data fails.

If applicable:

  • this PR contains user-facing changes (docs added)
  • this PR contains unit tests

Note

Medium Risk
Touches retrieval control-flow and reference counting, so mistakes could cause premature frees or missing cache hits; changes are bounded and covered by new unit tests for failure/truncation cases.

Overview
Fixes a KV-cache memory leak during retrieval when batched_get returns non-contiguous results (e.g., a later key succeeds after an earlier key fails).

LMCacheEngine._process_tokens_internal now (1) tracks the earliest failed block across locations, (2) immediately ref_count_down()s any successfully-fetched but unused blocks after a per-location failure, and (3) truncates returned chunks using end <= last_failed_block_start while freeing any already-added chunks that extend past the gap.
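
As a rough sketch, the fixed loop looks something like the following (a simplified reconstruction of the description above for a single location, reusing the toy blocks/memory_objs from the earlier sketch; the copy step and the real engine's bookkeeping are elided):

last_failed_block_start = None
used_keys = set()

for (key, start, end), memory_obj in zip(blocks, memory_objs):
    if memory_obj is None:
        if last_failed_block_start is None:
            last_failed_block_start = start  # earliest gap bounds the usable prefix
        continue  # keep scanning so later successes can be released below
    if last_failed_block_start is not None and end > last_failed_block_start:
        memory_obj.ref_count_down()  # fetched past the gap: unusable, release now
        continue
    # ... copy the block into the paged KV cache ...
    used_keys.add(key)  # recorded so a later truncation pass won't double-free
    memory_obj.ref_count_down()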

EICConnector.get_data now guarantees allocated memory_objs are released via ref_count_down() on both error status codes and unexpected exceptions during mget.
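
In sketch form, that guarantee looks roughly like this (assumed names throughout: connection.mget, the shape argument, and the zero-means-success status convention are placeholders, not the real EIC API):

def get_data(memory_allocator, connection, shape, keys):
    """Sketch: fetch keys via mget, guaranteeing release on any failure."""
    memory_objs = [memory_allocator.allocate(shape) for _ in keys]
    try:
        status_codes = connection.mget(keys, memory_objs)  # placeholder call
    except Exception:
        for memory_obj in memory_objs:
            memory_obj.ref_count_down()  # release everything on unexpected errors
        raise
    results = []
    for memory_obj, status_code in zip(memory_objs, status_codes):
        if status_code != 0:
            memory_obj.ref_count_down()  # release on a per-key error status
            results.append(None)
        else:
            results.append(memory_obj)
    return results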

Adds focused unit tests covering boundary truncation, multi-location failure ordering, freeing of unused successful blocks, and avoiding double-free scenarios.

Reviewed by Cursor Bugbot for commit 05c4043.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @liubj77, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical memory leaks identified in the cache engine and EIC connector, particularly when async_loading is not utilized. The leaks stem from RemoteConnector implementations, where batched_get operations might query keys independently, leading to MemoryObject instances not being correctly released upon partial failures or truncation. The changes introduce explicit reference count decrements for these unreleased memory objects, thereby preventing memory accumulation and improving system stability.

Highlights

  • LMCacheEngine Memory Leak Fix: Corrected memory object reference counting in LMCacheEngine's _process_tokens_internal method to prevent leaks. This addresses scenarios where batched_get operations partially fail or blocks are truncated, ensuring memory_obj instances are properly released.
  • EicConnector Memory Leak Fix: Ensured MemoryObject instances are properly released in EicConnector when get_data operations fail by adding a ref_count_down() call in the error handling path.


Changelog
  • lmcache/v1/cache_engine.py
    • Introduced a used_keys set to track successfully processed cache blocks within _process_tokens_internal.
    • Added logic to decrement ref_count for memory_obj instances that were retrieved but not ultimately used because a preceding key failed.
    • Implemented additional logic to decrement ref_count for memory_obj instances that are truncated due to last_failed_block_start.
  • lmcache/v1/storage_backend/connector/eic_connector.py
    • Added memory_obj.ref_count_down() call within the get_data method's error handling path to ensure memory release when data retrieval fails.

@gemini-code-assist (Bot) left a comment


Code Review

This pull request effectively addresses two memory leaks. The first occurs in the cache_engine when a batched retrieval has partial failures, leading to unused memory objects not being released. The second is in the eic_connector, where a memory object was not released upon a data fetching error. The fixes are correct and crucial for system stability. I have one suggestion for a minor refactoring in cache_engine.py to improve efficiency.

sammshen requested a review from deng451e on February 6, 2026
@maobaolong (Collaborator) left a comment


@liubj77 Left a comment:

f"eic mget data {key_str} failed, status_code {status_code}"
" err_code {err_code}"
)
memory_obj.ref_count_down()
@maobaolong (Collaborator):

What about the exception path? In other words, memory_obj should still be ref_count_down'd even if an exception occurs.

@liubj77 (Author):

> What about the exception path? In other words, memory_obj should still be ref_count_down'd even if an exception occurs.

By design, EIC reports failures via status codes, so this call should not raise exceptions.

@maobaolong (Collaborator):

That's OK. There is another trick: we can wrap everything after self.memory_allocator.allocate in a try/except that catches all exceptions, so that no matter what happens we call ref_count_down. Then the memory object can never leak.
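
For illustration, the suggested pattern looks roughly like this (a sketch with assumed names; fill_from_remote is a hypothetical placeholder for whatever work follows the allocation):

memory_obj = memory_allocator.allocate(shape)
try:
    # everything that can fail: remote fetch, deserialization, copy
    fill_from_remote(memory_obj)  # hypothetical helper
except Exception:
    memory_obj.ref_count_down()  # guaranteed release, whatever happened
    raise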

@liubj77 (Author):

Done. Possible exceptions are now handled in a try/except.

@liubj77 (Author):

@maobaolong Please review this again and see if it's acceptable to merge.

@sammshen (Contributor) left a comment


LGTM

@liubj77 (Author) commented on Mar 18, 2026

@maobaolong @sammshen @deng451e Could you please review this and see if it's acceptable to merge?

@deng451e (Collaborator):

LGTM

@sammshen (Contributor) left a comment


Need to take another look, sorry.

ret_mask[start:end] = True
used_keys.add(key)

for (key, _, _), memory_obj in zip(blocks, memory_objs, strict=False):
@sammshen (Contributor):

Dumb question: why do we need to do this check twice?

@deng451e has a great question: how do you know there is no overlap between these two sets (the complement of used_keys, and blocks with end >= last_failed_block_start)?

@liubj77 (Author):

> Dumb question: why do we need to do this check twice?
>
> @deng451e has a great question: how do you know there is no overlap between these two sets (the complement of used_keys, and blocks with end >= last_failed_block_start)?

@sammshen I've modified some logic and added unit tests. Please take a look again.

@DongDongJu (Collaborator) left a comment


LGTM, but I am not that familiar with the EIC connector, so I just left one comment from reading the code.

Comment thread: lmcache/v1/cache_engine.py
liubj77 force-pushed the fix/fix_some_memoy_leak branch from bc85af0 to 57437f7 on March 23, 2026
@liubj77 (Author) commented on Apr 2, 2026

> LGTM, but I am not that familiar with the EIC connector, so I just left one comment from reading the code.

@DongDongJu @sammshen @maobaolong @deng451e Can this PR be merged? The EIC change doesn't alter the original logic; it just releases memory on failure.

I've tested it for about two weeks, and with this patch LMCache memory leaks have decreased significantly. Of course, there are likely other memory leaks that haven't been discovered yet.

[screenshot: memory usage trend over roughly two weeks with the patch applied]

@sammshen (Contributor) left a comment

LGTM

sammshen enabled auto-merge (squash) on April 15, 2026
github-actions (Bot) added the label full (Run comprehensive tests on this PR) on Apr 15, 2026
liubj77 requested a review from YaoJiayi as a code owner on April 16, 2026
sammshen merged commit ab3a30c into LMCache:dev on Apr 16, 2026 (37 of 38 checks passed)

Labels

full (Run comprehensive tests on this PR)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[bug] Local CPU memory leaking when using remote connector

5 participants