3232
3333#ifdef USE_CUDA
3434#include < bits/stdint-uintn.h>
35- #include < cuda_runtime.h>
3635#include < cuda.h>
36+ #include < cuda_runtime.h>
3737
3838#ifdef USE_NVMEOF
3939#include < cufile.h>
@@ -100,12 +100,12 @@ static void *allocateMemoryPool(size_t size, int socket_id,
100100 int gpu_id = FLAGS_gpu_id;
101101 void *d_buf;
102102 checkCudaError (cudaSetDevice (gpu_id), " Failed to set device" );
103- # ifdef USE_NVLINK
104- d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory (size);
105- # else
106- checkCudaError (cudaMalloc (&d_buf, size),
107- " Failed to allocate device memory" );
108- # endif
103+ if (FLAGS_protocol == " nvlink " ) {
104+ d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory (size);
105+ } else {
106+ checkCudaError (cudaMalloc (&d_buf, size),
107+ " Failed to allocate device memory" );
108+ }
109109 return d_buf;
110110 }
111111#endif
@@ -114,14 +114,14 @@ static void *allocateMemoryPool(size_t size, int socket_id,
114114
115115static void freeMemoryPool (void *addr, size_t size) {
116116#ifdef USE_CUDA
117- #ifdef USE_NVLINK
118- CUmemGenericAllocationHandle handle;
119- auto result = cuMemRetainAllocationHandle (&handle, addr);
120- if (result == CUDA_SUCCESS) {
121- mooncake::NvlinkTransport::freePinnedLocalMemory (addr);
122- return ;
117+ if (FLAGS_protocol == " nvlink" ) {
118+ CUmemGenericAllocationHandle handle;
119+ auto result = cuMemRetainAllocationHandle (&handle, addr);
120+ if (result == CUDA_SUCCESS) {
121+ mooncake::NvlinkTransport::freePinnedLocalMemory (addr);
122+ return ;
123+ }
123124 }
124- #endif
125125 // check pointer on GPU
126126 cudaPointerAttributes attributes;
127127 checkCudaError (cudaPointerGetAttributes (&attributes, addr),
@@ -406,12 +406,7 @@ int target() {
406406 buffer_num = FLAGS_use_vram ? 1 : NR_SOCKETS;
407407 if (FLAGS_use_vram) LOG (INFO) << " VRAM is used" ;
408408 for (int i = 0 ; i < buffer_num; ++i) {
409- #ifdef USE_NVLINK
410- addr[i] = mooncake::NvlinkTransport::allocatePinnedLocalMemory (
411- FLAGS_buffer_size);
412- #else
413409 addr[i] = allocateMemoryPool (FLAGS_buffer_size, i, FLAGS_use_vram);
414- #endif
415410 std::string name_prefix = FLAGS_use_vram ? " cuda:" : " cpu:" ;
416411 int rc = engine->registerLocalMemory (addr[i], FLAGS_buffer_size,
417412 name_prefix + std::to_string (i));
@@ -431,11 +426,7 @@ int target() {
431426 while (target_running) sleep (1 );
432427 for (int i = 0 ; i < buffer_num; ++i) {
433428 engine->unregisterLocalMemory (addr[i]);
434- #ifdef USE_NVLINK
435- mooncake::NvlinkTransport::freePinnedLocalMemory (addr[i]);
436- #else
437429 freeMemoryPool (addr[i], FLAGS_buffer_size);
438- #endif
439430 }
440431
441432 return 0 ;
0 commit comments