Skip to content

Commit e76aa46

Browse files
DongDongJuidevasena
authored andcommitted
fix(raw-block): harden metadata checkpoint recovery
Signed-off-by: DongDongJu <commisori28@gmail.com> Signed-off-by: Dongjoo Seo <dongjoo.seo1@samsung.com>
1 parent 5a1d331 commit e76aa46

2 files changed

Lines changed: 160 additions & 18 deletions

File tree

lmcache/v1/storage_backend/plugins/rust_raw_block_backend.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
self.meta_verify_on_load: bool = bool(
149149
extra.get("rust_raw_block.meta_verify_on_load", True)
150150
)
151+
self._meta_copy_count: int = 2
151152

152153
get_full_chunk_size_bytes = getattr(
153154
self.local_cpu_backend, "get_full_chunk_size_bytes", None
@@ -192,6 +193,14 @@ def __init__(
192193
"rust_raw_block.meta_total_bytes must be > block_align "
193194
"(room for metadata header + payload)"
194195
)
196+
self._meta_container_bytes: int = (
197+
(self.meta_total_bytes // self._meta_copy_count) // self.block_align
198+
) * self.block_align
199+
if self._meta_container_bytes <= self.block_align:
200+
raise ValueError(
201+
"rust_raw_block.meta_total_bytes must provide room for at least "
202+
"two metadata copies (header + payload)"
203+
)
195204

196205
self._lock = threading.Lock()
197206
self._index: dict[CacheEngineKey, _Entry] = {}
@@ -789,7 +798,12 @@ def _checkpoint_loop(self) -> None:
789798
logger.warning(f"Periodic metadata checkpoint failed: {e}")
790799

791800
def _meta_payload_capacity(self) -> int:
792-
return self.meta_total_bytes - self.block_align
801+
return self._meta_container_bytes - self.block_align
802+
803+
def _meta_container_offsets(self) -> list[int]:
804+
return [
805+
idx * self._meta_container_bytes for idx in range(self._meta_copy_count)
806+
]
793807

794808
def _read_meta_header(self, container_offset: int) -> Optional[dict[str, int]]:
795809
raw = self._rawdev()
@@ -832,6 +846,23 @@ def _load_meta_payload(self, header: dict[str, int]) -> Optional[bytes]:
832846
return None
833847
return payload
834848

849+
def _select_latest_checkpoint(
850+
self,
851+
) -> tuple[Optional[dict[str, int]], Optional[bytes]]:
852+
best_header: Optional[dict[str, int]] = None
853+
best_payload: Optional[bytes] = None
854+
for offset in self._meta_container_offsets():
855+
header = self._read_meta_header(offset)
856+
if header is None:
857+
continue
858+
payload = self._load_meta_payload(header)
859+
if payload is None:
860+
continue
861+
if best_header is None or int(header["seq"]) > int(best_header["seq"]):
862+
best_header = header
863+
best_payload = payload
864+
return best_header, best_payload
865+
835866
def _snapshot_state(self) -> tuple[dict[str, Any], int]:
836867
with self._lock:
837868
dirty_total = self._meta_dirty_total
@@ -886,10 +917,9 @@ def _write_checkpoint(self, payload: bytes, dirty_total_snapshot: int) -> bool:
886917
)
887918
return False
888919

889-
target = 0
890-
current = self._read_meta_header(target)
891-
current_seq = int(current["seq"]) if current is not None else 0
892-
next_seq = max(current_seq, self._meta_seq) + 1
920+
next_seq = self._meta_seq + 1
921+
target_idx = int((next_seq - 1) % self._meta_copy_count)
922+
target = self._meta_container_offsets()[target_idx]
893923

894924
payload_len = len(payload)
895925
payload_total_len = _round_up(payload_len, self.block_align)
@@ -942,6 +972,17 @@ def _checkpoint_once(self, force: bool) -> bool:
942972
)
943973
return ok
944974

