Skip to content

Commit f20ffb2

Browse files
committed
Use transfer engine's notifications to implement collective signals
1 parent 3a0d872 commit f20ffb2

1 file changed

Lines changed: 26 additions & 23 deletions

File tree

mooncake-ep/src/mooncake_worker_thread.cpp

Lines changed: 26 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
2727
std::atomic<WorkerTaskStatus> task_status[kNumTasks_];
2828
using clock = std::chrono::high_resolution_clock;
2929
clock::time_point activeTime[kNumTasks_];
30+
bool signals[kNumTasks_][size_]{};
3031
while (running_) {
3132
_mm_pause();
3233
for (size_t i = 0; i < kNumTasks_; ++i) {
@@ -116,43 +117,45 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
116117
if (!batch_done) {
117118
continue;
118119
}
119-
auto source_ptr =
120-
(int32_t *)segment_descs_[rank_]->buffers[4 + i].addr;
121-
std::vector<TransferRequest> entries;
122120
for (int j = 0; j < size_; ++j) {
123121
if (brokenRanks_[j]) {
124122
continue;
125123
}
126-
*source_ptr = 1;
127-
entries.push_back(TransferRequest{
128-
.opcode = TransferRequest::WRITE,
129-
.source = (void *)source_ptr,
130-
.target_id = segment_ids_[j],
131-
.target_offset =
132-
segment_descs_[j]->buffers[6 + i].addr +
133-
rank_ * sizeof(int32_t),
134-
.length = sizeof(int32_t),
135-
});
124+
static_assert(kNumTasks_ <= 10);
125+
std::string notify_msg =
126+
std::to_string(i) + std::to_string(rank_);
127+
int ret = engine_->sendNotifyByName(
128+
segment_descs_[j]->name, {"signal", notify_msg});
129+
if (ret) {
130+
LOG(ERROR) << "Rank " << rank_ << " marking peer "
131+
<< j << " as broken during notifying op "
132+
<< (int)task.opType;
133+
brokenRanks_[j] = true;
134+
}
136135
}
137-
task.batchID = engine_->allocateBatchID(entries.size());
138-
engine_->submitTransfer(task.batchID, entries);
139136
activeTime[i] = clock::now();
140137
task_status[i].store(SIGNALED_1, std::memory_order_release);
141138
} else if (task_status[i].load(std::memory_order_acquire) ==
142139
SIGNALED_1) {
143140
bool all_received = true;
144-
auto signal_ptr =
145-
(int32_t *)segment_descs_[rank_]->buffers[6 + i].addr;
141+
std::vector<TransferMetadata::NotifyDesc> notifies;
142+
engine_->getNotifies(notifies);
143+
for (const auto &notify : notifies) {
144+
if (notify.name == "signal") {
145+
int taskId = notify.notify_msg[0] - '0';
146+
int src = std::atoi(notify.notify_msg.c_str() + 1);
147+
signals[taskId][src] = true;
148+
}
149+
}
146150
auto now = clock::now();
147151
auto diff =
148152
std::chrono::duration_cast<std::chrono::seconds>(
149153
now - activeTime[i]);
150154
for (int j = 0; j < size_; ++j) {
151-
if (signal_ptr[j] != 1 && !brokenRanks_[j]) {
152-
TransferMetadata::NotifyDesc msg{"ping", "ping"};
153-
if (diff.count() > 1 &&
154-
engine_->sendNotifyByName(
155-
segment_descs_[j]->name, msg)) {
155+
if (!signals[i][j] && !brokenRanks_[j]) {
156+
if (diff.count() > 1 && engine_->sendNotifyByName(
157+
segment_descs_[j]->name,
158+
{"ping", "ping"})) {
156159
LOG(ERROR)
157160
<< "Rank " << rank_ << " marking peer " << j
158161
<< " as broken during syncing op "
@@ -170,7 +173,7 @@ void MooncakeWorker::initWorker(const std::vector<std::string> &server_names) {
170173
}
171174
if (all_received) {
172175
for (int j = 0; j < size_; ++j) {
173-
signal_ptr[j] = 0;
176+
signals[i][j] = false;
174177
}
175178
task_status[i].store(DONE, std::memory_order_release);
176179
task.active = false;

0 commit comments

Comments
 (0)