Skip to content

Commit 45f8b41

Browse files
committed
fix nit comments
Signed-off-by: ApostaC <yihua98@uchicago.edu>
1 parent dea9e2f commit 45f8b41

1 file changed

Lines changed: 17 additions & 2 deletions

File tree

csrc/mp_mem_kernels.cu

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ namespace {
1919
* [2, L, 256, NH * HS], where 256 means that 256 tokens
2020
*/
2121

22+
/**
23+
* Calculate the offset for the current block in the paged buffer
24+
*/
2225
template <typename ScalarType, GPUKVFormat format>
2326
__device__ inline size_t calculate_engine_global_offset(
2427
const int k_or_v, const int engine_block_idx, const int layer_idx,
@@ -50,6 +53,10 @@ __device__ inline size_t calculate_engine_global_offset(
5053
}
5154
}
5255

56+
/**
57+
* Calculate the offset for the current token against the start
58+
* of the block in the paged buffer.
59+
*/
5360
template <typename ScalarType, GPUKVFormat format>
5461
__device__ inline size_t calculate_engine_local_offset(
5562
const int token_offset, const int head_idx,
@@ -60,21 +67,29 @@ __device__ inline size_t calculate_engine_local_offset(
6067
return head_idx * scalars_per_head + token_offset * scalars_per_token;
6168
}
6269

70+
/**
71+
* Calculate the global offset for the current `block` in the LMCache object.
72+
* The `block` here is the memory region corresponding to a thread-block.
73+
*/
6374
template <typename ScalarType, GPUKVFormat format>
6475
__device__ inline size_t calculate_lmcache_global_offset(
6576
const int k_or_v,
6677
const int
67-
token_offset_in_lmcache_block, // 0~255 if LMCache block size is 256
78+
token_offset_in_lmcache_object, // 0~255 if LMCache chunk size is 256
6879
const int layer_idx,
6980
const int lmcache_chunk_size, // e.g., 256
7081
const PageBufferShapeDesc shape_desc) {
7182
size_t scalars_per_token = shape_desc.scalars_per_token<ScalarType>();
7283
// LMCache is using 2LTD all the times
73-
return token_offset_in_lmcache_block * scalars_per_token +
84+
return token_offset_in_lmcache_object * scalars_per_token +
7485
layer_idx * lmcache_chunk_size * scalars_per_token +
7586
k_or_v * shape_desc.nl * lmcache_chunk_size * scalars_per_token;
7687
}
7788

89+
/**
90+
* Calculate the local offset for the current token against the start of the
91+
* block in the LMCache object.
92+
*/
7893
template <typename ScalarType, GPUKVFormat format>
7994
__device__ inline size_t calculate_lmcache_local_offset(
8095
const int token_offset, const int head_idx,

0 commit comments

Comments
 (0)