975+
def _is_valid_checkpoint_entry(self, offset: int, size: int) -> bool:
976+
if offset < self._data_base_offset:
977+
return False
978+
rel = offset - self._data_base_offset
979+
if rel % self.slot_bytes != 0:
980+
return False
981+
slot = rel // self.slot_bytes
982+
if slot < 0 or slot >= self._max_slots:
983+
return False
984+
return 0 < size <= (self.slot_bytes - self.header_bytes)
985+
945986
def _apply_loaded_state(self, data: dict[str, Any]) -> bool:
946987
if not isinstance(data, dict):
947988
return False
@@ -1036,6 +1077,16 @@ def _apply_loaded_state(self, data: dict[str, Any]) -> bool:
10361077
fmt_name = entry.get("fmt")
10371078
cached_positions_list = entry.get("cached_positions")
10381079

1080+
if not self._is_valid_checkpoint_entry(offset, size):
1081+
logger.warning(
1082+
"Skipping invalid checkpoint entry for key '%s': "
1083+
"offset=%d size=%d",
1084+
k_str,
1085+
offset,
1086+
size,
1087+
)
1088+
continue
1089+
10391090
shape = (
10401091
torch.Size(list(shape_list)) if shape_list is not None else None
10411092
)
@@ -1132,18 +1183,13 @@ def _validate_loaded_entries(self) -> None:
11321183
)
11331184

11341185
def _load_checkpoint_from_device(self) -> None:
1135-
header = self._read_meta_header(0)
1186+
header, payload = self._select_latest_checkpoint()
11361187
if header is None:
11371188
logger.info(
11381189
"RustRawBlockBackend: no valid on-device metadata checkpoint found"
11391190
)
11401191
return
1141-
payload = self._load_meta_payload(header)
1142-
if payload is None:
1143-
logger.warning(
1144-
"RustRawBlockBackend: metadata header exists but payload is invalid"
1145-
)
1146-
return
1192+
assert payload is not None
11471193
try:
11481194
data = json.loads(payload.decode("utf-8"))
11491195
except Exception:

tests/v1/storage_backend/test_rust_raw_block_backend.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def test_rust_raw_block_backend_ignores_torn_newer_checkpoint(
370370
memory_allocator, loop_in_thread
371371
):
372372
"""
373-
If checkpoint payload CRC is invalid, loader ignores metadata.
373+
If a newer checkpoint copy is torn, loader falls back to the older valid copy.
374374
"""
375375
with tempfile.TemporaryDirectory() as td:
376376
dev_path = os.path.join(td, "dev.bin")
@@ -423,22 +423,24 @@ def test_rust_raw_block_backend_ignores_torn_newer_checkpoint(
423423
)
424424
assert obj is not None and obj.tensor is not None
425425
obj.tensor.fill_(11)
426+
expected = bytes(obj.byte_array)
426427
try:
427428
fut = backend1.batched_submit_put_task([key], [obj])[0]
428429
fut.result(timeout=10)
429430
finally:
431+
torn_offset = backend1._meta_container_offsets()[1]
430432
backend1.close()
431433

432-
# Corrupt single checkpoint header/payload with invalid CRC.
434+
# Corrupt the newer checkpoint copy with invalid CRC.
433435
# Header format: <8sIQQI (magic, version, seq, payload_len, crc).
434436
header = struct.pack(
435437
"<8sIQQI", _DEFAULT_META_MAGIC, _DEFAULT_META_VERSION, 9999, 2, 0
436438
)
437439
padded_header = header + bytes(align - len(header))
438440
with open(dev_path, "r+b") as f:
439-
f.seek(align)
441+
f.seek(torn_offset + align)
440442
f.write(b"{}")
441-
f.seek(0)
443+
f.seek(torn_offset)
442444
f.write(padded_header)
443445

