
[SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) #159701

Closed
codingwithsurya wants to merge 17 commits into gh/codingwithsurya/17/base from gh/codingwithsurya/17/head

Conversation

@codingwithsurya
Contributor

@codingwithsurya codingwithsurya commented Aug 2, 2025

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

Triton 3.4 Compatibility

  • The JIT post-compile hook has been updated from the legacy `JITFunction.compiled_hook` to the new API path at `triton.knobs.runtime.jit_post_compile_hook`.
  • The internal parameter for kernel semantics in extern function definitions has been updated from `_builder` to `_semantic` to align with the Triton 3.4 API changes.

Fix CI Errors

  • The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
  • Added a decorator to run NVSHMEM tests only on H100s (compatible hardware)

Peer Rank Calculation Fix

  • The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank.
    Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.
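The two mappings can be compared in a standalone sketch (illustrative helper names, not the actual test code):

```python
def old_peer(rank: int, world_size: int) -> int:
    # Mirrors ranks across the world; a true 0<->1 pairing only when world_size == 2.
    return (world_size - 1) - rank

def new_peer(rank: int) -> int:
    # Fixed 0<->1 pairing regardless of world_size, matching the tests' assumption.
    return 1 - rank

# With 2 ranks both agree: rank 0 talks to rank 1.
assert old_peer(0, 2) == new_peer(0) == 1
# With 8 ranks the old logic pairs rank 0 with rank 7, breaking the 0<->1 pattern.
assert old_peer(0, 8) == 7
assert new_peer(0) == 1 and new_peer(1) == 0
```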

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot

pytorch-bot bot commented Aug 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159701

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 1925eb4 with merge base 3daef4d:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/h100-symm-mem oncall: distributed Add this issue/PR to distributed oncall triage queue labels Aug 2, 2025
codingwithsurya added a commit that referenced this pull request Aug 2, 2025
ghstack-source-id: a4ba304
Pull Request resolved: #159701
@codingwithsurya codingwithsurya added the release notes: distributed (symm_mem) release note label for symmetric memory label Aug 2, 2025
@codingwithsurya codingwithsurya changed the title support triton 3.4, fix ci skips [wip] [symmmem] support triton 3.4, fix ci skips Aug 2, 2025
@codingwithsurya codingwithsurya changed the title [wip] [symmmem] support triton 3.4, fix ci skips [wip] [SymmMem] Make NVSHMEM Triton support Triton 3.4 + Fix CI Skips Aug 2, 2025
p = p.strip().replace('$ORIGIN', torch_lib)
if p and p not in paths:
paths.append(p)
except Exception:
Collaborator


don't catch generic exception, narrow it down to exception you expect to happen here

Contributor Author


updated!

except Exception:
pass

for path in paths:
Collaborator


we need to decide the order of path search to match what we do during build, cc @kwen2501 . RPATH probably should be the first if we have it?

Contributor Author

@codingwithsurya codingwithsurya Aug 2, 2025


I was looking into this, and on second thought, can we just use RPATH? RPATH will always contain the exact paths where NVSHMEM was found during build time. I don't think we need any fallback searches for anything else since the RPATH is the source of truth. Thoughts on having something like this, where enable_triton doesn't take in a lib dir param and we just get it from RPATH? This might be cleaner.

def _find_nvshmem_device_library() -> str:
    # Assumes `os`, `subprocess`, and `torch` are imported at module level.
    torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib")
    so_path = os.path.join(torch_lib, "libtorch_nvshmem.so")

    # Read the dynamic section; the RPATH records where NVSHMEM was found at build time.
    result = subprocess.run(["readelf", "-d", so_path], capture_output=True, text=True)

    for line in result.stdout.splitlines():
        if "RPATH" in line and "[" in line:
            rpath = line.split("[")[1].split("]")[0]
            for path in rpath.split(":"):
                path = path.strip().replace("$ORIGIN", torch_lib)
                if path:
                    device_lib = os.path.join(path, "libnvshmem_device.bc")
                    if os.path.exists(device_lib):
                        return device_lib

    raise RuntimeError("NVSHMEM device library not found")


def enable_triton() -> dict[str, str]:
    from torch._C._distributed_c10d import _nvshmemx_cumodule_init

    lib_path = _find_nvshmem_device_library()
    extern_libs = {"libnvshmem_device": lib_path}

    def nvshmem_init_hook(*args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        key = kwargs["key"]
        device = kwargs["compile"]["device"]
        jit_function = kwargs["fn"].jit_function
        kernel_cache, _, _, _ = jit_function.device_caches[device]
        kernel = kernel_cache.get(key, None)
        if kernel is None:
            return
        kernel.run  # touch the lazy attribute so the CUDA module is loaded
        _nvshmemx_cumodule_init(kernel.module)

    triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook
    return extern_libs

Contributor Author


i have updated the commit with the above, lmk if you have other thoughts on how I should go about implementing this though

Collaborator


cc @malfet, are we guaranteed to have nvshmem path in RPATH?

Contributor


@ngimel I have no idea. If it's integrated correctly, then yes, it should be the case.
But I'm not sure how it's currently done, suspect that it's not, as nvshmem is not part of CUDA toolkit
cc: @kwen2501

Collaborator


Turns out nightly isn't built with nvshmem support, we need to fix that. For now, it's ok to fix the CI adding whatever needed search paths, we'll need to get back to this after nightly is fixed.

Collaborator


Ok with @malfet's fix #159907 nightly now builds with nvshmem, and nvshmem libhost is indeed in the RPATH

In [10]: print(result.stdout)

Dynamic section at offset 0x8b0140 contains 34 entries:
  Tag        Type                         Name/Value
 0x0000000000000001 (NEEDED)             Shared library: [libnvshmem_host.so.3]
 0x0000000000000001 (NEEDED)             Shared library: [librt.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libpthread.so.0]
 0x0000000000000001 (NEEDED)             Shared library: [libdl.so.2]
 0x0000000000000001 (NEEDED)             Shared library: [libstdc++.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libm.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libgcc_s.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libc.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [ld-linux-x86-64.so.2]
 0x000000000000000e (SONAME)             Library soname: [libtorch_nvshmem.so]
 0x000000000000000f (RPATH)              Library rpath: [$ORIGIN/../../nvidia/cublas/lib:$ORIGIN/../../nvidia/cuda_cupti/lib:$ORIGIN/../../nvidia/cuda_nvrtc/lib:$ORIGIN/../../nvidia/cuda_runtime/lib:$ORIGIN/../../nvidia/cudnn/lib:$ORIGIN/../../nvidia/cufft/lib:$ORIGIN/../../nvidia/curand/lib:$ORIGIN/../../nvidia/cusolver/lib:$ORIGIN/../../nvidia/cusparse/lib:$ORIGIN/../../nvidia/cusparselt/lib:$ORIGIN/../../cusparselt/lib:$ORIGIN/../../nvidia/nccl/lib:$ORIGIN/../../nvidia/nvshmem/lib:$ORIGIN/../../nvidia/nvtx/lib:$ORIGIN/../../nvidia/cufile/lib:$ORIGIN]

Contributor Author


awesome! I can roll with the simple method of finding the device library via RPATH then!


# Register the function as a post-compile hook
JITFunction.compiled_hook = nvshmem_init_hook
triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook
Collaborator


can we check triton version and use compiled_hook or knobs depending on triton version?
@kwen2501 do we plan to enable nvshmem backend in fbcode? fbcode is still triton 3.3

Contributor Author


It's not just the compile hook API name we'd need to change - we also have the `_semantic` parameter replacing the `_builder` parameter across all the extern function calls. After thinking about it, I think we should just support the latest version of pytorch-triton (e.g., Triton 3.4 here). If API names and interfaces on the Triton side keep changing, it'll be difficult to support multiple versions without significant code complexity.

Open to other approaches if you think supporting 3.3 is important. I was thinking we could handle it via a wrapper function that accepts **kwargs and dynamically maps either _semantic or _barrier parameters (based on version detection) to the correct parameter name when calling the underlying Triton APIs
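The wrapper idea floated above could be sketched roughly as follows. This is a hypothetical illustration, not the actual PyTorch implementation: it detects the installed Triton version and routes a single context keyword onto whichever parameter name (`_builder` for Triton <= 3.3, `_semantic` for 3.4+) the extern API expects. The helper names are invented for the sketch.

```python
def pick_hook_kwarg(triton_version: str) -> str:
    """Return the extern-call keyword name expected by this Triton version."""
    major, minor = (int(x) for x in triton_version.split(".")[:2])
    return "_semantic" if (major, minor) >= (3, 4) else "_builder"

def call_extern(fn, *args, ctx=None, triton_version="3.4.0", **kwargs):
    # Route the context object under whichever name the installed Triton expects.
    kwargs[pick_hook_kwarg(triton_version)] = ctx
    return fn(*args, **kwargs)
```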

@codingwithsurya
Contributor Author

looks like there are some issues w/ test hangs in ci, working on a fix right now

@codingwithsurya
Contributor Author

codingwithsurya commented Aug 2, 2025

looks like there are some issues w/ test hangs in ci, working on a fix right now

Figured out the issue, it was the peer calculation. Culprit was this change from 3 weeks ago: #158167 (comment)

I was on a 2-rank setup earlier so didn’t catch it, but CI runs with 8 ranks. Once I switched to 8 ranks on the dev GPU, it reproduced. The code had peer = (world_size - 1) - rank, which works for 2 ranks (0↔1) but maps rank 0 → 7 with 8 ranks. The test assumes 0↔1 comms, so this broke it. Fixed by switching to peer = 1 - rank to preserve that pattern. Just pushed out a fix here in this PR, but let me know if I should add it in a separate PR.

…ix CI Skips"

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
codingwithsurya added a commit that referenced this pull request Aug 3, 2025
ghstack-source-id: 807077a
Pull Request resolved: #159701
codingwithsurya added a commit that referenced this pull request Aug 3, 2025
ghstack-source-id: 380fbda
Pull Request resolved: #159701
codingwithsurya added a commit that referenced this pull request Aug 3, 2025
@codingwithsurya codingwithsurya changed the title [wip] [SymmMem] Make NVSHMEM Triton support Triton 3.4 + Fix CI Skips [SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) Aug 3, 2025
)
for line in result.stdout.split("\n"):
if ("RPATH" in line or "RUNPATH" in line) and "[" in line:
rpath = line.split("[")[1].split("]")[0]
Collaborator


Use maxsplit arg here
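For context, `str.split` accepts a `maxsplit` argument, so the bracket extraction in the snippet above can stop at the first delimiter instead of splitting the whole line (the sample line below is illustrative):

```python
line = " 0x0f (RPATH)  Library rpath: [$ORIGIN/../lib:/opt/nvshmem/lib]"
# Split only once on each bracket so any later "[" or "]" in the line is ignored.
rpath = line.split("[", maxsplit=1)[1].split("]", maxsplit=1)[0]
assert rpath == "$ORIGIN/../lib:/opt/nvshmem/lib"
```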

Contributor Author


updated!

p = p.strip().replace("$ORIGIN", torch_lib)
if p and p not in paths:
paths.append(p)
except Exception:
Collaborator


Not generic exception please

Contributor Author


updated!

@codingwithsurya codingwithsurya changed the title [SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) [wip] [SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) Aug 3, 2025
@codingwithsurya codingwithsurya self-assigned this Aug 3, 2025
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #159788


pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…m in their name (#159734)

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes  `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

Pull Request resolved: #159734
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…tomatic dtype‐based dispatch (#159755)

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
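The dtype normalization described in that change can be sketched as follows. This is a hypothetical illustration with an abbreviated mapping table; the helper names are invented, but the assembled function names (`nvshmem_float_sum_reduce`, `nvshmem_bfloat16_prod_reduce`) match those given above.

```python
# Abbreviated, illustrative mapping from dtype identifiers to NVSHMEM typenames.
_TYPENAMES = {
    "torch.float32": "float", "tl.fp32": "float",
    "torch.bfloat16": "bfloat16", "tl.bf16": "bfloat16",
    "torch.int32": "int32",
}

def reduce_fn_name(operation: str, dtype_id) -> str:
    # Normalize the dtype identifier, then assemble the extern function name.
    typename = _TYPENAMES[str(dtype_id)]
    return f"nvshmem_{typename}_{operation}_reduce"

assert reduce_fn_name("sum", "torch.float32") == "nvshmem_float_sum_reduce"
assert reduce_fn_name("prod", "tl.bf16") == "nvshmem_bfloat16_prod_reduce"
```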
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup.
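The needed lookup could be sketched roughly as follows, assuming a per-PE table of symmetric-heap base pointers analogous to `handle.buffer_ptrs[peer]`. The helper name and base values are illustrative (the bases echo the addresses in the trace above), not the actual mechanism.

```python
def remote_address(local_addr: int, local_base: int, peer_bases: dict, peer: int) -> int:
    # Offset within the local symmetric allocation...
    offset = local_addr - local_base
    # ...applied to the peer's base pointer yields the remote address.
    return peer_bases[peer] + offset

# Bases mirror the pointer-based trace above: rank 0 local vs rank 1 remote.
bases = {0: 0x430300000, 1: 0x2430300000}
assert remote_address(0x430300C00, bases[0], bases, peer=1) == 0x2430300C00
```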

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…(make device library discoverable + fix peer calculation bug) (pytorch#159701)

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path at triton.knobs.runtime.jit_post_compile_hook.
- The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware)

**Peer Rank Calculation Fix**
- The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank.
Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.

Pull Request resolved: pytorch#159701
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…m in their name (pytorch#159734)

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes  `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

Pull Request resolved: pytorch#159734
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…tomatic dtype‐based dispatch (pytorch#159755)

This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (`torch.dtype` or `tl.dtype`) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling `tl.dtype`, `torch.dtype`, `str`, or `constexpr`) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.
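The normalization-and-dispatch step can be sketched as follows. The typename table is illustrative (only `nvshmem_float_sum_reduce` and `nvshmem_bfloat16_prod_reduce` are confirmed above), and the real code also accepts `tl.dtype`/`torch.dtype`/`constexpr` objects rather than plain strings:

```python
# Illustrative mapping from framework dtype names to NVSHMEM typenames.
_NVSHMEM_TYPENAMES = {
    "float32": "float", "float64": "double",
    "float16": "half", "bfloat16": "bfloat16",
    "int32": "int32", "int64": "int64",
}

def nvshmem_reduce_fn_name(dtype_id, operation):
    # Normalize the incoming dtype id into a bare dtype name, look up the
    # canonical NVSHMEM typename, and assemble the extern function name.
    key = str(dtype_id).replace("torch.", "").replace("tl.", "")
    typename = _NVSHMEM_TYPENAMES[key]
    return f"nvshmem_{typename}_{operation}_reduce"

print(nvshmem_reduce_fn_name("float32", "sum"))         # nvshmem_float_sum_reduce
print(nvshmem_reduce_fn_name("torch.bfloat16", "prod")) # nvshmem_bfloat16_prod_reduce
```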

Pull Request resolved: pytorch#159755
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
Fed the NVSHMEM documentation to Claude Code and asked it to generate helpful docstrings; verified for correctness.

Pull Request resolved: pytorch#159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…on kernels (pytorch#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions, letting users pass tensors as inputs to their NVSHMEM Triton kernels (rather than raw pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if, for instance, you want to do some local math on the data.
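As a sketch of the tedium being hidden, assuming a hypothetical `to_raw_args` helper and an illustrative itemsize table:

```python
# What the wrapper hides: converting a typed tensor argument into the raw
# (pointer, byte count) pair the pointer-based extern expects.
ITEMSIZE = {"float32": 4, "bfloat16": 2, "int64": 8}

def to_raw_args(ptr, numel, dtype):
    nbytes = numel * ITEMSIZE[dtype]  # manual byte-size calculation
    return ptr, nbytes

print(to_raw_args(0x1000, 1024, "float32"))  # (4096, 4096)
```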

-----

**TODO:**
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`.

From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next steps: we need a mechanism to resolve a local tensor to the remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup.
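One plausible shape for such a mechanism (hypothetical, not implemented in this PR): re-base the local offset onto the peer's symmetric-heap base, as the pointer-based path effectively does via `handle.buffer_ptrs`:

```python
def remote_address(local_ptr, rank, peer, buffer_ptrs):
    # Translate a local symmetric-heap address into the peer's address
    # space: take the offset from our own base and re-apply it to the
    # peer's base pointer.
    offset = local_ptr - buffer_ptrs[rank]
    return buffer_ptrs[peer] + offset

# Bases mimic the traced run above: rank 0's heap vs rank 1's mapped heap.
buffer_ptrs = {0: 0x430300A00, 1: 0x2430300A00}
dst_local = 0x430300C00  # rank 0's local view of the destination buffer
print(hex(remote_address(dst_local, 0, 1, buffer_ptrs)))  # 0x2430300c00
```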

Pull Request resolved: pytorch#159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136, pytorch#159215, pytorch#159701, pytorch#159734, pytorch#159755, pytorch#159756
@github-actions github-actions bot deleted the gh/codingwithsurya/17/head branch September 8, 2025 02:14
markc-614 pushed the same stack of commits to markc-614/pytorch, referencing this pull request, on Sep 17, 2025.
Labels

Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (symm_mem) release note label for symmetric memory
