support non-disturbing remote-instance-weight-loader #13125

Merged
Kangyan-Zhou merged 5 commits into sgl-project:main from amysaq2023:amy/non-disturbing-remote-instance-weight-loader on Dec 12, 2025

Conversation

@amysaq2023
Contributor

@amysaq2023 amysaq2023 commented Nov 12, 2025

Motivation

In #8215, SGLang added a new load format, remote_instance, that allows a new instance to load weights from another running instance. This approach greatly reduces weight loading time during instance initialization. However, since it uses torch.distributed with NCCL as the backend, it disturbs ongoing inference requests: torch.distributed always launches CUDA kernels to transfer the weight tensors.

We introduce another backend option, TransferEngine, which does not disturb any GPU workload and still uses RDMA to transfer weights.

Modifications

We initialize one TransferEngine per ModelRunner and register its weights with the RDMA channel during initialization.

When a new instance initializes with the remote_instance load_format and the TransferEngine backend (a rough sketch in code follows the list):

  1. It will send an HTTP request to retrieve the source instance's TransferEngine metadata, including RDMA keys mapped to the corresponding GPU memory addresses.
  2. Using these RDMA keys, the new instance directly loads weights from the source's GPU memory.
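
A rough sketch of this client-side flow, assuming mooncake's TransferEngine Python binding; the HTTP endpoint name, the JSON layout of the metadata, and helpers such as local_weight_address are hypothetical placeholders, not the exact SGLang API:

import requests
from mooncake.engine import TransferEngine  # pip install mooncake-transfer-engine

# Placeholders for this sketch; real values come from the server args below.
seed_ip, seed_service_port = "10.0.0.1", 30000
local_hostname = "10.0.0.2"       # this instance's RDMA-reachable hostname/IP
metadata_server = "P2PHANDSHAKE"  # mooncake's P2P handshake mode (assumed here)
device_name = ""                  # empty lets mooncake pick NICs (cf. MOONCAKE_DEVICE)

# Step 1: fetch the seed instance's TransferEngine metadata over HTTP
# (endpoint name and JSON shape are illustrative, not the actual SGLang route).
meta = requests.get(
    f"http://{seed_ip}:{seed_service_port}/transfer_engine_metadata"
).json()
session_id = meta["session_id"]      # identifies the source engine
weights_info = meta["weights_info"]  # tensor name -> (remote GPU address, size)

# Step 2: read each registered weight region directly from the source GPU memory.
engine = TransferEngine()
engine.initialize(local_hostname, metadata_server, "rdma", device_name)
local_addrs, remote_addrs, lengths = [], [], []
for name, (remote_addr, size) in weights_info.items():
    local_addr = local_weight_address(name)  # hypothetical helper: data_ptr of the local tensor
    engine.register_memory(local_addr, size)
    local_addrs.append(local_addr)
    remote_addrs.append(remote_addr)
    lengths.append(size)

# One-sided RDMA read: no CUDA kernels are launched on the source instance.
ret = engine.batch_transfer_sync_read(session_id, local_addrs, remote_addrs, lengths)
assert ret == 0, "RDMA weight transfer failed"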

How to use:
python -m sglang.launch_server [args] \
--load-format remote_instance \
--remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \
--remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \
--remote-instance-weight-loader-backend "transfer_engine"

Checklist


@amysaq2023
Contributor Author

cc @tianyuzhou95 @zhaochenyang20

Comment thread python/sglang/srt/model_executor/model_runner.py
@XiaotaoChen

@amysaq2023 Hi, can you share some perf data for remote_instance with TransferEngine? I tried remote_instance with NCCL on DeepSeek-V3 (645 GB); it synced within 3s on 8*H200 with 8 NICs, but the ongoing inference tasks were blocked in the meantime.
Another question: how should we allocate NICs among GPU collective operations, KV cache transfer for PD disaggregation, and remote_instance with TransferEngine? In our practice (9 NICs), we pin 8 NICs (say ib0-7) to GPUs for collective operations and use 1 NIC (say ib8) for KV cache transfer in PD disaggregation, which leaves no NIC for remote_instance.
If we allocate one of ib0-7, it disturbs inference tasks; if we allocate ib8 for remote_instance, it disturbs KV cache transfer, which in turn disturbs inference under PD disaggregation. So how should we allocate NICs for remote_instance to minimize disturbance to ongoing inference? Should we allocate all NICs by default and let Mooncake or the system handle it? And can we limit the bandwidth of remote_instance, e.g., to 1/4 or 1/2?

Comment thread python/sglang/srt/model_loader/loader.py Outdated
@XiaotaoChen

XiaotaoChen commented Nov 19, 2025

@amysaq2023 @stmatengss Hi, in our tests (2 nodes of 8*H200 with DeepSeek-V3, 645 GB), the time spent in the register_memory stage is similar to the time taken by batch_transfer_sync_read: register_memory takes ~45s and batch_transfer_sync_read takes ~40s. How can we optimize this? I tried splitting the whole state_dict into 8 parts and launching threads to run register and batch_sync in parallel, but it did not help.

@amysaq2023
Contributor Author

> (quoting XiaotaoChen's perf/NIC-allocation question above)

Performance tested with 8*H20 on loading Deepseek-V3:

  • register memory time: ~22s (~7s with a WIP optimization)
  • transfer weights time: ~3s

amysaq2023 force-pushed the amy/non-disturbing-remote-instance-weight-loader branch from c003071 to 6155a38 on November 27, 2025 15:36
@zhaochenyang20
Collaborator

Nice PR anqi! Sorry for waiting so long.

Collaborator

@zhaochenyang20 zhaochenyang20 left a comment

You are my hero! As solid as always. Sorry for waiting so long for the review. I shall also connect the Mooncake team and Microsoft AI for review.

Comment thread python/sglang/srt/entrypoints/http_server.py Outdated
Comment thread python/sglang/srt/model_loader/loader.py Outdated
Comment thread python/sglang/srt/model_loader/loader.py Outdated
@XiaotaoChen

XiaotaoChen commented Nov 28, 2025

> (quoting the exchange above: the NIC-allocation question and amysaq2023's 8*H20 numbers)

@amysaq2023 thanks for the info. Could you share how to configure MOONCAKE_DEVICE to reproduce that performance? I tried this PR in our environment.

hardware: 2 nodes of 8*H200, with 8 * 400 Gb/s NICs (mlx5_1-mlx5_8) and 1 * 100 Gb/s NIC (mlx5_0)
software: mooncake-transfer-engine 0.3.7.post2
base env: export MOONCAKE_PROTOCOL="rdma"

I tested three settings; the performance is as follows. Looking forward to your reply, thanks.

1. don't set MOONCAKE_DEVICE, register_memory cost: 70.6599s, batch_transfer_sync_read: 6.3640s
2. export MOONCAKE_DEVICE='mlx5_0', register_memory cost: 8.7188s, batch_transfer_sync_read: 48.9649s
3. export MOONCAKE_DEVICE='mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8', register_memory cost: 60.5853s, batch_transfer_sync_read: 6.1010s

nvidia-smi topo -m output:

        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    NIC8    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    SYS     PIX     PHB     PHB     PHB     SYS     SYS     SYS     SYS     0-89    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    SYS     PHB     PIX     PHB     PHB     SYS     SYS     SYS     SYS     0-89    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    SYS     PHB     PHB     PIX     PHB     SYS     SYS     SYS     SYS     0-89    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    SYS     PHB     PHB     PHB     PIX     SYS     SYS     SYS     SYS     0-89    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     SYS     PIX     PHB     PHB     PHB     90-179  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     SYS     PHB     PIX     PHB     PHB     90-179  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     SYS     PHB     PHB     PIX     PHB     90-179  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     SYS     PHB     PHB     PHB     PIX     90-179  1               N/A
NIC0    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC1    PIX     PHB     PHB     PHB     SYS     SYS     SYS     SYS     SYS      X      PHB     PHB     PHB     SYS     SYS     SYS     SYS
NIC2    PHB     PIX     PHB     PHB     SYS     SYS     SYS     SYS     SYS     PHB      X      PHB     PHB     SYS     SYS     SYS     SYS
NIC3    PHB     PHB     PIX     PHB     SYS     SYS     SYS     SYS     SYS     PHB     PHB      X      PHB     SYS     SYS     SYS     SYS
NIC4    PHB     PHB     PHB     PIX     SYS     SYS     SYS     SYS     SYS     PHB     PHB     PHB      X      SYS     SYS     SYS     SYS
NIC5    SYS     SYS     SYS     SYS     PIX     PHB     PHB     PHB     SYS     SYS     SYS     SYS     SYS      X      PHB     PHB     PHB
NIC6    SYS     SYS     SYS     SYS     PHB     PIX     PHB     PHB     SYS     SYS     SYS     SYS     SYS     PHB      X      PHB     PHB
NIC7    SYS     SYS     SYS     SYS     PHB     PHB     PIX     PHB     SYS     SYS     SYS     SYS     SYS     PHB     PHB      X      PHB
NIC8    SYS     SYS     SYS     SYS     PHB     PHB     PHB     PIX     SYS     SYS     SYS     SYS     SYS     PHB     PHB     PHB      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8

@JD-ETH
Contributor

JD-ETH commented Dec 2, 2025

Super cool! Just for my understanding: the batch_transfer_sync_read call is a single-sided operation and does not require an additional TransferEngine instance at the target entity, correct?

@amysaq2023
Contributor Author

> (quoting XiaotaoChen's MOONCAKE_DEVICE results and topology report above)

Could you please provide more information about how you set up the test, such as whether the SGLang instance runs in a container and which SGLang image is used? Thanks.

Comment thread docs/advanced_features/rfork.md Outdated
@@ -0,0 +1,55 @@
# R-Fork

R-Fork (Tensor Remote Fork) provides a novel weight loading methodology that leverages an efficient inter-node GPU-to-GPU data transfer path to load tensors from a running SGLang instance into a new instance with zero-copy. It can significantly reduce SGLang instance boot-up time, cutting model weight loading from several minutes to mere seconds.
Collaborator

😂 I find that the docs are basically the same as the blog post. So we can just leave a basic introduction here, link to our blog, say what R-Fork can do with detailed parameter explanations, and give usage instructions.

amysaq2023 force-pushed the amy/non-disturbing-remote-instance-weight-loader branch 3 times, most recently from f2afd8c to 2fa8561 on December 10, 2025 14:30
@stmatengss
Collaborator

/rerun-failed-ci

Comment thread python/sglang/srt/model_loader/loader.py
Comment thread python/sglang/srt/model_loader/loader.py
@zhaochenyang20
Collaborator

/rerun-failed-ci

…ader

This commit reduces the time spent registering memory regions when using
TransferEngine as the backend for loading weights from a remote instance,
by merging contiguous memory blocks into one memory region.

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
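
A minimal sketch of the coalescing this commit message describes, assuming weights arrive as (address, size) pairs; the function name and data layout are illustrative, not the actual implementation:

def merge_contiguous_regions(regions):
    """Coalesce adjacent (address, size) blocks so fewer, larger
    regions are passed to TransferEngine.register_memory()."""
    merged = []
    for addr, size in sorted(regions):
        if merged and merged[-1][0] + merged[-1][1] == addr:
            prev_addr, prev_size = merged[-1]
            merged[-1] = (prev_addr, prev_size + size)  # extend the previous block
        else:
            merged.append((addr, size))
    return merged

# Three contiguous blocks collapse into a single registration call:
assert merge_contiguous_regions([(0x100, 16), (0x110, 16), (0x120, 8)]) == [(0x100, 40)]
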
amysaq2023 force-pushed the amy/non-disturbing-remote-instance-weight-loader branch from 2fa8561 to 0b450ad on December 11, 2025 08:59
@zhaochenyang20
Collaborator

/rerun-failed-ci

@Kangyan-Zhou Kangyan-Zhou merged commit 70758d4 into sgl-project:main Dec 12, 2025
124 of 156 checks passed
Comment on lines +2697 to +2714
pipe_writer.send(
{
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
"tp_rank": tp_rank,
"remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id,
"remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict,
}
)
else:
pipe_writer.send(
{
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
}
)
Contributor

Do not duplicate the code here. These lines are duplicated:

                    "status": "ready",
                    "max_total_num_tokens": scheduler.max_total_num_tokens,
                    "max_req_input_len": scheduler.max_req_input_len,

Do something like this instead:

if ...:
    result_dict.update({
        "remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id,
        "remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict,
    })
pipe_writer.send(result_dict)

template_manager,
scheduler_info,
port_args,
remote_instance_transfer_engine_info,
Contributor

This is redundant. All data in remote_instance_transfer_engine_info is already in scheduler_info.
Can you revert the changes here?

i.e., do not change the interface of _launch_subprocesses

@merrymercy
Contributor

merrymercy commented Dec 12, 2025

@amysaq2023 It conflicts with some refactors here: #14869. It would be great if you could keep the interface of _launch_subprocesses unchanged.

@amysaq2023
Contributor Author

> (quoting merrymercy's request above to keep the _launch_subprocesses interface unchanged)

Sure. Working on refactoring this.

@zhyncs
Collaborator

zhyncs commented Dec 12, 2025

This PR broke the B200 CI. We need to be more careful during review and before merging PRs — the main branch has been breaking too frequently lately.
@amysaq2023 @zhaochenyang20 @Kangyan-Zhou

@zhyncs
Collaborator

zhyncs commented Dec 12, 2025

ref #14958

amysaq2023 added a commit to amysaq2023/sglang that referenced this pull request Dec 12, 2025

This commit addresses comments in
sgl-project#13125 (comment),
sgl-project#13125 (comment)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
@amysaq2023
Contributor Author

> (quoting the _launch_subprocesses interface discussion above)

PR that addresses the above comments: #14971
@merrymercy @zhaochenyang20

@weireweire
Contributor

cause:

  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2720, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 320, in __init__
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 248, in __init__
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 363, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 448, in initialize
    register_memory_region_v2(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py", line 176, in register_memory_region_v2
    ret = transfer_engine.register_memory(address, size)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'register_memory'
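
The traceback shows register_memory being called on a transfer_engine that is None, i.e. the engine was never constructed on that platform. A hypothetical guard illustrating the failure mode (names and signature are assumptions; this is not the actual fix, which landed separately):

def register_memory_region_v2(transfer_engine, regions):
    # transfer_engine is None when the TransferEngine backend was not
    # initialized on this platform (the situation hit on the B200 CI).
    if transfer_engine is None:
        return  # skip registration instead of crashing with AttributeError
    for address, size in regions:
        ret = transfer_engine.register_memory(address, size)
        assert ret == 0, f"register_memory failed at {hex(address)}"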

Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 17, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026

Labels

documentation, run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants