[4/N] (Elastic EP) Back up Expert Weights in DRAM #17374

Merged
ShangmingCai merged 32 commits into sgl-project:main from HanHan009527:mapc-dev on Feb 27, 2026

Conversation

@ympcMark (Contributor) commented Jan 20, 2026

Motivation

Currently, in Elastic EP, when rank failures or expert weight loss occur, the system must reload missing weights from disk. This process significantly increases service interruption time.

We observed that servers often have unused spare DRAM capacity. Leveraging this available DRAM, we propose maintaining a hot backup of expert weights in memory. In the event of a failure, weights can be restored directly from DRAM over the high-speed RDMA network, updating the model weights in GPU memory much faster than disk-based recovery.

Experimental results show that this approach outperforms the original disk-based method—even when the original already employs optimizations such as multi-threading and OS page cache.

Furthermore, in multi-node deployment scenarios, each node only needs to store a subset of expert weights, significantly reducing per-node DRAM overhead.

Our changes are incremental and controlled by a server argument that is disabled by default, ensuring no impact on existing deployments.

Modifications

  • expert_backup_manager.py: Introduced a new ExpertBackupManager class and process responsible for loading a subset of expert weights from disk into a continuous CPU buffer and managing their transfer via Mooncake.

  • expert_backup_client.py: Implemented ZMQ-based communication between the ExpertBackupClient and the ExpertBackupManager to coordinate weight backup requests and to transfer expert weight pointer information and session IDs across distributed nodes. The client then initializes a client-side TransferEngine and dynamically loads expert weights from DRAM when needed.

  • io_struct.py: Added the UpdateExpertBackupReq and BackupDramReq request structures (see the sketch after this list).

  • engine.py: Launch the ExpertBackupManager process.

  • test_expert_backup.py: Add unit tests.
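For reference, a minimal sketch of the two request structs, with field names taken from the code excerpts quoted later in this thread; the types and defaults are assumptions, not the exact implementation:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class UpdateExpertBackupReq:
    # Sent by the client side to ask an ExpertBackupManager to publish its DRAM
    # backup metadata; carries no payload in this sketch.
    pass

@dataclass
class BackupDramReq:
    # Reply carrying the DRAM backup metadata for one engine rank.
    _rank: int = 0                                         # engine rank owning the backup
    _map: Dict[str, dict] = field(default_factory=dict)    # weight name -> pointer/size/shape info
    session_id: str = ""                                   # Mooncake transfer session id
    buffer_size: int = 0                                    # size in bytes of the continuous CPU buffer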

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @ympcMark, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the handling of expert weights in a distributed system to optimize DRAM usage. It introduces a dedicated ExpertBackupManager process that independently manages and backs up expert weights. The existing ExpertLocationUpdater and ModelRunner components are updated to communicate with this new manager, enabling efficient weight transfer and management, particularly beneficial for Mixture of Experts (MoE) models by avoiding redundant weight storage across scheduler processes.

Highlights

  • DRAM Resource Optimization: Decoupled the logic of backing up expert weights from the scheduler into an independent process, significantly reducing DRAM consumption by preventing each scheduler from storing a complete list of weights.
  • New Expert Backup Manager: Introduced a new ExpertBackupManager class and process responsible for loading a subset of expert weights from disk into a continuous CPU buffer and managing their transfer via a TransferEngine (Mooncake).
  • Inter-Process Communication: Implemented ZMQ-based communication between the ExpertLocationUpdater and the new ExpertBackupManager to coordinate weight backup requests and transfer metadata across distributed nodes.
  • ModelRunner Integration: Modified the ModelRunner to initialize a client-side TransferEngine, receive backup maps and session IDs from the ExpertBackupManager via the ExpertLocationUpdater, and dynamically load expert weights from DRAM when needed.
  • New Test Case: Added test_expert_backup.py to validate the new distributed expert weight backup and transfer mechanism, ensuring functional correctness and performance in a multi-process environment.

@gemini-code-assist Bot left a comment

Code Review

This pull request introduces a mechanism to back up expert weights in DRAM to reduce memory pressure on scheduler processes. It decouples the backup logic into a new, independent ExpertBackupManager process that communicates with the ExpertLocationUpdater via ZMQ and transfers weights using RDMA. While the overall approach is sound, I've identified some critical performance and correctness issues in the implementation, particularly concerning inefficient busy-wait loops and redundant operations in the new manager process. I've also pointed out areas for code cleanup and improved clarity.

Comment on lines +149 to +167
    def event_loop(self):
        while True:
            while True:
                try:
                    recv_req = self.recv_from_expert_location_updater.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                logger.info("recv_req: %s", recv_req)
                # TODO (stage 1): handle recv_req
                print("RECEIVED REQ: ", self.engine_rank)
                self.backup_weights_from_disk()
                self.start_transfer_server()
                back_req = BackupDramReq(
                    _rank=self.engine_rank,
                    _map=self.weight_pointer_map,
                    session_id=self.session_id,
                    buffer_size=self.continuous_buffer.numel() * self.continuous_buffer.element_size()
                )
                self.send_to_expert_location_updater.send_pyobj(back_req)
gemini-code-assist Bot (severity: high)

The event_loop implementation has several issues affecting performance and correctness:

  1. Busy-Wait Loop: The nested while True with a non-blocking recv_pyobj creates a busy-wait loop, consuming unnecessary CPU cycles. It's more efficient to use a single blocking recv_pyobj() to wait for messages.
  2. Redundant Operations: backup_weights_from_disk() and start_transfer_server() are called for every received request. Since the expert weights on disk are static, these expensive operations should be performed only once when the manager starts.
  3. Potential None Access: If backup_weights_from_disk results in self.continuous_buffer being None, the subsequent access self.continuous_buffer.numel() will raise an AttributeError.

I suggest refactoring this method to perform one-time setup and use a blocking receive loop.

    def event_loop(self):
        self.backup_weights_from_disk()
        if self.continuous_buffer is not None:
            self.start_transfer_server()

        while True:
            try:
                recv_req = self.recv_from_expert_location_updater.recv_pyobj()
            except zmq.ZMQError:
                # This can happen if the context is terminated.
                break

            logger.info("recv_req: %s", recv_req)
            # TODO (stage 1): handle recv_req
            print("RECEIVED REQ: ", self.engine_rank)

            buffer_size = 0
            if self.continuous_buffer is not None:
                buffer_size = self.continuous_buffer.numel() * self.continuous_buffer.element_size()

            back_req = BackupDramReq(
                _rank=self.engine_rank,
                _map=self.weight_pointer_map,
                session_id=self.session_id,
                buffer_size=buffer_size,
            )
            self.send_to_expert_location_updater.send_pyobj(back_req)

Comment on lines +82 to +97
def _receive_loop(self):
    cnt = 0
    while True:
        for i in range(self.engine_num):
            try:
                response = self.recv_list[i].recv_pyobj()
            except zmq.ZMQError:
                continue
            self.model_runner.dram_map_list[response._rank] = response._map
            self.model_runner.session_id_list[response._rank] = response.session_id
            self.model_runner.buffer_size = max(self.model_runner.buffer_size, response.buffer_size)
            cnt += 1
            print("RECEIVED: ", response._rank)
            if cnt == self.engine_num:
                self.model_runner.if_backup = True
                self.model_runner.start_transfer_client()
gemini-code-assist Bot (severity: high)

The current implementation of _receive_loop will sequentially block on receiving from each socket in self.recv_list. This means it will wait for a message from self.recv_list[0], then self.recv_list[1], and so on. This is likely not the intended behavior, as it should process messages from any backup manager as they arrive.

A better approach is to use zmq.Poller to wait for messages on all sockets simultaneously. This avoids both sequential blocking and busy-waiting.

Here's an example of how you could refactor it:

def _receive_loop(self):
    cnt = 0
    poller = zmq.Poller()
    for sock in self.recv_list:
        poller.register(sock, zmq.POLLIN)

    while cnt < self.engine_num:
        try:
            socks = dict(poller.poll())
        except zmq.ZMQError:
            # Context terminated
            break

        for sock in socks:
            if socks[sock] == zmq.POLLIN:
                try:
                    response = sock.recv_pyobj(zmq.NOBLOCK)
                    self.model_runner.dram_map_list[response._rank] = response._map
                    self.model_runner.session_id_list[response._rank] = response.session_id
                    self.model_runner.buffer_size = max(self.model_runner.buffer_size, response.buffer_size)
                    cnt += 1
                    print("RECEIVED: ", response._rank)
                except zmq.ZMQError:
                    continue
    
    if cnt == self.engine_num:
        self.model_runner.if_backup = True
        self.model_runner.start_transfer_client()

self.model_path = server_args.model_path
self.load_format = server_args.load_format
self.model_config = ModelConfig.from_server_args(server_args)
print("MANAGER START: ", server_args.node_rank)
gemini-code-assist Bot (severity: medium)

Replace print with logger.info for consistent and controllable logging. This allows log levels to be managed centrally and provides more context (like timestamp, module name) in logs.

Suggested change
print("MANAGER START: ", server_args.node_rank)
logger.info("MANAGER START: %s", server_args.node_rank)


if get_world_rank() % (get_world_size() // server_args.nnodes) == 0:
    self.send_to_backup_manager.send_pyobj(UpdateExpertBackupReq())
    print("SEND TO MANAGER!", server_args.node_rank)
gemini-code-assist Bot (severity: medium)

Please use the logging module instead of print for better log management and consistency.

Suggested change
print("SEND TO MANAGER!", server_args.node_rank)
logger.info("SEND TO MANAGER! node_rank=%s", server_args.node_rank)

    )

    def start_transfer_client(self):
        print("START CLIENT")
gemini-code-assist Bot (severity: medium)

Please replace print with logger.info for standardized logging. This will make logs more informative and easier to filter.

Suggested change
print("START CLIENT")
logger.info("START CLIENT")

Comment on lines +892 to +900
self.params_dict = dict(self.model.named_parameters())
for name, param in self.params_dict.items():
    param_data = param.data
    ret_value = self.transfer_engine.register_memory(
        param_data.data_ptr(), param_data.numel() * param_data.element_size()
    )
    if ret_value != 0:
        print(f"GPU buffer memory registration failed for param {name}")
        raise RuntimeError("GPU buffer memory registration failed.")
gemini-code-assist Bot (severity: medium)

This loop registers memory for all model parameters with the transfer engine. However, the backup and transfer mechanism seems to be focused on expert weights. If only expert weights are transferred, registering all parameters might be unnecessary and could increase initialization time and memory overhead for the transfer engine. It would be more efficient to only register the memory for the parameters that will actually be updated from the DRAM backup.
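A minimal sketch of the narrower registration, assuming expert parameters can be recognized by an "experts" substring in their names (the naming convention is an assumption) and, per the discussion below, degrading gracefully instead of crashing:

self.params_dict = dict(self.model.named_parameters())
for name, param in self.params_dict.items():
    if "experts" not in name:
        continue  # only expert weights are restored from the DRAM backup
    param_data = param.data
    ret_value = self.transfer_engine.register_memory(
        param_data.data_ptr(), param_data.numel() * param_data.element_size()
    )
    if ret_value != 0:
        # Skip the backup path for this weight instead of raising and crashing the engine.
        logger.warning("Memory registration failed for %s; falling back to disk load", name)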

Comment on lines +901 to +925
"""
def load_weight_from_dram(self, name):
weight_info = self.dram_map[name]
server_ptr = weight_info['weight_ptr']
weight_size = weight_info['numel'] * weight_info['element_size']

gpu_buffer_size = self.gpu_buffer.numel() * self.gpu_buffer.element_size()
if weight_size > gpu_buffer_size:
raise RuntimeError(f"Weight size {weight_size} exceeds GPU buffer size {gpu_buffer_size}")

ret = self.transfer_engine.transfer_sync_read(
self.server_session_id,
self.gpu_buffer.data_ptr(),
server_ptr,
weight_size
)

if ret != 0:
raise RuntimeError(f"Failed to read weight {name} from backup, error code: {ret}")

byte_data = self.gpu_buffer[:weight_size]
weight_tensor = byte_data.view(weight_info['dtype']).reshape(weight_info['shape'])

return weight_tensor
"""
gemini-code-assist Bot (severity: medium)

This large block of commented-out code for load_weight_from_dram should be removed to improve code clarity if it is no longer in use.

Comment on lines +950 to +1016
def get_weight_iter(config):
    iter = loader._get_weights_iterator(
        DefaultModelLoader.Source.init_new(config, self.model)
    )
    if weight_name_filter is not None:
        iter = (
            (name, weight) for name, weight in iter if weight_name_filter(name)
    if self.if_backup:
        print("USE BACKUP")
        global_expert_location_metadata = get_global_expert_location_metadata()
        num_experts = self.model_config.hf_config.n_routed_experts + self.server_args.ep_num_redundant_experts
        num_local_experts = num_experts // self.moe_ep_size
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=num_experts,
        )
        for i in range(self.engine_num):
            server_ptr_list = []
            local_ptr_list = []
            weight_size_list = []

            for name, weight_info in self.dram_map_list[i].items():
                layer_id, expert_id = extract_layer_and_expert_id(name)
                if layer_id >= self.model_config.hf_config.num_hidden_layers:
                    continue
                if weight_name_filter is None or weight_name_filter(name):
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
                            layer_id, expert_id
                        )
                        for physical_expert_id in physical_expert_ids:
                            if physical_expert_id not in range(
                                num_local_experts * self.moe_ep_rank,
                                num_local_experts * (self.moe_ep_rank + 1)
                            ):
                                continue
                            name = name.replace(weight_name, param_name)
                            param = self.params_dict[name]
                            param = param[physical_expert_id % num_local_experts]
                            if shard_id == "w1":
                                param = param.narrow(0, 0, param.shape[0] // 2)
                            elif shard_id == "w3":
                                param = param.narrow(0, param.shape[0] // 2, param.shape[0] // 2)
                            weight_info['tensor'] = param
                            server_ptr_list.append(weight_info['weight_ptr'])
                            local_ptr_list.append(weight_info['tensor'].data_ptr())
                            assert weight_info['tensor'].numel() * weight_info['tensor'].element_size() == weight_info['byte_size']
                            weight_size_list.append(weight_info['byte_size'])

            before_transfer = time.time()
            ret = self.transfer_engine.batch_transfer_sync_read(
                self.session_id_list[i],
                local_ptr_list,
                server_ptr_list,
                weight_size_list
            )
            after_transfer = time.time()
            logger.info(f"transfer time = {after_transfer - before_transfer} s")

    return iter
            if ret != 0:
                raise RuntimeError(f"Failed to read weights from backup, error code: {ret}")
    else:
        iter = loader._get_weights_iterator(
            DefaultModelLoader.Source.init_new(config, self.model)
        )
        for name, weight in iter:
            if weight_name_filter is None or weight_name_filter(name):
                yield name, weight
gemini-code-assist Bot (severity: medium)

When self.if_backup is true, this function performs weight updates directly via RDMA transfer instead of yielding weights. This behavior is contrary to what its name get_weight_iter suggests, making the control flow confusing. The caller model_load_weights receives an empty iterator and does nothing, while the actual weight update happens inside this function.

Consider refactoring this logic to improve clarity. For example, you could create a separate function like _update_weights_from_dram_backup that is called when self.if_backup is true, and have get_weight_iter consistently behave as an iterator.
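A minimal sketch of that split, with a hypothetical helper _collect_expert_transfer_lists standing in for the expert-mapping logic quoted above (method placement and signatures are assumptions):

def _update_weights_from_dram_backup(self, weight_name_filter=None):
    # Pull expert weights from each backup manager's DRAM buffer via batched RDMA reads.
    for i in range(self.engine_num):
        local_ptrs, server_ptrs, sizes = self._collect_expert_transfer_lists(
            self.dram_map_list[i], weight_name_filter
        )
        ret = self.transfer_engine.batch_transfer_sync_read(
            self.session_id_list[i], local_ptrs, server_ptrs, sizes
        )
        if ret != 0:
            raise RuntimeError(f"Failed to read weights from backup, error code: {ret}")

def get_weight_iter(self, config, weight_name_filter=None):
    # Always behaves as an iterator over (name, weight) pairs loaded from disk.
    it = loader._get_weights_iterator(
        DefaultModelLoader.Source.init_new(config, self.model)
    )
    for name, weight in it:
        if weight_name_filter is None or weight_name_filter(name):
            yield name, weight

The caller can then branch once: call _update_weights_from_dram_backup() when if_backup is set, and iterate over get_weight_iter() otherwise.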

@hzh0425 hzh0425 self-assigned this Jan 20, 2026
@ympcMark ympcMark changed the title from "Back up Expert Weights in DRAM" to "[7/N] (Elastic EP) Back up Expert Weights in DRAM" on Jan 20, 2026
@stmatengss stmatengss self-assigned this Jan 20, 2026
@stmatengss (Collaborator)

/tag-run-ci-label

@UNIDY2002 UNIDY2002 force-pushed the mapc-dev branch 3 times, most recently from f361680 to 0102b24 on January 22, 2026 08:25
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
@github-actions github-actions Bot added the documentation label (Improvements or additions to documentation) on Jan 23, 2026
@ympcMark ympcMark marked this pull request as ready for review January 23, 2026 09:39
| `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. | `None` | Type: str |
| `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports. | `None` | Type: int |
| `--elastic-ep-backend` | Specify the collective communication backend for elastic EP. Currently supports 'mooncake'. | `none` | `none`, `mooncake` |
| `--enable-elastic-expert-backup` | Enable elastic expert backup feature. | `False` | bool flag (set to enable) |
Collaborator
the explanation here carries zero information

Contributor Author

Yes, we will make the explanation clearer.
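For illustration, a hypothetical launch command combining the flags from the table above; the model path and any other arguments are placeholders:

python -m sglang.launch_server \
    --model-path /path/to/moe-model \
    --elastic-ep-backend mooncake \
    --enable-elastic-expert-backup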

        param_data.data_ptr(), param_data.numel() * param_data.element_size()
    )
    if ret_value != 0:
        raise RuntimeError("GPU buffer memory registration failed.")
Collaborator
Can we avoid engine crash for this resilient feature?

Contributor Author

We will avoid using the backup when it fails, instead of crashing.

Comment thread: python/sglang/srt/elastic_ep/expert_backup_client.py
        self.server_args.language_only
        and self.server_args.encoder_transfer_backend == "mooncake"
    )
    or self.server_args.enable_elastic_expert_backup
Collaborator
I feel that we should check elastic_ep_backend here instead

Contributor Author
yes!

Collaborator

Does it only work for the mooncake EP backend? We need to confirm whether it is general; otherwise we should check that the elastic EP backend is mooncake.

Contributor Author

It is general; to be precise, we check that the elastic EP backend is not None.
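A minimal sketch of the adjusted condition, assuming the surrounding context from the snippet above (the variable name is a placeholder):

need_transfer_engine = (
    (
        self.server_args.language_only
        and self.server_args.encoder_transfer_backend == "mooncake"
    )
    or self.server_args.elastic_ep_backend is not None
)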

for i in range(self.engine_num):
    self.recv_list[i] = context.socket(zmq.SUB)
    self.recv_list[i].connect(
        f"tcp://{all_ips[i * get_world_size() // server_args.nnodes]}:{PORT_BASE + i * 2 + 1}"
Collaborator
Looks a little bit hacky. But you can optimize this in the future.

@ShangmingCai (Collaborator)

/tag-and-rerun-ci

@UNIDY2002 (Contributor)

Hi @ShangmingCai! May I request a rerun of the failed tests? Thank you!

@stmatengss (Collaborator)

/rerun-failed-ci

1 similar comment
@ShangmingCai (Collaborator)

/rerun-failed-ci

@UNIDY2002 (Contributor)

DeepEP tests are failing. Similar failures can be observed in other runs like https://github.com/sgl-project/sglang/actions/runs/22431204967.

@ympcMark ympcMark changed the title from "[7/N] (Elastic EP) Back up Expert Weights in DRAM" to "[4/N] (Elastic EP) Back up Expert Weights in DRAM" on Feb 26, 2026
@UNIDY2002 (Contributor)

Hi all! May I request a rerun of the failed tests? Thanks!

@ShangmingCai (Collaborator)

DeepEP tests are failing. Similar failures can be observed in other runs like https://github.com/sgl-project/sglang/actions/runs/22431204967.

I have fixed the CI. Let me rerun the remaining tests.

@ShangmingCai (Collaborator)

/rerun-stage stage-c-test-deepep-4-gpu

@github-actions (Contributor)

✅ Triggered stage-c-test-deepep-4-gpu to run independently (skipping dependencies).

@github-actions (Contributor)

🔗 View workflow run

@UNIDY2002 (Contributor)

The DeepEP tests have passed. The only failing test is irrelevant. Do we need to rerun it? [Screenshot of CI results, 2026-02-27]

@ShangmingCai (Collaborator)

/rerun-failed-ci

@ShangmingCai (Collaborator) left a comment

Failed AMD tests (lora, tool call parser) are irrelevant.

@ShangmingCai ShangmingCai merged commit 43fade5 into sgl-project:main Feb 27, 2026
190 of 201 checks passed
@ympcMark (Contributor Author)

@ShangmingCai @ch-wan thank you for the valuable review comments!

@UNIDY2002 UNIDY2002 deleted the mapc-dev branch February 28, 2026 12:26
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>

Labels

documentation (Improvements or additions to documentation), run-ci
