[SymmMem] Standardize NVSHMEM Triton wrappers on byte-based APIs + improve code clarity#159136
Closed
codingwithsurya wants to merge 16 commits into gh/codingwithsurya/15/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159136
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit a2086b8 with merge base 3daef4d
CANCELLED JOB - The following job was cancelled. Please retry.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This was referenced Jul 25, 2025
This was referenced Aug 4, 2025
Collaborator
Starting merge as part of PR stack under #159788
2 similar comments
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
When playing around with it, I noticed some flakiness in this test across sessions. After debugging, it turned out that the heavy sync primitives I was calling from inside Triton kernels (like `nvshmem_quiet()` or `nvshmem_fence()`) were causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock. The fix was realizing that `wait_until` already provides all the synchronization you need. Just do:
- PE A: `nvshmem_wait_until(&ivar, ...)`
- PE B: `nvshmem_put(&ivar_on_PE_A, ...)`

Pull Request resolved: #159215
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136
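The wait-until/put pairing above can be sketched in plain Python, with threads standing in for PEs and a shared list standing in for the symmetric variable. This is only an analogy of the synchronization pattern, not real NVSHMEM code; all names here are illustrative.

```python
import threading
import time

FLAG = 21
ivar = [0]  # stand-in for the symmetric variable living on PE A


def pe_a():
    # PE A: spin until ivar holds the expected value, mirroring
    # nvshmem_wait_until(&ivar, NVSHMEM_CMP_EQ, FLAG).
    while ivar[0] != FLAG:
        time.sleep(0)  # yield, like a thread polling memory


waiter = threading.Thread(target=pe_a)
waiter.start()

# PE B: write the flag into PE A's ivar, mirroring
# nvshmem_put(&ivar_on_PE_A, ...). No fence/quiet is needed: the waiter
# can only complete once it has observed the delivered write.
ivar[0] = FLAG
waiter.join(timeout=5)
assert not waiter.is_alive() and ivar[0] == FLAG
```

The point of the fix is visible in the sketch: the waiting side provides the ordering guarantee, so the sending side never has to block in a heavyweight `quiet()`/`fence()` while holding the SM.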
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…(make device library discoverable + fix peer calculation bug) (#159701)

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy `JITFunction.compiled_hook` to the new API path at `triton.knobs.runtime.jit_post_compile_hook`.
- The internal parameter for kernel semantics in extern function definitions has been updated from `_semantic` to `_builder` to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of `libtorch_nvshmem.so` to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware).

**Peer Rank Calculation Fix**
- The peer calculation in `test_nvshmem_triton.py` was changed from `peer = (world_size - 1) - rank` to `peer = 1 - rank`. The previous logic was only valid for a 2-rank setup; in the 8-rank CI environment it incorrectly mapped peers (e.g., rank 0 to rank 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.

Pull Request resolved: #159701
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215
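The two peer formulas can be compared directly in plain Python (the function names are illustrative; the test computes `peer` inline):

```python
def peer_old(rank: int, world_size: int) -> int:
    # Previous mapping: pairs rank r with (world_size - 1 - r).
    return (world_size - 1) - rank


def peer_new(rank: int) -> int:
    # Fixed mapping: always pairs rank 0 with rank 1.
    return 1 - rank


# With 2 ranks both formulas agree:
assert peer_old(0, 2) == peer_new(0) == 1
assert peer_old(1, 2) == peer_new(1) == 0

# With 8 ranks the old formula pairs rank 0 with rank 7, breaking
# tests that assume a 0 <-> 1 communication pattern:
assert peer_old(0, 8) == 7
assert peer_new(0) == 1
```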
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…m in their name (#159734)

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names of all of our nvshmem kernels to align with this.

Pull Request resolved: #159734
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701
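A minimal sketch of the conditional initialization, with a hypothetical hook signature (the real hook is installed via `triton.knobs.runtime.jit_post_compile_hook` and receives the compiled kernel, whose module is passed to `_nvshmemx_cumodule_init`):

```python
def post_compile_hook(kernel_name: str, module, init_fn) -> bool:
    # Only kernels whose name contains "nvshmem" get the NVSHMEM
    # device-module initialization; everything else is skipped.
    if "nvshmem" in kernel_name:
        init_fn(module)  # stands in for _nvshmemx_cumodule_init(module)
        return True
    return False


initialized = []
assert post_compile_hook("nvshmem_put_kernel", "mod_a", initialized.append)
assert not post_compile_hook("plain_matmul_kernel", "mod_b", initialized.append)
assert initialized == ["mod_a"]
```

This is also why the PR renames the kernels: the name-based check only works if every NVSHMEM kernel actually carries "nvshmem" in its name.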
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…tomatic dtype-based dispatch (#159755)

This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (`torch.dtype` or `tl.dtype`) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling `tl.dtype`, `torch.dtype`, `str`, or `constexpr`) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
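The name-assembly step can be sketched with a small lookup table. The table below is illustrative and deliberately partial (the real wrapper also normalizes `torch.dtype`, `tl.dtype`, and `constexpr` objects, not just strings); the two asserted names match the examples in the description.

```python
# Hypothetical normalization table: dtype identifier -> NVSHMEM typename.
_NVSHMEM_TYPENAMES = {
    "torch.float32": "float",
    "torch.float64": "double",
    "torch.float16": "half",
    "torch.bfloat16": "bfloat16",
    "torch.int32": "int32",
    "torch.int64": "int64",
}


def reduce_fn_name(dtype_id: str, operation: str) -> str:
    # Assemble the extern function name from the canonical typename
    # and the reduction op, e.g. nvshmem_float_sum_reduce.
    typename = _NVSHMEM_TYPENAMES[str(dtype_id)]
    return f"nvshmem_{typename}_{operation}_reduce"


assert reduce_fn_name("torch.float32", "sum") == "nvshmem_float_sum_reduce"
assert reduce_fn_name("torch.bfloat16", "prod") == "nvshmem_bfloat16_prod_reduce"
```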
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
Fed the NVSHMEM documentation to Claude Code and asked it to generate helpful docstrings; verified them for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions so users can pass tensors (rather than pointers) as inputs to their NVSHMEM Triton kernels. The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

-----

**TODO:** This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`. From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
  Remote signal: 0x2430301600 (sig) ← Rank 1's signal
Rank 1 (waiting):
  Local signal:  0x430301600  (waits here)

Tensor-Based Version:
Rank 0 → Rank 1:
  Local buffer: 0x430300a00 (src)
  Local buffer: 0x430300c00 (dst) ← this is wrong
  Local signal: 0x430300e00 (sig) ← this is wrong
Rank 1 (waiting):
  Local signal: 0x430300e00 (waits here)
```

Next steps: need a mechanism to resolve a local tensor to a remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup.

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
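The missing translation step amounts to rebasing an offset from one PE's copy of the symmetric buffer onto another's. A toy model in plain Python, with made-up base addresses in the spirit of the trace above (`remote_address` and the `buffer_ptrs` dict are illustrative, not the actual API):

```python
def remote_address(local_addr: int, local_base: int, peer_base: int) -> int:
    # Symmetric buffers share a layout across PEs, so a local address
    # maps to the peer address at the same offset from the peer's base.
    offset = local_addr - local_base
    return peer_base + offset


# Per-PE base pointers for the same symmetric buffer (illustrative
# values), playing the role of handle.buffer_ptrs.
buffer_ptrs = {0: 0x430300A00, 1: 0x2430300C00}

# Rank 0 holds an address 0x200 bytes into its copy; the put must
# target the same offset inside rank 1's copy.
assert remote_address(0x430300C00, buffer_ptrs[0], buffer_ptrs[1]) == 0x2430300E00
```

With such a lookup in place, the tensor-based `putmem_signal_block` could resolve the peer's destination and signal addresses instead of reusing the local ones, which is exactly the bug the trace shows.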
hinriksnaer pushed commits to hinriksnaer/pytorch that referenced this pull request on Aug 8, 2025.
markc-614 pushed commits to markc-614/pytorch that referenced this pull request on Sep 17, 2025.
Quick refactor for consistency and clarity.

1. We now standardize all NVSHMEM data-moving collectives (put, get, alltoall, broadcast) on their byte-based `*_mem_block` variants. This makes the API behavior more predictable and avoids mixing paradigms.
2. Previously, some functions operated on element counts (`nelems`), while others expected byte sizes but still used `nelems` as the param name. That inconsistency was easy to miss and could lead to bugs, especially for devs not familiar with the NVSHMEM internals. To clean this up:
   - All byte-based APIs now use `nbytes` or `nbytes_per_pe` to make the units explicit.
   - Typed APIs consistently use `nelems` for element counts.
   - Docstrings were added or updated to clarify expected units.

Also did some code cleanup: removed unused functions, fixed typos in comments, and did some general housekeeping. This should make the API more intuitive and reduce friction for developers.
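The naming convention can be illustrated with two stub signatures (illustrative only, not the actual wrapper API): a byte-based API says `nbytes`, a typed API says `nelems`, so the unit is visible at every call site and the conversion between them is a single multiply by the element size.

```python
import inspect


def putmem_block(dst_ptr: int, src_ptr: int, nbytes: int, pe: int) -> None:
    """Byte-based variant: copies `nbytes` bytes to the remote PE."""


def put(dst, src, nelems: int, pe: int) -> None:
    """Typed variant: copies `nelems` elements of dst's dtype."""


def nbytes_for(nelems: int, itemsize: int) -> int:
    # The conversion a typed wrapper performs before calling a
    # byte-based extern: bytes = element count * element size.
    return nelems * itemsize


assert "nbytes" in inspect.signature(putmem_block).parameters
assert "nelems" in inspect.signature(put).parameters
assert nbytes_for(1024, 4) == 4096  # 1024 float32 elements
```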
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta