
[EP] In-place Member Update#1630

Merged
UNIDY2002 merged 23 commits into main from mapc-upd
Mar 9, 2026

Conversation

Collaborator

@ympcMark ympcMark commented Mar 8, 2026

Description

Part of #1225.

When a process fails and rejoins, we used to recreate all EP buffers. This meant capturing the CUDA graph again, causing longer service interruptions. Now, for a process that already exists, we only recreate its QPs. All processes then exchange their metadata and re-establish connections with one another. To recycle old QPs, we also modify the memheap implementation.
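The rejoin sequence described above can be sketched roughly as follows. This is an illustrative Python model only: the `runtime` methods and the injected `all_gather` collective are stand-ins, not the actual Mooncake EP API.

```python
def rejoin(runtime, all_gather):
    """Hypothetical sketch of the in-place rejoin flow.

    `runtime` and `all_gather` are placeholders; only the sequencing
    mirrors what the PR description says happens on rejoin.
    """
    # 1. Recreate only this process's queue pairs; EP buffers are kept,
    #    so no CUDA graph recapture is needed.
    runtime.update_local_qpns()

    # 2. Every process contributes its fresh connection metadata
    #    (e.g. QP numbers, rkeys) to a collective exchange.
    local_meta = runtime.get_local_metadata()
    all_meta = all_gather(local_meta)

    # 3. Re-establish connections to every peer with the gathered data.
    runtime.connect(all_meta)
```

The key property is that step 1 touches only per-process state, while steps 2 and 3 are collective, so every rank (new or surviving) runs the same exchange.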

Module

  • Transfer Engine (mooncake-transfer-engine)
  • Mooncake Store (mooncake-store)
  • Mooncake EP (mooncake-ep)
  • Integration (mooncake-integration)
  • P2P Store (mooncake-p2p-store)
  • Python Wheel (mooncake-wheel)
  • PyTorch Backend (mooncake-pg)
  • Mooncake RL (mooncake-rl)
  • CI/CD
  • Docs
  • Other

Type of Change

  • Bug fix
  • New feature
  • Refactor
  • Breaking change
  • Documentation update
  • Other

How Has This Been Tested?

Checklist

  • I have performed a self-review of my own code.
  • I have formatted my own code using ./scripts/code_format.sh before submitting.
  • I have updated the documentation.
  • I have added tests to prove my changes are effective.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the resilience and efficiency of the Mooncake EP system by optimizing the process rejoin mechanism. Instead of a full recreation of all EP buffers and subsequent CUDA graph recaptures, which led to notable service interruptions, the system now intelligently recreates only the necessary Queue Pairs (QPs) for existing processes. This change, coupled with a refined memory management strategy that allows for the recycling of old QPs, ensures faster recovery and reduced downtime when processes fail and rejoin the cluster.

Highlights

  • Optimized Process Rejoin: The system now recreates only the necessary Queue Pairs (QPs) for existing processes when a process fails and rejoins, significantly reducing service interruption compared to recreating all EP buffers and recapturing CUDA graphs.
  • Enhanced Memory Management: The memheap implementation has been modified to support the recycling of old QPs, improving resource efficiency.
  • Dynamic Connection Establishment: Processes are now able to exchange their metadata and re-establish connections after QP recreation, ensuring seamless operation post-rejoin.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the EP member update process to reduce service interruption time by only recreating QPs for existing processes. However, the implementation introduces several critical security and stability issues. The custom memheap allocator is fundamentally broken, leading to memory overlap and corruption of RDMA hardware structures. Additionally, the QP update logic contains out-of-bounds vector accesses, and cleanup routines are vulnerable to null pointer dereferences. Further areas for improvement include addressing code duplication, enhancing robustness in a destructor function, and improving clarity in a test case.

Comment on lines 138 to 154
```c
static inline void memheap_free(struct memheap *heap, size_t offset) {
    // currently no-op
    if (!heap || offset == (size_t)-1) {
        return;
    }

    mutex_lock(&heap->lock);

    for (int i = 0; i < heap->alloc_count; i++) {
        if (heap->allocs[i].used && heap->allocs[i].offset == offset) {
            heap->allocs[i].used = false;
            heap->allocated -= heap->allocs[i].size;
            break;
        }
    }

    mutex_unlock(&heap->lock);
}
```
Contributor


Severity: critical (security)

The memheap implementation incorrectly manages the allocated field, which is used as a watermark for new allocations. In memheap_free, heap->allocated is decremented by the size of the freed block regardless of its position. If the freed block is not the last one allocated, this moves the watermark backward into already occupied memory. Conversely, in memheap_aligned_alloc, reusing a free block increments heap->allocated, moving the watermark forward unnecessarily. When a new allocation is needed, it starts at the current heap->allocated offset. This logic will inevitably lead to overlapping memory allocations and corruption of critical hardware structures (CQs, WQs, DBRs), causing unpredictable behavior or system crashes.
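A minimal sketch of the fix this comment implies: keep the bump-allocator high-water mark separate from the in-use byte count, so that freeing a block never moves the watermark backward, and reuse freed blocks by scanning the allocation table first. This is an illustrative Python model under those assumptions, not the actual C memheap.

```python
class MemHeap:
    """Offset-based bump allocator with a free list.

    `watermark` only ever grows, so fresh allocations can never overlap
    live blocks; `allocated` is pure bookkeeping and never feeds back
    into placement decisions.
    """

    def __init__(self, size):
        self.size = size
        self.watermark = 0   # high-water mark used for placement
        self.allocated = 0   # bytes currently in use (bookkeeping only)
        self.allocs = []     # records: [offset, size, used]

    def alloc(self, size, align=8):
        # First try to reuse a freed block that fits and is aligned.
        for rec in self.allocs:
            if not rec[2] and rec[1] >= size and rec[0] % align == 0:
                rec[2] = True
                self.allocated += rec[1]
                return rec[0]
        # Otherwise bump the watermark (align must be a power of two).
        offset = (self.watermark + align - 1) & ~(align - 1)
        if offset + size > self.size:
            return None  # out of memory
        self.allocs.append([offset, size, True])
        self.watermark = offset + size
        self.allocated += size
        return offset

    def free(self, offset):
        # Mark the block free; deliberately do NOT touch the watermark.
        for rec in self.allocs:
            if rec[2] and rec[0] == offset:
                rec[2] = False
                self.allocated -= rec[1]
                return
```

Under this scheme, freeing and reallocating a block returns the same offset, while new blocks still land above the watermark, which is exactly the invariant the flawed `heap->allocated` dual use breaks.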

Comment on lines +460 to +498
```cpp
void MooncakeEpBuffer::update_local_qpns() {
    for (int i = 0; i < MAX_QP_COUNT; ++i) {
        if (qps[i]) {
            mlx5gda_destroy_qp(ctrl_buf_heap, qps[i]);
            qps[i] = nullptr;
        }
    }

    for (int i = 0; i < MAX_QP_COUNT; ++i) {
        mlx5gda_qp* qp =
            mlx5gda_create_rc_qp(mpd, ctrl_buf, ctrl_buf_umem, ctrl_buf_heap,
                                 pd, 16384, 1, comm_stream.stream());
        if (!qp) {
            perror("Failed to recreate QP");
            ibgda_disabled_ = true;
            return;
        }
        is_roce_ = qp->port_attr.link_layer == IBV_LINK_LAYER_ETHERNET;
        if (mlx5gda_modify_rc_qp_rst2init(qp, 0)) {
            perror("Failed to mlx5gda_modify_rc_qp_rst2init");
            ibgda_disabled_ = true;
            return;
        }
        // Ensure all async memset operations are complete before accessing QP
        // structures
        CUDA_CHECK(cudaStreamSynchronize(comm_stream.stream()));

        mlx5gda_qp_devctx qp_devctx = {
            .qpn = qp->qpn,
            .wqeid_mask = qp->num_wqebb - 1,
            .wq = (mlx5gda_wqebb*)(ctrl_buf + qp->wq_offset),
            .cq = (mlx5_cqe64*)(ctrl_buf + qp->send_cq->cq_offset),
            .dbr = (mlx5gda_wq_dbr*)(ctrl_buf + qp->dbr_offset),
            .bf = (char*)qp->uar->reg_addr,
        };
        cudaMemcpy(qp_devctxs + i * sizeof(mlx5gda_qp_devctx), &qp_devctx,
                   sizeof(mlx5gda_qp_devctx), cudaMemcpyHostToDevice);
        qps[i] = qp;
    }
```
Contributor


Severity: high (security)

The update_local_qpns function iterates up to MAX_QP_COUNT and accesses the qps vector using the subscript operator. However, the qps vector is populated in init_ibgda using push_back, and its size may be less than MAX_QP_COUNT if initialization partially failed. Accessing qps[i] when i >= qps.size() is an out-of-bounds operation that leads to undefined behavior, including potential crashes or memory corruption during the assignment at line 497.
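The bounds-safe pattern this comment asks for can be sketched language-neutrally (Python here; the real code is C++ over a `std::vector`, and all names are illustrative):

```python
MAX_QP_COUNT = 8  # illustrative stand-in for the real constant

def update_local_qpns(qps, create_qp, destroy_qp):
    """Bounds-safe variant of the flagged loops.

    `qps` may be shorter than MAX_QP_COUNT if a previous init failed
    partway through, so we never index past its current length.
    """
    # Destroy only the QPs that actually exist.
    for i, qp in enumerate(qps):
        if qp is not None:
            destroy_qp(qp)
            qps[i] = None
    # Recreate up to MAX_QP_COUNT, appending when the list is short
    # instead of writing out of bounds.
    for i in range(MAX_QP_COUNT):
        qp = create_qp()
        if qp is None:
            return False  # mirrors the ibgda_disabled_ early return
        if i < len(qps):
            qps[i] = qp
        else:
            qps.append(qp)
    return True
```

In the C++ version the equivalent fix is to bound the first loop by `qps.size()` and use `push_back` (or pre-size the vector to `MAX_QP_COUNT`) in the second.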

Comment on lines +104 to 123
```c
    if (ret == (size_t)-1) {
        size_t offset = heap->allocated;
        if (offset & (align - 1)) {
            offset = (offset | (align - 1)) + 1;
        }
        if (offset + size <= heap->size) {
            ret = offset;

            if (heap->alloc_count < MEMHEAP_MAX_ALLOCATIONS) {
                heap->allocs[heap->alloc_count].offset = offset;
                heap->allocs[heap->alloc_count].size = size;
                heap->allocs[heap->alloc_count].used = true;
                heap->alloc_count++;
            }

            heap->allocated = offset + size;
        } else {
            errno = ENOMEM;
        }
    }
```
Contributor


Severity: medium (security)

The memheap allocator has a permanent memory leak. If heap->alloc_count reaches MEMHEAP_MAX_ALLOCATIONS, allocations are not recorded in the allocs array, preventing them from being reclaimed by memheap_free. Additionally, the allocation logic has a critical issue where heap->allocated is used for conflicting purposes (tracking total size and high-water mark), which, after memheap_free decrements it, can lead to new allocations overlapping with existing ones and causing memory corruption. A separate member should track the high-water mark for the bump allocator logic.

Comment on lines +350 to +369
```c
void mlx5gda_destroy_qp(struct memheap *ctrl_buf_heap, struct mlx5gda_qp *qp) {
    if (qp->mqp) {
        mlx5dv_devx_obj_destroy(qp->mqp);
    }
    if (qp->uar) {
        destroy_uar(qp->uar);
    }
    if (qp->send_cq) {
        mlx5gda_destroy_cq(ctrl_buf_heap, qp->send_cq);
    }
    if (qp->wq_offset != -1) {
        memheap_free(ctrl_buf_heap, qp->wq_offset);
    }
    if (qp->dbr_offset != -1) {
        memheap_free(ctrl_buf_heap, qp->dbr_offset);
    }
    if (qp) {
        free(qp);
    }
}
```
Contributor


Severity: medium (security)

The function mlx5gda_destroy_qp dereferences the qp pointer (e.g., qp->mqp, qp->uar) before checking for NULL at line 366. This can lead to a null pointer dereference and application crash if qp is NULL. A null check should be added at the beginning of the function for robustness.

Comment on lines +163 to +235
```python
def update_ep_member(self):
    from mooncake import ep

    if not self._use_fallback:
        (raddr, rkey) = self.runtime.get_mr_info()

        raddr = torch.tensor([raddr], dtype=torch.int64, device='cuda')
        raddrs = [torch.empty(1, dtype=torch.int64, device='cuda') for _ in range(self.group_size)]
        dist.all_gather(raddrs, raddr, self.group)
        raddrs = torch.cat(raddrs).tolist()

        rkey = torch.tensor([rkey], dtype=torch.int32, device='cuda')
        rkeys = [torch.empty(1, dtype=torch.int32, device='cuda') for _ in range(self.group_size)]
        dist.all_gather(rkeys, rkey, self.group)
        rkeys = torch.cat(rkeys).tolist()

        all_to_all_size = ep.MAX_QP_COUNT // self.group_size

        self.runtime.update_local_qpns()

        local_qpns = self.runtime.get_local_qpns()
        local_qpns = list(torch.unbind(torch.tensor(local_qpns, dtype=torch.int32, device='cuda').view(-1, all_to_all_size)))
        remote_qpns = [torch.empty(all_to_all_size, dtype=torch.int32, device='cuda') for _ in range(self.group_size)]
        dist.all_to_all(remote_qpns, local_qpns, self.group)
        remote_qpns = torch.cat(remote_qpns).tolist()

        if self.runtime.is_roce():
            (subnet_prefix, interface_id) = self.runtime.get_gid()

            subnet_prefix = torch.tensor([subnet_prefix], dtype=torch.int64, device='cuda')
            subnet_prefixes = [torch.empty(1, dtype=torch.int64, device='cuda') for _ in range(self.group_size)]
            dist.all_gather(subnet_prefixes, subnet_prefix, self.group)
            subnet_prefixes = torch.cat(subnet_prefixes).tolist()

            interface_id = torch.tensor([interface_id], dtype=torch.int64, device='cuda')
            interface_ids = [torch.empty(1, dtype=torch.int64, device='cuda') for _ in range(self.group_size)]
            dist.all_gather(interface_ids, interface_id, self.group)
            interface_ids = torch.cat(interface_ids).tolist()

            self.runtime.sync_roce(raddrs, rkeys, remote_qpns, subnet_prefixes, interface_ids)
        else:
            local_lids = self.runtime.get_local_lids()
            local_lids = list(torch.unbind(torch.tensor(local_lids, dtype=torch.int32, device='cuda').view(-1, all_to_all_size)))
            remote_lids = [torch.empty(all_to_all_size, dtype=torch.int32, device='cuda') for _ in range(self.group_size)]
            dist.all_to_all(remote_lids, local_lids, self.group)
            remote_lids = torch.cat(remote_lids).tolist()

            self.runtime.sync_ib(raddrs, rkeys, remote_qpns, remote_lids)

        try:
            local_handle_ints = self.runtime.get_ipc_handle()
            # pybind11 converts std::vector<int32_t> to a list of integers
            local_handle_tensor = torch.tensor(local_handle_ints, dtype=torch.int32, device='cuda')
            handles = [torch.empty(len(local_handle_ints), dtype=torch.int32, device='cuda') for _ in range(self.group_size)]
            dist.all_gather(handles, local_handle_tensor, self.group)
            remote_handles = [h.tolist() for h in handles]
            self.runtime.sync_nvlink_ipc_handles(remote_handles)
        except Exception as e:
            import warnings
            warnings.warn(
                f"[Rank {self.rank}] Failed to exchange IPC handles: {e}. Falling back.",
                RuntimeWarning,
                stacklevel=2,
            )

    use_fast_path = False
    try:
        use_fast_path = bool(self.runtime.use_fast_path())
    except Exception:
        ibgda_disabled = bool(self.runtime.ibgda_disabled())
        use_fast_path = not ibgda_disabled

    self._use_fallback = not use_fast_path
```
Contributor


Severity: medium

The new method update_ep_member has a large amount of code that is duplicated from the __init__ method (lines 77-161). This includes logic for exchanging RDMA metadata, QP numbers, GID/LID information, and IPC handles.

To improve maintainability and reduce redundancy, this common logic should be extracted into a private helper method that can be called from both __init__ and update_ep_member.
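A sketch of that refactor, with the shared exchange pulled into a private helper. The helper name `_exchange_metadata`, the runtime methods, and the injected collective are hypothetical stand-ins, not the real Mooncake API; the point is only the call structure.

```python
class Buffer:
    """Sketch: __init__ and update_ep_member share one helper.

    The collective is injected as a plain callable so this sketch
    stays torch-free; in the real code it would be dist.all_gather
    and friends on self.group.
    """

    def __init__(self, runtime, all_gather):
        self.runtime = runtime
        self.all_gather = all_gather
        # Initial connection setup uses the shared exchange logic.
        self._exchange_metadata()

    def _exchange_metadata(self):
        # Shared logic: gather every rank's (raddr, rkey, qpns, ...)
        # and hand the result to the runtime to (re)connect.
        local = self.runtime.get_local_metadata()
        self.runtime.connect(self.all_gather(local))

    def update_ep_member(self):
        # Rejoin path: refresh local QPs, then rerun the same exchange.
        self.runtime.update_local_qpns()
        self._exchange_metadata()
```

With this shape, a change to the metadata exchange only has to be made once, and the collective-call sequence in both paths stays identical by construction, which also addresses the brittleness noted in the test review below.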

Comment on lines +171 to +174
```python
if local_rank == 0:
    buffer.update_ep_member()
else:
    buffer = Buffer(group, num_ep_buffer_bytes=num_ep_buffer_bytes)
```
Contributor


Severity: medium

The test logic here is confusing and potentially brittle. It appears to be testing a scenario where rank 0 calls update_ep_member() while other ranks re-initialize the Buffer object. Since both update_ep_member and Buffer.__init__ contain collective communication calls, this relies on the sequence of collectives being identical in both paths.

While this might work for now, it's not a clear or robust way to test the functionality. A small change in either __init__ or update_ep_member could break this test in subtle ways.

A clearer approach would be to have all ranks call update_ep_member() to test its correctness collectively. If the goal is to test a mixed-state scenario, the test should be documented clearly to explain the scenario it's simulating.

A more direct test for update_ep_member would be:

```python
buffer = Buffer(group, num_ep_buffer_bytes=num_ep_buffer_bytes)
# ... any operations ...
# All ranks call update_ep_member to resynchronize
group.barrier()  # if needed
buffer.update_ep_member()
# ... continue testing ...
```

Comment thread mooncake-wheel/mooncake/mooncake_ep_buffer.py Fixed
@codecov-commenter


Codecov Report

✅ All modified and coverable lines are covered by tests.


Collaborator

@UNIDY2002 UNIDY2002 left a comment


Overall LGTM. Some code still needs to be updated.

```python
        # Use fast-path only if runtime says it's safe
        self._use_fallback = not use_fast_path

    def update_ep_member(self):
```
Collaborator


This method duplicates code from `__init__`.

Consider merging?

Collaborator


As I remember, the CI does not check formatting for Python code.

So I think it would be better to keep unrelated lines unchanged; otherwise it might increase the chance of conflicts with other collaborators.

Collaborator Author


I used black to format it. And if `__init__` is left untouched, I don't see how to merge the two.

```python
        gathered = torch.empty(
            (num_valid, hidden), dtype=torch.bfloat16, device=x.device
        )
        src_meta = torch.empty(
```
Collaborator

@UNIDY2002 UNIDY2002 left a comment


You should add a comment here.

Comment thread mooncake-wheel/tests/test_mooncake_ep.py Outdated
Co-authored-by: Xun Sun <UNIDY2002@outlook.com>
@UNIDY2002 UNIDY2002 changed the title from "EP Member Update" to "[EP] In-place Member Update" Mar 9, 2026
@UNIDY2002 UNIDY2002 merged commit 59ae11f into main Mar 9, 2026
11 of 18 checks passed
@UNIDY2002 UNIDY2002 deleted the mapc-upd branch March 9, 2026 03:51
Socratesa pushed a commit to Socratesa/Mooncake that referenced this pull request Mar 20, 2026
whn09 pushed a commit to whn09/Mooncake that referenced this pull request Apr 4, 2026