
[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels #159788

Closed
codingwithsurya wants to merge 18 commits into gh/codingwithsurya/21/base from gh/codingwithsurya/21/head

Conversation

@codingwithsurya
Contributor

@codingwithsurya codingwithsurya commented Aug 4, 2025

This PR introduces a small `@triton.jit` wrapper over our core NVSHMEM extern functions, letting users pass tensors (rather than raw pointers) as inputs to their NVSHMEM Triton kernels.

The goal is to abstract tedious details away from the developer, such as manual byte-size calculations and handling of raw `int64` pointers. Developers can instead work directly with typed Triton tensors and element counts, which is also useful if, for instance, you want to do some local math on the data.
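As a rough illustration of the bookkeeping such a wrapper hides, here is a minimal host-side sketch. This is not the PR's actual code: `FakeTensor`, `putmem_block_ptr`, and `putmem_block_tensor` are illustrative names, and the real wrapper is itself a `@triton.jit` function operating on Triton tensors. The division of labor is the point: the wrapper, not the caller, turns typed tensors into pointer-and-byte-count arguments.

```python
from collections import namedtuple

# Illustrative stand-in for a typed tensor: just an address and a dtype name.
FakeTensor = namedtuple("FakeTensor", ["addr", "dtype"])

DTYPE_SIZE = {"float16": 2, "bfloat16": 2, "float32": 4, "int32": 4, "int64": 8}

def putmem_block_ptr(dst_ptr, src_ptr, nbytes, pe):
    # Stand-in for a pointer-based extern: raw addresses plus a byte count.
    return (dst_ptr, src_ptr, nbytes, pe)

def putmem_block_tensor(dst, src, nelems, pe):
    # Tensor-aware wrapper: derives the raw-pointer arguments itself, so the
    # caller never does byte-size arithmetic or threads int64 pointers around.
    assert dst.dtype == src.dtype, "source and destination dtypes must match"
    nbytes = nelems * DTYPE_SIZE[src.dtype]
    return putmem_block_ptr(dst.addr, src.addr, nbytes, pe)
```

With a tensor-level entry point like this, the caller only supplies tensors, an element count, and a peer rank.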


TODO:
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`.

From my investigation, the root cause is that this tensor API uses local addresses where it should use the peer's remote addresses:

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)
```

Next Steps: We need a mechanism to resolve a local tensor address to the corresponding remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup.
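One way to picture that missing piece is a base-and-offset translation against a per-peer table of symmetric-buffer bases, analogous to the `handle.buffer_ptrs[peer]` lookup. A hedged sketch, with all names illustrative and the addresses taken from the trace above (the real resolution would have to happen wherever the tensor argument is lowered to a pointer):

```python
def translate_to_peer(local_addr, local_base, peer_bases, peer):
    # Symmetric memory gives every PE the same buffer layout, so a local
    # address can be re-based: offset within the local buffer + peer's base.
    offset = local_addr - local_base
    return peer_bases[peer] + offset

# Re-using the addresses from the trace: rank 0's local dst 0x430300c00
# should resolve to rank 1's remote buffer 0x2430300c00.
peer_bases = {0: 0x430300A00, 1: 0x2430300A00}
remote_dst = translate_to_peer(0x430300C00, peer_bases[0], peer_bases, 1)
```

Under this scheme the tensor-based `putmem_signal_block` would translate both the destination and signal addresses before issuing the transfer, matching the pointer-based version's behavior.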

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot

pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159788

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9534bab with merge base 3daef4d:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…n to NVSHMEM Triton kernels"

**I have broadcast and alltoall implemented for this now. Working on the rest but pushing this out now for early feedback!**

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 4, 2025
…M Triton kernels

ghstack-source-id: 9b3cdd0
Pull Request resolved: #159788
Collaborator

@ngimel ngimel left a comment


looks good!

…n to NVSHMEM Triton kernels"

**I have broadcast, alltoall, put, and get implemented for this now. Working on the rest but pushing this out now for early feedback!**

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 5, 2025
…M Triton kernels

ghstack-source-id: cb5cf46
Pull Request resolved: #159788
…n to NVSHMEM Triton kernels"

**I have broadcast, alltoall, put, get, barrier, sync, wait_until, quiet, fence implemented for this now. Working on the rest but pushing this out now for early feedback!**

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 5, 2025
…M Triton kernels

ghstack-source-id: 92fc01d
Pull Request resolved: #159788
codingwithsurya added a commit that referenced this pull request Aug 5, 2025
…M Triton kernels

ghstack-source-id: 3a07067
Pull Request resolved: #159788
codingwithsurya added a commit that referenced this pull request Aug 6, 2025
…M Triton kernels

ghstack-source-id: 8d0465c
Pull Request resolved: #159788
…n to NVSHMEM Triton kernels"

**I have everything but put_with_signal and signal_wait_until implemented for this now (dealing with nccl hangs when signaling). Working on the rest but pushing this out now for early feedback!**

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 6, 2025
…M Triton kernels

ghstack-source-id: cf2c454
Pull Request resolved: #159788
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 7, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • Distributed (wconstab, mrshenli, pritamdamania87, zhaojuanmao, rohan-varma, ...)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@codingwithsurya
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: PR #159215 has not been reviewed yet

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@codingwithsurya
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • Distributed (wconstab, mrshenli, pritamdamania87, zhaojuanmao, rohan-varma, ...)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

codingwithsurya added a commit that referenced this pull request Aug 8, 2025
…M Triton kernels

ghstack-source-id: 51b15f3
Pull Request resolved: #159788

```python
# Reduction Operation
@triton.jit  # type: ignore[misc]
def reduce(team, dest, source, nreduce, operation: tl.constexpr):  # type: ignore[no-untyped-def]
```
Collaborator


I still think having default args for team and operation would make sense here.
But this is a super nice improvement over previous dtype parsing!

@codingwithsurya
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…on kernels (pytorch#159788)


Pull Request resolved: pytorch#159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755, pytorch#159756
@github-actions github-actions bot deleted the gh/codingwithsurya/21/head branch September 8, 2025 02:14
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…on kernels (pytorch#159788)


Labels

ciflow/h100-symm-mem ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (symm_mem) release note label for symmetric memory


4 participants