[SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) #159701
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159701
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 1925eb4 with merge base 3daef4d.
CANCELLED JOB - The following job was cancelled. Please retry.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
p = p.strip().replace('$ORIGIN', torch_lib)
if p and p not in paths:
    paths.append(p)
except Exception:
don't catch a generic exception; narrow it down to the exception you expect to happen here
except Exception:
    pass
for path in paths:
we need to decide the order of path search to match what we do during build, cc @kwen2501. RPATH should probably come first if we have it?
I was looking into this, and on second thought, can we just use RPATH? RPATH will always contain the exact paths where NVSHMEM was found during build time. I don't think we need any fallback searches for anything else since the RPATH is the source of truth. Thoughts on having something like this, where enable_triton doesn't take in a lib dir param and we just get it from RPATH? This might be cleaner.
import os
import subprocess

import torch
import triton


def _find_nvshmem_device_library() -> str:
    torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib")
    so_path = os.path.join(torch_lib, "libtorch_nvshmem.so")
    result = subprocess.run(["readelf", "-d", so_path], capture_output=True, text=True)
    for line in result.stdout.splitlines():
        if "RPATH" in line and "[" in line:
            # e.g. "(RPATH) Library rpath: [$ORIGIN/../../nvidia/nvshmem/lib:...]"
            rpath = line.split("[")[1].split("]")[0]
            for path in rpath.split(":"):
                path = path.strip().replace("$ORIGIN", torch_lib)
                if path:
                    device_lib = os.path.join(path, "libnvshmem_device.bc")
                    if os.path.exists(device_lib):
                        return device_lib
    raise RuntimeError("NVSHMEM device library not found")


def enable_triton() -> dict[str, str]:
    from torch._C._distributed_c10d import _nvshmemx_cumodule_init

    lib_path = _find_nvshmem_device_library()
    extern_libs = {"libnvshmem_device": lib_path}

    def nvshmem_init_hook(*args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        key = kwargs["key"]
        device = kwargs["compile"]["device"]
        jit_function = kwargs["fn"].jit_function
        kernel_cache, _, _, _ = jit_function.device_caches[device]
        kernel = kernel_cache.get(key, None)
        kernel.run  # touch the kernel so its CUmodule is loaded before init
        _nvshmemx_cumodule_init(kernel.module)

    triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook
    return extern_libs
i have updated the commit with the above, lmk if you have other thoughts on how I should go about implementing this though
cc @malfet, are we guaranteed to have nvshmem path in RPATH?
Turns out nightly isn't built with nvshmem support; we need to fix that. For now, it's ok to fix the CI by adding whatever search paths are needed, and we'll get back to this after nightly is fixed.
Ok with @malfet's fix #159907 nightly now builds with nvshmem, and nvshmem libhost is indeed in the RPATH
In [10]: print(result.stdout)
Dynamic section at offset 0x8b0140 contains 34 entries:
Tag Type Name/Value
0x0000000000000001 (NEEDED) Shared library: [libnvshmem_host.so.3]
0x0000000000000001 (NEEDED) Shared library: [librt.so.1]
0x0000000000000001 (NEEDED) Shared library: [libpthread.so.0]
0x0000000000000001 (NEEDED) Shared library: [libdl.so.2]
0x0000000000000001 (NEEDED) Shared library: [libstdc++.so.6]
0x0000000000000001 (NEEDED) Shared library: [libm.so.6]
0x0000000000000001 (NEEDED) Shared library: [libgcc_s.so.1]
0x0000000000000001 (NEEDED) Shared library: [libc.so.6]
0x0000000000000001 (NEEDED) Shared library: [ld-linux-x86-64.so.2]
0x000000000000000e (SONAME) Library soname: [libtorch_nvshmem.so]
0x000000000000000f (RPATH) Library rpath: [$ORIGIN/../../nvidia/cublas/lib:$ORIGIN/../../nvidia/cuda_cupti/lib:$ORIGIN/../../nvidia/cuda_nvrtc/lib:$ORIGIN/../../nvidia/cuda_runtime/lib:$ORIGIN/../../nvidia/cudnn/lib:$ORIGIN/../../nvidia/cufft/lib:$ORIGIN/../../nvidia/curand/lib:$ORIGIN/../../nvidia/cusolver/lib:$ORIGIN/../../nvidia/cusparse/lib:$ORIGIN/../../nvidia/cusparselt/lib:$ORIGIN/../../cusparselt/lib:$ORIGIN/../../nvidia/nccl/lib:$ORIGIN/../../nvidia/nvshmem/lib:$ORIGIN/../../nvidia/nvtx/lib:$ORIGIN/../../nvidia/cufile/lib:$ORIGIN]
awesome! I can roll with the simple method of finding the device library via RPATH then!
# Register the function as a post-compile hook
JITFunction.compiled_hook = nvshmem_init_hook
triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook
can we check the Triton version and use compiled_hook or knobs accordingly?
@kwen2501 do we plan to enable nvshmem backend in fbcode? fbcode is still triton 3.3
It's not just the compile hook API name we'd need to change - we also have the _semantic parameter replacing the _barrier parameter across all the extern function calls. After thinking about it, I think we should just support the latest version of pytorch-triton (e.g., Triton 3.4 here). If API names and interfaces on the Triton side keep changing, it'll be difficult to support multiple versions without significant code complexity.
Open to other approaches if you think supporting 3.3 is important. I was thinking we could handle it via a wrapper function that accepts **kwargs and dynamically maps either _semantic or _barrier parameters (based on version detection) to the correct parameter name when calling the underlying Triton APIs
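The `**kwargs`-mapping idea above could be sketched roughly like this. This is a hedged illustration, not the PR's implementation: the helper names (`_triton_uses_semantic`, `make_extern_call`) and the 3.4 version cutoff are assumptions based on the discussion, and the keyword rename shown (`_builder` → `_semantic`) mirrors the parameter change mentioned in this thread.

```python
from typing import Any, Callable


def _triton_uses_semantic(version: str) -> bool:
    # Assumed cutoff from the discussion: Triton >= 3.4 renamed the
    # builder keyword in extern call signatures.
    major, minor = (int(x) for x in version.split(".")[:2])
    return (major, minor) >= (3, 4)


def make_extern_call(fn: Callable[..., Any], triton_version: str) -> Callable[..., Any]:
    """Wrap an extern call so call sites always pass `_builder`; the wrapper
    renames it to `_semantic` when the detected Triton version expects that."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if _triton_uses_semantic(triton_version) and "_builder" in kwargs:
            kwargs["_semantic"] = kwargs.pop("_builder")
        return fn(*args, **kwargs)
    return wrapper
```

The downside, as noted above, is that every new Triton API rename adds another branch to this shim, which is why the PR opted to track only the latest pytorch-triton.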
looks like there are some issues w/ test hangs in CI, working on a fix right now
Figured out the issue: it was the peer calculation. The culprit was this change from 3 weeks ago: #158167 (comment). I was on a 2-rank setup earlier so didn't catch it, but CI runs with 8 ranks. Once I switched to 8 ranks on the dev GPU, it reproduced. The code had peer = (world_size - 1) - rank, which works for 2 ranks (0↔1) but maps rank 0 → 7 with 8 ranks. The test assumes 0↔1 comms, so this broke it. Fixed by switching to peer = 1 - rank to preserve that pattern. Just pushed out a fix here in this PR, but let me know if I should add it in a separate PR.
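The two formulas are easy to compare directly; this small check reproduces the mismatch described above:

```python
# Old formula from #158167: pairs rank r with its "mirror" rank.
def old_peer(rank: int, world_size: int) -> int:
    return (world_size - 1) - rank

# Fixed formula from this PR: always pairs rank 0 with rank 1.
def new_peer(rank: int) -> int:
    return 1 - rank

# With 2 ranks both formulas agree (0 <-> 1) ...
assert old_peer(0, 2) == new_peer(0) == 1
# ... but with 8 ranks the old formula maps rank 0 to rank 7,
# breaking tests that assume a 0 <-> 1 communication pattern.
assert old_peer(0, 8) == 7
assert new_peer(0) == 1
```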
…ix CI Skips" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta [ghstack-poisoned]
)
for line in result.stdout.split("\n"):
    if ("RPATH" in line or "RUNPATH" in line) and "[" in line:
        rpath = line.split("[")[1].split("]")[0]
p = p.strip().replace("$ORIGIN", torch_lib)
if p and p not in paths:
    paths.append(p)
except Exception:
Not generic exception please
…x CI tests (make device library discoverable + fix peer calculation bug)

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path, triton.knobs.runtime.jit_post_compile_hook.
- The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to skip NVSHMEM tests on A100s (hardware not compatible).

**Peer Rank Calculation Fix**
- The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank. Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
Starting merge as part of PR stack under #159788
…m in their name (#159734) Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this. Pull Request resolved: #159734 Approved by: https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701
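The conditional initialization in #159734 can be sketched as follows. This is a simplified stand-in, not the merged code: `init_nvshmem_module` represents `_nvshmemx_cumodule_init`, and the hook here reads the kernel name and module straight from kwargs rather than going through Triton's device cache as the real hook does.

```python
def make_post_compile_hook(init_nvshmem_module):
    """Return a post-compile hook that only initializes the NVSHMEM
    module for kernels whose name contains "nvshmem"."""
    def hook(*args, **kwargs):
        fn = kwargs["fn"]
        name = getattr(fn, "name", "")
        if "nvshmem" in name:  # skip unrelated Triton kernels entirely
            init_nvshmem_module(kwargs["module"])
    return hook
```

Gating on the kernel name is what makes renaming all the nvshmem kernels (as the commit message notes) necessary for this scheme to work.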
…tomatic dtype‐based dispatch (#159755) This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce Pull Request resolved: #159755 Approved by: https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
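The dtype normalization and name assembly described in #159755 might look roughly like this. A hedged sketch only: the mapping table is abbreviated for illustration (the real module covers the full int8…int64, uint8…uint64, float16, bfloat16, float32, float64 list), and `reduce_fn_name` is a hypothetical helper name.

```python
# Illustrative subset of the dtype-id -> NVSHMEM typename mapping.
_TYPENAMES = {
    "fp32": "float", "float32": "float",
    "fp64": "double", "float64": "double",
    "bf16": "bfloat16", "bfloat16": "bfloat16",
    "int32": "int32", "int64": "int64",
}


def reduce_fn_name(dtype_id, operation: str) -> str:
    """Normalize a dtype id (tl.dtype, torch.dtype, or str) into the
    canonical NVSHMEM typename and assemble the extern function name."""
    # str() handles non-string dtype objects; torch dtypes stringify as
    # "torch.float32", so strip the module prefix.
    key = str(dtype_id).replace("torch.", "")
    typename = _TYPENAMES[key]
    return f"nvshmem_{typename}_{operation}_reduce"
```

Accepting real dtype objects at the kernel launch site, as the commit does, avoids forcing callers to memorize NVSHMEM's typename spellings.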
Fed NVSHMEM documentation to Claude Code and asked it to generate helpful docstrings. Verified for correctness. Pull Request resolved: #159756 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
…on kernels (#159788) This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions so users can pass tensors (rather than raw pointers) into their NVSHMEM Triton kernels. The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

**TODO:** This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`. From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
  Remote signal: 0x2430301600 (sig) ← Rank 1's signal
Rank 1 (waiting):
  Local signal:  0x430301600  (waits here)

Tensor-Based Version:
Rank 0 → Rank 1:
  Local buffer: 0x430300a00 (src)
  Local buffer: 0x430300c00 (dst) ← this is wrong
  Local signal: 0x430300e00 (sig) ← this is wrong
Rank 1 (waiting):
  Local signal: 0x430300e00 (waits here)
```

Next Steps: Need a mechanism to resolve a local tensor → remote PE address, equivalent to the handle.buffer_ptrs[peer] lookup. Pull Request resolved: #159788 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
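The byte-size and pointer handling the wrapper hides can be sketched on the host side like this. This is a hedged illustration, not the `@triton.jit` wrapper from the PR: `FakeTensor` is a stand-in exposing only the `data_ptr()`/`element_size()` interface of a torch.Tensor, and `putmem_block_raw` is a hypothetical stand-in for the pointer-based extern.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    ptr: int        # stand-in for torch.Tensor.data_ptr()
    elem_size: int  # bytes per element, e.g. 4 for float32

    def data_ptr(self) -> int:
        return self.ptr

    def element_size(self) -> int:
        return self.elem_size


def put_tensor(dst, src, nelems, peer, putmem_block_raw):
    # The wrapper hides the manual byte-size calculation and raw-pointer
    # plumbing from callers, who pass tensors and element counts instead.
    nbytes = nelems * src.element_size()
    putmem_block_raw(dst.data_ptr(), src.data_ptr(), nbytes, peer)
```

The unresolved TODO above is exactly that `dst.data_ptr()` here is a local address; the signal-based ops need the remote PE's address instead.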
This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path, triton.knobs.runtime.jit_post_compile_hook.
- The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware).

**Peer Rank Calculation Fix**
- The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank.

Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta