@@ -27,6 +27,7 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
2727 std::atomic<WorkerTaskStatus> task_status[kNumTasks_ ];
2828 using clock = std::chrono::high_resolution_clock;
2929 clock::time_point activeTime[kNumTasks_ ];
30+ bool signals[kNumTasks_ ][size_]{};
3031 while (running_) {
3132 _mm_pause ();
3233 for (size_t i = 0 ; i < kNumTasks_ ; ++i) {
@@ -116,43 +117,45 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
116117 if (!batch_done) {
117118 continue ;
118119 }
119- auto source_ptr =
120- (int32_t *)segment_descs_[rank_]->buffers [4 + i].addr ;
121- std::vector<TransferRequest> entries;
122120 for (int j = 0 ; j < size_; ++j) {
123121 if (brokenRanks_[j]) {
124122 continue ;
125123 }
126- *source_ptr = 1 ;
127- entries.push_back (TransferRequest{
128- .opcode = TransferRequest::WRITE,
129- .source = (void *)source_ptr,
130- .target_id = segment_ids_[j],
131- .target_offset =
132- segment_descs_[j]->buffers [6 + i].addr +
133- rank_ * sizeof (int32_t ),
134- .length = sizeof (int32_t ),
135- });
124+ static_assert (kNumTasks_ <= 10 );
125+ std::string notify_msg =
126+ std::to_string (i) + std::to_string (rank_);
127+ int ret = engine_->sendNotifyByName (
128+ segment_descs_[j]->name , {" signal" , notify_msg});
129+ if (ret) {
130+ LOG (ERROR) << " Rank " << rank_ << " marking peer "
131+ << j << " as broken during notifying op "
132+ << (int )task.opType ;
133+ brokenRanks_[j] = true ;
134+ }
136135 }
137- task.batchID = engine_->allocateBatchID (entries.size ());
138- engine_->submitTransfer (task.batchID , entries);
139136 activeTime[i] = clock::now ();
140137 task_status[i].store (SIGNALED_1, std::memory_order_release);
141138 } else if (task_status[i].load (std::memory_order_acquire) ==
142139 SIGNALED_1) {
143140 bool all_received = true ;
144- auto signal_ptr =
145- (int32_t *)segment_descs_[rank_]->buffers [6 + i].addr ;
141+ std::vector<TransferMetadata::NotifyDesc> notifies;
142+ engine_->getNotifies (notifies);
143+ for (const auto &notify : notifies) {
144+ if (notify.name == " signal" ) {
145+ int taskId = notify.notify_msg [0 ] - ' 0' ;
146+ int src = std::atoi (notify.notify_msg .c_str () + 1 );
147+ signals[taskId][src] = true ;
148+ }
149+ }
146150 auto now = clock::now ();
147151 auto diff =
148152 std::chrono::duration_cast<std::chrono::seconds>(
149153 now - activeTime[i]);
150154 for (int j = 0 ; j < size_; ++j) {
151- if (signal_ptr[j] != 1 && !brokenRanks_[j]) {
152- TransferMetadata::NotifyDesc msg{" ping" , " ping" };
153- if (diff.count () > 1 &&
154- engine_->sendNotifyByName (
155- segment_descs_[j]->name , msg)) {
155+ if (!signals[i][j] && !brokenRanks_[j]) {
156+ if (diff.count () > 1 && engine_->sendNotifyByName (
157+ segment_descs_[j]->name ,
158+ {" ping" , " ping" })) {
156159 LOG (ERROR)
157160 << " Rank " << rank_ << " marking peer " << j
158161 << " as broken during syncing op "
@@ -170,7 +173,7 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
170173 }
171174 if (all_received) {
172175 for (int j = 0 ; j < size_; ++j) {
173- signal_ptr[ j] = 0 ;
176+ signals[i][ j] = false ;
174177 }
175178 task_status[i].store (DONE, std::memory_order_release);
176179 task.active = false ;
0 commit comments