Skip to content

Commit 4675e9d

Browse files
authored
[TransferEngine] Fix minor bugs in NVLink transport and benchmark (kvcache-ai#468)
* [TransferEngine] Fix compilation bug in NVLink transport
* [TransferEngine] Fix minor bugs in NVLink benchmark
1 parent 36c6a63 commit 4675e9d

2 files changed

Lines changed: 20 additions & 23 deletions

File tree

mooncake-transfer-engine/example/transfer_engine_bench.cpp

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,8 @@
3232

3333
#ifdef USE_CUDA
3434
#include <bits/stdint-uintn.h>
35-
#include <cuda_runtime.h>
3635
#include <cuda.h>
36+
#include <cuda_runtime.h>
3737

3838
#ifdef USE_NVMEOF
3939
#include <cufile.h>
@@ -100,12 +100,12 @@ static void *allocateMemoryPool(size_t size, int socket_id,
100100
int gpu_id = FLAGS_gpu_id;
101101
void *d_buf;
102102
checkCudaError(cudaSetDevice(gpu_id), "Failed to set device");
103-
#ifdef USE_NVLINK
104-
d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory(size);
105-
#else
106-
checkCudaError(cudaMalloc(&d_buf, size),
107-
"Failed to allocate device memory");
108-
#endif
103+
if (FLAGS_protocol == "nvlink") {
104+
d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory(size);
105+
} else {
106+
checkCudaError(cudaMalloc(&d_buf, size),
107+
"Failed to allocate device memory");
108+
}
109109
return d_buf;
110110
}
111111
#endif
@@ -114,14 +114,14 @@ static void *allocateMemoryPool(size_t size, int socket_id,
114114

115115
static void freeMemoryPool(void *addr, size_t size) {
116116
#ifdef USE_CUDA
117-
#ifdef USE_NVLINK
118-
CUmemGenericAllocationHandle handle;
119-
auto result = cuMemRetainAllocationHandle(&handle, addr);
120-
if (result == CUDA_SUCCESS) {
121-
mooncake::NvlinkTransport::freePinnedLocalMemory(addr);
122-
return;
117+
if (FLAGS_protocol == "nvlink") {
118+
CUmemGenericAllocationHandle handle;
119+
auto result = cuMemRetainAllocationHandle(&handle, addr);
120+
if (result == CUDA_SUCCESS) {
121+
mooncake::NvlinkTransport::freePinnedLocalMemory(addr);
122+
return;
123+
}
123124
}
124-
#endif
125125
// check pointer on GPU
126126
cudaPointerAttributes attributes;
127127
checkCudaError(cudaPointerGetAttributes(&attributes, addr),
@@ -406,12 +406,7 @@ int target() {
406406
buffer_num = FLAGS_use_vram ? 1 : NR_SOCKETS;
407407
if (FLAGS_use_vram) LOG(INFO) << "VRAM is used";
408408
for (int i = 0; i < buffer_num; ++i) {
409-
#ifdef USE_NVLINK
410-
addr[i] = mooncake::NvlinkTransport::allocatePinnedLocalMemory(
411-
FLAGS_buffer_size);
412-
#else
413409
addr[i] = allocateMemoryPool(FLAGS_buffer_size, i, FLAGS_use_vram);
414-
#endif
415410
std::string name_prefix = FLAGS_use_vram ? "cuda:" : "cpu:";
416411
int rc = engine->registerLocalMemory(addr[i], FLAGS_buffer_size,
417412
name_prefix + std::to_string(i));
@@ -431,11 +426,7 @@ int target() {
431426
while (target_running) sleep(1);
432427
for (int i = 0; i < buffer_num; ++i) {
433428
engine->unregisterLocalMemory(addr[i]);
434-
#ifdef USE_NVLINK
435-
mooncake::NvlinkTransport::freePinnedLocalMemory(addr[i]);
436-
#else
437429
freeMemoryPool(addr[i], FLAGS_buffer_size);
438-
#endif
439430
}
440431

441432
return 0;

mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ int NvlinkTransport::registerLocalMemory(void *addr, size_t length,
281281
return 0;
282282
}
283283

284+
cudaError_t err = cudaSetDevice(0);
285+
if (err != cudaSuccess) {
286+
LOG(ERROR) << "NvlinkTransport: cudaSetDevice failed";
287+
return -1;
288+
}
289+
284290
// Find whole physical page for memory registration
285291
void *real_addr;
286292
size_t real_size;

0 commit comments

Comments
 (0)