
[Misc] Mooncake EP & Mooncake Backend#805

Merged
alogfans merged 91 commits into main from sunxun/mooncake-backend-dev
Sep 26, 2025

Conversation

@UNIDY2002
Collaborator

In this PR, we propose Mooncake EP and the Mooncake Backend.

Mooncake EP is an adaptation of DeepEP that supports fault tolerance for large-scale MoE inference. It remains API-compatible with DeepEP, with an extra broken_ranks tensor to track failed ranks.
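
As a rough usage illustration (a sketch only, not the actual API: the buffer construction, tensor shapes, and dtypes below are assumptions based on the `dispatch` signature in `mooncake_ep_buffer.py`):

```python
import torch

# Assumed to exist: `buffer` (a Mooncake EP buffer), `x` (input tokens), and
# `topk_idx` (per-token expert assignments), set up as in DeepEP.
num_ranks = 8
broken_ranks = torch.zeros(num_ranks, dtype=torch.int32, device="cuda")  # assumed dtype

packed_recv, recv_count, handle, event, hook = buffer.dispatch(
    x, topk_idx, broken_ranks,
    num_max_dispatch_tokens_per_rank=128,  # illustrative values
    num_experts=256,
    timeout_us=1_000_000,
)
# Ranks flagged in `broken_ranks` are treated as failed, so the caller can
# drop them from the subsequent combine step instead of hanging on a dead peer.
```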

Mooncake Backend is a PyTorch distributed backend, designed as a fault-tolerant replacement for NCCL and Gloo. It can continue to perform collective communication under rank failures and reports them to upper layers for graceful handling.
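
For example, a minimal sketch of how it might be used (assuming the package registers the backend with PyTorch under the name "mooncake" on import; the actual registration entry point may differ):

```python
import os
import torch
import torch.distributed as dist

import mooncake  # assumption: importing this registers the "mooncake" backend

rank = int(os.environ["RANK"])             # standard torchrun environment
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="mooncake", rank=rank, world_size=world_size)

x = torch.ones(1024, device="cuda")
dist.all_reduce(x)  # unlike with NCCL, this should complete even if a peer
                    # rank fails, with the failure reported for graceful handling
```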

Read more at doc/en/ep-backend.md.


Tests

Since the C++ APIs are not intended for direct use, no C++ unit tests are provided. Instead, three Python unit tests are included under mooncake-wheel/tests/:

  • test_mooncake_ep.py: Adapted from DeepEP’s test_low_latency.py. Verifies the correctness of the EP APIs and includes a basic performance test.
  • test_mooncake_backend.py: Validates the correctness of the Mooncake Backend.
  • test_mooncake_backend_perf.py: Compares the performance of the Mooncake Backend against NCCL and Gloo.

Performance

Tested on a single node with 8× H100 GPUs.

Mooncake EP (pure RDMA)

| Impl     | Dispatch bandwidth | Dispatch latency | Combine bandwidth | Combine latency |
| -------- | ------------------ | ---------------- | ----------------- | --------------- |
| Mooncake | 41 GB/s            | 184 us           | 38 GB/s           | 387 us          |
| DeepEP   | 46 GB/s            | 163 us           | 46 GB/s           | 318 us          |

Mooncake Backend

Here are preliminary performance results for the Mooncake Backend; further optimizations are planned.

All latencies are in microseconds.

Mooncake vs. Gloo

Allgather

| Data Size | Mooncake | Gloo   |
| --------- | -------- | ------ |
| 1K        | 94       | 681    |
| 4K        | 125      | 834    |
| 16K       | 288      | 1121   |
| 64K       | 928      | 6253   |
| 256K      | 3715     | 8163   |
| 1M        | 7929     | 37067  |
| 4M        | 31239    | 142334 |

Allreduce

| Data Size | Mooncake | Gloo  |
| --------- | -------- | ----- |
| 1K        | 87       | 1334  |
| 4K        | 163      | 1358  |
| 16K       | 476      | 1482  |
| 64K       | 1623     | 1606  |
| 256K      | 6382     | 2202  |
| 1M        | 23194    | 5324  |
| 4M        | 92664    | 15734 |

Broadcast

| Data Size | Mooncake | Gloo  |
| --------- | -------- | ----- |
| 1K        | 61       | 101   |
| 4K        | 87       | 129   |
| 16K       | 142      | 177   |
| 64K       | 389      | 449   |
| 256K      | 1389     | 1130  |
| 1M        | 1662     | 2759  |
| 4M        | 7876     | 11559 |

Mooncake vs. NCCL

Allgather

| Data Size | Mooncake | NCCL |
| --------- | -------- | ---- |
| 1K        | 67       | 93   |
| 4K        | 69       | 88   |
| 16K       | 78       | 93   |
| 64K       | 122      | 84   |
| 256K      | 293      | 81   |
| 1M        | 1038     | 178  |
| 4M        | 4158     | 521  |

Allreduce

| Data Size | Mooncake | NCCL |
| --------- | -------- | ---- |
| 1K        | 57       | 34   |
| 4K        | 60       | 30   |
| 16K       | 77       | 31   |
| 64K       | 122      | 30   |
| 256K      | 300      | 31   |
| 1M        | 1112     | 53   |
| 4M        | 14421    | 119  |

Broadcast

| Data Size | Mooncake | NCCL |
| --------- | -------- | ---- |
| 1K        | 50       | 28   |
| 4K        | 38       | 26   |
| 16K       | 47       | 27   |
| 64K       | 100      | 28   |
| 256K      | 246      | 34   |
| 1M        | 834      | 28   |
| 4M        | 3196     | 68   |

Comment thread doc/en/ep-backend.md Outdated
@whybeyoung
Collaborator

Amazing work!

Comment thread mooncake-wheel/setup.py
Comment on lines +112 to +148
if int(os.getenv("BUILD_WITH_EP", "0")):
    import torch
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    abi_flag = int(torch._C._GLIBCXX_USE_CXX11_ABI)
    current_dir = os.path.abspath(os.path.dirname(__file__))
    ext_modules = [
        CUDAExtension(
            name="mooncake.ep",
            include_dirs=[
                os.path.join(current_dir, "../mooncake-ep/include"),
                os.path.join(current_dir, "../mooncake-transfer-engine/include"),
            ],
            sources=["../mooncake-integration/ep/ep_py.cpp"],
            extra_compile_args={
                "cxx": [f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}", "-std=c++20"],
                "nvcc": [f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}", "-std=c++20"],
            },
            libraries=["ibverbs", "mlx5"],
            extra_objects=[
                os.path.join(current_dir, "../build/mooncake-ep/src/libmooncake_ep.a"),
                os.path.join(current_dir, "mooncake/engine.so"),
            ],
        )
    ]
    setup(
        distclass=BinaryDistribution,
        cmdclass={
            "bdist_wheel": CustomBdistWheel,
            "build_ext": BuildExtension,
        },
        ext_modules=ext_modules,
    )
else:
    setup(
        distclass=BinaryDistribution,
        cmdclass={"bdist_wheel": CustomBdistWheel},
    )
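
(For reference, this extension-build path is taken when the wheel is built with `BUILD_WITH_EP=1`, e.g. something like `BUILD_WITH_EP=1 python setup.py bdist_wheel`; the exact entry point used by the project's build scripts may differ.)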
Collaborator

Is -std=c++20 the minimum required version? cc: @xiaguan

Collaborator

Mooncake Store needs C++20; the other components could probably use a lower C++ standard like C++17.

Collaborator Author

It seems that a C++20 feature (std::string::starts_with) is used here:

if (server_name.starts_with("[")) {

Comment thread mooncake-wheel/mooncake/mooncake_ep_buffer.py Outdated
def dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, broken_ranks: torch.Tensor,
             num_max_dispatch_tokens_per_rank: int, num_experts: int, timeout_us: int,
             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
Collaborator

This should be fixed as well.

Collaborator Author

Changed `Tuple[torch.Tensor, torch.Tensor]` to `Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]`.
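
For reference, the annotation after the fix presumably reads as follows (a sketch; the placement of the Union is inferred from this reply, and EventOverlap is assumed to come from the same module):

```python
from typing import Callable, Tuple, Union
import torch

def dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, broken_ranks: torch.Tensor,
             num_max_dispatch_tokens_per_rank: int, num_experts: int, timeout_us: int,
             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
        Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
              torch.Tensor, Tuple, EventOverlap, Callable]:
    # With use_fp8=True the first element is presumably a (data, scales) pair;
    # with use_fp8=False it is a single tensor -- hence the Union.
    ...
```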

Comment thread mooncake-wheel/mooncake/mooncake_ep_buffer.py Outdated
@ShangmingCai
Collaborator

I have another urgent PR that needs testing and review today; I will continue with this PR tomorrow.

@alogfans Please take some time to review this PR as well.

Comment thread mooncake-wheel/mooncake/mooncake_ep_buffer.py Outdated
TORCH_CHECK(tensorSize * meta->size < kBufferSize, "Too large!");
auto future = c10::make_intrusive<c10::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()));
int taskId = cpuTaskCount % 2;
Collaborator

Maybe this needs a comment for clarification?

Collaborator Author

A comment has been added.

.attr("__version__")
.attr("split")("+")
.cast<std::vector<std::string>>()[0];
TORCH_CHECK(version == "2.8.0", "Mooncake Backend requires torch==2.8.0");
Collaborator

Should we use >= in case SGLang/vLLM require a newer version of PyTorch?

Collaborator Author

I'm afraid a strict equality check is required here, as the Mooncake library must match the libtorch C++ ABI.

If SGLang/vLLM require a newer version of PyTorch, we would have to recompile Mooncake against the corresponding PyTorch version. (Or, to be optimistic, we may figure out a better solution in future versions.)
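
A hypothetical pre-import guard in Python, mirroring the TORCH_CHECK above for users who want a clearer error before the extension is loaded (not part of this PR):

```python
import torch

# Same check as the C++ side: strip any local-version suffix (e.g. "+cu121")
# and require an exact match, since the prebuilt library is tied to the
# libtorch C++ ABI of torch 2.8.0.
base_version = torch.__version__.split("+")[0]
if base_version != "2.8.0":
    raise RuntimeError(
        f"Mooncake Backend was built against torch==2.8.0; found {base_version}")
```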

Comment thread mooncake-ep/include/mooncake_ep_buffer.h Outdated
@ShangmingCai
Collaborator

This is a huge PR. I have finished several rounds of basic review and found only some easy-to-fix problems. I think we can merge this first, after addressing the above comments, to see if we can get some user feedback. CC: @alogfans, better take a look before merging this PR.

@UNIDY2002
Collaborator Author

@ShangmingCai Thanks for your review and valuable feedback! I'll fix the issues.

@alogfans
Collaborator

I agree with @ShangmingCai, merge it first.

@alogfans merged commit c5829aa into main on Sep 26, 2025
13 checks passed
@UNIDY2002 deleted the sunxun/mooncake-backend-dev branch on September 29, 2025, 12:36
wanyue-wy pushed a commit to wanyue-wy/Mooncake that referenced this pull request Dec 14, 2025
* Initialize a mooncake backend

* Add pybind

* Fix incorrect backend registration

* Fix wheel building of mooncake_ep

* Add a fake allreduce implementation

* Introduce transfer_engine to mooncake_backend

* Add a basic CPU proxy execution framework

* Implement a seemingly working allgather

* Remove mooncake_ep's dependency on etcd

* Implement `_allgather_base`

* Implement `allreduce`

* Implement `alltoall`

* Use an even-odd pattern for data transfer

* Add a `set_host_ip` method

* Switch to an extended-API implementation of the Mooncake backend

* Implement `broadcast`

* Implement `barrier`

* Extend Mooncake backend to CPU

* Support more operations for reduction

* Fix the backend-worker coordination logic

* Optimize CPU worker with a callback pattern

* Add a timeout-based broken-ranks detection

* Merge EP module into Mooncake's build system

* Share transfer buffer across all worker instances

* Switch to a more robust approach to detect broken ranks

* Specify CUDA device for test_mooncake_backend.py

* Explicitly stop mooncake worker

* Use transfer engine's notifications to implement collective signals

* Remove the unused `all_reduce_without` API

* Switch to mooncake backend for test_mooncake_ep.py

* Support both IB and RoCE

* Fix EP unit test

* Pass the auto-detected nic_id to EP Buffer

* Fix CMake conditional branches when `PYTORCH_CMAKE_PATH` is not set

* Fix ibgda syncing for RoCE

* Revert "Share transfer buffer across all worker instances"

This reverts commit 964e0a9

* Implement `_reduce_scatter_base`

* Make CPU backends aware of broken ranks

* Fix .typos.toml

* Add a perf test for mooncake backend

* Support more dtypes for reduction

* Revert "Use transfer engine's notifications to implement collective signals"

This reverts commit f20ffb2

* Share worker thread among all process groups

* Share transfer engine among all process groups

* Fix unit tests

* Add a warmup phase for transfer engine

* Fix transfer engine buffer locations

* Fix incorrect calculation of mooncake ep buffer

* Do not use timeout detection in mooncake_ep tests

* Update mooncake backend perf test

* Demangle per-group buffer offset from the shared taskId

* Stop allocating the useless `cuda_counter_buffer` and `cuda_data_buffer`

* Split the task list into a CPU region and a CUDA region

* Add a warmup for test_mooncake_backend_perf.py

* Switch from raw cudaEvent to `torch::Event`

* Fix MooncakeWorkCuda::wait() to make it compatible with cuda graphs

* Add doc

* Fix perf test

* Implement all-gather for perf test

* Move impl of `MooncakeEpBuffer`'s member functions to .cpp

* Change `gathered_experts` to `broken_nodes` to make the API more consistent

* `broken_nodes` should be `broken_ranks`

* API rename

* Fix format

* Enable WITH_EP option in CI

* Try installing torch in advance in CI

* Set `TORCH_CUDA_ARCH_LIST` in CMakeLists.txt

* Install required dependencies in the CI CUDA environment

* [CI] Add the matching PyTorch

* [CI] Add a workaround for missing `CUDA::nvToolsExt`

* Remove unused pybind base class declaration of `MooncakeBackendOptions`

* Support `set_device_filter`

* Remove unused headers for ep_py.cpp

* Build the EP-wheel with setuptools on CI

* [CI] Add the build-with-ep process to release.yaml

* Minor format fix

* Update build guide

* Fix docs

* Only build EP wheel with torch==2.8.0

* Add a torch version assertion for Mooncake Backend

* Fix some python typing

* Use the correct group for EP's initial data sharing

* API: invert `broken_ranks` and change into `active_ranks`

* Followup fix for inverting the API

* Fix format

* Bug-fix in mooncake_ep_kernel.cu

* Mooncake EP has to be built with USE_CUDA on

* Fixed some issues according to the review

* Fix bug
JasonZhang517 pushed a commit to JasonZhang517/Mooncake that referenced this pull request Feb 9, 2026