[4/N] (Elastic EP) Back up Expert Weights in DRAM #17374
ShangmingCai merged 32 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request refactors the handling of expert weights in a distributed system to optimize DRAM usage. It introduces a dedicated, independent ExpertBackupManager process for the backup logic.
Code Review
This pull request introduces a mechanism to back up expert weights in DRAM to reduce memory pressure on scheduler processes. It decouples the backup logic into a new, independent ExpertBackupManager process that communicates with the ExpertLocationUpdater via ZMQ and transfers weights using RDMA. While the overall approach is sound, I've identified some critical performance and correctness issues in the implementation, particularly concerning inefficient busy-wait loops and redundant operations in the new manager process. I've also pointed out areas for code cleanup and improved clarity.
def event_loop(self):
    while True:
        while True:
            try:
                recv_req = self.recv_from_expert_location_updater.recv_pyobj(zmq.NOBLOCK)
            except zmq.ZMQError:
                break
        logger.info("recv_req: %s", recv_req)
        # TODO (stage 1): handle recv_req
        print("RECEIVED REQ: ", self.engine_rank)
        self.backup_weights_from_disk()
        self.start_transfer_server()
        back_req = BackupDramReq(
            _rank=self.engine_rank,
            _map=self.weight_pointer_map,
            session_id=self.session_id,
            buffer_size=self.continuous_buffer.numel() * self.continuous_buffer.element_size()
        )
        self.send_to_expert_location_updater.send_pyobj(back_req)
The event_loop implementation has several issues affecting performance and correctness:

- Busy-wait loop: The nested `while True` with a non-blocking `recv_pyobj` creates a busy-wait loop, consuming unnecessary CPU cycles. It's more efficient to use a single blocking `recv_pyobj()` to wait for messages.
- Redundant operations: `backup_weights_from_disk()` and `start_transfer_server()` are called for every received request. Since the expert weights on disk are static, these expensive operations should be performed only once when the manager starts.
- Potential `None` access: If `backup_weights_from_disk` leaves `self.continuous_buffer` as `None`, the subsequent access `self.continuous_buffer.numel()` will raise an `AttributeError`.

I suggest refactoring this method to perform one-time setup and use a blocking receive loop.
def event_loop(self):
    self.backup_weights_from_disk()
    if self.continuous_buffer is not None:
        self.start_transfer_server()
    while True:
        try:
            recv_req = self.recv_from_expert_location_updater.recv_pyobj()
        except zmq.ZMQError:
            # This can happen if the context is terminated.
            break
        logger.info("recv_req: %s", recv_req)
        # TODO (stage 1): handle recv_req
        print("RECEIVED REQ: ", self.engine_rank)
        buffer_size = 0
        if self.continuous_buffer is not None:
            buffer_size = self.continuous_buffer.numel() * self.continuous_buffer.element_size()
        back_req = BackupDramReq(
            _rank=self.engine_rank,
            _map=self.weight_pointer_map,
            session_id=self.session_id,
            buffer_size=buffer_size,
        )
        self.send_to_expert_location_updater.send_pyobj(back_req)

def _receive_loop(self):
    cnt = 0
    while True:
        for i in range(self.engine_num):
            try:
                response = self.recv_list[i].recv_pyobj()
            except zmq.ZMQError:
                continue
            self.model_runner.dram_map_list[response._rank] = response._map
            self.model_runner.session_id_list[response._rank] = response.session_id
            self.model_runner.buffer_size = max(self.model_runner.buffer_size, response.buffer_size)
            cnt += 1
            print("RECEIVED: ", response._rank)
        if cnt == self.engine_num:
            self.model_runner.if_backup = True
            self.model_runner.start_transfer_client()
The current implementation of _receive_loop will sequentially block on receiving from each socket in self.recv_list. This means it will wait for a message from self.recv_list[0], then self.recv_list[1], and so on. This is likely not the intended behavior, as it should process messages from any backup manager as they arrive.
A better approach is to use zmq.Poller to wait for messages on all sockets simultaneously. This avoids both sequential blocking and busy-waiting.
Here's an example of how you could refactor it:
def _receive_loop(self):
    cnt = 0
    poller = zmq.Poller()
    for sock in self.recv_list:
        poller.register(sock, zmq.POLLIN)
    while cnt < self.engine_num:
        try:
            socks = dict(poller.poll())
        except zmq.ZMQError:
            # Context terminated
            break
        for sock in socks:
            if socks[sock] == zmq.POLLIN:
                try:
                    response = sock.recv_pyobj(zmq.NOBLOCK)
                    self.model_runner.dram_map_list[response._rank] = response._map
                    self.model_runner.session_id_list[response._rank] = response.session_id
                    self.model_runner.buffer_size = max(self.model_runner.buffer_size, response.buffer_size)
                    cnt += 1
                    print("RECEIVED: ", response._rank)
                except zmq.ZMQError:
                    continue
    if cnt == self.engine_num:
        self.model_runner.if_backup = True
        self.model_runner.start_transfer_client()

self.model_path = server_args.model_path
self.load_format = server_args.load_format
self.model_config = ModelConfig.from_server_args(server_args)
print("MANAGER START: ", server_args.node_rank)
Replace print with logger.info for consistent and controllable logging. This allows log levels to be managed centrally and provides more context (like timestamp and module name) in logs.

-    print("MANAGER START: ", server_args.node_rank)
+    logger.info("MANAGER START: %s", server_args.node_rank)
if get_world_rank() % (get_world_size() // server_args.nnodes) == 0:
    self.send_to_backup_manager.send_pyobj(UpdateExpertBackupReq())
    print("SEND TO MANAGER!", server_args.node_rank)
def start_transfer_client(self):
    print("START CLIENT")
self.params_dict = dict(self.model.named_parameters())
for name, param in self.params_dict.items():
    param_data = param.data
    ret_value = self.transfer_engine.register_memory(
        param_data.data_ptr(), param_data.numel() * param_data.element_size()
    )
    if ret_value != 0:
        print(f"GPU buffer memory registration failed for param {name}")
        raise RuntimeError("GPU buffer memory registration failed.")
This loop registers memory for all model parameters with the transfer engine. However, the backup and transfer mechanism seems to be focused on expert weights. If only expert weights are transferred, registering all parameters might be unnecessary and could increase initialization time and memory overhead for the transfer engine. It would be more efficient to only register the memory for the parameters that will actually be updated from the DRAM backup.
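A minimal sketch of that narrower registration, assuming expert parameters can be identified by name (the `"experts" in name` predicate is an illustration, not the PR's actual filter):

```python
# Only register the parameters that the DRAM backup can actually refresh.
# The name-based predicate is an assumption about how expert weights are named.
self.params_dict = dict(self.model.named_parameters())
for name, param in self.params_dict.items():
    if "experts" not in name:
        continue  # skip non-expert weights; they are never read back from DRAM
    param_data = param.data
    ret_value = self.transfer_engine.register_memory(
        param_data.data_ptr(), param_data.numel() * param_data.element_size()
    )
    if ret_value != 0:
        logger.warning("Memory registration failed for expert param %s", name)
```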
| """ | ||
| def load_weight_from_dram(self, name): | ||
| weight_info = self.dram_map[name] | ||
| server_ptr = weight_info['weight_ptr'] | ||
| weight_size = weight_info['numel'] * weight_info['element_size'] | ||
|
|
||
| gpu_buffer_size = self.gpu_buffer.numel() * self.gpu_buffer.element_size() | ||
| if weight_size > gpu_buffer_size: | ||
| raise RuntimeError(f"Weight size {weight_size} exceeds GPU buffer size {gpu_buffer_size}") | ||
|
|
||
| ret = self.transfer_engine.transfer_sync_read( | ||
| self.server_session_id, | ||
| self.gpu_buffer.data_ptr(), | ||
| server_ptr, | ||
| weight_size | ||
| ) | ||
|
|
||
| if ret != 0: | ||
| raise RuntimeError(f"Failed to read weight {name} from backup, error code: {ret}") | ||
|
|
||
| byte_data = self.gpu_buffer[:weight_size] | ||
| weight_tensor = byte_data.view(weight_info['dtype']).reshape(weight_info['shape']) | ||
|
|
||
| return weight_tensor | ||
| """ |
def get_weight_iter(config):
    iter = loader._get_weights_iterator(
        DefaultModelLoader.Source.init_new(config, self.model)
    )
    if weight_name_filter is not None:
        iter = (
            (name, weight) for name, weight in iter if weight_name_filter(name)
        )
    if self.if_backup:
        print("USE BACKUP")
        global_expert_location_metadata = get_global_expert_location_metadata()
        num_experts = self.model_config.hf_config.n_routed_experts + self.server_args.ep_num_redundant_experts
        num_local_experts = num_experts // self.moe_ep_size
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=num_experts,
        )
        for i in range(self.engine_num):
            server_ptr_list = []
            local_ptr_list = []
            weight_size_list = []

            for name, weight_info in self.dram_map_list[i].items():
                layer_id, expert_id = extract_layer_and_expert_id(name)
                if layer_id >= self.model_config.hf_config.num_hidden_layers:
                    continue
                if weight_name_filter is None or weight_name_filter(name):
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
                            layer_id, expert_id
                        )
                        for physical_expert_id in physical_expert_ids:
                            if physical_expert_id not in range(
                                num_local_experts * self.moe_ep_rank,
                                num_local_experts * (self.moe_ep_rank + 1)
                            ):
                                continue
                            name = name.replace(weight_name, param_name)
                            param = self.params_dict[name]
                            param = param[physical_expert_id % num_local_experts]
                            if shard_id == "w1":
                                param = param.narrow(0, 0, param.shape[0] // 2)
                            elif shard_id == "w3":
                                param = param.narrow(0, param.shape[0] // 2, param.shape[0] // 2)
                            weight_info['tensor'] = param
                            server_ptr_list.append(weight_info['weight_ptr'])
                            local_ptr_list.append(weight_info['tensor'].data_ptr())
                            assert weight_info['tensor'].numel() * weight_info['tensor'].element_size() == weight_info['byte_size']
                            weight_size_list.append(weight_info['byte_size'])

            before_transfer = time.time()
            ret = self.transfer_engine.batch_transfer_sync_read(
                self.session_id_list[i],
                local_ptr_list,
                server_ptr_list,
                weight_size_list
            )
            after_transfer = time.time()
            logger.info(f"transfer time = {after_transfer - before_transfer} s")

            if ret != 0:
                raise RuntimeError(f"Failed to read weights from backup, error code: {ret}")
    else:
        iter = loader._get_weights_iterator(
            DefaultModelLoader.Source.init_new(config, self.model)
        )
        for name, weight in iter:
            if weight_name_filter is None or weight_name_filter(name):
                yield name, weight
    return iter
When self.if_backup is true, this function performs weight updates directly via RDMA transfer instead of yielding weights. This behavior is contrary to what its name get_weight_iter suggests, making the control flow confusing. The caller model_load_weights receives an empty iterator and does nothing, while the actual weight update happens inside this function.
Consider refactoring this logic to improve clarity. For example, you could create a separate function like _update_weights_from_dram_backup that is called when self.if_backup is true, and have get_weight_iter consistently behave as an iterator.
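A rough sketch of that split, assuming a caller shaped like model_load_weights; the helper name `_update_weights_from_dram_backup` comes from the suggestion above, and everything else is illustrative rather than the PR's actual code:

```python
def model_load_weights(self, config, loader, weight_name_filter=None):
    if self.if_backup:
        # In-place RDMA update; no iterator is produced or consumed on this path.
        self._update_weights_from_dram_backup(weight_name_filter)
        return
    # get_weight_iter-style loading: always a plain (name, weight) iterator.
    weights = (
        (name, weight)
        for name, weight in loader._get_weights_iterator(
            DefaultModelLoader.Source.init_new(config, self.model)
        )
        if weight_name_filter is None or weight_name_filter(name)
    )
    self.model.load_weights(weights)
```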
/tag-run-ci-label
Force-pushed from f361680 to 0102b24.
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
| `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. | `None` | Type: str |
| `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with a large TP size, errors occur because MLP layer weights have a dimension smaller than the minimum dimension GEMM supports. | `None` | Type: int |
| `--elastic-ep-backend` | Specify the collective communication backend for elastic EP. Currently supports 'mooncake'. | `none` | `none`, `mooncake` |
| `--enable-elastic-expert-backup` | Enable elastic expert backup feature. | `False` | bool flag (set to enable) |
The explanation here carries zero information.

Yes, we will make it clearer.
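One possible wording, drawing only on the Motivation section of this PR (illustrative, not the final docs text):

| `--enable-elastic-expert-backup` | Keep a hot backup of expert weights in spare DRAM (each node stores only its subset) so that, after a rank failure or expert weight loss, weights are restored over RDMA instead of reloaded from disk. Requires an elastic EP backend. | `False` | bool flag (set to enable) |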
    param_data.data_ptr(), param_data.numel() * param_data.element_size()
)
if ret_value != 0:
    raise RuntimeError("GPU buffer memory registration failed.")
Can we avoid crashing the engine here? This is meant to be a resiliency feature.

We will avoid using the backup when it fails.
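A minimal sketch of that fallback, inside the registration loop quoted above; `self.backup_enabled` is a hypothetical flag used only for illustration:

```python
ret_value = self.transfer_engine.register_memory(
    param_data.data_ptr(), param_data.numel() * param_data.element_size()
)
if ret_value != 0:
    # Do not crash the engine: log, disable the DRAM backup path, and fall back
    # to the existing disk-based recovery.
    logger.warning(
        "GPU buffer registration failed for %s; disabling DRAM expert backup.", name
    )
    self.backup_enabled = False
    break
```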
    self.server_args.language_only
    and self.server_args.encoder_transfer_backend == "mooncake"
)
or self.server_args.enable_elastic_expert_backup
I feel that we should check elastic_ep_backend here instead.

Does it only work with the mooncake EP backend? We need to confirm whether it is general, or whether we should check that the elastic EP backend == mooncake.

It is general; to be precise, we check that the elastic EP backend is not None.
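For concreteness, a sketch of the condition as discussed (field names follow the snippet above; this is not necessarily the final code):

```python
need_transfer_engine = (
    (
        self.server_args.language_only
        and self.server_args.encoder_transfer_backend == "mooncake"
    )
    # Elastic expert backup is general across elastic EP backends, so gate on the
    # backend being configured at all rather than on a specific backend value.
    or self.server_args.elastic_ep_backend is not None
)
```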
for i in range(self.engine_num):
    self.recv_list[i] = context.socket(zmq.SUB)
    self.recv_list[i].connect(
        f"tcp://{all_ips[i * get_world_size() // server_args.nnodes]}:{PORT_BASE + i * 2 + 1}"
Looks a little bit hacky. But you can optimize this in the future.
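As a possible future cleanup, the inlined arithmetic could move into a small helper (hypothetical, not part of this PR):

```python
def backup_manager_endpoint(all_ips, engine_idx, world_size, nnodes, port_base):
    # The first rank of the node hosting engine `engine_idx` owns the backup manager socket.
    leader_rank = engine_idx * world_size // nnodes
    return f"tcp://{all_ips[leader_rank]}:{port_base + engine_idx * 2 + 1}"

# Usage sketch:
# self.recv_list[i].connect(
#     backup_manager_endpoint(all_ips, i, get_world_size(), server_args.nnodes, PORT_BASE)
# )
```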
/tag-and-rerun-ci

Hi @ShangmingCai! May I request a rerun of the failed tests? Thank you!

/rerun-failed-ci

/rerun-failed-ci

DeepEP tests are failing. Similar failures can be observed in other runs like https://github.com/sgl-project/sglang/actions/runs/22431204967.

Hi all! May I request a rerun of the failed tests? Thanks!

I have fixed the CI. Let me rerun the remaining tests.

/rerun-stage stage-c-test-deepep-4-gpu

✅ Triggered

/rerun-failed-ci

@ShangmingCai @ch-wan thank you for the valuable review comments!


Motivation
Currently, in Elastic EP, when rank failures or expert weight loss occur, the system must reload missing weights from disk. This process significantly increases service interruption time.
We observed that servers often have unused spare DRAM capacity. Leveraging this available DRAM, we propose maintaining a hot backup of expert weights in memory. In the event of a failure, weights can be restored directly from DRAM over the high-speed RDMA network, updating the model weights in GPU memory much faster than disk-based recovery.
Experimental results show that this approach outperforms the original disk-based method, even when the disk-based path already employs optimizations such as multi-threading and the OS page cache.
Furthermore, in multi-node deployment scenarios, each node only needs to store a subset of expert weights, significantly reducing per-node DRAM overhead.
Our changes are incremental and controlled by a server argument that is disabled by default, ensuring no impact on existing deployments.
Modifications
- expert_backup_manager.py: Introduced a new `ExpertBackupManager` class and process responsible for loading a subset of expert weights from disk into a continuous CPU buffer and managing their transfer via Mooncake.
- expert_backup_client.py: Implemented ZMQ-based communication between the `ExpertBackupClient` and the `ExpertBackupManager` to coordinate weight backup requests and to exchange expert weight pointer information and session IDs across distributed nodes; it then initializes a client-side TransferEngine and dynamically loads expert weights from DRAM when needed.
- io_struct.py: Added the `UpdateExpertBackupReq` and `BackupDramReq` request types (a sketch of their approximate shape follows this list).
- engine.py: Launch the `ExpertBackupManager`.
- test_expert_backup.py: Added unit tests.
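For reference, the approximate shape of the two new request types, inferred from how they are constructed elsewhere in this PR (field types and defaults are assumptions):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class UpdateExpertBackupReq:
    """Asks an ExpertBackupManager to back up its expert weights and report back."""


@dataclass
class BackupDramReq:
    """Reply from an ExpertBackupManager describing its DRAM-resident expert weights."""

    _rank: int = 0                 # engine rank that owns this backup shard
    _map: Dict[str, Any] = field(default_factory=dict)  # weight name -> pointer/size metadata
    session_id: str = ""           # transfer-engine session to read the weights from
    buffer_size: int = 0           # size in bytes of the continuous CPU backup buffer
```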
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci