[core][rdt][cherry-pick] Reuse previous metadata if transferring the same tensor list with nixl #58309
…ist with nixl (ray-project#58263)

## Description

For nixl, reuse previous metadata if transferring the same tensor list. This is to avoid repeated `register_memory` before `deregister_memory`.

---------

Signed-off-by: Dhyey Shah <dhyey2019@gmail.com>
Co-authored-by: Dhyey Shah <dhyey2019@gmail.com>
Co-authored-by: Stephanie Wang <smwang@cs.washington.edu>
Code Review
This pull request optimizes NIXL tensor transport by reusing metadata when the same tensor list is transferred repeatedly. It caches the metadata and identifies tensors by their data pointers, which also allows multiple objects to contain the same tensor. The changes span documentation, core object management logic, and tests. The overall approach is sound. I've flagged two `defaultdict` initializations that call a subscripted alias and should use the plain constructor instead, and I've included a medium-severity suggestion to improve the clarity and efficiency of one of the new methods.
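As a rough illustration of the caching idea summarized above, here is a minimal sketch. It is not the code from this PR: `FakeTensor`, `MetadataCache`, `get_or_register`, and the `_register` placeholder are hypothetical names, and the real registration call is replaced by a counter so the example runs without NIXL or PyTorch.

```python
from typing import Any, Dict, List, Tuple


class FakeTensor:
    """Hypothetical stand-in for a tensor; only data_ptr() matters here."""

    def __init__(self, ptr: int) -> None:
        self._ptr = ptr

    def data_ptr(self) -> int:
        return self._ptr


class MetadataCache:
    """Hypothetical cache keyed by the data pointers of a tensor list."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[int, ...], Any] = {}
        self.register_calls = 0  # counts simulated registrations

    def _register(self, tensors: List[FakeTensor]) -> Dict[str, Any]:
        # Placeholder for the expensive registration step (assumption).
        self.register_calls += 1
        return {"ptrs": [t.data_ptr() for t in tensors]}

    def get_or_register(self, tensors: List[FakeTensor]) -> Dict[str, Any]:
        # The same tensor list yields the same key, so metadata is reused.
        key = tuple(t.data_ptr() for t in tensors)
        if key not in self._cache:
            self._cache[key] = self._register(tensors)
        return self._cache[key]


cache = MetadataCache()
tensors = [FakeTensor(0x1000), FakeTensor(0x2000)]
first = cache.get_or_register(tensors)
second = cache.get_or_register(tensors)  # served from the cache
```

In this sketch the second call returns the object created by the first one, and the simulated registration runs only once.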
self._tensor_to_object_ids: Dict[int, Set[str]] = defaultdict[int, Set[str]](
    set
)
The expression `defaultdict[int, Set[str]](set)` is unconventional. On Python 3.9+, `defaultdict[int, Set[str]]` is a `types.GenericAlias`; calling it does construct a `defaultdict`, but the type parameters have no effect at runtime, and on Python versions before 3.9 the subscript itself raises a `TypeError`. The annotation on the left already carries the types, so pass the factory function directly to the constructor: `defaultdict(set)`.
self._tensor_to_object_ids: Dict[int, Set[str]] = defaultdict(set)

# Mapping from object ID to the NIXL managed meta.
self._managed_meta_nixl: Dict[str, Any] = {}
# Mapping from NIXL managed meta to the number of objects that contain it.
self._managed_meta_counts_nixl: Dict[Any, int] = defaultdict[Any, int](int)
The expression `defaultdict[Any, int](int)` is unconventional for the same reason. On Python 3.9+, `defaultdict[Any, int]` is a `types.GenericAlias`; calling it does construct a `defaultdict`, but the type parameters have no effect at runtime, and on Python versions before 3.9 the subscript itself raises a `TypeError`. Pass the factory function directly to the constructor: `defaultdict(int)`.
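A small standalone sketch of the recommended form, using illustrative variable names and keys rather than the PR's attributes: the annotation carries the key and value types, and the constructor receives only the factory.

```python
from collections import defaultdict
from typing import Any, Dict, Set

# The annotation documents the types; the constructor only needs the factory.
tensor_to_object_ids: Dict[int, Set[str]] = defaultdict(set)
meta_counts: Dict[Any, int] = defaultdict(int)

tensor_to_object_ids[0x1000].add("obj-a")  # missing key starts as an empty set
tensor_to_object_ids[0x1000].add("obj-b")
meta_counts["meta-1"] += 1  # missing key starts at 0
```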
self._managed_meta_counts_nixl: Dict[Any, int] = defaultdict(int)

self._managed_meta_counts_nixl[meta] -= 1
if self._managed_meta_counts_nixl[meta] == 0:
    self._managed_meta_counts_nixl.pop(meta)
return self._managed_meta_counts_nixl[meta]
This implementation is confusing and inefficient: it looks up `self._managed_meta_counts_nixl[meta]` several times. When the count reaches 0, the key is popped and then accessed again on the return line, which relies on `defaultdict` to create a fresh entry and return 0, so the key that was just removed is reinserted. It would be clearer and more efficient to store the count in a local variable before the conditional pop.
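The side effect described above can be reproduced in isolation. The sketch below uses hypothetical helper names (`decrement_reaccess`, `decrement_local`) and a plain string key; it only demonstrates standard `defaultdict` behavior and is not the PR's code.

```python
from collections import defaultdict
from typing import Dict


def decrement_reaccess(counts: Dict[str, int], key: str) -> int:
    counts[key] -= 1
    if counts[key] == 0:
        counts.pop(key)
    # On a defaultdict, this lookup reinserts the key with value 0.
    return counts[key]


def decrement_local(counts: Dict[str, int], key: str) -> int:
    counts[key] -= 1
    count = counts[key]  # read once, before the conditional pop
    if count == 0:
        counts.pop(key)
    return count


a = defaultdict(int, {"meta": 1})
b = defaultdict(int, {"meta": 1})
result_a = decrement_reaccess(a, "meta")
result_b = decrement_local(b, "meta")
```

Both helpers return 0 here, but only the second leaves the dictionary without a stale `"meta"` entry.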
self._managed_meta_counts_nixl[meta] -= 1
count = self._managed_meta_counts_nixl[meta]
if count == 0:
    self._managed_meta_counts_nixl.pop(meta)
return count

…same tensor list with nixl (ray-project#58309)

Cherry-picking ray-project#58263 for 2.51.1 release.

Signed-off-by: Dhyey Shah <dhyey2019@gmail.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: Stephanie Wang <smwang@cs.washington.edu>
Cherry-picking #58263 for 2.51.1 release.