Skip to content

Commit 70758d4

Browse files
authored
support non-disturbing remote-instance-weight-loader (#13125)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
1 parent fd1ebbb commit 70758d4

11 files changed

Lines changed: 500 additions & 29 deletions

File tree

docs/advanced_features/rfork.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# R-Fork
2+
3+
R-Fork (Tensor Remote Fork) is a novel weight loading methodology that leverages an efficient inter-node GPU-to-GPU data transfer path to load tensors from a running SGLang instance into a new instance with zero-copy. It can significantly optimize SGLang instance boot-up time by reducing model weight loading from several minutes to mere seconds.
4+
5+
To learn more about R-Fork, please check the **[R-Fork blog](https://lmsys.org/blog/2025-12-10-rfork/)**
6+
7+
## Usage
8+
9+
| Argument | Usage |
10+
|--------------|--------------------------------------------|
11+
| load-format | set to `remote_instance` to enable R-Fork. |
12+
| remote-instance-weight-loader-backend | `nccl` or `transfer_engine`, default value is `nccl` |
13+
| remote-instance-weight-loader-seed-instance-ip | IP address of the seed instance that will provide the model weights |
14+
| remote-instance-weight-loader-seed-instance-service-port | the port that the seed instance's HTTP server is listening on |
15+
| remote-instance-weight-loader-send-weights-group-ports | the list of available ports on the seed instance that will be used to build NCCL communication groups between the seed and client instances. This argument is only needed by the `nccl` backend. |
16+
17+
### NCCL as backend
18+
19+
```shell
20+
python -m sglang.launch_server [args] \
21+
--load-format remote_instance \
22+
--remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \
23+
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
24+
--remote-instance-weight-loader-send-weights-group-ports [send_weights_nccl_group_ports_list] \
25+
--remote-instance-weight-loader-backend nccl
26+
```
27+
28+
### TransferEngine as backend
29+
30+
```shell
31+
python -m sglang.launch_server [args] \
32+
--load-format remote_instance \
33+
--remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \
34+
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
35+
--remote-instance-weight-loader-backend transfer_engine
36+
```

python/sglang/srt/configs/load_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class LoadConfig:
7373
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
7474
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
7575
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
76+
remote_instance_weight_loader_backend: Optional[str] = None
77+
remote_instance_weight_loader_transfer_engine: Optional[any] = None
7678

7779
# ModelOpt-specific loading options
7880
modelopt_checkpoint_restore_path: Optional[str] = None

python/sglang/srt/entrypoints/engine.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,18 @@ def __init__(self, **kwargs):
127127
atexit.register(self.shutdown)
128128

129129
# Launch subprocesses
130-
tokenizer_manager, template_manager, scheduler_info, port_args = (
131-
_launch_subprocesses(server_args=server_args)
132-
)
130+
(
131+
tokenizer_manager,
132+
template_manager,
133+
scheduler_info,
134+
port_args,
135+
remote_instance_transfer_engine_info,
136+
) = _launch_subprocesses(server_args=server_args)
133137
self.tokenizer_manager = tokenizer_manager
134138
self.template_manager = template_manager
135139
self.scheduler_info = scheduler_info
136140
self.port_args = port_args
141+
self.remote_instance_transfer_engine_info = remote_instance_transfer_engine_info
137142

138143
# Initialize ZMQ sockets
139144
context = zmq.Context(2)
@@ -910,6 +915,7 @@ def _launch_subprocesses(
910915

911916
# Wait for the model to finish loading
912917
scheduler_infos = []
918+
remote_instance_transfer_engine_info = {}
913919
for i in range(len(scheduler_pipe_readers)):
914920
try:
915921
data = scheduler_pipe_readers[i].recv()
@@ -926,9 +932,24 @@ def _launch_subprocesses(
926932
"Initialization failed. Please see the error messages above."
927933
)
928934
scheduler_infos.append(data)
935+
if (
936+
"tp_rank" in data
937+
and "remote_instance_transfer_engine_session_id" in data
938+
and "remote_instance_transfer_engine_weights_info_dict" in data
939+
):
940+
remote_instance_transfer_engine_info[data["tp_rank"]] = (
941+
data["remote_instance_transfer_engine_session_id"],
942+
data["remote_instance_transfer_engine_weights_info_dict"],
943+
)
929944

930945
# Assume all schedulers have the same scheduler_info
931946
scheduler_info = scheduler_infos[0]
932947
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
933948

934-
return tokenizer_manager, template_manager, scheduler_info, port_args
949+
return (
950+
tokenizer_manager,
951+
template_manager,
952+
scheduler_info,
953+
port_args,
954+
remote_instance_transfer_engine_info,
955+
)

python/sglang/srt/entrypoints/http_server.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,15 @@ class _GlobalState:
144144
tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
145145
template_manager: TemplateManager
146146
scheduler_info: Dict
147+
# Dict{
148+
# rank: Tuple(
149+
# session_id,
150+
# Dict{
151+
# name: Tuple (d_ptr, numel, element_size)
152+
# }
153+
# )
154+
# }
155+
remote_instance_transfer_engine_info: Optional[Dict] = None
147156

148157

149158
_global_state: Optional[_GlobalState] = None
@@ -813,6 +822,24 @@ async def send_weights_to_remote_instance(
813822
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
814823

815824

825+
@app.get("/get_remote_instance_transfer_engine_info")
826+
async def get_remote_instance_transfer_engine_info(rank: int = None):
827+
if rank is None or rank < 0:
828+
return Response(status_code=HTTPStatus.BAD_REQUEST)
829+
830+
try:
831+
result = {
832+
"rank": rank,
833+
"remote_instance_transfer_engine_info": _global_state.remote_instance_transfer_engine_info[
834+
rank
835+
],
836+
}
837+
return result
838+
except Exception as e:
839+
logger.error(f"Exception: {e}")
840+
return Response(status_code=HTTPStatus.BAD_REQUEST)
841+
842+
816843
@app.post("/init_weights_update_group")
817844
async def init_weights_update_group(
818845
obj: InitWeightsUpdateGroupReqInput, request: Request
@@ -1386,15 +1413,20 @@ def launch_server(
13861413
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
13871414
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
13881415
"""
1389-
tokenizer_manager, template_manager, scheduler_info, port_args = (
1390-
_launch_subprocesses(server_args=server_args)
1391-
)
1416+
(
1417+
tokenizer_manager,
1418+
template_manager,
1419+
scheduler_info,
1420+
port_args,
1421+
remote_instance_transfer_engine_info,
1422+
) = _launch_subprocesses(server_args=server_args)
13921423

13931424
set_global_state(
13941425
_GlobalState(
13951426
tokenizer_manager=tokenizer_manager,
13961427
template_manager=template_manager,
13971428
scheduler_info=scheduler_info,
1429+
remote_instance_transfer_engine_info=remote_instance_transfer_engine_info,
13981430
)
13991431
)
14001432

python/sglang/srt/managers/scheduler.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,6 +2573,9 @@ def handle_freeze_gc(self, recv_req: FreezeGCReq):
25732573
self.send_to_detokenizer.send_output(recv_req, recv_req)
25742574
return None
25752575

2576+
def get_remote_instance_transfer_engine_info(self):
2577+
return self.tp_worker.get_remote_instance_transfer_engine_info()
2578+
25762579

25772580
class IdleSleeper:
25782581
"""
@@ -2686,13 +2689,29 @@ def run_scheduler_process(
26862689
pp_rank,
26872690
dp_rank,
26882691
)
2689-
pipe_writer.send(
2690-
{
2691-
"status": "ready",
2692-
"max_total_num_tokens": scheduler.max_total_num_tokens,
2693-
"max_req_input_len": scheduler.max_req_input_len,
2694-
}
2695-
)
2692+
if server_args.remote_instance_weight_loader_support_transfer_engine:
2693+
(
2694+
remote_instance_transfer_engine_session_id,
2695+
remote_instance_transfer_engine_weights_info_dict,
2696+
) = scheduler.get_remote_instance_transfer_engine_info()
2697+
pipe_writer.send(
2698+
{
2699+
"status": "ready",
2700+
"max_total_num_tokens": scheduler.max_total_num_tokens,
2701+
"max_req_input_len": scheduler.max_req_input_len,
2702+
"tp_rank": tp_rank,
2703+
"remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id,
2704+
"remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict,
2705+
}
2706+
)
2707+
else:
2708+
pipe_writer.send(
2709+
{
2710+
"status": "ready",
2711+
"max_total_num_tokens": scheduler.max_total_num_tokens,
2712+
"max_req_input_len": scheduler.max_req_input_len,
2713+
}
2714+
)
26962715

26972716
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
26982717
if disaggregation_mode == DisaggregationMode.NULL:

python/sglang/srt/managers/tp_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ def _forward_batch_generation_dllm(
366366
can_run_cuda_graph=can_run_cuda_graph,
367367
)
368368

369+
def get_remote_instance_transfer_engine_info(self):
370+
return (
371+
self.model_runner.remote_instance_transfer_engine_session_id,
372+
self.model_runner.remote_instance_transfer_engine_weight_info,
373+
)
374+
369375
def forward_batch_generation(
370376
self,
371377
model_worker_batch: ModelWorkerBatch,

python/sglang/srt/model_executor/model_runner.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
)
6666
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
6767
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
68+
from sglang.srt.environ import envs
6869
from sglang.srt.eplb.eplb_manager import EPLBManager
6970
from sglang.srt.eplb.expert_distribution import (
7071
ExpertDistributionRecorder,
@@ -135,9 +136,10 @@
135136
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
136137
PiecewiseCudaGraphRunner,
137138
)
138-
from sglang.srt.model_loader import get_model
139139
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
140140
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
141+
RemoteInstanceWeightLoaderBackend,
142+
register_memory_region_v2,
141143
trigger_init_weights_send_group_for_remote_instance_request,
142144
)
143145
from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -157,6 +159,7 @@
157159
get_available_gpu_memory,
158160
get_bool_env_var,
159161
get_cpu_ids_by_node,
162+
get_local_ip_auto,
160163
init_custom_process_group,
161164
is_cuda,
162165
is_float4_e2m1fn_x2,
@@ -319,6 +322,10 @@ def __init__(
319322
self.forward_pass_id = 0
320323
self.init_new_workspace = False
321324

325+
self.remote_instance_transfer_engine = None
326+
self.remote_instance_transfer_engine_session_id = ""
327+
self.remote_instance_transfer_engine_weight_info = None
328+
322329
# Apply the rank zero filter to logger
323330
if server_args.show_time_cost:
324331
enable_show_time_cost()
@@ -393,6 +400,9 @@ def initialize(self, min_per_gpu_memory: float):
393400
enable=self.server_args.enable_memory_saver
394401
)
395402

403+
if self.server_args.remote_instance_weight_loader_support_transfer_engine:
404+
self.remote_instance_init_transfer_engine()
405+
396406
if not self.is_draft_worker:
397407
set_global_expert_location_metadata(
398408
compute_initial_expert_location_metadata(
@@ -433,6 +443,16 @@ def initialize(self, min_per_gpu_memory: float):
433443
self.sampler = Sampler()
434444
self.load_model()
435445

446+
if (
447+
self.server_args.remote_instance_weight_loader_support_transfer_engine
448+
and self.remote_instance_transfer_engine_weight_info is None
449+
):
450+
self.remote_instance_transfer_engine_weight_info = (
451+
register_memory_region_v2(
452+
self.model, self.remote_instance_transfer_engine
453+
)
454+
)
455+
436456
# Check if the model is using hybrid SWA
437457
if (
438458
not self.server_args.disable_hybrid_swa_memory
@@ -547,6 +567,23 @@ def initialize(self, min_per_gpu_memory: float):
547567
# Initialize piecewise CUDA graph
548568
self.init_piecewise_cuda_graphs()
549569

570+
def remote_instance_init_transfer_engine(self):
571+
try:
572+
from mooncake.engine import TransferEngine
573+
except ImportError as e:
574+
logger.warning(
575+
"Please install mooncake for using remote instance transfer engine: pip install mooncake"
576+
)
577+
return
578+
self.remote_instance_transfer_engine = TransferEngine()
579+
local_ip = get_local_ip_auto()
580+
self.remote_instance_transfer_engine.initialize(
581+
local_ip, "P2PHANDSHAKE", "rdma", envs.MOONCAKE_DEVICE.value
582+
)
583+
self.remote_instance_transfer_engine_session_id = (
584+
f"{local_ip}:{self.remote_instance_transfer_engine.get_rpc_port()}"
585+
)
586+
550587
def model_specific_adjustment(self):
551588
server_args = self.server_args
552589

@@ -764,6 +801,8 @@ def load_model(self):
764801
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
765802
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
766803
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
804+
remote_instance_weight_loader_backend=self.server_args.remote_instance_weight_loader_backend,
805+
remote_instance_weight_loader_transfer_engine=self.remote_instance_transfer_engine,
767806
modelopt_config=modelopt_config,
768807
rl_quant_profile=self.server_args.rl_quant_profile,
769808
)
@@ -772,7 +811,11 @@ def load_model(self):
772811
self.model_config, self.load_config, self.tp_size
773812
)
774813

775-
if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
814+
if (
815+
self.server_args.load_format == LoadFormat.REMOTE_INSTANCE
816+
and self.server_args.remote_instance_weight_loader_backend
817+
== RemoteInstanceWeightLoaderBackend.NCCL
818+
):
776819
if self.tp_rank == 0:
777820
instance_ip = socket.gethostbyname(socket.gethostname())
778821
t = threading.Thread(
@@ -797,11 +840,18 @@ def load_model(self):
797840
GPU_MEMORY_TYPE_WEIGHTS,
798841
enable_cpu_backup=enable_cpu_backup,
799842
):
800-
self.model = get_model(
801-
model_config=self.model_config,
843+
self.loader = get_model_loader(
802844
load_config=self.load_config,
845+
model_config=self.model_config,
846+
)
847+
self.model = self.loader.load_model(
848+
model_config=self.model_config,
803849
device_config=DeviceConfig(self.device, self.gpu_id),
804850
)
851+
if hasattr(self.loader, "remote_instance_transfer_engine_weight_info"):
852+
self.remote_instance_transfer_engine_weight_info = (
853+
self.loader.remote_instance_transfer_engine_weight_info
854+
)
805855
monkey_patch_vllm_parallel_state(reverse=True)
806856

807857
get_offloader().post_init()

0 commit comments

Comments
 (0)