66
77namespace mooncake {
88
// Compile-time error-message constants for this backend.
// REGISTER_BUFFER_ERROR_MSG is the message handed to TORCH_CHECK whenever
// engine_.registerLocalMemory() returns a non-zero code in the constructor.
// NOTE(review): the use sites of the other two messages are outside this view.
9+ constexpr const char * REGISTER_BUFFER_ERROR_MSG =
10+ " Failed to register local memory." ;
911constexpr const char * MULTI_DEVICE_ERROR_MSG =
1012 " Expecting one tensor only but got multiple." ;
1113constexpr const char * SYNC_OP_ERROR_MSG = " Expecting async op but got sync op." ;
@@ -15,9 +17,11 @@ constexpr const char* REDUCE_DTYPE_ERROR_MSG = "Unsupported reduce dtype: ";
1517
// Definitions of the static data members of MooncakeBackend.
// hostIp_ is a static string with the initial value shown below.
1618std::string MooncakeBackend::hostIp_ = " 127.0.0.1" ;
1719
// buffer_ is a static BackendBuffer owner that starts out null; the
// constructor creates it on first use when isTest is false and every later
// non-test instance reuses the same object.
20+ std::unique_ptr<BackendBuffer> MooncakeBackend::buffer_ = nullptr ;
21+
// Constructor for a backend serving one rank out of size participants.
// This text is a rendered diff: minus-marked lines are the removed version,
// plus-marked lines the added version, and hunk headers mark elided spans,
// so parts of the initializer list and body are not visible here.
// The added version does the following:
//   1. queries the current CUDA device id;
//   2. selects a BackendBuffer: a fresh allocation when isTest is true,
//      otherwise the static buffer_, created on first use;
//   3. registers that buffer with engine_: the CPU regions when isCpu is
//      true, the CUDA regions otherwise, two slots per region kind;
//   4. passes the buffer to worker_ and then calls worker_.initWorker().
// Every registerLocalMemory() return code is checked with TORCH_CHECK.
1822MooncakeBackend::MooncakeBackend (
1923 c10::intrusive_ptr<::c10d::Store> store, int rank, int size,
20- c10::intrusive_ptr<MooncakeBackendOptions> options, bool isCpu)
24+ c10::intrusive_ptr<MooncakeBackendOptions> options, bool isCpu, bool isTest )
2125 : Backend(rank, size),
2226 isCpu_ (isCpu),
2327 worker_(&engine_, rank, size,
@@ -26,7 +30,8 @@ MooncakeBackend::MooncakeBackend(
2630 {size},
2731 torch::dtype (torch::kInt32 ).device(torch::kCUDA ))) {
2832 // Get device data
29- cudaError err = cudaGetDevice (&device_id_);
// NOTE(review): deviceId_ is a local variable despite the member-style
// trailing underscore; it replaces device_id_, presumably a former member.
33+ int deviceId_;
34+ cudaError err = cudaGetDevice (&deviceId_);
3035 TORCH_CHECK (!err, c10::str (" Failed to get device id" ));
3136
3237 // Initialize transfer engine
@@ -37,62 +42,70 @@ MooncakeBackend::MooncakeBackend(
3742 std::string localServerName = localRpcMeta.ip_or_host_name + " :" +
3843 std::to_string (localRpcMeta.rpc_port );
3944
40- // Register GPU buffers
41- constexpr size_t buffer_size = 1u << 29 ;
45+ // Register buffers
// NOTE(review): the isTest branch allocates with a raw new and no matching
// delete is visible in this view; the destructor only touches the static
// buffer_. Confirm whether worker_ takes ownership in setBackendBuffer(),
// otherwise the test buffer leaks and its regions are never unregistered.
46+ BackendBuffer* buffer;
47+ if (isTest) {
48+ buffer = new BackendBuffer ();
49+ } else {
50+ if (!buffer_) {
51+ buffer_ = std::make_unique<BackendBuffer>();
52+ }
53+ buffer = buffer_.get ();
54+ }
55+
4256 if (isCpu) {
4357 for (size_t i = 0 ; i < 2 ; i++) {
44- send_buffer_[i] = malloc (buffer_size);
45- TORCH_CHECK (send_buffer_[i],
46- c10::str (" Failed to allocate CPU send buffer" ));
58+ int rc = engine_.registerLocalMemory (buffer->cpuSendBuffer_ [i],
59+ kBufferSize );
60+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
61+ }
4762
48- int rc = engine_.registerLocalMemory (send_buffer_[i], buffer_size);
49- TORCH_CHECK (!rc, c10::str (" Failed to register local memory" ));
63+ for (size_t i = 0 ; i < 2 ; i++) {
64+ int rc = engine_.registerLocalMemory (buffer->cpuRecvBuffer_ [i],
65+ kBufferSize );
66+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
5067 }
5168
5269 for (size_t i = 0 ; i < 2 ; i++) {
53- recv_buffer_[i] = malloc (buffer_size);
54- TORCH_CHECK (recv_buffer_[i],
55- c10::str (" Failed to allocate CPU recv buffer" ));
// Sync regions are registered with kMaxNumRanks * sizeof(int32_t) bytes and
// kWildcardLocation in both the CPU and the CUDA branch; the removed code
// sized them by the size argument instead.
70+ int rc = engine_.registerLocalMemory (buffer->cpuSyncSendRegion_ [i],
71+ kMaxNumRanks * sizeof (int32_t ),
72+ kWildcardLocation );
73+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
74+ }
5675
57- int rc = engine_.registerLocalMemory (recv_buffer_[i], buffer_size);
58- TORCH_CHECK (!rc, c10::str (" Failed to register local memory" ));
76+ for (size_t i = 0 ; i < 2 ; i++) {
77+ int rc = engine_.registerLocalMemory (buffer->cpuSyncRecvRegion_ [i],
78+ kMaxNumRanks * sizeof (int32_t ),
79+ kWildcardLocation );
80+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
5981 }
6082 } else {
61- std::string location = " cuda:" + std::to_string (device_id_ );
83+ std::string location = " cuda:" + std::to_string (deviceId_ );
6284 for (size_t i = 0 ; i < 2 ; i++) {
63- err = cudaMalloc (&send_buffer_[i], buffer_size);
64- TORCH_CHECK (!err, c10::str (" Failed to allocate CUDA send buffer" ));
65-
66- int rc = engine_.registerLocalMemory (send_buffer_[i], buffer_size,
67- location);
68- TORCH_CHECK (!rc, c10::str (" Failed to register local memory" ));
85+ int rc = engine_.registerLocalMemory (buffer->cudaSendBuffer_ [i],
86+ kBufferSize , location);
87+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
6988 }
7089
7190 for (size_t i = 0 ; i < 2 ; i++) {
72- err = cudaMalloc (&recv_buffer_[i], buffer_size);
73- TORCH_CHECK (!err, c10::str (" Failed to allocate CUDA recv buffer" ));
74-
75- int rc = engine_.registerLocalMemory (recv_buffer_[i], buffer_size,
76- location);
77- TORCH_CHECK (!rc, c10::str (" Failed to register local memory" ));
91+ int rc = engine_.registerLocalMemory (buffer->cudaRecvBuffer_ [i],
92+ kBufferSize , location);
93+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
7894 }
79- }
8095
81- // Register CPU sync regions
82- for (size_t i = 0 ; i < 2 ; i++) {
83- cpu_sync_send_region_[i] = new int32_t [size];
84- int rc = engine_.registerLocalMemory (cpu_sync_send_region_[i],
85- size * sizeof (int32_t ),
86- kWildcardLocation );
87- TORCH_CHECK (!rc, c10::str (" Failed to register local memory" ));
88- }
96+ for (size_t i = 0 ; i < 2 ; i++) {
97+ int rc = engine_.registerLocalMemory (buffer->cudaSyncSendRegion_ [i],
98+ kMaxNumRanks * sizeof (int32_t ),
99+ kWildcardLocation );
100+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG);
101+ }
89102
90- for (size_t i = 0 ; i < 2 ; i++) {
91- cpu_sync_recv_region_[i] = new int32_t [size];
92- int rc = engine_. registerLocalMemory (cpu_sync_recv_region_[i] ,
93- size * sizeof ( int32_t ),
94- kWildcardLocation );
95- TORCH_CHECK (!rc, c10::str ( " Failed to register local memory " ));
103+ for (size_t i = 0 ; i < 2 ; i++) {
104+ int rc = engine_. registerLocalMemory (buffer-> cudaSyncRecvRegion_ [i],
105+ kMaxNumRanks * sizeof ( int32_t ) ,
106+ kWildcardLocation );
107+ TORCH_CHECK (!rc, REGISTER_BUFFER_ERROR_MSG );
108+ }
96109 }
97110
98111 // Sync metadata
@@ -103,23 +116,24 @@ MooncakeBackend::MooncakeBackend(
103116 server_names.push_back (
104117 store->get_to_str ({" server_name_" + std::to_string (i)}));
105118 }
119+ worker_.setBackendBuffer (buffer);
106120 worker_.initWorker (server_names);
107121}
108122
// Destructor. This text is a rendered diff: minus-marked lines are the
// removed version and plus-marked lines the added version.
// The added version does nothing when the static buffer_ is null. Otherwise,
// for each of the two slots, it unregisters from engine_ the four CPU regions
// when isCpu_ is true, or the four CUDA regions when it is false. Unlike the
// removed version it calls no free, cudaFree or delete[] here.
// NOTE(review): only the static buffer_ is consulted. A buffer allocated by
// the constructor in its isTest branch is never unregistered from here, and
// an instance may ask its engine_ to unregister buffer_ regions that this
// engine_ never registered. Confirm both cases against the header and worker.
109123MooncakeBackend::~MooncakeBackend () {
110- for ( size_t i = 0 ; i < 2 ; i++ ) {
111- engine_. unregisterLocalMemory (cpu_sync_send_region_[i]);
112- delete[] cpu_sync_send_region_[i];
113- engine_.unregisterLocalMemory (cpu_sync_recv_region_ [i]);
114- delete[] cpu_sync_recv_region_ [i];
115- engine_.unregisterLocalMemory (send_buffer_ [i]);
116- engine_.unregisterLocalMemory (recv_buffer_ [i]);
117- if (isCpu_) {
118- free (send_buffer_ [i]);
119- free (recv_buffer_ [i]);
120- } else {
121- cudaFree (send_buffer_ [i]);
122- cudaFree (recv_buffer_[i]);
124+ if (buffer_ ) {
125+ for ( size_t i = 0 ; i < 2 ; i++) {
126+ if (isCpu_) {
127+ engine_.unregisterLocalMemory (buffer_-> cpuSendBuffer_ [i]);
128+ engine_. unregisterLocalMemory (buffer_-> cpuRecvBuffer_ [i]) ;
129+ engine_.unregisterLocalMemory (buffer_-> cpuSyncSendRegion_ [i]);
130+ engine_.unregisterLocalMemory (buffer_-> cpuSyncRecvRegion_ [i]);
131+ } else {
132+ engine_. unregisterLocalMemory (buffer_-> cudaSendBuffer_ [i]);
133+ engine_. unregisterLocalMemory (buffer_-> cudaRecvBuffer_ [i]);
134+ engine_. unregisterLocalMemory (buffer_-> cudaSyncSendRegion_ [i]);
135+ engine_. unregisterLocalMemory (buffer_-> cudaSyncRecvRegion_ [i]);
136+ }
123137 }
124138 }
125139}
0 commit comments