[SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name #159734

Closed

codingwithsurya wants to merge 14 commits into gh/codingwithsurya/18/base from gh/codingwithsurya/18/head

Conversation

@codingwithsurya codingwithsurya (Contributor) commented Aug 3, 2025

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change calls `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels that contain "nvshmem" in their name. The names of all of our NVSHMEM kernels have also been updated to align with this.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
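
A minimal sketch of the conditional initialization described above (illustrative only; the helper name and the way the hook is wired into Triton's post-compile path are assumptions, not the exact code in this PR):

```
# Illustrative sketch: register the compiled kernel's CUmodule with NVSHMEM only
# for kernels whose name marks them as NVSHMEM kernels, instead of for every
# compiled Triton kernel. `cumodule_init` stands in for the
# `_nvshmemx_cumodule_init` binding named above.

def maybe_init_nvshmem_module(kernel_name: str, compiled_kernel, cumodule_init) -> None:
    if "nvshmem" not in kernel_name:
        return  # non-NVSHMEM kernels skip the module registration entirely
    # Make sure the compiled kernel's CUmodule handle exists before handing it
    # to NVSHMEM (Triton populates it lazily).
    compiled_kernel._init_handles()
    cumodule_init(compiled_kernel.module)
```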

@pytorch-bot

pytorch-bot bot commented Aug 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159734

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 52250d3 with merge base 3daef4d:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

codingwithsurya added a commit that referenced this pull request Aug 3, 2025

…m in their name

ghstack-source-id: 6fc7857
Pull Request resolved: #159734
@pytorchmergebot pytorchmergebot (Collaborator) commented

Starting merge as part of PR stack under #159788

2 similar comments

pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…tomatic dtype‐based dispatch (#159755)

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce.

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
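
A small sketch of the name-assembly step described above (the mapping table and helper name are illustrative; the actual wrapper in #159755 also handles tl.dtype and constexpr inputs, which are omitted here):

```
import torch

# Canonical NVSHMEM typenames keyed by torch dtype (subset shown; "half" for
# float16 is an assumption about NVSHMEM's naming, the rest follow the
# examples in the commit message above).
_NVSHMEM_TYPENAMES = {
    torch.int8: "int8",
    torch.int64: "int64",
    torch.uint8: "uint8",
    torch.float16: "half",
    torch.bfloat16: "bfloat16",
    torch.float32: "float",
    torch.float64: "double",
}

def nvshmem_reduce_fn_name(dtype, operation: str) -> str:
    """Return e.g. 'nvshmem_float_sum_reduce' for (torch.float32, 'sum')."""
    typename = dtype if isinstance(dtype, str) else _NVSHMEM_TYPENAMES[dtype]
    return f"nvshmem_{typename}_{operation}_reduce"

# nvshmem_reduce_fn_name(torch.bfloat16, "prod") -> "nvshmem_bfloat16_prod_reduce"
```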
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
Fed the NVSHMEM documentation to Claude Code and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.
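
For illustration, here is a rough host-side sketch of the bookkeeping such a wrapper hides (the helper name `put_args_from_tensors` is made up; the actual `@triton.jit` wrapper in this PR does the equivalent inside the kernel):

```
import torch

def put_args_from_tensors(dst: torch.Tensor, src: torch.Tensor, nelems: int):
    """Turn typed tensors plus an element count into the raw
    (dst_ptr, src_ptr, nbytes) triple that the pointer-based externs expect."""
    assert src.dtype == dst.dtype, "source and destination must share a dtype"
    nbytes = nelems * src.element_size()  # byte-size calculation done for the caller
    # data_ptr() returns the address as a plain int, matching the raw int64
    # pointer arguments of the pointer-based API.
    return dst.data_ptr(), src.data_ptr(), nbytes
```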

-----

**TODO:**
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`.

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next steps: we need a mechanism to resolve a local tensor to the corresponding remote PE address, equivalent to a `handle.buffer_ptrs[peer]` lookup.
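
For what that could look like, a hedged sketch (only `buffer_ptrs` comes from the text above; the function name and the explicit `my_rank` argument are assumptions, not the actual symmetric-memory handle API):

```
import torch

def remote_address(handle, local_tensor: torch.Tensor, peer: int, my_rank: int) -> int:
    """Map a tensor in this rank's symmetric buffer to the matching address on `peer`."""
    local_base = handle.buffer_ptrs[my_rank]        # base of this rank's symmetric buffer
    offset = local_tensor.data_ptr() - local_base   # tensor's offset within that buffer
    return handle.buffer_ptrs[peer] + offset        # same offset inside the peer's buffer
```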

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025

…m in their name (pytorch#159734)

Pull Request resolved: pytorch#159734
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701
@github-actions github-actions bot deleted the gh/codingwithsurya/18/head branch September 8, 2025 02:14

Labels

Merged
oncall: distributed (adds this issue/PR to the distributed oncall triage queue)
release notes: distributed (symm_mem) (release note label for symmetric memory)


3 participants