Skip to content

Commit 2a90190

Browse files
glimchb and Devin committed
feat(gds): add gds_path_sharding config for multi-path strategy
Add a top-level `gds_path_sharding` config field (default: "by_gpu") that controls how GPUs are assigned to storage paths when multiple comma-separated paths are provided in `gds_path`. This replaces the previously hardcoded by_gpu logic with an explicit, extensible setting. Currently only "by_gpu" is supported (selects path via `device_id % num_paths`); unsupported values raise AssertionError. Generated with [Devin](https://cli.devin.ai/docs) Co-Authored-By: Devin <noreply@cognition.ai> Signed-off-by: Boris Glimcher <Boris.Glimcher@emc.com>
1 parent 996da47 commit 2a90190

5 files changed

Lines changed: 62 additions & 9 deletions

File tree

docs/source/api_reference/configurations.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,10 @@ Settings for different storage backends and paths.
341341
- Description
342342
* - gds_path
343343
- LMCACHE_GDS_PATH
344-
- Path for GDS backend. Supports comma-separated paths for multi-device I/O (e.g. ``/mnt/nvme0/cache,/mnt/nvme1/cache``). Each GPU selects a path via ``device_id % num_paths``.
344+
- Path for GDS backend. Supports comma-separated paths for multi-device I/O (e.g. ``/mnt/nvme0/cache,/mnt/nvme1/cache``). See ``gds_path_sharding`` for how paths are assigned to GPUs.
345+
* - gds_path_sharding
346+
- LMCACHE_GDS_PATH_SHARDING
347+
- Strategy for selecting a path when multiple paths are provided. Currently only ``"by_gpu"`` is supported, which selects paths based on GPU device ID (default: "by_gpu").
345348
* - cufile_buffer_size
346349
- LMCACHE_CUFILE_BUFFER_SIZE
347350
- Buffer size for cuFile/hipFile operations

docs/source/kv_cache/storage_backends/gds.rst

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ Multi-Path (Multi-Device) Support
5252
---------------------------------
5353

5454
When a system has multiple NVMe drives, you can distribute GDS I/O across them
55-
by specifying a comma-separated list of paths in ``gds_path``. Each GPU worker
56-
automatically selects one path based on its device index (``device_id % num_paths``),
57-
so traffic is spread evenly across the drives without any manual pinning.
55+
by specifying a comma-separated list of paths in ``gds_path``. The
56+
``gds_path_sharding`` option controls how each GPU worker selects its path.
57+
Currently only ``"by_gpu"`` is supported (the default), which selects a path
58+
based on the device index (``device_id % num_paths``), so traffic is spread
59+
evenly across the drives without any manual pinning.
5860

5961
**Why this helps:** a single PCIe Gen 4 x4 NVMe tops out at ~7 GB/s. With four
6062
drives the aggregate bandwidth can reach ~28 GB/s, matching what multi-GPU
@@ -65,12 +67,14 @@ systems need for KV cache eviction and prefetch.
6567
.. code-block:: bash
6668
6769
export LMCACHE_GDS_PATH="/mnt/nvme0/cache,/mnt/nvme1/cache,/mnt/nvme2/cache,/mnt/nvme3/cache"
70+
export LMCACHE_GDS_PATH_SHARDING="by_gpu"
6871
6972
**YAML config:**
7073

7174
.. code-block:: yaml
7275
7376
gds_path: "/mnt/nvme0/cache,/mnt/nvme1/cache,/mnt/nvme2/cache,/mnt/nvme3/cache"
77+
gds_path_sharding: "by_gpu"
7478
7579
With the above configuration on a 4-GPU node:
7680

lmcache/v1/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@
231231
},
232232
# Storage paths
233233
"gds_path": {"type": Optional[str], "default": None, "env_converter": str},
234+
"gds_path_sharding": {
235+
"type": str,
236+
"default": "by_gpu",
237+
"env_converter": str,
238+
},
234239
"cufile_buffer_size": {
235240
"type": Optional[int],
236241
"default": None,

lmcache/v1/storage_backend/gds_backend.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,15 @@ def __init__(
202202
assert config.gds_path is not None, "Need to specify gds_path for GdsBackend"
203203

204204
# Multi-path support: parse comma-separated paths and select one
205-
# based on GPU device ID (by_gpu sharding, like NIXL PR #2418).
205+
# based on the configured sharding strategy.
206206
self.gds_paths = [p.strip() for p in config.gds_path.split(",") if p.strip()]
207207
assert len(self.gds_paths) > 0, "gds_path cannot be empty"
208208

209-
# TODO: next patch we can add additional sharding strategies
209+
self.gds_path_sharding = config.gds_path_sharding
210+
assert self.gds_path_sharding == "by_gpu", (
211+
f"Unsupported gds_path_sharding '{self.gds_path_sharding}'. "
212+
"Only 'by_gpu' is supported currently."
213+
)
210214
self.gds_path = self.gds_paths[device_id % len(self.gds_paths)]
211215
self.fstype = get_fstype(self.gds_path)
212216

tests/v1/storage_backend/test_gds_backend.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
from tests.v1.utils import create_test_memory_obj, has_cufile, has_hipfile
3131

3232

33-
def create_test_config(gds_path: str):
33+
def create_test_config(gds_path: str, gds_path_sharding: str = "by_gpu"):
3434
config = LMCacheEngineConfig.from_defaults(
3535
chunk_size=256,
3636
gds_path=gds_path,
37+
gds_path_sharding=gds_path_sharding,
3738
lmcache_instance_id="test_instance",
3839
cufile_buffer_size=256,
3940
extra_config={"use_direct_io": True},
@@ -492,7 +493,12 @@ class TestGdsMultiPath:
492493
"""
493494

494495
@staticmethod
495-
def _make_backend(gds_path: str, dst_device: str, async_loop):
496+
def _make_backend(
497+
gds_path: str,
498+
dst_device: str,
499+
async_loop,
500+
gds_path_sharding: str = "by_gpu",
501+
):
496502
"""Create a GdsBackend with mocked allocator and fstype.
497503
498504
Mocks are used so the tests run without cuFile / real NVMe.
@@ -505,7 +511,7 @@ def __init__(self):
505511
def close(self):
506512
pass
507513

508-
config = create_test_config(gds_path)
514+
config = create_test_config(gds_path, gds_path_sharding=gds_path_sharding)
509515
metadata = create_test_metadata()
510516
with (
511517
mock.patch(
@@ -755,3 +761,34 @@ def test_try_to_read_metadata_finds_across_all_paths(self, async_loop):
755761
finally:
756762
for p in paths:
757763
shutil.rmtree(p, ignore_errors=True)
764+
765+
def test_gds_path_sharding_default(self, temp_gds_path, async_loop):
766+
"""Default gds_path_sharding is 'by_gpu'."""
767+
backend = self._make_backend(temp_gds_path, "cuda:0", async_loop)
768+
try:
769+
assert backend.gds_path_sharding == "by_gpu"
770+
finally:
771+
backend.close()
772+
773+
def test_gds_path_sharding_explicit_by_gpu(self, temp_gds_path, async_loop):
774+
"""Explicitly setting gds_path_sharding='by_gpu' works."""
775+
backend = self._make_backend(
776+
temp_gds_path,
777+
"cuda:0",
778+
async_loop,
779+
gds_path_sharding="by_gpu",
780+
)
781+
try:
782+
assert backend.gds_path_sharding == "by_gpu"
783+
finally:
784+
backend.close()
785+
786+
def test_gds_path_sharding_unsupported_raises(self, temp_gds_path, async_loop):
787+
"""Unsupported gds_path_sharding value raises AssertionError."""
788+
with pytest.raises(AssertionError, match="Unsupported gds_path_sharding"):
789+
self._make_backend(
790+
temp_gds_path,
791+
"cuda:0",
792+
async_loop,
793+
gds_path_sharding="round_robin",
794+
)

0 commit comments

Comments
 (0)