
feat(gds): add multipath KV-cache offloading support#2817

Merged
DongDongJu merged 1 commit into LMCache:dev from glimchb:gds_multipath_support
Mar 27, 2026

Conversation

@glimchb (Contributor) commented Mar 18, 2026

Parse comma-separated gds_path and select a path per GPU worker using by_gpu sharding (device_id % num_paths), matching the approach in NIXL PR #2418. This distributes KV cache I/O across multiple NVMe drives to increase aggregate bandwidth.

Single-path configs work unchanged.

Includes tests (TestGdsMultiPath) and documentation updates.
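
A minimal sketch of the selection logic described above (illustrative only; `select_gds_path` and the paths are names chosen for this example, not the PR's actual code):

```python
import os


def select_gds_path(gds_path: str, device_id: int) -> str:
    """Pick one path for this GPU worker from a comma-separated gds_path.

    Illustrative by_gpu sharding: device_id % num_paths.
    A single-path value behaves exactly as before (num_paths == 1).
    """
    paths = [p.strip() for p in gds_path.split(",") if p.strip()]
    chosen = paths[device_id % len(paths)]
    os.makedirs(chosen, exist_ok=True)  # make sure the target directory exists
    return chosen


# Example: two NVMe mount points; the worker on GPU 1 uses the second path.
print(select_gds_path("/mnt/nvme0/lmcache,/mnt/nvme1/lmcache", 1))
# -> /mnt/nvme1/lmcache
```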

What this PR does / why we need it:

Special notes for your reviewers:

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository-wide code reviews.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces multi-path support for the GDS backend, allowing users to distribute KV cache I/O across multiple NVMe drives. This is achieved by parsing comma-separated paths and selecting a path per GPU worker based on its device ID. The changes include code modifications to handle multiple paths, documentation updates, and new tests to ensure the functionality works as expected.

Highlights

  • Multi-Path GDS Support: Introduces the ability to distribute GDS I/O across multiple NVMe drives by specifying a comma-separated list of paths in the gds_path configuration. Each GPU worker selects a path based on its device index, improving aggregate bandwidth.
  • GPU Affinity: Implements GPU-to-NVMe affinity using device_id % num_paths to ensure traffic is spread evenly across available drives without manual pinning.
  • Testing and Documentation: Includes new tests (TestGdsMultiPath) to verify the multi-path GDS backend support, along with updates to the documentation to explain the new functionality.
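
As an illustrative example (the mount points below are hypothetical), a configuration such as `gds_path: "/mnt/nvme0/lmcache,/mnt/nvme1/lmcache"` would send the worker on GPU 0 to the first path and the worker on GPU 1 to the second, while a single-path value keeps the previous behavior.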

@gemini-code-assist (Bot) left a comment:

Code Review

The pull request successfully implements multi-path KV-cache offloading support for GDS, enabling distribution of I/O across multiple NVMe drives. The logic for parsing comma-separated paths and selecting a path per GPU worker using device_id % num_paths is correctly implemented. Documentation has been updated to reflect this new feature, and a comprehensive set of unit tests ensures the correctness and backward compatibility of the multi-path functionality, including path parsing, GPU affinity, and directory creation. The changes are robust and well-tested.
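
To make the reviewed behaviors concrete, here is a hedged sketch of the kind of checks such tests could make; it reuses the hypothetical `select_gds_path` helper from the sketch above and is not the PR's actual TestGdsMultiPath code:

```python
def check_by_gpu_sharding() -> None:
    # Two paths, four workers: the assignment alternates across drives.
    paths = "/tmp/nvme0,/tmp/nvme1"
    assert [select_gds_path(paths, d) for d in range(4)] == [
        "/tmp/nvme0", "/tmp/nvme1", "/tmp/nvme0", "/tmp/nvme1"
    ]


def check_single_path_unchanged() -> None:
    # A single-path config behaves as before, regardless of the GPU index.
    assert select_gds_path("/tmp/nvme0", 3) == "/tmp/nvme0"
```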

@DongDongJu self-requested a review March 19, 2026 13:03

@sammshen (Contributor) left a comment:

LGTM

@sammshen requested a review from deng451e March 20, 2026 01:56

@DongDongJu (Collaborator) left a comment:

Hello @glimchb, thanks for the contribution.

In distributed workloads, especially when CUDA_VISIBLE_DEVICES is set, this index-based selection can produce an unbalanced distribution globally. Maybe UUID- or config-based affinity is a better idea for the future?

And could you modify the read path to support this as well? I will take a look at your patch series one by one. Let's move on.
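
To illustrate the concern (an editor's hedged example, not from the thread): if the selection uses global device IDs and CUDA_VISIBLE_DEVICES exposes only odd-numbered GPUs, every worker can hash to the same path:

```python
# Hypothetical scenario: CUDA_VISIBLE_DEVICES=1,3 with two configured paths.
# If global device IDs are used, both workers pick index 1 % 2 == 3 % 2 == 1,
# so one NVMe drive receives all the traffic while the other sits idle.
paths = ["/mnt/nvme0/lmcache", "/mnt/nvme1/lmcache"]
for device_id in (1, 3):
    print(device_id, "->", paths[device_id % len(paths)])
# 1 -> /mnt/nvme1/lmcache
# 3 -> /mnt/nvme1/lmcache
```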

Comment thread lmcache/v1/storage_backend/gds_backend.py Outdated
Comment thread lmcache/v1/storage_backend/gds_backend.py Outdated
@glimchb force-pushed the gds_multipath_support branch from 0a0761e to ec785ff March 20, 2026 23:25
@glimchb requested a review from DongDongJu March 20, 2026 23:27

@glimchb (Contributor, Author) commented Mar 20, 2026

Hello @glimchb, thanks for the contribution.

In distributed workloads, especially when CUDA_VISIBLE_DEVICES is set, this index-based selection can produce an unbalanced distribution globally. Maybe UUID- or config-based affinity is a better idea for the future?

And could you modify the read path to support this as well? I will take a look at your patch series one by one. Let's move on.

Thanks @DongDongJu. I added comments; read and write both work and are tested. I also added a bit more clarity.

Regarding other sharding/affinity strategies, that is exactly our next patch once this one goes in: config-based affinity, UUID, or other techniques. This PR is the basis for it, without complications; the next change will be a small diff.

@glimchb force-pushed the gds_multipath_support branch 3 times, most recently from 46c53e6 to 88c19de March 21, 2026 01:00
@DongDongJu (Collaborator) left a comment:

Thanks! LGTM.

Comment thread docs/source/kv_cache/storage_backends/gds.rst
@DongDongJu enabled auto-merge (squash) March 21, 2026 14:03
@github-actions (Bot) added the full (Run comprehensive tests on this PR) label Mar 21, 2026
auto-merge was automatically disabled March 25, 2026 11:57

Head branch was pushed to by a user without write access

@glimchb force-pushed the gds_multipath_support branch from 54c5d91 to 9e73ce7 March 25, 2026 11:57
@DongDongJu enabled auto-merge (squash) March 25, 2026 23:29
auto-merge was automatically disabled March 26, 2026 00:39

Head branch was pushed to by a user without write access

@glimchb force-pushed the gds_multipath_support branch from 35c6441 to 5435f7c March 26, 2026 00:39

@glimchb (Contributor, Author) commented Mar 26, 2026

I ran `uvx pre-commit run --all-files` now.

@deng451e enabled auto-merge (squash) March 26, 2026 00:42
@github-actions (Bot) added and then removed the full (Run comprehensive tests on this PR) label Mar 26, 2026
auto-merge was automatically disabled March 26, 2026 14:40

Head branch was pushed to by a user without write access

@glimchb force-pushed the gds_multipath_support branch from 5435f7c to 6b940a8 March 26, 2026 14:40
@github-actions (Bot) removed the full (Run comprehensive tests on this PR) label Mar 26, 2026
@DongDongJu enabled auto-merge (squash) March 26, 2026 16:52
@github-actions (Bot) added the full (Run comprehensive tests on this PR) label Mar 26, 2026
auto-merge was automatically disabled March 26, 2026 17:00

Head branch was pushed to by a user without write access

@glimchb force-pushed the gds_multipath_support branch from 6b940a8 to 524ea01 March 26, 2026 17:00
@github-actions (Bot) removed the full (Run comprehensive tests on this PR) label Mar 26, 2026
Parse comma-separated gds_path and select path per GPU worker
using by_gpu sharding (device_id % num_paths), matching the
approach in NIXL PR LMCache#2418. This distributes KV cache I/O across
multiple NVMe drives to increase aggregate bandwidth.

Single-path configs work unchanged.

Includes tests (TestGdsMultiPath) and documentation updates.

Signed-off-by: Boris Glimcher <Boris.Glimcher@emc.com>
@glimchb force-pushed the gds_multipath_support branch from 524ea01 to 543a684 March 26, 2026 20:11
@DongDongJu enabled auto-merge (squash) March 27, 2026 05:31
@github-actions (Bot) added the full (Run comprehensive tests on this PR) label Mar 27, 2026
@DongDongJu merged commit 4ba991e into LMCache:dev Mar 27, 2026
34 checks passed
@glimchb deleted the gds_multipath_support branch March 27, 2026 10:44
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
glimchb added a commit to glimchb/LMCache that referenced this pull request Apr 2, 2026
…finity

Allow `local_disk` to accept comma-separated paths (e.g.
"/mnt/nvme0/,/mnt/nvme1/") to use multiple NVMe devices.  Each GPU
worker selects one path at init time via device_id % num_paths,
matching the GDS backend approach LMCache#2817 and NIXL approach LMCache#2418.

- Path selected once in __init__; _key_to_path, write_file, read_file
  unchanged from upstream
- file:// prefix stripped per-part in the backend (no config parser change)
- All directories created at startup
- Added tests and updated docs

Before this change the only way to increase performance was to use
any of the linux multi-pathing technologies to aggregate IOs.

Signed-off-by: Boris Glimcher <Boris.Glimcher@emc.com>
glimchb added a commit to glimchb/LMCache that referenced this pull request Apr 2, 2026
glimchb added a commit to glimchb/LMCache that referenced this pull request Apr 6, 2026
glimchb added a commit to glimchb/LMCache that referenced this pull request Apr 7, 2026
glimchb added a commit to glimchb/LMCache that referenced this pull request Apr 7, 2026
DongDongJu pushed a commit that referenced this pull request Apr 8, 2026
feat(disk): support multi-path local disk backend with path sharding (#2801)

Allow `local_disk` to accept comma-separated paths (e.g.
"/mnt/nvme0/,/mnt/nvme1/") to use multiple NVMe devices.  Each GPU
worker selects one path at init time via the `local_disk_path_sharding`
strategy (currently only "by_gpu": device_id % num_paths), matching
the GDS backend approach #2817 and NIXL approach #2418.

- Path selected once in __init__; _key_to_path, write_file, read_file
  unchanged from upstream
- _parse_local_disk now uses startswith("file://") instead of regex,
  fixing file:// URIs without a trailing slash
- All directories created at startup
- Added local_disk_path_sharding config field (default: "by_gpu")
- Added tests and updated docs

Before this change the only way to increase performance was to use
any of the linux multi-pathing technologies to aggregate IOs.

Signed-off-by: Boris Glimcher <Boris.Glimcher@emc.com>
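
As a hedged sketch of the startswith-based parsing the commit message describes (the function name below is illustrative, not the actual _parse_local_disk implementation):

```python
def parse_local_disk_paths(value: str) -> list[str]:
    """Split a comma-separated local_disk value and strip any file:// prefix.

    Checking startswith() instead of a regex also accepts file:// URIs
    without a trailing slash, e.g. "file:///mnt/nvme0".
    """
    paths = []
    for part in value.split(","):
        part = part.strip()
        if part.startswith("file://"):
            part = part[len("file://"):]
        if part:
            paths.append(part)
    return paths


print(parse_local_disk_paths("file:///mnt/nvme0,/mnt/nvme1/"))
# -> ['/mnt/nvme0', '/mnt/nvme1/']
```
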
Oasis-Git pushed a commit to Oasis-Git/LMCache that referenced this pull request Apr 13, 2026

Labels

full (Run comprehensive tests on this PR)


4 participants