3232
3333#ifdef USE_CUDA
3434#include < bits/stdint-uintn.h>
35- #include < cuda.h>
3635#include < cuda_runtime.h>
36+ #include < cuda.h>
3737
3838#ifdef USE_NVMEOF
3939#include < cufile.h>
@@ -100,12 +100,12 @@ static void *allocateMemoryPool(size_t size, int socket_id,
100100 int gpu_id = FLAGS_gpu_id;
101101 void *d_buf;
102102 checkCudaError (cudaSetDevice (gpu_id), " Failed to set device" );
103- if (FLAGS_protocol == " nvlink " ) {
104- d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory (size);
105- } else {
106- checkCudaError (cudaMalloc (&d_buf, size),
107- " Failed to allocate device memory" );
108- }
103+ # ifdef USE_NVLINK
104+ d_buf = mooncake::NvlinkTransport::allocatePinnedLocalMemory (size);
105+ # else
106+ checkCudaError (cudaMalloc (&d_buf, size),
107+ " Failed to allocate device memory" );
108+ # endif
109109 return d_buf;
110110 }
111111#endif
@@ -114,14 +114,14 @@ static void *allocateMemoryPool(size_t size, int socket_id,
114114
115115static void freeMemoryPool (void *addr, size_t size) {
116116#ifdef USE_CUDA
117- if (FLAGS_protocol == " nvlink" ) {
118- CUmemGenericAllocationHandle handle;
119- auto result = cuMemRetainAllocationHandle (&handle, addr);
120- if (result == CUDA_SUCCESS) {
121- mooncake::NvlinkTransport::freePinnedLocalMemory (addr);
122- return ;
123- }
117+ #ifdef USE_NVLINK
118+ CUmemGenericAllocationHandle handle;
119+ auto result = cuMemRetainAllocationHandle (&handle, addr);
120+ if (result == CUDA_SUCCESS) {
121+ mooncake::NvlinkTransport::freePinnedLocalMemory (addr);
122+ return ;
124123 }
124+ #endif
125125 // check pointer on GPU
126126 cudaPointerAttributes attributes;
127127 checkCudaError (cudaPointerGetAttributes (&attributes, addr),
@@ -406,7 +406,12 @@ int target() {
406406 buffer_num = FLAGS_use_vram ? 1 : NR_SOCKETS;
407407 if (FLAGS_use_vram) LOG (INFO) << " VRAM is used" ;
408408 for (int i = 0 ; i < buffer_num; ++i) {
409+ #ifdef USE_NVLINK
410+ addr[i] = mooncake::NvlinkTransport::allocatePinnedLocalMemory (
411+ FLAGS_buffer_size);
412+ #else
409413 addr[i] = allocateMemoryPool (FLAGS_buffer_size, i, FLAGS_use_vram);
414+ #endif
410415 std::string name_prefix = FLAGS_use_vram ? " cuda:" : " cpu:" ;
411416 int rc = engine->registerLocalMemory (addr[i], FLAGS_buffer_size,
412417 name_prefix + std::to_string (i));
@@ -426,7 +431,11 @@ int target() {
426431 while (target_running) sleep (1 );
427432 for (int i = 0 ; i < buffer_num; ++i) {
428433 engine->unregisterLocalMemory (addr[i]);
434+ #ifdef USE_NVLINK
435+ mooncake::NvlinkTransport::freePinnedLocalMemory (addr[i]);
436+ #else
429437 freeMemoryPool (addr[i], FLAGS_buffer_size);
438+ #endif
430439 }
431440
432441 return 0 ;
0 commit comments