Skip to content

[SymmMem] Refactor NVSHMEM Reduction API to be more ergonomic with automatic dtype‐based dispatch#159755

Closed
codingwithsurya wants to merge 15 commits intogh/codingwithsurya/19/basefrom
gh/codingwithsurya/19/head
Closed

[SymmMem] Refactor NVSHMEM Reduction API to be more ergonomic with automatic dtype‐based dispatch#159755
codingwithsurya wants to merge 15 commits intogh/codingwithsurya/19/basefrom
gh/codingwithsurya/19/head

Conversation

@codingwithsurya
Copy link
Contributor

@codingwithsurya codingwithsurya commented Aug 4, 2025

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id), that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159755

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit ba1f61a with merge base 3daef4d (image):

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 6, 2025
…and more ergonomic

ghstack-source-id: c539c4f
Pull Request resolved: #159755
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #159788

1 similar comment
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #159788

Comment on lines +1173 to +1187
expected = []
for i in range(nreduce):
# Product across all ranks
product = 1
for r in range(world_size):
if i == 0:
# rank 0,2,4... contributes 1, rank 1,3,5... contributes 2
product *= 1 if r % 2 == 0 else 2 # 2^(world_size//2)
elif i == 1:
# all ranks contribute 1
product *= 1 # result is 1
else:
# rank 0,1 contribute 1, rank 2,3 contribute 2, etc.
product *= 1 if (r // 2) % 2 == 0 else 2
expected.append(product)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
expected = []
for i in range(nreduce):
# Product across all ranks
product = 1
for r in range(world_size):
if i == 0:
# rank 0,2,4... contributes 1, rank 1,3,5... contributes 2
product *= 1 if r % 2 == 0 else 2 # 2^(world_size//2)
elif i == 1:
# all ranks contribute 1
product *= 1 # result is 1
else:
# rank 0,1 contribute 1, rank 2,3 contribute 2, etc.
product *= 1 if (r // 2) % 2 == 0 else 2
expected.append(product)
vals = torch.empty(nreduce, world_size, dtype=dtype)
vals[0, ::2] = 1
vals[0, 1::2] = 2
vals[1] = 1
vals2 = vals[2].view(-1, 2, 2)
vals2[:, 0] = 1
vals2[:, 1] = 2
expected = vals.prod(-1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated this on my most recent commit!

src_ptr,
nreduce,
operation: tl.constexpr,
dtype_id: tl.constexpr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't call it dtype_id, it's just dtype (torch.dtype)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of this stuff was changed in #159788, I have updated the commit there to have dtype_id renamed to just dtype

Comment on lines +407 to +424
if hasattr(dtype_id, "name"):
# Triton language dtype (e.g., tl.float32)
dtype_name = dtype_id.name
elif isinstance(dtype_id, str):
# Already a plain string name
dtype_name = dtype_id
elif hasattr(dtype_id, "value"):
# Constexpr wrapper around a dtype
inner_value = dtype_id.value
if hasattr(inner_value, "name"):
# Triton dtype inside constexpr
dtype_name = inner_value.name
else:
# PyTorch dtype inside constexpr
dtype_name = str(inner_value).replace("torch.", "")
else:
# PyTorch dtype (e.g., torch.float32)
dtype_name = str(dtype_id).replace("torch.", "")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if hasattr(dtype_id, "name"):
# Triton language dtype (e.g., tl.float32)
dtype_name = dtype_id.name
elif isinstance(dtype_id, str):
# Already a plain string name
dtype_name = dtype_id
elif hasattr(dtype_id, "value"):
# Constexpr wrapper around a dtype
inner_value = dtype_id.value
if hasattr(inner_value, "name"):
# Triton dtype inside constexpr
dtype_name = inner_value.name
else:
# PyTorch dtype inside constexpr
dtype_name = str(inner_value).replace("torch.", "")
else:
# PyTorch dtype (e.g., torch.float32)
dtype_name = str(dtype_id).replace("torch.", "")
inner_value = dtype_id.value if hasattr(dtype_id.value) else dtype_id
if hasattr(inner_value, "name"):
# Triton dtype inside constexpr
dtype_name = inner_value.name
elif isinstance(inner_value, str):
# Already a plain string name
dtype_name = inner_value
else:
# PyTorch dtype (e.g., torch.float32)
dtype_name = str(inner_value).replace("torch.", "")

Is there a more reliable way to distinguish between triton dtype and torch dtype, rather than relying on hasattr(..., "name")?

Copy link
Contributor Author

@codingwithsurya codingwithsurya Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #159788 It's much simpler now because we only have triton dtypes coming in. we just have a map now, so no more hasattr checks!

…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
…mic with automatic dtype‐based dispatch"

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce  




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #159788

pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is tensor-aware implementation of `nvshmem.putmem_signal_block `and `nvshmem.signal_wait_until`

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup.

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…tomatic dtype‐based dispatch (pytorch#159755)

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce

Pull Request resolved: pytorch#159755
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: pytorch#159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…on kernels (pytorch#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is tensor-aware implementation of `nvshmem.putmem_signal_block `and `nvshmem.signal_wait_until`

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup.

Pull Request resolved: pytorch#159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755, pytorch#159756
@github-actions github-actions bot deleted the gh/codingwithsurya/19/head branch September 8, 2025 02:14
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…tomatic dtype‐based dispatch (pytorch#159755)

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce

Pull Request resolved: pytorch#159755
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: pytorch#159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…on kernels (pytorch#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is tensor-aware implementation of `nvshmem.putmem_signal_block `and `nvshmem.signal_wait_until`

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup.

Pull Request resolved: pytorch#159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755, pytorch#159756
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (symm_mem) release note label for symmetric memory

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants