@@ -19,6 +19,9 @@ namespace {
1919 * [2, L, 256, NH * HS], where 256 means that 256 tokens
2020 */
2121
22+ /**
23+  * Calculate the offset of the current block in the paged buffer.
24+  */
2225template <typename ScalarType, GPUKVFormat format>
2326__device__ inline size_t calculate_engine_global_offset (
2427 const int k_or_v, const int engine_block_idx, const int layer_idx,
@@ -50,6 +53,10 @@ __device__ inline size_t calculate_engine_global_offset(
5053 }
5154}
5255
56+ /**
57+  * Calculate the offset of the current token relative to the start
58+  * of the block in the paged buffer.
59+  */
5360template <typename ScalarType, GPUKVFormat format>
5461__device__ inline size_t calculate_engine_local_offset (
5562 const int token_offset, const int head_idx,
@@ -60,21 +67,29 @@ __device__ inline size_t calculate_engine_local_offset(
6067 return head_idx * scalars_per_head + token_offset * scalars_per_token;
6168}
6269
/**
 * Calculate the global offset for the current `block` in the LMCache object.
 * The `block` here is the memory region corresponding to a thread-block.
 *
 * The LMCache object is always laid out as 2LTD, i.e. [2 (K/V), layer, token,
 * scalars-per-token], so the returned offset (counted in scalars, not bytes)
 * is the sum of three strided terms:
 *
 *   k_or_v      * (nl * lmcache_chunk_size * scalars_per_token)
 *   layer_idx   *       (lmcache_chunk_size * scalars_per_token)
 *   token_offset_in_lmcache_object *         scalars_per_token
 *
 * @tparam ScalarType  element type used to derive scalars_per_token from
 *                     shape_desc.
 * @tparam format      GPU KV format tag; not consulted by this function.
 * @param k_or_v       index along the leading K/V dimension.
 * @param token_offset_in_lmcache_object  token index inside the LMCache
 *                     object, e.g. 0~255 if the LMCache chunk size is 256.
 * @param layer_idx    index of the layer.
 * @param lmcache_chunk_size  number of tokens per LMCache chunk, e.g. 256.
 * @param shape_desc   shape descriptor supplying `nl` and scalars_per_token.
 * @return offset, in scalars, from the start of the LMCache object.
 */
template <typename ScalarType, GPUKVFormat format>
__device__ inline size_t calculate_lmcache_global_offset(
    const int k_or_v,
    const int
        token_offset_in_lmcache_object,  // 0~255 if LMCache chunk size is 256
    const int layer_idx,
    const int lmcache_chunk_size,  // e.g., 256
    const PageBufferShapeDesc shape_desc) {
  const size_t scalars_per_token = shape_desc.scalars_per_token<ScalarType>();

  // Widen to size_t before multiplying so that the stride products are never
  // evaluated in (overflow-prone) int arithmetic.
  const size_t scalars_per_layer =
      static_cast<size_t>(lmcache_chunk_size) * scalars_per_token;
  const size_t scalars_per_kv =
      static_cast<size_t>(shape_desc.nl) * scalars_per_layer;

  // LMCache is using 2LTD all the time.
  return token_offset_in_lmcache_object * scalars_per_token +
         layer_idx * scalars_per_layer + k_or_v * scalars_per_kv;
}
7788
89+ /**
90+  * Calculate the local offset of the current token relative to the start of
91+  * the block in the LMCache object.
92+  */
7893template <typename ScalarType, GPUKVFormat format>
7994__device__ inline size_t calculate_lmcache_local_offset (
8095 const int token_offset, const int head_idx,
0 commit comments