Skip to content

Commit db5929e

Browse files
authored
[MooncakeAdaptor] Init sglang_adaptor to support SGLang using transfer engine (#181)
* [MooncakeAdaptor] Init SGLang Adaptor
1 parent b5f170b commit db5929e

3 files changed

Lines changed: 417 additions & 0 deletions

File tree

mooncake-integration/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,14 @@ target_link_libraries(mooncake_vllm_adaptor PUBLIC
2828
)
2929
message("${PYTHON_SYS_PATH}")
3030
install(TARGETS mooncake_vllm_adaptor DESTINATION ${PYTHON_SYS_PATH}/)
31+
32+
# Python extension module built from the shared sources plus the SGLang adaptor.
pybind11_add_module(
    mooncake_sglang_adaptor
    ${SOURCES}
    ${CACHE_ALLOCATOR_SOURCES}
    sglang/sglang_adaptor.cpp
)
target_link_libraries(mooncake_sglang_adaptor PUBLIC transfer_engine glog gflags)

# Print the install destination, then install the module into it.
message("${PYTHON_SYS_PATH}")
install(TARGETS mooncake_sglang_adaptor DESTINATION ${PYTHON_SYS_PATH}/)
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
// Copyright 2024 KVCache.AI
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "sglang_adaptor.h"
16+
17+
#include <cassert>
18+
19+
/// Default constructor. Does no work of its own; the transfer engine is
/// created later by initializeExt().
SGLangAdaptor::SGLangAdaptor() = default;
20+
21+
/// Tears the adaptor down: closes every cached segment handle, drops the
/// engine, then frees the slab buffers and the large buffers.
SGLangAdaptor::~SGLangAdaptor() {
    // Handles are closed through the engine, so this must precede reset().
    for (auto it = handle_map_.begin(); it != handle_map_.end(); ++it) {
        engine_->closeSegment(it->second);
    }
    handle_map_.clear();
    engine_.reset();

    // Backing memory is released only after the engine has been dropped.
    for (auto *slab : buffer_list_) {
        free(slab);
    }
    buffer_list_.clear();

    for (auto *large : large_buffer_list_) {
        free(large);
    }
    large_buffer_list_.clear();
}
30+
31+
/// Turns a comma-separated device list into a comma-separated list of
/// double-quoted names (e.g. `a,b` -> `"a","b"`), as spliced into the NIC
/// priority matrix string by initializeExt().
///
/// Leading/trailing whitespace around each name is stripped and blank entries
/// are skipped, so inputs such as `"a, b"` or `"a,,b"` do not produce names
/// with embedded spaces or empty `""` entries.
///
/// @param device_names Comma-separated device names; may be empty.
/// @return The quoted, comma-joined list; empty when no name is present.
std::string formatDeviceNames(const std::string &device_names) {
    static const char kWhitespace[] = " \t\r\n";

    std::stringstream ss(device_names);
    std::string item;
    std::string formatted;
    while (getline(ss, item, ',')) {
        const std::size_t first = item.find_first_not_of(kWhitespace);
        if (first == std::string::npos) continue;  // empty or all-blank entry
        const std::size_t last = item.find_last_not_of(kWhitespace);

        if (!formatted.empty()) formatted += ",";
        formatted += "\"" + item.substr(first, last - first + 1) + "\"";
    }
    return formatted;
}
48+
49+
/// Splits a connection string of the form `proto://domain` at the first
/// `://`.
///
/// @param conn_string Connection string, with or without a `proto://` prefix.
/// @return {proto, domain}. When no `://` is present the protocol defaults to
///         "etcd" and the whole input is returned as the domain.
std::pair<std::string, std::string> parseConnectionString(
    const std::string &conn_string) {
    const std::size_t sep = conn_string.find("://");
    if (sep == std::string::npos) {
        return std::make_pair(std::string("etcd"), conn_string);
    }
    // Skip the three characters of "://" to reach the domain part.
    return std::make_pair(conn_string.substr(0, sep),
                          conn_string.substr(sep + 3));
}
67+
68+
int SGLangAdaptor::initialize(const char *local_hostname,
69+
const char *metadata_server, const char *protocol,
70+
const char *device_name) {
71+
auto conn_string = parseConnectionString(metadata_server);
72+
return initializeExt(local_hostname, conn_string.second.c_str(), protocol,
73+
device_name, conn_string.first.c_str());
74+
}
75+
76+
int SGLangAdaptor::initializeExt(const char *local_hostname,
77+
const char *metadata_server,
78+
const char *protocol, const char *device_name,
79+
const char *metadata_type) {
80+
std::string conn_string = metadata_server;
81+
if (conn_string.find("://") == std::string::npos)
82+
conn_string =
83+
std::string(metadata_type) + "://" + std::string(metadata_server);
84+
85+
// TODO: remove `false` in the feature, it's for keep same API in SGLang.
86+
engine_ = std::make_unique<TransferEngine>(false);
87+
auto hostname_port = parseHostNameWithPort(local_hostname);
88+
int ret = engine_->init(conn_string, local_hostname,
89+
hostname_port.first.c_str(), hostname_port.second);
90+
if (ret) return -1;
91+
92+
xport_ = nullptr;
93+
if (strcmp(protocol, "rdma") == 0) {
94+
auto device_names = formatDeviceNames(device_name);
95+
std::string nic_priority_matrix =
96+
"{\"cpu:0\": [[" + device_names + "], []]}";
97+
void **args = (void **)malloc(2 * sizeof(void *));
98+
args[0] = (void *)nic_priority_matrix.c_str();
99+
args[1] = nullptr;
100+
xport_ = engine_->installTransport("rdma", args);
101+
} else if (strcmp(protocol, "tcp") == 0) {
102+
xport_ = engine_->installTransport("tcp", nullptr);
103+
} else {
104+
LOG(ERROR) << "Unsupported protocol";
105+
return -1;
106+
}
107+
108+
if (!xport_) return -1;
109+
free_list_.resize(kSlabSizeKBTabLen);
110+
doBuddyAllocate(kMaxClassId);
111+
return 0;
112+
}
113+
114+
char *SGLangAdaptor::allocateRawBuffer(size_t capacity) {
115+
auto buffer = malloc(capacity);
116+
if (!buffer) return nullptr;
117+
int ret = engine_->registerLocalMemory(buffer, capacity, "cpu:0");
118+
if (ret) {
119+
free(buffer);
120+
return nullptr;
121+
}
122+
return (char *)buffer;
123+
}
124+
125+
int SGLangAdaptor::findClassId(size_t size) {
126+
if (size > 1024ull * kSlabSizeKB[kMaxClassId]) return -1;
127+
for (int i = kMaxClassId - 2; i >= 0; --i)
128+
if (size > 1024ull * kSlabSizeKB[i]) return i + 1;
129+
return 0;
130+
}
131+
132+
int SGLangAdaptor::doBuddyAllocate(int class_id) {
133+
if (class_id == kMaxClassId) {
134+
auto buffer = allocateRawBuffer(kDefaultBufferCapacity);
135+
buffer_list_.push_back(buffer);
136+
for (size_t offset = 0; offset < kDefaultBufferCapacity;
137+
offset += 1024ull * kSlabSizeKB[kMaxClassId])
138+
free_list_[kMaxClassId].push(buffer + offset);
139+
return 0;
140+
}
141+
if (free_list_[class_id + 1].empty()) {
142+
int ret = doBuddyAllocate(class_id + 1);
143+
if (ret) return ret;
144+
}
145+
assert(!free_list_[class_id + 1].empty());
146+
char *buffer = free_list_[class_id + 1].top();
147+
free_list_[class_id + 1].pop();
148+
free_list_[class_id].push(buffer);
149+
free_list_[class_id].push(buffer + kSlabSizeKB[class_id] * 1024);
150+
return 0;
151+
}
152+
153+
uintptr_t SGLangAdaptor::allocateManagedBuffer(size_t length) {
154+
std::lock_guard<std::mutex> guard(mutex_);
155+
int class_id = findClassId(length);
156+
if (class_id < 0) {
157+
char *buffer = allocateRawBuffer(length);
158+
if (buffer) large_buffer_list_.insert(buffer);
159+
return (uintptr_t)buffer;
160+
}
161+
if (free_list_[class_id].empty())
162+
if (doBuddyAllocate(class_id)) return 0;
163+
assert(!free_list_[class_id].empty());
164+
char *buffer = free_list_[class_id].top();
165+
free_list_[class_id].pop();
166+
return (uintptr_t)buffer;
167+
}
168+
169+
int SGLangAdaptor::freeManagedBuffer(uintptr_t buffer_addr, size_t length) {
170+
std::lock_guard<std::mutex> guard(mutex_);
171+
auto buffer = (char *)buffer_addr;
172+
int class_id = findClassId(length);
173+
if (class_id < 0) {
174+
large_buffer_list_.erase(buffer);
175+
engine_->unregisterLocalMemory(buffer);
176+
free(buffer);
177+
return 0;
178+
}
179+
free_list_[class_id].push(buffer);
180+
return 0;
181+
}
182+
183+
int SGLangAdaptor::transferSync(const char *target_hostname, uintptr_t buffer,
184+
uintptr_t peer_buffer_address, size_t length) {
185+
Transport::SegmentHandle handle;
186+
if (handle_map_.count(target_hostname)) {
187+
handle = handle_map_[target_hostname];
188+
} else {
189+
handle = engine_->openSegment(target_hostname);
190+
if (handle == (Transport::SegmentHandle)-1) return -1;
191+
handle_map_[target_hostname] = handle;
192+
}
193+
194+
auto batch_id = engine_->allocateBatchID(1);
195+
TransferRequest entry;
196+
entry.opcode = TransferRequest::READ;
197+
entry.length = length;
198+
entry.source = (void *)buffer;
199+
entry.target_id = handle;
200+
entry.target_offset = peer_buffer_address;
201+
202+
Status s = engine_->submitTransfer(batch_id, {entry});
203+
if (!s.ok()) return -1;
204+
205+
TransferStatus status;
206+
while (true) {
207+
Status s = engine_->getTransferStatus(batch_id, 0, status);
208+
LOG_ASSERT(s.ok());
209+
if (status.s == TransferStatusEnum::COMPLETED) {
210+
engine_->freeBatchID(batch_id);
211+
return 0;
212+
} else if (status.s == TransferStatusEnum::FAILED) {
213+
engine_->freeBatchID(batch_id);
214+
return -1;
215+
}
216+
}
217+
}
218+
219+
int SGLangAdaptor::transferSyncExt(const char *target_hostname, uintptr_t buffer,
220+
uintptr_t peer_buffer_address, size_t length, TransferOpcode opcode) {
221+
Transport::SegmentHandle handle;
222+
if (handle_map_.count(target_hostname)) {
223+
handle = handle_map_[target_hostname];
224+
} else {
225+
handle = engine_->openSegment(target_hostname);
226+
if (handle == (Transport::SegmentHandle)-1) return -1;
227+
handle_map_[target_hostname] = handle;
228+
}
229+
230+
auto batch_id = engine_->allocateBatchID(1);
231+
TransferRequest entry;
232+
if (opcode == TransferOpcode::WRITE) {
233+
entry.opcode = TransferRequest::WRITE;
234+
} else {
235+
entry.opcode = TransferRequest::READ;
236+
}
237+
entry.length = length;
238+
entry.source = (void *)buffer;
239+
entry.target_id = handle;
240+
entry.target_offset = peer_buffer_address;
241+
242+
Status s = engine_->submitTransfer(batch_id, {entry});
243+
if (!s.ok()) return -1;
244+
245+
TransferStatus status;
246+
while (true) {
247+
Status s = engine_->getTransferStatus(batch_id, 0, status);
248+
LOG_ASSERT(s.ok());
249+
if (status.s == TransferStatusEnum::COMPLETED) {
250+
engine_->freeBatchID(batch_id);
251+
return 0;
252+
} else if (status.s == TransferStatusEnum::FAILED) {
253+
engine_->freeBatchID(batch_id);
254+
return -1;
255+
}
256+
}
257+
}
258+
259+
int SGLangAdaptor::expRegisterMemory(uintptr_t buffer_addr, size_t capacity) {
260+
char *buffer = reinterpret_cast<char *>(buffer_addr);
261+
return engine_->registerLocalMemory(buffer, capacity, "cpu:0");
262+
}
263+
264+
int SGLangAdaptor::expUnregisterMemory(uintptr_t buffer_addr) {
265+
char *buffer = reinterpret_cast<char *>(buffer_addr);
266+
return engine_->unregisterLocalMemory(buffer);
267+
}
268+
269+
uintptr_t SGLangAdaptor::getFirstBufferAddress(const std::string &segment_name) {
270+
Transport::SegmentHandle segment_id = engine_->openSegment(segment_name.c_str());
271+
auto segment_desc = engine_->getMetadata()->getSegmentDescByID(segment_id);
272+
return segment_desc->buffers[0].addr;
273+
}
274+
275+
namespace py = pybind11;
276+
277+
PYBIND11_MODULE(mooncake_sglang_adaptor, m) {
278+
py::enum_<SGLangAdaptor::TransferOpcode> transfer_opcode(
279+
m, "TransferOpcode", py::arithmetic());
280+
transfer_opcode
281+
.value("READ", SGLangAdaptor::TransferOpcode::READ)
282+
.value("WRITE", SGLangAdaptor::TransferOpcode::WRITE)
283+
.export_values();
284+
285+
auto adaptor_cls = py::class_<SGLangAdaptor>(m, "TransferEngine")
286+
.def(py::init<>())
287+
.def("initialize", &SGLangAdaptor::initialize)
288+
.def("initializeExt", &SGLangAdaptor::initializeExt)
289+
.def("allocateManagedBuffer", &SGLangAdaptor::allocateManagedBuffer)
290+
.def("freeManagedBuffer", &SGLangAdaptor::freeManagedBuffer)
291+
.def("transferSyncExt", &SGLangAdaptor::transferSyncExt)
292+
.def("transferSync", &SGLangAdaptor::transferSync)
293+
.def("writeBytesToBuffer", &SGLangAdaptor::writeBytesToBuffer)
294+
.def("readBytesFromBuffer", &SGLangAdaptor::readBytesFromBuffer)
295+
.def("expRegisterMemory", &SGLangAdaptor::expRegisterMemory)
296+
.def("expUnregisterMemory", &SGLangAdaptor::expUnregisterMemory)
297+
.def("getFirstBufferAddress", &SGLangAdaptor::getFirstBufferAddress);
298+
299+
adaptor_cls.attr("TransferOpcode") = transfer_opcode;
300+
}

0 commit comments

Comments
 (0)