[Distributed] all_reduce op and distributed info in graphs#284
[Distributed] all_reduce op and distributed info in graphs#284soodoshll merged 42 commits intohidet-org:mainfrom
Conversation
|
@yaoyaoding this pr is ready for review :) |
Merely assigning environment variables is insufficient for setting up dev environment now. We need to run pip to install hidet package in develop mode. Users still need to build source files written in C++ manually. Consider integrating that into `setup.py` in the future?
yaoyaoding
left a comment
There was a problem hiding this comment.
Thanks @soodoshll !
I left some suggestions on the data organization and implementation.
python/hidet/cuda/nccl/comm.py
Outdated
| def init_unique_id(unqie_id: NcclUniqueId) -> None: | ||
| if not nccl_available(): | ||
| raise RuntimeError("NCCL is not available") | ||
| nccl_runtime_api.get_unique_id(unqie_id) |
There was a problem hiding this comment.
Can we define init_unique_id(...) as
def create_unique_id() -> NcclUniqueId:
...I feel the current API is not very intuitive.
There was a problem hiding this comment.
The point here is now we need the NcclUniqueId to be shared by all processes. And the current solution is
- Create a shared NcclUniqueId object;
- Launch multiple processes with the shared uniqueid object as one argument;
- Init the shared uniqueid object in process 0, which need the reference to the shared object
If we create the NcclUniqueId in process 0 after processes have been launched, it's not so easy to do the broadcast (if there's an elegant way of broadcasting, please let me know).
A workaround is to 1) create the shared object; 2) launch processes; 3) create a unique id object; 4) copy its value back to the shared object.
python/hidet/graph/flow_graph.py
Outdated
| # For distributed graphs | ||
| self.nrank = nrank | ||
| self.rank = rank | ||
| self.groups = groups | ||
|
|
There was a problem hiding this comment.
Let's define a new class called FlowGraphAttrs and define these attributes in that class. Then add a field in FlowGraph with FlowGraphAttrs type.
There was a problem hiding this comment.
something like
class FlowGraph:
def __init__(..., attrs=None):
...
self.attrs: FlowGraphAttrs = attrs if attrs else FlowGraphAttrs()
python/hidet/graph/flow_graph.py
Outdated
| def is_distributed(self): | ||
| return self.nrank is not None or self.rank is not None | ||
|
|
||
| def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None): | ||
| self.nrank = nrank | ||
| self.rank = rank | ||
| self.groups = groups | ||
|
|
There was a problem hiding this comment.
Let's define thses functions at the module that will use these functionality, instead of defining them as FlowGraph methods.
There was a problem hiding this comment.
I have replaced them with set_attrs
| self.comm_id = comm_id | ||
| self.op = op | ||
|
|
||
| super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) |
There was a problem hiding this comment.
Better also add comm_id and op to attributes, so that the user can see the comm_id and op when compiling the task.
| return f"all_reduce_{self.op}_{self.comm_id}" | ||
|
|
||
| def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: | ||
| # we may need current rank here to avoid duplicated working_dirs |
There was a problem hiding this comment.
Could you clarify the problem here? Thanks.
There was a problem hiding this comment.
If we add the comm_id to attributes, then the op hash would be different.
There was a problem hiding this comment.
if we run the compilation concurrently in multiple processes, for the same op, there might be race conditions in the local filesystem.
| comms_array = comms_to_array(self.nccl_comms) | ||
| runtime_api.set_nccl_comms(comms_array) |
There was a problem hiding this comment.
Let's create this when initialize the dist-related info, to avoid repeating creating the comm Array.
| self.cpu_workspace: Optional[Storage] = None | ||
|
|
||
| # distributed properties | ||
| self.dist_info: Optional[GraphDistributedInfo] = dist_info |
There was a problem hiding this comment.
Better to put this in GraphMetaData.
There was a problem hiding this comment.
I think a better idea is to put the FlowGraphAttr in the GraphMetaData as a whole, instead of reiterating all attributes. But then where should we put FlowGraphAttr? Putting it in flow_graph.py will cause circular import.
|
|
||
| # distributed properties | ||
| self.dist_info: Optional[GraphDistributedInfo] = dist_info | ||
| self.nccl_comms: List[NcclCommunicator] = [] |
There was a problem hiding this comment.
store it as Array of NcclCommunicator directly, to avoid repeating creating the Array in run_async.
There was a problem hiding this comment.
Array of NcclCommunicator cannot be directly passed into C++. C++ needs an array of ncclComm_t, which is basically the handle of NcclCommunicator. And to avoid NcclCommunicators being released by GC, we need to maintain the list of NcclCommunicator. If we also maintain the ncclComm_t array, we will have two redundant arrays which almost save the same value
| def _recursive_find(root: Stmt): | ||
| if isinstance(root, BlackBoxStmt): | ||
| if root.template_string.startswith('nccl'): | ||
| return True | ||
| for child in dir(root): | ||
| if isinstance(child, Stmt): | ||
| if _recursive_find(child): | ||
| return True | ||
| return False | ||
|
|
||
| ret = _recursive_find(func.body) |
There was a problem hiding this comment.
Use hidet.ir.tools.collect to collect all BlackStmt.
python/hidet/transforms/__init__.py
Outdated
| rule_based_simplify_pass(), | ||
| inline_let_stmt_pass(), | ||
| simplify_stmt_pass(), | ||
| include_nccl_pass(), |
There was a problem hiding this comment.
Later, we will use a pass to add the header information. Let's make this pass a general one and give a name like "annotate_headers" or "annotate_include_headers". Or "annotate_header_and_libs".
) Previously, if a primitive function calls a primitive function, the `instantiate_symbols` pass will update the corresponding `hidet.ir.primitives.func.PrimitiveFunctionRegistry.function` in-place (I am not sure exactly how it's done, but this is what I observed), adding symbol variables to its parameters. The primitive function pool is a global variable, therefore this effect is cumulative across tuning candidates. So while candidate 0 will have no problem, candidate 1 will have two extra copies of symbol params, and so on, leading to compile errors. Since primitive functions do not need symbol vars, a quick fix is just to not instantiate any symbols for them.
yaoyaoding
left a comment
There was a problem hiding this comment.
Thanks @soodoshll !
I left some comments.
python/hidet/distributed/group.py
Outdated
| NCCL_COMMS = [] | ||
| _NCCL_ARRAY = None |
There was a problem hiding this comment.
| NCCL_COMMS = [] | |
| _NCCL_ARRAY = None | |
| NCCL_COMMS: List[NcclCommunicator] = [] | |
| _NCCL_ARRAY: 'Array' = None |
python/hidet/distributed/store.py
Outdated
| self._filename = filename | ||
| self._lock_filename = filename + '.lock' | ||
| self._world_size = world_size | ||
|
|
||
| self._lock = filelock.FileLock(self._lock_filename) | ||
| self._cache = {} | ||
| self._timeout = None |
There was a problem hiding this comment.
Better to add some type annotations to reduce the time of the code reader.
python/hidet/distributed/store.py
Outdated
| key = self.REGULAR_PREFIX + key | ||
| with self._lock: | ||
| with open(self._filename, "ab+") as f: | ||
| f.seek(0) |
python/hidet/distributed/store.py
Outdated
| f.seek(0) | ||
| self._update(f) | ||
| has_key = key in self._cache | ||
| print(has_key, self._cache[key]) |
There was a problem hiding this comment.
| print(has_key, self._cache[key]) |
python/hidet/distributed/store.py
Outdated
| if k is None: | ||
| return | ||
| v = self._read(f) | ||
| k = str(k, encoding='raw_unicode_escape') |
There was a problem hiding this comment.
Can I know why we choose this encoding, instead of encoding like 'utf-8'?
There was a problem hiding this comment.
Also better to add the reason to the comments.
There was a problem hiding this comment.
No special reasons besides pickle uses that. Switching to utf-8 since it is the default value of encoding/decoding.
| @@ -0,0 +1,137 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
There was a problem hiding this comment.
Better to place this test to hidet/tests/distributed/test_file_store.py.
python/hidet/distributed/store.py
Outdated
| manually if required. | ||
|
|
||
| We use a 4-byte integer to record the length of each (encoded) key and value. So do not insert | ||
| more than 32768 bytes for each entry. |
There was a problem hiding this comment.
4 byte integer could represent up to 2^31-1?
There was a problem hiding this comment.
Oops. let me fix it
|
|
||
| Deletion of an entry is done by adding a new entry with a suffix '-' (DELETE_PREFIX). It will | ||
| overwrite the insertion of the given entry when we scanning the file. | ||
| """ |
There was a problem hiding this comment.
Thanks for the comments, now the design is very clear!
|
Thanks @soodoshll ! Looks good to me now. Good job! There seems is a typo in the comment. Feel free to merge this PR by yourself after fixing that. |
all_reduceopall_reduce(relu(x * w))in./examples/distributed/test.py