[MP][Core] Block id based kernel for MP mode#2838

Merged
ApostaC merged 15 commits into LMCache:dev from ApostaC:local-dev/mp-kernel-real
Mar 26, 2026

Conversation

@ApostaC (Contributor) commented Mar 21, 2026

What this PR does / why we need it:

Adds a new block-level CUDA kernel (multi_layer_block_kv_transfer) for MP mode that operates on block IDs instead of flat token-level slot mappings. This replaces the old multi_layer_kv_transfer in the MP server's store/retrieve paths.

Also fixes a correctness issue (the double-stream race condition).

Key changes:

  • New CUDA kernel in csrc/mp_mem_kernels.cu supporting all 6 GPU KV formats (Normal, Cross-layer, Flash Infer, MLA, SGLang MHA, SGLang MLA) with bf16 and fp8
  • PTX ld.global.cs / st.global.cs streaming copies to bypass the L2 cache (see the sketch after this list)
  • Block IDs passed as a GPU tensor (pre-staged via a 1M-element pre-allocated buffer on GPUCacheContext)
  • GPUCacheContext extended with PageBufferShapeDesc, num_heads, head_size, stage_block_ids(), and shape_desc property
  • MP server store/retrieve paths updated to use the new kernel with block-level skip prefix
  • 30 unit tests covering all 6 formats × 2 dtypes × 2 memory devices + skip-prefix
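For reference, here is a minimal sketch of the streaming-copy idea in isolation (assumed shape only; the actual kernel in csrc/mp_mem_kernels.cu vectorizes over uint4 and adds format-specific offset logic):

    // Hedged illustration of cache-streaming ("evict-first") accesses.
    // The ".cs" qualifier hints that the data is touched exactly once,
    // so the GPU should not keep it resident in L2.
    __device__ inline uint4 load_streaming(const uint4* __restrict__ ptr) {
      uint4 v;
      asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];"
                   : "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w)
                   : "l"(ptr));
      return v;
    }

    __device__ inline void store_streaming(uint4* __restrict__ ptr, uint4 v) {
      asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};"
                   :
                   : "l"(ptr), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)
                   : "memory");
    }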

Special notes for your reviewers:

  • The kernel takes block IDs (not slot mappings), eliminating the slot_mapping_tensor computation
  • The tmp_buffer + lmcache_memcpy_async pattern is preserved for optimal GPU↔CPU DMA transfer
  • Token-level skip_first_n_tokens is converted to block-level, with an error log if the skip is not block-aligned (see the sketch after this list)
  • mem_kernels.cuh gained #pragma once (previously missing, which caused redefinition errors)
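The token-to-block conversion mentioned above is small enough to sketch. All names below are hypothetical, not the actual server code:

    #include <cstdint>
    #include <cstdio>

    // The block-level kernel can only skip whole blocks, so a token-level
    // skip must be a multiple of the engine block size; otherwise log an
    // error (this sketch floors to whole blocks).
    int64_t tokens_to_skip_blocks(int64_t skip_first_n_tokens,
                                  int64_t block_size) {
      if (skip_first_n_tokens % block_size != 0) {
        std::fprintf(stderr,
                     "skip_first_n_tokens=%lld is not aligned to block_size=%lld\n",
                     (long long)skip_first_n_tokens, (long long)block_size);
      }
      return skip_first_n_tokens / block_size;  // floor division
    }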

If applicable:

  • this PR contains user-facing changes - docs added
  • this PR contains unit tests

Performance snapshot:

Before: gaps between H2D copies; ~90 ms per retrieve

image

After: gaps are mitigated; ~50 ms per retrieve, a 1.8x improvement

image

Note

Medium Risk
Introduces a new CUDA/pybind transfer path and rewires MP store/retrieve to use it, which can affect correctness/perf across multiple KV layouts and dtypes. Risk is mitigated by added validation and broad CUDA unit test coverage, but kernel/offset logic and stream interactions are still sensitive.

Overview
Adds a new block-ID–driven CUDA kernel (multi_layer_block_kv_transfer) to move KV data between vLLM paged buffers and LMCache contiguous chunks, supporting all existing GPUKVFormats and using vectorized uint4 streaming loads/stores.

Updates MP mode to use this block-level transfer instead of token-level slot mappings: GPUCacheContext now stages block IDs to a preallocated GPU buffer, precomputes a PageBufferShapeDesc, and supports batched temp buffers; the MP server’s store/retrieve paths are rewired accordingly (including converting token skip to block skip and batching retrieve copies).

Build/ABI plumbing is extended (new sources, pybind exports, #pragma once fix), and new CUDA tests validate roundtrip + skip-prefix behavior across formats, bf16/fp8, and CPU/GPU memory objects.

Written by Cursor Bugbot for commit 45f8b41.

ApostaC added 5 commits March 20, 2026 21:50
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant architectural improvement for multi-process (MP) mode by replacing the token-level KV transfer mechanism with a more efficient block-level CUDA kernel. This change streamlines the process of moving KV data between vLLM paged buffers and LMCache objects, reducing overhead by operating directly on block IDs and optimizing memory access patterns. The update enhances the system's ability to manage KV caches across different GPU memory layouts and data types, leading to more robust and potentially faster operations in distributed environments.

Highlights

  • New Block-Level CUDA Kernel: Introduced multi_layer_block_kv_transfer, a new CUDA kernel in csrc/mp_mem_kernels.cu, designed for block-level KV data transfer in MP mode. This kernel supports all six GPU KV formats (Normal, Cross-layer, Flash Infer, MLA, SGLang MHA, SGLang MLA) with bf16 and fp8 data types.
  • Performance Optimization: The new kernel utilizes PTX ld.global.cs / st.global.cs streaming copies to bypass the L2 cache, aiming for improved memory transfer efficiency.
  • Simplified Block ID Handling: The kernel now directly takes block IDs as a GPU tensor, eliminating the need for slot_mapping_tensor computation. A 1M-element pre-allocated buffer on GPUCacheContext is used to stage these block IDs.
  • Updated GPUCacheContext: The GPUCacheContext has been extended with PageBufferShapeDesc, num_heads, head_size, and a stage_block_ids() method to support the new block-level operations.
  • MP Server Integration: The MP server's store/retrieve paths have been updated to leverage the new block-level kernel, including logic for block-level skip prefix handling and error logging for unaligned token skips.
  • Comprehensive Testing: Over 30 unit tests have been added, covering all six KV formats, bf16 and fp8 data types, GPU and CPU memory devices, and skip-prefix functionality, ensuring robustness and correctness.
  • Header Fix: The mem_kernels.cuh header file now includes #pragma once to prevent redefinition errors during compilation.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a new CUDA kernel, multi_layer_block_kv_transfer, and supporting C++/Python infrastructure to enable efficient block-level KV cache transfers between vLLM paged buffers and LMCache memory objects. The changes include defining a PageBufferShapeDesc struct to describe KV cache geometry, implementing the CUDA kernel to handle various KV formats (Normal, Cross-layer, Flash Infer, MLA, SGLang MHA/MLA) for both H2D and D2H transfers, and exposing this functionality to Python via pybind11. Python-side updates involve initializing PageBufferShapeDesc in GPUCacheContext, pre-allocating a GPU buffer for block IDs, and updating store and retrieve methods in server.py to utilize the new block-level transfer mechanism. Review comments highlight a performance issue in stage_block_ids where torch.frombuffer creates a non-pinned tensor, negating asynchronous copy benefits, and suggest using pin_memory=True. Additionally, code duplication in the CUDA kernel's dispatch logic is noted as an opportunity for refactoring with a macro, and a style guide violation regarding import grouping in gpu_context.py is pointed out. Finally, it's suggested that the _MAX_BLOCK_IDS constant in GPUCacheContext should be moved to a class or module level for better maintainability.

Comment on lines +238 to +253
    def stage_block_ids(self, block_ids: list[int]) -> torch.Tensor:
        """Copy block_ids into the pre-allocated GPU buffer and return a
        view of the occupied region. Uses non-blocking copy via a pinned
        CPU tensor created from the list's underlying buffer.

        Args:
            block_ids: Block indices as a Python list of ints.

        Returns:
            A GPU int64 tensor view into the pre-allocated buffer.
        """
        n = len(block_ids)
        cpu_tensor = torch.frombuffer(array.array("l", block_ids), dtype=torch.long)
        buf = self.block_ids_buffer_[:n]
        buf.copy_(cpu_tensor, non_blocking=True)
        return buf

high

The stage_block_ids method aims for a non-blocking copy to the GPU by using non_blocking=True. However, for this to work as intended, the source CPU tensor must be page-locked (pinned). The current implementation with torch.frombuffer creates a non-pinned tensor, which will cause the copy to be synchronous, negating the performance benefit. Additionally, the docstring is misleading as it claims to use a pinned tensor.

To fix this, create a pinned tensor using torch.tensor(..., pin_memory=True). This ensures the copy is truly asynchronous.

    def stage_block_ids(self, block_ids: list[int]) -> torch.Tensor:
        """Copy block_ids into the pre-allocated GPU buffer and return a
        view of the occupied region. Uses non-blocking copy via a pinned
        CPU tensor.

        Args:
            block_ids: Block indices as a Python list of ints.

        Returns:
            A GPU int64 tensor view into the pre-allocated buffer.
        """
        n = len(block_ids)
        cpu_tensor = torch.tensor(block_ids, dtype=torch.long, pin_memory=True)
        buf = self.block_ids_buffer_[:n]
        buf.copy_(cpu_tensor, non_blocking=True)
        return buf

Comment thread csrc/mp_mem_kernels.cu
Comment on lines +260 to +314
if (direction == TransferDirection::H2D) {
  switch (gpu_kv_format) {
    case GPUKVFormat::NB_NL_TWO_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::NB_NL_TWO_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_TWO_NB_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::NL_X_TWO_NB_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NB_TWO_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::NL_X_NB_TWO_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NB_BS_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::NL_X_NB_BS_HS);
      break;
    case GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NBBS_ONE_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(true, GPUKVFormat::NL_X_NBBS_ONE_HS);
      break;
    default:
      TORCH_CHECK(false, "Unsupported GPUKVFormat: ",
                  static_cast<int>(gpu_kv_format));
  }
} else {
  switch (gpu_kv_format) {
    case GPUKVFormat::NB_NL_TWO_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::NB_NL_TWO_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_TWO_NB_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::NL_X_TWO_NB_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NB_TWO_BS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::NL_X_NB_TWO_BS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NB_BS_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::NL_X_NB_BS_HS);
      break;
    case GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS);
      break;
    case GPUKVFormat::NL_X_NBBS_ONE_HS:
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(false, GPUKVFormat::NL_X_NBBS_ONE_HS);
      break;
    default:
      TORCH_CHECK(false, "Unsupported GPUKVFormat: ",
                  static_cast<int>(gpu_kv_format));
  }
}

medium

The switch statement for dispatching based on gpu_kv_format is duplicated for H2D and D2H transfer directions. This code duplication can make maintenance harder. You can refactor this using a macro to eliminate the repeated code block, improving maintainability.

#define LAUNCH_SWITCH(direction) \
  switch (gpu_kv_format) { \
    case GPUKVFormat::NB_NL_TWO_BS_NH_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::NB_NL_TWO_BS_NH_HS); \
      break; \
    case GPUKVFormat::NL_X_TWO_NB_BS_NH_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::NL_X_TWO_NB_BS_NH_HS); \
      break; \
    case GPUKVFormat::NL_X_NB_TWO_BS_NH_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::NL_X_NB_TWO_BS_NH_HS); \
      break; \
    case GPUKVFormat::NL_X_NB_BS_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::NL_X_NB_BS_HS); \
      break; \
    case GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::TWO_X_NL_X_NBBS_NH_HS); \
      break; \
    case GPUKVFormat::NL_X_NBBS_ONE_HS: \
      LAUNCH_BLOCK_KERNEL_WITH_FORMAT(direction, GPUKVFormat::NL_X_NBBS_ONE_HS); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported GPUKVFormat: ", static_cast<int>(gpu_kv_format)); \
  }

  // --- Dispatch on direction x format ---
  if (direction == TransferDirection::H2D) {
    LAUNCH_SWITCH(true);
  } else {
    LAUNCH_SWITCH(false);
  }

Comment on lines 31 to 39
)

if torch.cuda.is_available():
    import lmcache.c_ops as lmc_ops

# First Party
from lmcache.v1.multiprocess.custom_types import (
    KVCache,
)

medium

The import structure here violates the project's style guide (line 28), which requires grouping imports by type (Standard, Third Party, First Party). The conditional import of lmcache.c_ops is a first-party import and should be grouped with the others under a single # First Party comment to improve readability and consistency.

References
  1. Import order should be: Standard / Third Party / First Party / Local (with section heading comments). This change introduces a disjointed 'First Party' import section. (link)

# Pre-allocated GPU buffer for block IDs (up to 1M elements).
# The caller copies block_ids into this buffer before launching the
# block-level kernel. Single-thread assumption: no lock needed.
_MAX_BLOCK_IDS = 1_000_000

medium

The constant _MAX_BLOCK_IDS is defined as a local variable within __init__. For better readability and maintainability, it's preferable to define such constants as class attributes or module-level constants. This makes it easier to find and modify if needed.

class GPUCacheContext:
    _MAX_BLOCK_IDS = 1_000_000

    def __init__(self, kv_caches: KVCache, lmcache_chunk_size: int = 256):
        # ... existing code ...
        self.block_ids_buffer_ = torch.empty(
            self._MAX_BLOCK_IDS, dtype=torch.long, device=self.device_
        )

@ApostaC added the "mp" label (Buildkite trigger for multi-processing mode test) Mar 21, 2026
Comment thread csrc/mp_mem_kernels.cu Outdated

/**
* Key logic in the kernel implementation:
* 1 Each thread block is for (BS, NH, HS) part (i.e., a single block in the

nit: format newlines

ApostaC added 8 commits March 23, 2026 19:54
Signed-off-by: Yihua Cheng <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
@ApostaC added the "full" label (Run comprehensive tests on this PR) Mar 24, 2026
Comment thread csrc/mp_mem_kernels.cu Outdated
__device__ inline size_t calculate_lmcache_global_offset(
    const int k_or_v,
    const int token_offset_in_lmcache_block,  // 0~255 if LMCache block size is 256

should this just be chunk_offset? maybe keep the vocab of block for serving engine and chunk for LMCache?

ApostaC (Author):

fixed

Comment thread csrc/mp_mem_kernels.cu
}

template <typename ScalarType, GPUKVFormat format>
__device__ inline size_t calculate_lmcache_local_offset(

sorry for nitpicky naming but how about just:
calculate_lmcache_offset and calculate_engine_offset for above?

just a suggestion, feel free to ignore!

ApostaC (Author):

There is a difference between global and local: "global" is the offset in the buffer for the whole thread block, while "local" is the offset for the current token.
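Illustratively, the two offsets compose additively. A self-contained sketch with hypothetical names (not the actual kernel code):

    #include <cstddef>

    // "Global": base offset shared by the whole CUDA thread block
    // (which chunk, and where that engine block starts inside the chunk).
    // "Local": per-token offset within that block.
    size_t lmcache_element_offset(size_t chunk_idx, size_t chunk_stride,
                                  size_t block_base_in_chunk,
                                  size_t token_in_block, size_t token_stride) {
      const size_t global_off = chunk_idx * chunk_stride + block_base_in_chunk;
      const size_t local_off = token_in_block * token_stride;
      return global_off + local_off;
    }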

ApostaC (Author):

Added the comments

Comment thread csrc/mp_mem_kernels.cuh
#include <c10/cuda/CUDAGuard.h>
#include <vector>

struct PageBufferShapeDesc {

this is great!!

Comment thread csrc/mp_mem_kernels.cu
      shape_desc);
  ScalarType* paged_buffer_layer_ptr;
  if constexpr (format == GPUKVFormat::NB_NL_TWO_BS_NH_HS) {
    paged_buffer_layer_ptr = (ScalarType*)paged_buffer_ptrs[0];

super nit (feel free to ignore): maybe use reinterpret_cast<ScalarType*> here instead?


nvm I just read that C style cast is best practice.

Comment thread csrc/mp_mem_kernels.cu
    return;
  }
  const int obj_idx = flat_block_idx / num_blocks_per_object;
  const int block_idx_in_object = flat_block_idx % num_blocks_per_object;

same nit comments as above. maybe sticking to chunk naming for all lmcache related things will be much easier to read

e.g.

chunk_idx, block_id_in_chunk

ApostaC (Author):

done

Comment thread csrc/mp_mem_kernels.cu
}

template <typename ScalarType>
__device__ inline void warp_copy(ScalarType* __restrict__ dst,
@sammshen (Contributor) commented Mar 25, 2026

maybe leave a comment that the caller should dispatch with blockDim.x <= 32?
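For context, a warp-scoped copy helper typically looks like the sketch below (assumed shape, not the actual implementation). The blockDim.x <= 32 expectation exists because the loop strides by blockDim.x and the helper is meant to run within a single warp:

    // Each thread copies elements i, i + blockDim.x, i + 2*blockDim.x, ...
    // Callers should launch with blockDim.x <= 32 (one warp).
    template <typename ScalarType>
    __device__ inline void warp_copy(ScalarType* __restrict__ dst,
                                     const ScalarType* __restrict__ src,
                                     int num_elems) {
      for (int i = threadIdx.x; i < num_elems; i += blockDim.x) {
        dst[i] = src[i];
      }
    }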

with (
    torch.cuda.device(gpu_context.device),
    torch.cuda.stream(gpu_context.high_priority_stream),
    torch.cuda.stream(gpu_context.stream),

thanks for the fix!

@sammshen (Contributor)

so to summarize this PR does three things:

  1. remove slot_mapping and use block_mapping instead
  2. batch kernel operations into 4 memory objects
  3. fix the double stream race condition

did I miss anything @ApostaC?

Signed-off-by: ApostaC <yihua98@uchicago.edu>
@sammshen (Contributor) left a comment


LGTM!!

@sammshen (Contributor)

@cursor review

@sammshen (Contributor)

cursor review

@sammshen (Contributor) commented Mar 26, 2026

bugbot run

@cursor (Bot) left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.


    @property
    def shape_desc(self) -> "lmc_ops.PageBufferShapeDesc":
        return self.shape_desc_


Inserted property stole @property from gpu_kv_format_name

High Severity

The new shape_desc property was inserted between the existing @property decorator (line 227) and gpu_kv_format_name (line 231), stealing the decorator. Now gpu_kv_format_name is a plain method, but server.py accesses it without parentheses as ctx.gpu_kv_format_name, which returns a bound method object instead of the format name string. This breaks report_status() output and could cause serialization failures.



    @property
    def head_size(self) -> int:
        return self.head_size_

New public properties missing required docstrings

Low Severity

The new public properties num_heads, head_size, and shape_desc lack docstrings. Per AGENTS.md, every public function/method must have a docstring covering what it does, arguments, and return values. All neighboring properties (block_size, num_layers, num_blocks, hidden_dim_size, is_mla) include docstrings. This violates the project convention.


Triggered by project rule: LMCache Code Review Style Guide

@KuntaiDu (Contributor) left a comment

LGTM

@sammshen (Contributor) left a comment

HND

@ApostaC enabled auto-merge (squash) March 26, 2026 20:23
@ApostaC merged commit 65b834a into LMCache:dev Mar 26, 2026
35 checks passed
@DongDongJu (Collaborator)

Wow, terrific work.

royyhuang pushed a commit to royyhuang/LMCache that referenced this pull request Mar 26, 2026
* add new kernels and unit tests for mp mode

Signed-off-by: ApostaC <yihua98@uchicago.edu>

* change the block ids to be on gpu tensor

Signed-off-by: Yihua Cheng <yihua98@uchicago.edu>
royyhuang added a commit that referenced this pull request Mar 27, 2026
…fy config (#2806)

* [MP][Observability] Migrate MP server telemetry to EventBus, unify config, remove old telemetry system

- Migrate 6 telemetry call sites in MPCacheEngine (store/retrieve/lookup)
  from log_telemetry(make_start_event/make_end_event) to EventBus.publish(Event)
- Add MPServerLoggingSubscriber for debug logging of all MP server events
- Add MPServerTracingSubscriber for OTel spans from START/END event pairs
- Replace PrometheusConfig + TelemetryConfig with unified ObservabilityConfig
  (--disable-observability, --disable-metrics, --disable-logging,
   --enable-tracing, --otlp-endpoint, --prometheus-port)
- Move OTLP endpoint from env var to config option (--otlp-endpoint)
- Wire init_otel_tracing() when tracing is enabled
- Conditionally register subscribers based on config toggles
- Add __all__ exports to all subscriber __init__.py files
- Delete entire telemetry/ subdirectory and its tests
- Rename REFACT_DESIGN.md to DESIGN.md (architecture doc only)
- Split event metadata contracts into EVENTS.md
- Update README.md with full config reference
- Update METRICS.md with MP server event types

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* [MP][Observability] Fix mypy: update test files to use ObservabilityConfig

Replace DEFAULT_PROMETHEUS_CONFIG / prometheus_config with
DEFAULT_OBSERVABILITY_CONFIG / obs_config in test_cache_server,
test_blend_server, and test_blend_server_v2.

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* add back the lost logging subscribers

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* add event bus queue size configurable from cli args

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* [MP][Observability] Address PR review comments

- Move OTel LoggingHandler into init_logger so all loggers get OTel
  forwarding automatically; remove duplicated setup from 3 subscriber
  files. Log level now respects LMCACHE_LOG_LEVEL instead of being
  hardcoded to DEBUG.
- Add global observability flag (is_observability_enabled) to skip
  launch_host_func calls in CUDA streams when observability is disabled.
- Validate that --enable-tracing requires --otlp-endpoint at startup.
- Create AGENTS.override.md for mp_observability module.
- Rewrite observability docs to match new EventBus CLI args and
  document all three modes (metrics, logging, tracing).

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* [MP][Observability] Add publish_on_stream util to EventBus

Extract the repeated pattern of checking observability + calling
launch_host_func into EventBus.publish_on_stream(stream, event).
Callers no longer need to manually check the flag before scheduling
host functions on CUDA streams.

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* [CLI] Update server command for new observability config API

Replace removed add_prometheus_args/add_telemetry_args with
add_observability_args, and parse_args_to_prometheus_config/
parse_args_to_telemetry_config with parse_args_to_observability_config.
Update test assertions to match new kwarg name (obs_config).

Signed-off-by: royyhuang <roy.y.huang@gmail.com>

* [MP][Core] Block id based kernel for MP mode (#2838)

* add new kernels and unit tests for mp mode

Signed-off-by: ApostaC <yihua98@uchicago.edu>

* change the block ids to be on gpu tensor

Signed-off-by: Yihua Cheng <yihua98@uchicago.edu>

---------

Signed-off-by: royyhuang <roy.y.huang@gmail.com>
Signed-off-by: Yihua Cheng <yihua98@uchicago.edu>
Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
jooho-XCENA pushed commits to xcena-dev/LMCache that referenced this pull request Apr 2, 2026 (repeating the same commit messages as the two entries above).

Labels

full: Run comprehensive tests on this PR
mp: Buildkite trigger for multi-processing mode test
