[Distributed] add nccl primitives #280
Conversation
yaoyaoding left a comment
Hi @soodoshll, thanks!
I left some comments on how to organize the NCCL-related source code. I prefer putting it under the `hidet.cuda.nccl` submodule. This would keep our code structure cleaner when we add other vendor libraries like cuBLAS and cuDNN.
python/hidet/ffi/ffi.py
Outdated
```python
def load_nccl_library():
    global _LIB_NCCL
    library_dirs = get_nccl_library_search_dirs()
    for library_dir in library_dirs:
        lib_nccl_paths = glob.glob(os.path.join(library_dir, 'libnccl.so*'))
        if len(lib_nccl_paths) == 0:
            continue
        _LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
        library_paths['nccl'] = lib_nccl_paths[0]
        break
    if _LIB_NCCL is None:
        raise OSError('Cannot find the NCCL library in the following directories:\n' + '\n'.join(library_dirs))
```
Let's put this part in `hidet/cuda/nccl/ffi.py`, and leave `hidet/ffi/ffi.py` to contain only the hidet runtime library.
In the future, when we want to add another library (e.g., cuDNN), we can put it under `hidet/cuda/cudnn` with its own `ffi.py`.
python/hidet/ffi/runtime_api.py
Outdated
```python
class NcclUniqueId(Structure):
    """
    Defined in nccl.h
    """
    _fields_ = [("internal", c_byte * 128)]


class NcclCommunicator:
    def __init__(self, handle: int):
        """
        Users should not call this constructor directly, because there are two ways of creating
        a new communicator: 1) using unique_id and rank; 2) using split.
        """
        if not nccl_available():
            raise RuntimeError("NCCL Library not found.")
        self._handle = handle

        # TODO: how to ensure the following two are identical?
        _comms.append(self)
        runtime_api.add_nccl_comm(self)

    def __del__(self):
        # Should we manage the lifetime of communicator objects in Python or C++?
        nccl_runtime_api.comm_destroy(self)

    def split(self):
        raise NotImplementedError()
```
Consider moving this to "hidet/cuda/nccl.py".
@yaoyaoding I have some questions related to code organization:
- How should we manage the lifetime of `NcclCommunicator` objects in Python and `ncclComm_t` objects in C++? There are two choices:
  1. In Python, as it is now: the resource is released when the Python object is released. This means that if we need a Python API `get_nccl_comm(idx)`, we have to maintain a global list `_comms` in Python holding all communicators, and make sure it stays aligned with the `ncclComm_t` list in C++. However, it's likely that we don't need this API in Python; `NcclCommunicator.get_id()` is more important.
  2. In C++: we don't manage the lifetime of `ncclComm_t` at all, and only release them on exit. `NcclCommunicator` would then only be a wrapper around `ncclComm_t`.

  IMO, we should follow (1) and abandon `get_nccl_comm(idx)` in Python. Then no global list `_comms` is needed; if users think it's necessary, they can create one themselves.
- Should we import `nccl.py` in `nccl/ffi.py` or the opposite? `NcclCommunicator` might depend on some runtime API, so we would need `nccl/ffi.py` in `nccl.py`, which might cause a circular import.
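For reference, option (1) — Python-managed lifetime — can be sketched as a self-contained toy. The FFI destroy call is replaced by a recording stub so the behavior is observable; all names here are illustrative, not the real hidet/NCCL API:

```python
import gc

destroyed = []

def comm_destroy(handle: int) -> None:
    # Stand-in for nccl_runtime_api.comm_destroy: just records the call.
    destroyed.append(handle)

class NcclCommunicator:
    def __init__(self, handle: int):
        self._handle = handle

    def __del__(self):
        # Python-managed lifetime: the underlying ncclComm_t is released
        # when the Python wrapper is garbage-collected.
        comm_destroy(self._handle)

comm = NcclCommunicator(42)
del comm      # drop the last reference
gc.collect()  # make collection deterministic on non-refcounting runtimes
print(destroyed)  # → [42]
```

With this scheme no global `_comms` list is needed; whoever holds the wrapper decides when the communicator is destroyed.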
Good questions!
> How should we manage the lifetime of `NcclCommunicator` objects in Python and `ncclComm_t` objects in C++?

Because different flow graphs may have different sets of communicators, we should create the communicators per flow graph. Let's keep the communicators in `CompiledGraph`. Something like:
```python
class CompiledGraph:
    def __init__(...):
        ...
        self.nccl_comms = ...  # create the communicators;
        # let's store it as hidet.ffi.utils.Array(void*, num of comm)
        # and create a runtime api called
        # "void set_nccl_comms(int num_comm, void** comm_array)"

    def run_async(...):
        ...
        runtime_api.set_nccl_comms(len(self.nccl_comms), self.nccl_comms)
        ...

    def __del__(...):
        # destroy communicators
        ...
```

For `FlowGraph.forward(...)`, we can raise an error for now if we find it is a distributed flow graph.
> Should we import `nccl.py` in `nccl/ffi.py` or the opposite? `NcclCommunicator` might depend on some runtime API, so we would need `nccl/ffi.py` in `nccl.py`, which might cause a circular import.
Consider: let the NCCL runtime API return an integer (the pointer to the communicator), and have `hidet/cuda/nccl/nccl.py` import `hidet/cuda/nccl/ffi.py`. The NCCL FFI should only be used by `nccl.py`, and we expose the API to users in `nccl.py`.
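That one-way dependency could look like the following stand-alone sketch. The ctypes calls are stubbed out and the handle value is fabricated; the point is only the layering — `ffi.py` traffics in plain integer handles, while `nccl.py` imports it and owns the wrapper class:

```python
# --- hidet/cuda/nccl/ffi.py (sketch) ---
# The FFI layer only deals in plain integers (raw ncclComm_t pointers),
# so it never needs to import the NcclCommunicator class.
class NcclRuntimeAPI:
    @staticmethod
    def comm_init_rank(ndev: int, unique_id: bytes, rank: int) -> int:
        # The real code calls ncclCommInitRank via ctypes and returns the
        # ncclComm_t pointer as an int; stubbed with a fake handle here.
        return 0xDEAD0000 + rank

    @staticmethod
    def comm_destroy(handle: int) -> None:
        pass  # real code: ncclCommDestroy on the handle

# --- hidet/cuda/nccl/nccl.py (sketch) ---
# Imports ffi.py (one-way) and wraps the integer handle for users.
class NcclCommunicator:
    def __init__(self, handle: int):
        self._handle = handle

    @classmethod
    def init_rank(cls, ndev: int, unique_id: bytes, rank: int) -> 'NcclCommunicator':
        return cls(NcclRuntimeAPI.comm_init_rank(ndev, unique_id, rank))

    def __del__(self):
        NcclRuntimeAPI.comm_destroy(self._handle)

comm = NcclCommunicator.init_rank(ndev=2, unique_id=b'\x00' * 128, rank=1)
print(hex(comm._handle))  # → 0xdead0001
```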
Thanks for your reply! I agree with the idea of attaching communicators to compiled graphs.
Maybe we can leave the modification of `CompiledGraph` for another PR (which might include upper-level structures like ops and graphs)? Let's focus this PR on primitives, otherwise it will be huge and hard to test.
python/hidet/ffi/runtime_api.py
Outdated
```python
if nccl_available():
    class NCCLRuntimeAPI:
        """
        Runtime APIs regarding NCCL
        TODO: Exception handling
        """
        _get_version = get_func('ncclGetVersion', [c_void_p], c_int)
        _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int)
        _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int)
        _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int)

        _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int)
        _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int)

        @staticmethod
        def get_version() -> int:
            version = c_int(0)
            NCCLRuntimeAPI._get_version(pointer(version))
            return version.value

        @staticmethod
        def get_unique_id(comm_id: NcclUniqueId) -> None:
            """
            In-place initialization of the NcclUniqueId object
            """
            ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
            assert ret == 0, ret

        @staticmethod
        def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
            comm = c_void_p()
            ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
            assert ret == 0, ret
            return comm.value

        @staticmethod
        def comm_destroy(comm: NcclCommunicator) -> None:
            ret = NCCLRuntimeAPI._comm_destroy(comm._handle)
            assert ret == 0

    nccl_runtime_api = NCCLRuntimeAPI()
    _comms: List[NcclCommunicator] = []

    def get_nccl_comm(comm_id: int):
        return _comms[comm_id]
```
Also move this to `hidet.cuda.nccl`.
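One detail worth asserting in this binding: `NcclUniqueId` must stay exactly 128 bytes, matching `NCCL_UNIQUE_ID_BYTES` in nccl.h, otherwise `ncclGetUniqueId` would write past the buffer. A self-contained check:

```python
from ctypes import Structure, c_byte, sizeof

class NcclUniqueId(Structure):
    # Mirrors ncclUniqueId from nccl.h: an opaque 128-byte blob
    # (NCCL_UNIQUE_ID_BYTES == 128).
    _fields_ = [("internal", c_byte * 128)]

# ncclGetUniqueId fills this struct in place, so the size must match the
# header exactly; checking it at import time catches drift early.
print(sizeof(NcclUniqueId))  # → 128
```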
python/hidet/libinfo.py
Outdated
```python
def _get_nccl_dirs():
    import site
    return [os.path.join(root, 'nvidia', 'nccl') for root in site.getsitepackages()]


def get_nccl_include_dirs():
    return [os.path.join(root, 'include') for root in _get_nccl_dirs()]


def get_nccl_library_search_dirs():
    return [os.path.join(root, 'lib') for root in _get_nccl_dirs()]
```
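For context, the `nvidia-nccl-cu11`/`nvidia-nccl-cu12` wheels unpack their headers and `libnccl.so` under `nvidia/nccl/` inside site-packages, which is what these helpers search. A runnable sketch of the same path construction, with `site.getsitepackages()` replaced by a fixed root so it runs anywhere:

```python
import os

def nccl_search_dirs(site_packages_roots):
    # Same logic as the helpers above: each site-packages root may contain
    # nvidia/nccl/{include,lib} installed by the nvidia-nccl-cu* wheels.
    nccl_roots = [os.path.join(root, 'nvidia', 'nccl') for root in site_packages_roots]
    return {
        'include': [os.path.join(r, 'include') for r in nccl_roots],
        'lib': [os.path.join(r, 'lib') for r in nccl_roots],
    }

dirs = nccl_search_dirs(['/usr/lib/python3/site-packages'])
print(dirs['lib'][0])  # → /usr/lib/python3/site-packages/nvidia/nccl/lib
```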
yaoyaoding left a comment
Thanks @soodoshll !
Overall LGTM. Left some minor suggestions. We can merge this first and then work on other primitives and graph-level operators.
include/hidet/runtime/cuda/context.h
Outdated
```cpp
/**
 * Add a NCCL communicator to the context.
 */
DLL void set_nccl_comms(void** comm, int num_comms);
```
```diff
-DLL void set_nccl_comms(void** comm, int num_comms);
+DLL void set_nccl_comms(int num_comms, void** comm);
```
Let's keep the argument order consistent with other APIs in hidet.
```python
from hidet.cuda.nccl import NcclDataType, NcclRedOp


def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: NcclDataType, op: NcclRedOp):
```
```diff
-def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: NcclDataType, op: NcclRedOp):
+def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp):
```
Let's pass `hidet.ir.type.DataType` to this primitive and convert it into `NcclDataType` inside the primitive function.
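A sketch of what that conversion could look like. The enum values follow `ncclDataType_t` in nccl.h; the helper name and the string-keyed lookup are illustrative, not the actual hidet API:

```python
from enum import IntEnum

class NcclDataType(IntEnum):
    # Values follow ncclDataType_t in nccl.h
    int8 = 0
    uint8 = 1
    int32 = 2
    uint32 = 3
    int64 = 4
    uint64 = 5
    float16 = 6
    float32 = 7
    float64 = 8
    bfloat16 = 9

# Hypothetical helper: map a hidet DataType name to the NCCL enum, so the
# all_reduce primitive can accept hidet.ir.type.DataType directly.
_NAME_TO_NCCL = {
    'int8': NcclDataType.int8,
    'uint8': NcclDataType.uint8,
    'int32': NcclDataType.int32,
    'int64': NcclDataType.int64,
    'float16': NcclDataType.float16,
    'bfloat16': NcclDataType.bfloat16,
    'float32': NcclDataType.float32,
    'float64': NcclDataType.float64,
}

def dtype_to_nccl(dtype_name: str) -> NcclDataType:
    # Not every hidet dtype has an NCCL counterpart, so fail loudly.
    if dtype_name not in _NAME_TO_NCCL:
        raise ValueError(f'dtype {dtype_name} is not supported by NCCL')
    return _NAME_TO_NCCL[dtype_name]

print(int(dtype_to_nccl('float32')))  # → 7
```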
```cpp
DLL void* get_nccl_comm(int idx) {
    return CudaContext::global()->nccl_comms[idx];
}
```
Better to add some check logic here to make sure idx < num_comms.
@yaoyaoding fixed. Please take a look :)
What we have now:
- NCCL provided through the pip packages `nvidia-nccl-cu11` or `nvidia-nccl-cu12`
- a distributed example (`./examples/distributed/test.py`)