[None][feat] Add DWDP (Distributed Weight Data Parallelism) support for MoE inference (#12136)
Conversation
syuoni
left a comment
This is a nice feature, great to see it's going to production.
Regarding the CuTeDSL MoE interface, overall I would suggest:
- Use two different ops for single-b and multi-b cases.
- Above the op level, we use two different code paths; this avoids confusing users -- most users would use the single-b op only.
- Below the op level, we unify them in the multi-b implementation; this simplifies the implementation and avoids code duplication.
Thanks!
Force-pushed from 1d3e509 to 3532338
…or MoE inference

Core DWDP runtime (dwdp.py):
- DwdpManager: IPC handle exchange across MPI ranks
- DwdpHandleCollector: per-layer weight/scale/bias handle collection
- Expert weight prefetching with double-buffering

MoE integration (configurable_moe.py, fused_moe_cute_dsl.py, interface.py):
- DWDP support in ConfigurableMoE with CuteDSL+NVFP4 backend
- NvFp4WeightView for DWDP weight access patterns
- Contiguous gather/scatter grouped GEMM kernels

CuteDSL kernel extensions:
- Blockscaled contiguous gather grouped GEMM with SwiGLU fusion
- Blockscaled contiguous grouped GEMM finalize fusion

Executor integration (py_executor.py, py_executor_creator.py, llm_args.py):
- DwdpConfig dataclass for YAML-based configuration
- DwdpManager initialization and per-step prefetching

Disaggregated serving scripts:
- start_worker_dwdp.sh for MPI-based worker launch
- submit.py DWDP configuration support

CI test:
- DWDP accuracy test with DeepSeek-V3-Lite (NVFP4, 4 GPUs, GSM8K)

Co-authored-by: wanqian-nv <221923321+wanqian-nv@users.noreply.github.com>
Co-authored-by: zongfeijing <20381269+zongfeijing@users.noreply.github.com>
Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
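The double-buffered prefetching mentioned in this commit message can be illustrated with a toy buffer-rotation loop. This is a hedged sketch only: the class name `DoubleBufferedPrefetcher`, the synchronous `fetch_fn` callback, and byte-string "weights" are illustrative stand-ins, not the dwdp.py API; the real implementation overlaps transfers with compute on CUDA streams.

```python
from typing import Callable, List, Optional


class DoubleBufferedPrefetcher:
    """Toy illustration of double-buffered expert-weight prefetching."""

    def __init__(self, fetch_fn: Callable[[int], bytes], layer_ids: List[int]):
        self.fetch_fn = fetch_fn
        self.layer_ids = layer_ids
        # Two slots: one buffer is consumed while the other is being filled.
        self.buffers: List[Optional[bytes]] = [None, None]
        if layer_ids:
            self.buffers[0] = fetch_fn(layer_ids[0])  # warm up the first slot

    def run(self, consume: Callable[[int, bytes], None]) -> None:
        for i, layer in enumerate(self.layer_ids):
            cur, nxt = i % 2, (i + 1) % 2
            # Kick off the next layer's fetch before consuming the current one,
            # so the two slots alternate between "in use" and "being filled".
            if i + 1 < len(self.layer_ids):
                self.buffers[nxt] = self.fetch_fn(self.layer_ids[i + 1])
            consume(layer, self.buffers[cur])
```

The sketch only shows the slot rotation; in a real runtime `fetch_fn` would be an asynchronous copy and `consume` the layer's MoE computation.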
Force-pushed from 3532338 to bbe48fa
/bot run

/bot run --disable-fail-fast

PR_Github #39672 [ run ] triggered by Bot. Commit:

PR_Github #39672 [ run ] completed with state

/bot run --disable-fail-fast
Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
…matting

- Remove DwdpConfig.from_dict() to comply with Pydantic best practices
- test: Initialize GroupedGemmInputsHelper in test_nvfp4_gather_grouped_gemm_swiglu_blackwell
- Apply pre-commit formatting (isort, yapf, ruff, autoflake, trailing whitespace)

Signed-off-by: Tianyu Zhang <tianyuz@nvidia.com>
Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
Made-with: Cursor
/bot run --disable-fail-fast

PR_Github #39925 [ run ] triggered by Bot. Commit:

PR_Github #39925 [ run ] completed with state

PR_Github #40992 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41218 [ run ] triggered by Bot. Commit:

PR_Github #41218 [ run ] completed with state

/bot run --stage-list "DGX_B300-4_GPUs-PyTorch-1"

PR_Github #41281 [ run ] triggered by Bot. Commit:

PR_Github #41281 [ run ] completed with state

/bot run --stage-list "DGX_B300-4_GPUs-PyTorch-1" --disable-fail-fast

PR_Github #41300 [ run ] triggered by Bot. Commit:

PR_Github #41300 [ run ] completed with state
/LLM/main/L0_MergeRequest_PR pipeline #32181 and /LLM/main/L0_MergeRequest_PR pipeline #32254 (Partly Tested) have covered all CI tests, and all succeeded.
/bot skip |
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with the pull request.

skip
Skip testing for the latest commit on the pull request.

reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.
/bot skip --comment '/LLM/main/L0_MergeRequest_PR pipeline #32181 and /LLM/main/L0_MergeRequest_PR pipeline #32254 (Partly Tested) have covered all CI tests, and all success.'

PR_Github #41320 Bot args parsing error: Failed to parse bot args

/bot skip --comment "L0 MergeRequest PR pipelines 32181 and 32254 (partly tested) already covered CI; all success."

PR_Github #41325 [ skip ] triggered by Bot. Commit:

PR_Github #41325 [ skip ] completed with state
…nfigurableMoE load_weights

PR NVIDIA#12136 (DWDP) added a load_weights override in CuteDslFusedMoE that dropped the allow_partial_loading parameter from the base class signature. ConfigurableMoE.load_weights also lacked this parameter. This causes a TypeError when qwen2_moe_weight_mapper calls module.load_weights(weights=..., allow_partial_loading=...) on models using the CuteDSL or ConfigurableMoE backend (e.g., Qwen3 MoE).

Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
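The signature contract being restored can be sketched as follows. This is a hypothetical minimal sketch: the class and method names follow the commit message above, but the bodies are placeholders, not TensorRT-LLM code; it only illustrates why the override must keep the base-class keyword.

```python
from typing import Any, Dict


class MoEBase:
    def load_weights(self, weights: Dict[str, Any],
                     allow_partial_loading: bool = False) -> None:
        # Placeholder body: record what was passed in.
        self.loaded = dict(weights)
        self.partial = allow_partial_loading


class CuteDslFusedMoE(MoEBase):
    # The fix: the override keeps the base-class `allow_partial_loading`
    # keyword, so weight mappers that pass it no longer hit a TypeError.
    def load_weights(self, weights: Dict[str, Any],
                     allow_partial_loading: bool = False) -> None:
        super().load_weights(weights,
                             allow_partial_loading=allow_partial_loading)
```

With the buggy override (no `allow_partial_loading` parameter), the keyword call shown in the commit message would raise `TypeError: load_weights() got an unexpected keyword argument`.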
…or MoE inference (NVIDIA#12136)

Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
Signed-off-by: Tianyu Zhang <tianyuz@nvidia.com>
Signed-off-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
Co-authored-by: wanqian-nv <221923321+wanqian-nv@users.noreply.github.com>
Co-authored-by: zongfeijing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
… unused by VA path)

Commit 3 of the DWDP IPC->VA refactor. Three files (custom op wrapper + two blackwell kernels) are restored verbatim to their pre-PR-NVIDIA#12136 state. These multi-B paths were introduced purely to support DWDP's IPC scheme, which passed N peer expert shards as separate B tensors into each kernel call. The VA pipeline swaps param.data to a single [num_experts, ...] tensor via cuMemMap, so the standard single-B kernel path handles every case: the multi-B parameters, MAX_B_TENSORS branches, and tuple-ified b/sfb/alpha signatures become dead code.

Files reverted to the commit before e92ee4f (PR NVIDIA#12136):
* _torch/custom_ops/cute_dsl_custom_ops.py
  - Removes *_multi_b custom op registrations
  - Restores GatherGroupedGemmInputsHelper to single-tensor layout
* _torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  - Removes b_tensor_l_sizes param, MAX_B_TENSORS, num_b_tensors const_expr branches for 1/2/3/4 B tensors, _make_tma_b helper, kernel-side gB_nkl_0..3 / gSFB_nkl_0..3 expansions, tuple'd signatures
* _torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  - Same pattern

Verified no commits touched these files between PR NVIDIA#12136 and HEAD, so the revert is surgical and does not risk clobbering unrelated work:

$ git log --oneline e92ee4f..HEAD -- <file>  # empty for all 3

Net change: +371 / -1519 = -1148 lines in the three files.

Smoke tested: 78 pytest passes (DWDP units + full api_stability), plus runtime import checks confirm both kernel modules and the custom ops module load cleanly after the revert.

Co-Authored-By: dongxuy04 <dongxuy@nvidia.com>
Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
`num_groups` has been a dead schema field in DwdpConfig from PR NVIDIA#12136 (the original IPC implementation) through commit 1fbc0d4: read into DwdpManager but never consumed by the runtime. Mis-configured YAMLs, where the user's declared topology disagrees with the launch, were left to fail mysteriously deeper in MPI sub-communicator creation, or to silently accept a launch that didn't match the schema.

Convert the field into runtime validation in DwdpManager.__init__:
1. num_groups must be positive.
2. This rank's computed group_id (`rank // dwdp_size`) must be less than num_groups, so the launch hasn't started more CTX workers than the declared topology can hold.
3. `num_groups * dwdp_size <= MPI world size`, so the world is large enough to fit all declared groups.

The three checks together catch over-subscription, world-size under-allocation, and obviously bogus values, while remaining local (no extra inter-rank communication is required since DwdpConfig is identical on every rank).

Verification:
- Unit tests: 61 passed + 4 subtests (4 new num_groups cases in tests/unittest/_torch/executor/test_dwdp_manager.py)
- Accuracy (DSv3-Lite + dwdp=2 + GSM8K): PASS, above the 61.5% threshold (3 min 52 s)
- Perf DEP baseline: 12.95 req/s
- Perf DWDP=4: 14.26 req/s (+10.1% over DEP)
- Perf DWDP=8 cross-tray: 27.21 req/s (1.91x DWDP=4 scaling)

Signed-off-by: tianyuz-nv <tianyuz@nvidia.com>
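The three checks above can be sketched as a standalone function. The name `validate_dwdp_topology` and its argument list are illustrative assumptions, not the actual DwdpManager.__init__ signature; only the check logic follows the commit message.

```python
def validate_dwdp_topology(num_groups: int, dwdp_size: int,
                           rank: int, world_size: int) -> int:
    """Return this rank's group_id after validating the declared topology."""
    # Check 1: obviously bogus values.
    if num_groups <= 0:
        raise ValueError(f"num_groups must be positive, got {num_groups}")
    # Check 2: over-subscription -- this rank's group must exist in the
    # declared topology (group_id = rank // dwdp_size).
    group_id = rank // dwdp_size
    if group_id >= num_groups:
        raise ValueError(
            f"rank {rank} maps to group {group_id}, but only "
            f"{num_groups} groups were declared")
    # Check 3: under-allocation -- the MPI world must be large enough
    # to fit all declared groups.
    if num_groups * dwdp_size > world_size:
        raise ValueError(
            f"{num_groups} groups x {dwdp_size} ranks per group "
            f"exceed world size {world_size}")
    return group_id
```

All three checks are purely local arithmetic on values identical across ranks, which is why no inter-rank communication is needed.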
Description
In the context phase of LLM inference, workload imbalances and communication bottlenecks often lead to excessive synchronization overhead, limiting GPU utilization. To resolve this, we propose Distributed Weight Data Parallelism (DWDP), a strategy that leverages Data Parallelism combined with NVLink-based weight offloading to enable fully asynchronous execution across ranks.
Key properties of DWDP:
A detailed technical report on DWDP internals and optimizations will be published separately. We welcome discussions and feedback.
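As a usage illustration, a YAML fragment driving DwdpConfig might look like the following. The field names dwdp_size, num_groups, and experts_per_worker appear elsewhere in this PR, but the nesting, key spellings, and values shown here are assumptions, not the authoritative schema (which is the DwdpConfig dataclass in llm_args.py).

```yaml
# Hypothetical sketch of a DWDP YAML fragment.
dwdp_config:
  dwdp_size: 4            # ranks per DWDP group
  num_groups: 2           # declared number of DWDP groups (validated at runtime)
  experts_per_worker: 8   # expert shards owned by each worker
```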
Changes
Core DWDP runtime (tensorrt_llm/_torch/pyexecutor/dwdp.py, new file):
- DwdpManager: orchestrates IPC handle exchange across MPI ranks for zero-copy expert weight sharing
- DwdpHandleCollector: per-layer collector that gathers CUDA IPC handles for weight/scale/bias tensors

MoE integration (configurable_moe.py, fused_moe_cute_dsl.py, interface.py):
- ConfigurableMoE with automatic detection of compatible backends (CuteDSL + NVFP4)
- NvFp4WeightView dataclass for clean separation of DWDP vs non-DWDP weight access patterns

CuteDSL kernel extensions (cute_dsl_custom_ops.py, blockscaled_contiguous_*_fusion.py):
- Blockscaled contiguous gather grouped GEMM with SwiGLU fusion
- Blockscaled contiguous grouped GEMM finalize fusion

Executor integration (py_executor.py, py_executor_creator.py, _util.py, llm_args.py):
- DwdpConfig dataclass added to LlmArgs for YAML-based DWDP configuration
- DwdpManager initialized during executor creation when DWDP is enabled

Disaggregated serving scripts (examples/disaggregated/slurm/benchmark/):
- start_worker_dwdp.sh for launching DWDP workers via mpirun
- submit.py with DWDP-specific configuration (dwdp_size, num_group, experts_per_worker, etc.)

Test Coverage
- tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py: unit tests for DWDP-specific CuteDSL MoE kernels (contiguous gather grouped GEMM with SwiGLU fusion, finalize fusion)
- tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py: standalone kernel correctness tests for the DWDP gather GEMM
- tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_grouped_gemm_finalize_fusion.py: standalone kernel correctness tests for the DWDP finalize GEMM
- tests/integration/defs/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_dwdp_accuracy: end-to-end DWDP disaggregated serving accuracy test with DeepSeek-V3-Lite (NVFP4, 4 GPUs, GSM8K)

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.