444446
backend2 = RustRawBlockBackend(
@@ -449,8 +451,102 @@ def test_rust_raw_block_backend_ignores_torn_newer_checkpoint(
449451
dst_device="cpu",
450452
)
451453
try:
452-
assert not backend2.contains(key)
454+
assert backend2.contains(key)
453455
out = backend2.get_blocking(key)
454-
assert out is None
456+
assert out is not None
457+
assert bytes(out.byte_array) == expected
455458
finally:
456459
backend2.close()
460+
461+
462+
@pytest.mark.skipif(
463+
not _has_ext(), reason="lmcache_rust_raw_block_io extension not installed"
464+
)
465+
def test_rust_raw_block_backend_skips_invalid_checkpoint_entries(
466+
memory_allocator, loop_in_thread
467+
):
468+
"""Checkpoint restore should reject invalid offset/size metadata entries."""
469+
with tempfile.TemporaryDirectory() as td:
470+
dev_path = os.path.join(td, "dev.bin")
471+
with open(dev_path, "wb") as f:
472+
f.truncate(64 * 1024 * 1024)
473+
474+
base_cfg = LMCacheEngineConfig.from_defaults(
475+
chunk_size=256,
476+
local_cpu=True,
477+
max_local_cpu_size=0.1,
478+
lmcache_instance_id="test_rust_raw_block_backend_invalid_checkpoint",
479+
)
480+
base_cfg.extra_config = {
481+
"rust_raw_block.device_path": dev_path,
482+
"rust_raw_block.block_align": 4096,
483+
"rust_raw_block.header_bytes": 4096,
484+
"rust_raw_block.meta_total_bytes": 4 * 1024 * 1024,
485+
"rust_raw_block.meta_enable_periodic": False,
486+
"rust_raw_block.meta_verify_on_load": False,
487+
}
488+
metadata = LMCacheMetadata(
489+
model_name="test_model",
490+
world_size=1,
491+
local_world_size=1,
492+
worker_id=0,
493+
local_worker_id=0,
494+
kv_dtype=torch.bfloat16,
495+
kv_shape=(4, 2, 256, 8, 128),
496+
)
497+
498+
local_cpu = LocalCPUBackend(
499+
config=base_cfg,
500+
metadata=metadata,
501+
dst_device="cpu",
502+
memory_allocator=memory_allocator,
503+
)
504+
backend = RustRawBlockBackend(
505+
config=base_cfg,
506+
metadata=metadata,
507+
local_cpu_backend=local_cpu,
508+
loop=loop_in_thread,
509+
dst_device="cpu",
510+
)
511+
try:
512+
entries = {}
513+
for chunk_hash, (offset, size) in {
514+
1: (backend._data_base_offset - backend.slot_bytes, 1024),
515+
2: (backend._data_base_offset + 1, 1024),
516+
3: (
517+
backend._data_base_offset,
518+
backend.slot_bytes - backend.header_bytes + 1,
519+
),
520+
}.items():
521+
key = CacheEngineKey("test_model", 1, 0, chunk_hash, torch.bfloat16)
522+
entries[key.to_string()] = {
523+
"offset": offset,
524+
"size": size,
525+
"shape": [2, 16, 8, 128],
526+
"dtype": "bfloat16",
527+
"fmt": MemoryFormat.KV_T2D.name,
528+
"cached_positions": None,
529+
}
530+
531+
applied = backend._apply_loaded_state(
532+
{
533+
"version": 1,
534+
"device_path": dev_path,
535+
"capacity_bytes": backend.capacity_bytes,
536+
"block_align": backend.block_align,
537+
"header_bytes": backend.header_bytes,
538+
"slot_bytes": backend.slot_bytes,
539+
"meta_total_bytes": backend.meta_total_bytes,
540+
"meta_magic": backend.meta_magic_text,
541+
"meta_version": backend.meta_version,
542+
"data_base_offset": backend._data_base_offset,
543+
"next_slot": 0,
544+
"free_slots": [],
545+
"lru_keys": [],
546+
"entries": entries,
547+
}
548+
)
549+
assert applied is True
550+
assert backend._index == {}
551+
finally:
552+
backend.close()

0 commit comments

Comments
 (0)