[SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name #159734

Closed

codingwithsurya wants to merge 14 commits into gh/codingwithsurya/18/base
Conversation
[SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name [ghstack-poisoned]
This was referenced Aug 3, 2025
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159734
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 52250d3 with merge base 3daef4d.
CANCELLED JOB - The following job was cancelled. Please retry.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
"[SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name"

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
ngimel approved these changes on Aug 4, 2025
This was referenced Aug 4, 2025
Collaborator

Starting merge as part of PR stack under #159788

2 similar comments
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…automatic dtype-based dispatch (#159755)

This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
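As a rough illustration of that normalization and name assembly, here is a minimal sketch. The helper name `reduce_fn_name` and the dtype table below are illustrative (only a subset of types, and only torch.dtype and plain strings are handled; the real helper also unwraps tl.dtype and tl.constexpr inputs):

```python
import torch

# Illustrative subset of the torch.dtype -> NVSHMEM typename mapping.
_NVSHMEM_TYPENAMES = {
    torch.int8: "int8",
    torch.int64: "int64",
    torch.float16: "half",
    torch.bfloat16: "bfloat16",
    torch.float32: "float",
    torch.float64: "double",
}

def reduce_fn_name(operation: str, dtype_id) -> str:
    """Assemble the extern name, e.g. 'nvshmem_float_sum_reduce'."""
    if isinstance(dtype_id, torch.dtype):
        typename = _NVSHMEM_TYPENAMES[dtype_id]
    else:
        typename = str(dtype_id)  # assume an already-canonical typename string
    return f"nvshmem_{typename}_{operation}_reduce"

# reduce_fn_name("sum", torch.float32)   -> "nvshmem_float_sum_reduce"
# reduce_fn_name("prod", torch.bfloat16) -> "nvshmem_bfloat16_prod_reduce"
```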
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request on Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions so users can pass tensors as inputs to their NVSHMEM Triton kernels (rather than pointers). The goal is to abstract away tedious details from the developer, such as manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

**TODO:** This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`. From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:
  Rank 0 → Rank 1:
    Local buffer:  0x430300a00  (src)
    Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
    Remote signal: 0x2430301600 (sig) ← Rank 1's signal
  Rank 1 (waiting):
    Local signal:  0x430301600  (waits here)

Tensor-Based Version:
  Rank 0 → Rank 1:
    Local buffer: 0x430300a00 (src)
    Local buffer: 0x430300c00 (dst) ← this is wrong
    Local signal: 0x430300e00 (sig) ← this is wrong
  Rank 1 (waiting):
    Local signal: 0x430300e00 (waits here)
```

Next steps: need a mechanism to resolve local tensor → remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup.

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
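To make the wrapper idea concrete, here is a hedged sketch of what such a tensor-aware wrapper can look like. The import path and the pointer-based extern `nvshmem.putmem_block(dst, src, nbytes, pe)` are assumptions, not the actual API surface of this PR:

```python
import triton
from torch.distributed._symmetric_memory import _nvshmem_triton as nvshmem  # import path assumed

@triton.jit
def putmem_block_tensor(dst, src, nelems, pe):
    # Derive the byte count from the pointee type, so callers pass typed
    # tensors and element counts instead of raw int64 pointers and bytes.
    nbytes = nelems * (dst.dtype.element_ty.primitive_bitwidth // 8)
    # Delegate to the pointer-based extern (signature assumed above).
    nvshmem.putmem_block(dst, src, nbytes, pe)
```

The design point is simply that the element size is recoverable from the tensor's own type at compile time, so the byte arithmetic never has to appear in user code.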
hinriksnaer pushed commits to hinriksnaer/pytorch that referenced this pull request on Aug 8, 2025: pytorch#159734, pytorch#159755, pytorch#159756, pytorch#159788
markc-614 pushed commits to markc-614/pytorch that referenced this pull request on Sep 17, 2025: pytorch#159734, pytorch#159755, pytorch#159756, pytorch#159788
Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta
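As a rough illustration of the gating described above, a minimal sketch follows. The import path of `_nvshmemx_cumodule_init` and the hook wiring are assumptions; only the name check mirrors what this PR describes:

```python
# Minimal sketch of the name-gated initialization, assuming a Triton-style
# kernel object exposing `.name` and `.module`; the import path below is an
# assumption, not necessarily where the binding actually lives.
from torch._C._distributed_c10d import _nvshmemx_cumodule_init  # path assumed

def maybe_init_nvshmem(kernel) -> None:
    # Only kernels that opt in via an "nvshmem" name need the CU-module
    # registration; all other Triton kernels skip the init entirely.
    if "nvshmem" in kernel.name:
        _nvshmemx_cumodule_init(kernel.module)
```

This is why the PR also renames all of the nvshmem kernels: the substring check only works if every kernel that needs the init actually carries "nvshmem" in its name.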