[SymmMem] Refactor NVSHMEM Reduction API to be more ergonomic with automatic dtype-based dispatch #159755

codingwithsurya wants to merge 15 commits into gh/codingwithsurya/19/base
Conversation
…and more ergonomic [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159755
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit ba1f61a with merge base 3daef4d.
CANCELLED JOB - The following job was cancelled. Please retry.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…mic with automatic dtype-based dispatch"

This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
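The normalization and name assembly described above can be sketched roughly as follows. This is an illustrative sketch, not the actual implementation: the `NVSHMEM_TYPENAMES` table is abbreviated, and `nvshmem_reduce_fn_name` is a name made up here.

```python
# Illustrative sketch: normalize whatever the caller passed (tl.dtype,
# torch.dtype, str, or a constexpr wrapper) into a canonical NVSHMEM
# typename, then assemble the extern function name.
NVSHMEM_TYPENAMES = {
    "float16": "half",
    "bfloat16": "bfloat16",
    "float32": "float",
    "float64": "double",
    "int32": "int32",
    "int64": "int64",
}

def nvshmem_reduce_fn_name(dtype_id, operation: str) -> str:
    # Unwrap a constexpr-style wrapper first (anything exposing .value).
    if hasattr(dtype_id, "value"):
        dtype_id = dtype_id.value
    if hasattr(dtype_id, "name"):
        # Triton language dtype (e.g., tl.float32)
        dtype_name = dtype_id.name
    elif isinstance(dtype_id, str):
        # Already a plain string name
        dtype_name = dtype_id
    else:
        # PyTorch dtype (e.g., torch.float32)
        dtype_name = str(dtype_id).replace("torch.", "")
    typename = NVSHMEM_TYPENAMES[dtype_name]
    return f"nvshmem_{typename}_{operation}_reduce"
```

For example, `("float32", "sum")` assembles `nvshmem_float_sum_reduce`, matching the names quoted in the description.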
Starting merge as part of PR stack under #159788
```python
expected = []
for i in range(nreduce):
    # Product across all ranks
    product = 1
    for r in range(world_size):
        if i == 0:
            # rank 0,2,4... contributes 1, rank 1,3,5... contributes 2
            product *= 1 if r % 2 == 0 else 2  # 2^(world_size//2)
        elif i == 1:
            # all ranks contribute 1
            product *= 1  # result is 1
        else:
            # rank 0,1 contribute 1, rank 2,3 contribute 2, etc.
            product *= 1 if (r // 2) % 2 == 0 else 2
    expected.append(product)
```
Suggested change (replace the loop with a vectorized construction):

```python
vals = torch.empty(nreduce, world_size, dtype=dtype)
vals[0, ::2] = 1
vals[0, 1::2] = 2
vals[1] = 1
vals2 = vals[2].view(-1, 2, 2)
vals2[:, 0] = 1
vals2[:, 1] = 2
expected = vals.prod(-1)
updated this on my most recent commit!
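As a sanity check, the original loop and the suggested vectorized construction do agree element-wise. Here is a torch-free sketch of both (the helper names `expected_loop` and `expected_vectorized` are made up here; `world_size` is assumed to be a multiple of 4 so the pairwise `(r // 2)` pattern tiles evenly):

```python
def expected_loop(nreduce, world_size):
    # Mirrors the original per-element loop from the test.
    expected = []
    for i in range(nreduce):
        product = 1
        for r in range(world_size):
            if i == 0:
                product *= 1 if r % 2 == 0 else 2
            elif i == 1:
                product *= 1
            else:
                product *= 1 if (r // 2) % 2 == 0 else 2
        expected.append(product)
    return expected

def expected_vectorized(nreduce, world_size):
    # Mirrors the vectorized suggestion: build the contribution matrix,
    # then take the product across ranks (the last axis).
    vals = [[1] * world_size for _ in range(nreduce)]
    vals[0] = [1 if r % 2 == 0 else 2 for r in range(world_size)]
    for i in range(2, nreduce):
        vals[i] = [1 if (r // 2) % 2 == 0 else 2 for r in range(world_size)]
    out = []
    for row in vals:
        p = 1
        for v in row:
            p *= v
        out.append(p)
    return out
```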
```python
src_ptr,
nreduce,
operation: tl.constexpr,
dtype_id: tl.constexpr,
```
don't call it dtype_id, it's just dtype (torch.dtype)?
A lot of this was changed in #159788; I have updated the commit there to rename `dtype_id` to just `dtype`.
```python
if hasattr(dtype_id, "name"):
    # Triton language dtype (e.g., tl.float32)
    dtype_name = dtype_id.name
elif isinstance(dtype_id, str):
    # Already a plain string name
    dtype_name = dtype_id
elif hasattr(dtype_id, "value"):
    # Constexpr wrapper around a dtype
    inner_value = dtype_id.value
    if hasattr(inner_value, "name"):
        # Triton dtype inside constexpr
        dtype_name = inner_value.name
    else:
        # PyTorch dtype inside constexpr
        dtype_name = str(inner_value).replace("torch.", "")
else:
    # PyTorch dtype (e.g., torch.float32)
    dtype_name = str(dtype_id).replace("torch.", "")
```
Suggested change (unwrap the constexpr first, then branch once):

```python
inner_value = dtype_id.value if hasattr(dtype_id, "value") else dtype_id
if hasattr(inner_value, "name"):
    # Triton dtype (possibly inside constexpr)
    dtype_name = inner_value.name
elif isinstance(inner_value, str):
    # Already a plain string name
    dtype_name = inner_value
else:
    # PyTorch dtype (e.g., torch.float32)
    dtype_name = str(inner_value).replace("torch.", "")
```
Is there a more reliable way to distinguish between triton dtype and torch dtype, rather than relying on hasattr(..., "name")?
In #159788 it's much simpler because only Triton dtypes come in: we just have a map now, so no more `hasattr` checks!
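A rough sketch of what that map-based dispatch could look like. The table and helper name below are illustrative (the real table lives in #159788), and the Triton dtype name strings (`"fp32"`, `"bf16"`, …) are assumptions about how Triton spells them:

```python
# Illustrative: Triton dtype names map directly to NVSHMEM typenames,
# so dispatch becomes a single dict lookup with no hasattr probing.
TRITON_TO_NVSHMEM = {
    "int8": "int8", "int16": "int16", "int32": "int32", "int64": "int64",
    "uint8": "uint8", "uint16": "uint16", "uint32": "uint32", "uint64": "uint64",
    "fp16": "half", "bf16": "bfloat16", "fp32": "float", "fp64": "double",
}

def extern_name(tl_dtype_name: str, operation: str) -> str:
    # Assemble e.g. "nvshmem_float_sum_reduce" from ("fp32", "sum").
    return f"nvshmem_{TRITON_TO_NVSHMEM[tl_dtype_name]}_{operation}_reduce"
```

A bad dtype now fails loudly with a `KeyError` at the lookup instead of silently falling through duck-typed branches.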
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness. Pull Request resolved: #159756 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions so users can pass tensors as inputs to their NVSHMEM Triton kernels (rather than pointers). The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

**TODO:** This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`. From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
  Remote signal: 0x2430301600 (sig) ← Rank 1's signal
Rank 1 (waiting):
  Local signal:  0x430301600  (waits here)

Tensor-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Local buffer:  0x430300c00  (dst) ← this is wrong
  Local signal:  0x430300e00  (sig) ← this is wrong
Rank 1 (waiting):
  Local signal:  0x430300e00  (waits here)
```

Next Steps: Need a mechanism to resolve a local tensor → remote PE address, equivalent to a `handle.buffer_ptrs[peer]` lookup.

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
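The missing resolution step could, in principle, be modeled like this. Everything here is hypothetical: `resolve_remote_addr` and `peer_bases` are names made up for illustration of the symmetric-heap offset arithmetic, not the actual NVSHMEM or handle API.

```python
def resolve_remote_addr(local_addr: int, local_base: int,
                        peer_bases: dict[int, int], peer: int) -> int:
    # Symmetric allocations live at the same offset within each PE's
    # symmetric heap, so a local address can be translated by taking its
    # offset from the local base and applying it to the peer's base.
    offset = local_addr - local_base
    return peer_bases[peer] + offset
```

With the addresses from the log above (local base 0x430300a00, Rank 1's base 0x2430300a00), the local dst 0x430300c00 would resolve to the expected remote 0x2430300c00.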
This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta