
[RFC] Parallel get to improve TTFT a bit #863

Closed

da-x wants to merge 1 commit into LMCache:dev from da-x:parallel-blocking-get-pr

Conversation

@da-x (Contributor) commented Jun 18, 2025

We saw a 10% improvement in TTFT by allowing `blocking_get` to be served from a thread pool.

This adds the ability to enable the thread pool and to specify its size in threads.

I'm not expecting this to be merged as-is; I am putting it here mainly to open a discussion.

Some points:

  • I am aware there's the LayerwiseLMCacheEngine, which uses the async APIs for fetching KV cache, so adding a thread pool may seem somewhat redundant. However, after some experimentation I see it is either not fully stable or does not improve latency over the original LMCacheEngine. Even once LayerwiseLMCacheEngine works well, I think this change still has value, because it would allow a more robust comparison of the performance of the two engines.
  • We likely need a better name than the overly long parallel_blocking_get_thread_count. Suggestions welcome.
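
For illustration, the proposed mechanism boils down to a lazily created shared pool that blocking gets are submitted to. A minimal sketch (the `get_thread_pool` helper mirrors the diff discussed below; the rest of the wiring is simplified and not the PR's actual code):

```python
from concurrent.futures import ThreadPoolExecutor

_pool = None


def get_thread_pool(thread_count: int) -> ThreadPoolExecutor:
    # Lazily create a single shared pool, sized by the proposed
    # parallel_blocking_get_thread_count config option.
    global _pool
    if _pool is None:
        _pool = ThreadPoolExecutor(max_workers=thread_count)
    return _pool


def parallel_blocking_get(storage_manager, keys, thread_count):
    # Fan the per-key blocking gets out to the pool instead of
    # fetching them serially, then wait for all results.
    pool = get_thread_pool(thread_count)
    futures = [pool.submit(storage_manager.get, key) for key in keys]
    return [f.result() for f in futures]
```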

@guymguym (Contributor) commented:

👍 I/O concurrency is extremely valuable for maximizing the throughput of any storage system because it hides round-trip delays, and I agree this is notably missing from LMCache at the moment.

For filesystem I/O, aiofiles already uses a thread pool internally (it is used by LocalDiskBackend and FSConnector), so if we can make sure that the engine (both layerwise and blockwise) uses async, that would be great.

For network I/O, it is typically considered most efficient to use asyncio. However, this has a hidden throughput limit: a single thread running an async event loop will eventually top out at roughly 3 GB/s because of the overhead of actually copying data to/from the socket. In some cases this is sufficient, because there are multiple separate LMCache instances and each one adds another ~3 GB/s to the total throughput, but when each GPU requires more KV throughput, a single thread will limit the ability to saturate the GPU.
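
One common way past that single-loop copy limit is to shard connections across several event loops, each running on its own thread. A rough sketch of the pattern (not LMCache code; `fetch` is a placeholder for real network I/O):

```python
import asyncio
import threading


def start_loop_thread() -> asyncio.AbstractEventLoop:
    # Each loop runs forever on a dedicated thread, so the per-thread
    # socket copy cost is paid N times in parallel.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


async def fetch(key: str) -> str:
    await asyncio.sleep(0)  # stand-in for real network I/O
    return key


loops = [start_loop_thread() for _ in range(4)]


def submit(key: str):
    # Pick a loop deterministically per key and hand the coroutine
    # to it from the caller's thread.
    loop = loops[hash(key) % len(loops)]
    return asyncio.run_coroutine_threadsafe(fetch(key), loop)


print(submit("some-key").result())
```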

For which backend did you see TTFT improvement?

@da-x (Contributor, Author) commented Jun 19, 2025

@guymguym I tested GdsBackend in cuFile POSIX compat mode with use_direct_io=True.

@YaoJiayi (Collaborator) commented:

@da-x Could you fix the code quality check first?

@xiaguan (Contributor) commented Jun 20, 2025

For remote connectors, we'd love to have something like `batch_exist`, `batch_get`, and `batch_put` so we can fully leverage the performance of the underlying C++ network and storage stack.
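
For concreteness, the suggested connector surface might look roughly like this (a hypothetical sketch; the method names follow the comment above and are not LMCache's actual API):

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class BatchedConnector(ABC):
    # Hypothetical batched interface: one call per batch lets a C++
    # backend pipeline the round trips itself.

    @abstractmethod
    def batch_exist(self, keys: List[str]) -> List[bool]: ...

    @abstractmethod
    def batch_get(self, keys: List[str]) -> List[Optional[bytes]]: ...

    @abstractmethod
    def batch_put(self, keys: List[str], values: List[bytes]) -> None: ...
```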

@da-x force-pushed the parallel-blocking-get-pr branch from 9667428 to 704e558 on June 22, 2025, 12:34
@YaoJiayi (Collaborator) commented Jun 22, 2025

@guymguym @xiaguan
I've been working on exposing the "batch" to the underlying backends so that each backend can operate on the batch more efficiently.
For example:
https://github.com/LMCache/LMCache/blob/dev/lmcache/v1/cache_engine.py#L207
https://github.com/LMCache/LMCache/blob/dev/lmcache/v1/storage_backend/storage_manager.py#L186

@guymguym (Contributor) commented:

@YaoJiayi @xiaguan Also using batched get/exists would be great and would allow the backends to optimize with async/threads.

What will be the actual batch size on get? Assuming no chunked prefill, will the batch be the full list of input tokens, or smaller?

Review thread on lmcache/v1/cache_engine.py (outdated), lines +285 to +287:

```python
get_thread_pool(
    self.config.parallel_blocking_get_thread_count
).submit(self.storage_manager.get, key)
```

@guymguym (Contributor) commented:

@da-x I wonder whether it's needed to separate this thread pool from the default executor that asyncio/aiofiles already initializes lazily? See these docs:

> awaitable loop.run_in_executor(executor, func, *args)
>
> The executor argument should be an concurrent.futures.Executor instance. The default executor is used if executor is None. The default executor can be set by loop.set_default_executor(), otherwise, a concurrent.futures.ThreadPoolExecutor will be lazy-initialized and used by run_in_executor() if needed.
>
> class concurrent.futures.ThreadPoolExecutor(max_workers=None, thread_name_prefix='', initializer=None, initargs=())
>
> Changed in version 3.13: Default value of max_workers is changed to min(32, (os.process_cpu_count() or 1) + 4).
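
A minimal sketch of the mechanism those docs describe (`blocking_get` here is a stand-in for a real blocking storage read):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def blocking_get(key: str) -> str:
    return key  # stand-in for a blocking storage read


async def main():
    loop = asyncio.get_running_loop()
    # Override the lazily created default executor with an explicit size...
    loop.set_default_executor(ThreadPoolExecutor(max_workers=16))
    # ...then executor=None routes the call through that default executor.
    value = await loop.run_in_executor(None, blocking_get, "some-key")
    print(value)


asyncio.run(main())
```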

Follow-up in the same review thread (Contributor):

We might still want to add a general config option to override the default executor's number of workers...

@da-x (Contributor, Author) replied Jun 28, 2025

> @da-x I wonder whether it's needed to separate this thread pool from the default executor that asyncio/aiofiles already initializes lazily? See these docs:

@guymguym

  • The threads created for the thread pool are not connected to the current default loop. They have no loop unless they decide to create or bind one specifically. See the example program below.
  • If 'blocking get' used an existing loop to which we have a handle, say self.loop, then we would have an issue here, but I think self.loop is only used for put and prefetch. You can check this too.
  • To avoid adding another config option, maybe we can allow a "default" value that uses the Python default instead of passing a number, so that the type is string | number | None?
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def threaded_f():
    loop = asyncio.get_running_loop()
    # This is a different loop than the one in `main`.
    print("loop in threaded_f", id(loop))


def has_no_async_loop():
    # NOTE: calling asyncio.get_running_loop() here would raise
    # RuntimeError: no running event loop -- pool threads have no loop.
    asyncio.run(threaded_f())


async def main():
    loop = asyncio.get_running_loop()
    print("loop in main", id(loop))

    futures = []
    tp = ThreadPoolExecutor(max_workers=1)
    for x in range(1):
        futures.append(tp.submit(has_no_async_loop))
    _results = [f.result() for f in futures]


if __name__ == '__main__':
    asyncio.run(main())
```

@YaoJiayi (Collaborator) commented:

@da-x Could you resolve the merge conflict?

@chunxiaozheng (Collaborator) commented:

Hi @da-x, I tested this PR with FSConnector, running about 220 prompts. Without this PR the TTFT is 2.45 s; with this PR it is 1.12 s. This is a very good improvement! But I'm not sure whether the optimization will be as effective under high concurrency.

@da-x (Contributor, Author) commented Jul 28, 2025

There is a new batched_get_blocking backend API that each backend can implement, and I am considering moving this to GdsBackend.
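
A default implementation could simply fan per-key blocking gets out to a pool. A hedged sketch (the class wiring and `get_blocking` signature are assumptions; only the `batched_get_blocking` name comes from the discussion):

```python
from concurrent.futures import ThreadPoolExecutor


class ExampleBackend:
    def __init__(self, thread_count: int = 8):
        self._pool = ThreadPoolExecutor(max_workers=thread_count)

    def get_blocking(self, key):
        # A real backend would read the KV chunk for `key` here.
        raise NotImplementedError

    def batched_get_blocking(self, keys):
        # Submit every key up front, then wait, so round trips overlap.
        futures = [self._pool.submit(self.get_blocking, k) for k in keys]
        return [f.result() for f in futures]
```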

@chunxiaozheng what backend were you using?

@da-x force-pushed the parallel-blocking-get-pr branch 2 times, most recently from c2a841e to 8b6b0f8 on August 25, 2025, 22:10
We saw a 10% improvement in TTFT by allowing `blocking_get` to be served
from a thread pool.

This adds the ability to enable this thread pool and to specify its size
in threads.

Signed-off-by: Dan Aloni <dan@kernelim.com>
@da-x force-pushed the parallel-blocking-get-pr branch from 8b6b0f8 to a70cc96 on August 25, 2025, 22:16
@github-actions (Bot) commented:

This pull request has been automatically marked as stale because it has not had activity within 60 days. It will be automatically closed if no further activity occurs within 30 days.

@github-actions bot added the stale label Oct 25, 2025
@github-actions (Bot) commented:

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it!

@github-actions bot closed this Nov 24, 2025