[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels#159788
codingwithsurya wants to merge 18 commits into gh/codingwithsurya/21/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159788
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure)
As of commit 9534bab with merge base 3daef4d.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Update on "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels"

**I have broadcast and alltoall implemented for this now. Working on the rest, but pushing this out now for early feedback!**

------

This PR introduces a small `triton.jit` wrapper function over our core NVSHMEM extern functions. The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
Update on "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels"

**Now also implemented: put and get.**

[ghstack-poisoned]
Update on "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels"

**Now also implemented: barrier, sync, wait_until, quiet, and fence.**

[ghstack-poisoned]
Update on "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels"

**Everything but put_with_signal and signal_wait_until is implemented now (dealing with NCCL hangs when signaling).**

[ghstack-poisoned]
Merge failed
Reason: Approvers from one of the following sets are needed:
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: PR #159215 has not been reviewed yet.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: Approvers from one of the following sets are needed:
Update on "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels"
This PR introduces a small `triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).
The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.
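As an illustrative sketch (not the PR's actual code), this is the kind of byte-size bookkeeping the wrapper hides from call sites; the helper name and dtype table below are assumptions for illustration:

```python
# Illustrative only: the byte-count bookkeeping a tensor-aware wrapper
# can hide. The helper name and dtype table are assumptions, not the
# PR's actual implementation.
DTYPE_SIZES = {"float16": 2, "bfloat16": 2, "float32": 4, "int32": 4, "int64": 8}

def byte_count(nelems: int, dtype: str) -> int:
    # A raw NVSHMEM extern takes a byte count; a typed wrapper can derive
    # it from the element count and the tensor's dtype, instead of making
    # every call site do this multiplication by hand.
    return nelems * DTYPE_SIZES[dtype]

print(byte_count(1024, "float32"))  # 4096
```

With the typed wrapper, callers pass the tensor and an element count, and this multiplication happens once inside the wrapper.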
-----
**TODO:**
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`.
From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:
```
Pointer-Based Version:
Rank 0 → Rank 1:
Local buffer: 0x430300a00 (src)
Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
Remote signal: 0x2430301600 (sig) ← Rank 1's signal
Rank 1 (waiting):
Local signal: 0x430301600 (waits here)
Tensor-Based Version:
Rank 0 → Rank 1:
Local buffer: 0x430300a00 (src)
Local buffer: 0x430300c00 (dst) ← this is wrong
Local signal: 0x430300e00 (sig) ← this is wrong
Rank 1 (waiting):
Local signal: 0x430300e00 (waits here)
```
Next steps: we need a mechanism to resolve a local tensor to its remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup.
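A minimal sketch of such a lookup, assuming a symmetric heap where each PE's buffer sits at the same offset from that PE's base pointer (the function name and the base addresses below are illustrative; the concrete addresses reuse the trace above):

```python
# Sketch of resolving a local tensor address to the peer's address,
# analogous to the handle.buffer_ptrs[peer] lookup mentioned above.
# Names and the symmetric-heap bases are illustrative assumptions.

def remote_address(local_addr: int, local_base: int, peer_bases: dict, peer: int) -> int:
    # On a symmetric heap, a buffer lives at the same offset from every
    # PE's base, so the peer's address is its base plus the local offset.
    offset = local_addr - local_base
    return peer_bases[peer] + offset

# Reusing the trace above: rank 0's local dst at 0x430300c00 should
# resolve to rank 1's 0x2430300c00, given these (assumed) bases.
peer_bases = {0: 0x430300A00, 1: 0x2430300A00}
print(hex(remote_address(0x430300C00, peer_bases[0], peer_bases, 1)))  # 0x2430300c00
```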
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
[ghstack-poisoned]
```python
# Reduction Operation
@triton.jit  # type: ignore[misc]
def reduce(team, dest, source, nreduce, operation: tl.constexpr):  # type: ignore[no-untyped-def]
```
I still think having default args for team and operation would make sense here.
But this is a super nice improvement over previous dtype parsing!
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merged as "[SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels" (pytorch#159788); the commit message repeats the PR description above.

Pull Request resolved: pytorch#159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755, pytorch#159756