Port AllGather/ReduceScatter from TensorRT-LLM#1145
wenscarl wants to merge 5 commits into flashinfer-ai:main
Conversation
Summary of Changes
Hello @wenscarl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request ports essential collective communication primitives, specifically AllGather and ReduceScatter, along with their underlying MPI and NCCL utility infrastructure from TensorRT-LLM. This work lays the foundation for enabling various forms of distributed parallelism (like Tensor Parallelism) within FlashInfer, allowing operations to scale across multiple GPUs and nodes. A new Python Mapping class is introduced to manage these complex parallelism configurations.
Highlights
- Ported Collective Operations: Introduced AllGather and ReduceScatter collective communication operations, ported from TensorRT-LLM, enabling multi-GPU/multi-node capabilities for relevant parts of the codebase.
- MPI/NCCL Utilities: Ported core MPI and NCCL utility functions and classes (like `MpiComm`, `MpiRequest`, and NCCL communicator management) necessary for distributed operations, guarded by the `ENABLE_MULTI_DEVICE` flag.
- Parallelism Mapping: Added a new `Mapping` class in Python to represent and manage various parallelism configurations (Tensor Parallelism, Pipeline Parallelism, Context Parallelism, MoE Parallelism, etc.) and provide rank/group information.
- PyTorch Custom Operations: Implemented PyTorch custom CUDA operations (`trtllm_allgather`, `trtllm_reducescatter`) that wrap the underlying NCCL calls, making these collectives available in the Python frontend.
- Python API Wrappers: Added user-friendly Python functions (`all_gather`, `reduce_scatter`) in `flashinfer.comm` that utilize the new custom operations and the `Mapping` object to handle distributed tensors.
- Integration and Testing: Integrated the new C++/CUDA files into the JIT compilation process and added multi-GPU tests using `mpi4py` to verify the correctness of the ported collective operations.
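As an illustration of what such a `Mapping` tracks, here is a minimal sketch. The class name `MappingSketch`, its field names, and the rank layout below (tp-contiguous within each pp stage) are illustrative assumptions, not the ported API:

```python
from dataclasses import dataclass


@dataclass
class MappingSketch:
    """Illustrative stand-in: world rank decomposed pp-major, tp-minor."""

    world_size: int
    rank: int
    tp_size: int = 1
    pp_size: int = 1

    def __post_init__(self):
        assert self.world_size == self.tp_size * self.pp_size
        self.tp_rank = self.rank % self.tp_size
        self.pp_rank = self.rank // self.tp_size
        # Ranks belonging to this rank's tensor-parallel group.
        self.tp_group = [self.pp_rank * self.tp_size + i for i in range(self.tp_size)]
```

The real class additionally covers context parallelism, MoE parallelism, and related group queries.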
Code Review
This pull request ports AllGather and ReduceScatter functionalities from TensorRT-LLM, introducing new C++/CUDA custom operations for PyTorch, along with MPI/NCCL utilities and Python bindings. A new Mapping class is added to manage distributed configurations, and tests for these collective operations are included. Key areas for improvement include addressing file duplication, enhancing portability of build configurations (MPI/NCCL path discovery), fixing potential runtime errors, and ensuring code correctness in the new Mapping class and C++ utilities.
```cpp
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);

// We cap at max alignment-bytes chunks that can be sent at once.
auto const step = maxP1 - (alignment / elementSize);
```
There's a potential division by zero if `elementSize` is zero. `MPI_Type_size` could return 0 for certain types (e.g., `MPI_DATATYPE_NULL` or if an error occurs, though `MPICHECK` should catch MPI errors). If `elementSize` becomes 0, the calculation `alignment / elementSize` in `step = maxP1 - (alignment / elementSize)` would lead to a runtime error. Consider adding a check for `elementSize <= 0` before this line.
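The suggested guard, sketched here in Python for clarity (the function name `chunk_step` and the standalone form are illustrative; the actual fix belongs in the C++ snippet above):

```python
def chunk_step(max_p1, element_size, alignment):
    """Python mirror of the C++ step computation, with the suggested guard added."""
    if element_size <= 0:
        # MPI_Type_size reported nothing usable; fail loudly instead of dividing by zero.
        raise ValueError("MPI_Type_size returned a non-positive element size")
    element_size = min(element_size, alignment)
    return max_p1 - (alignment // element_size)
```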
```python
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
            mpi_include_path,
        ],
        extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", "-L/usr/lib/aarch64-linux-gnu/ -lnccl"],
```
The NCCL library path `-L/usr/lib/aarch64-linux-gnu/` is hardcoded. This will cause build failures on systems with different architectures (e.g., x86_64) or where NCCL is installed in a non-standard location. NCCL paths should be discovered dynamically (e.g., from an `NCCL_HOME` environment variable or by searching standard library paths) or be configurable by the user.
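One possible discovery routine, sketched in Python. The `NCCL_HOME` environment variable and the candidate directory list are assumptions, not an existing FlashInfer convention:

```python
import glob
import os


def find_nccl_lib_dir():
    """Return a directory containing libnccl, or None if none is found."""
    # 1. Explicit override, e.g. NCCL_HOME=/opt/nccl with libs under lib/.
    nccl_home = os.environ.get("NCCL_HOME")
    if nccl_home and glob.glob(os.path.join(nccl_home, "lib", "libnccl*")):
        return os.path.join(nccl_home, "lib")
    # 2. Common multi-arch and default install locations.
    for d in (
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib/aarch64-linux-gnu",
        "/usr/local/lib",
        "/usr/lib",
    ):
        if glob.glob(os.path.join(d, "libnccl*")):
            return d
    return None
```

The result (when not None) could then be passed as `-L{dir}` together with `-lnccl` in `extra_ldflags`.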
```python
extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi"],  # NCCL path needs to be discovered or configurable
# Example: Add discovered NCCL lib path and -lnccl here
```

```cpp
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/utils/mpiTags.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include "cuda.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <functional>
#include <mutex>
#include <thread>

#if ENABLE_MULTI_DEVICE

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
    static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
        {nvinfer1::DataType::kFLOAT, ncclFloat32},
        {nvinfer1::DataType::kHALF, ncclFloat16},
        {nvinfer1::DataType::kBF16, ncclBfloat16},
        {nvinfer1::DataType::kFP8, ncclInt8},
        {nvinfer1::DataType::kBOOL, ncclInt8},
        {nvinfer1::DataType::kINT32, ncclInt32},
        {nvinfer1::DataType::kINT64, ncclInt64},
        {nvinfer1::DataType::kUINT8, ncclUint8},
        {nvinfer1::DataType::kINT8, ncclInt8},
    };
    return &dtypeMap;
}

namespace
{

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    ncclUniqueId id;
    if (rank == *group.begin())
    {
        NCCLCHECK_THROW(ncclGetUniqueId(&id));
        for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
        {
            COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
        }
    }
    else
    {
        COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault);
    }
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    std::ostringstream oss;
    int index = 0;
    for (auto const& rank : group)
    {
        if (index != 0)
        {
            oss << ",";
        }
        oss << rank;
        index++;
    }
    auto groupStr = oss.str();
    auto it = commMap.find(group);
    if (it != commMap.end())
    {
        auto ncclComm = it->second;
        TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
        return ncclComm;
    }

    TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
    ncclUniqueId id = getUniqueId(group);
    int groupRank = 0;
    for (auto const& currentRank : group)
    {
        if (rank == currentRank)
            break;
        ++groupRank;
    }
    TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());
    std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
        [](ncclComm_t* comm)
        {
            ncclCommDestroy(*comm);
            delete comm;
        });
    // Need static connection initialization for accurate KV cache size estimation
#if defined(_WIN32)
    if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
        _putenv_s("NCCL_RUNTIME_CONNECT", "0");
#else
    setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#endif // _WIN32
    NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
    commMap[group] = ncclComm;
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
```
This file appears to be identical to csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp. Code duplication can lead to maintenance issues. Please consolidate them into a single file. Based on the include paths used in flashinfer/comm.py (nv_internal/cpp/common/opUtils.cpp), this location seems more appropriate if tensorrt_llm/common/opUtils.h is the intended header.
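For reference, the group-rank lookup in `getComm` simply counts how many group members precede the caller's rank in the ordered set; in Python the same logic reads (`group_rank` is an illustrative name, not part of the ported code):

```python
def group_rank(rank, group):
    """Index of `rank` within the ascending-ordered group, as getComm derives it."""
    for i, r in enumerate(sorted(group)):
        if r == rank:
            return i
    # Mirrors the TLLM_CHECK that fires when rank is absent from the group.
    raise ValueError(f"rank {rank} is not a member of group {group}")
```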
```python
import ctypes
import functools
import math
from dataclasses import dataclass
```
```python
def gen_comm_module() -> JitSpec:
    mpi_include_path, mpi_lib_path = get_mpi_include_lib_path()
    mpi_lib_path = str(mpi_lib_path[0])
    print(mpi_include_path, mpi_lib_path)
```
```cpp
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi
```
```python
        if tp_size != 1 or pp_size != 1 or tp_size != 1:
            raise ValueError(
                f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
            )
```
The condition `if tp_size != 1 or pp_size != 1 or tp_size != 1:` has a duplicated check for `tp_size != 1`. It seems the intention was to check `cp_size != 1` in the last term, matching the error message.
Suggested change:

```python
        if tp_size != 1 or pp_size != 1 or cp_size != 1:  # Assuming cp_size was intended for the last check
            raise ValueError(
                f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
            )
```
```cpp
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
```
```python
    print(f"Running test for world_size={world_size}")
    _run_reduce_scatter_worker(rank, world_size, dtype)

    print(f"reduce_scatter with tp = {world_size}: OK")
```
|
The CI error is because MPI is missing from our Dockerfile. We can add an MPI installation to https://github.com/flashinfer-ai/flashinfer/blob/main/docker/Dockerfile.ci_gpu with `conda install mpi4py`. After the updated Dockerfile is merged, we can manually trigger an update of the Docker image on Docker Hub.
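A possible Dockerfile addition (a sketch based on the comment above; the exact `RUN` phrasing and whether `pynvml` is installed in the same step are assumptions):

```dockerfile
# Install MPI Python bindings (and NVML bindings) needed by the comm-op tests.
RUN conda install -y mpi4py pynvml
```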
## 📌 Description

Install the python packages for CI docker: mpi4py, pynvml. They will be used for the comm ops.

## 🔍 Related Issues

#1145, #1134

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
📌 Port AllGather/ReduceScatter from TensorRT-LLM
🔍 Related Issues
This PR introduces dependency on mpi4py.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes