
[SymmMem] Back symm_mem.empty() with implicit pool #172292

Closed
kwen2501 wants to merge 2 commits into gh/kwen2501/308/base from gh/kwen2501/308/head
Conversation

@kwen2501
Collaborator

@kwen2501 kwen2501 commented Jan 13, 2026

Stack from ghstack (oldest at bottom):

Resolves #172050

Two motivations:

  • Give better UX and perf to users who explicitly use symm_mem.empty().
  • Simplify the code generated by Inductor, i.e. symm_mem.empty() would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (CUDA, NVSHMEM, NCCL) has been built previously.
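To illustrate the intended behavior, here is a toy sketch of an "implicit pool" (not PyTorch internals — `ImplicitPool` and its methods are invented for illustration): repeated `empty()` calls draw from a cache of freed buffers, so memory is reused rather than re-allocated by the backend each time.

```python
class ImplicitPool:
    """Toy stand-in for the symmetric-memory pool: caches freed buffers
    by size so a later empty() of the same size reuses them (illustrative
    only; the real pool is a torch.cuda.MemPool backed by a symm_mem
    allocator)."""

    def __init__(self):
        self._free = {}  # size -> list of cached buffers

    def empty(self, size):
        cached = self._free.get(size)
        if cached:
            return cached.pop()   # implicit reuse: no backend allocation
        return bytearray(size)    # fresh allocation from the backend

    def release(self, buf):
        self._free.setdefault(len(buf), []).append(buf)


pool = ImplicitPool()
a = pool.empty(1024)
pool.release(a)
b = pool.empty(1024)
print(a is b)  # True: the freed buffer was reused
```

This is the bookkeeping the PR moves out of Inductor-generated code and into `symm_mem.empty()` itself.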

@pytorch-bot

pytorch-bot bot commented Jan 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172292

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit fe394be with merge base 8cfe6f1:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Jan 13, 2026
@kwen2501 kwen2501 added the release notes: distributed (symm_mem) release note label for symmetric memory label Jan 13, 2026
kwen2501 added a commit that referenced this pull request Jan 13, 2026
@kwen2501 kwen2501 requested a review from eqy January 13, 2026 01:02
@eqy eqy requested a review from galv January 13, 2026 05:09
Collaborator

@eqy eqy left a comment


Does this need to use any special custom allocator or is just having another pool OK?

@kwen2501
Collaborator Author

@eqy The get_mem_pool API points to symm_mem's internal pool:

```python
def get_mem_pool(device: _device) -> torch.cuda.MemPool:
    """
    Get the symmetric memory pool for a given device. If not found, create
    a new pool.

    The pool is backed by symm_mem's allocator, be it CUDA, NVSHMEM or NCCL,
    depending on the user's symm_mem.set_backend(...) setting.
    """
```
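A minimal get-or-create pattern mirroring the API quoted above (pure-Python sketch under stated assumptions: the dict stands in for the internal device-to-pool map, and `object()` is a placeholder for constructing a `torch.cuda.MemPool` on the symm_mem allocator):

```python
_device_pools = {}  # device -> pool; stand-in for the internal map

def get_mem_pool(device):
    """Return the pool for `device`, creating one on first use (the real
    version constructs a torch.cuda.MemPool backed by the symm_mem
    allocator selected via symm_mem.set_backend(...))."""
    if device not in _device_pools:
        _device_pools[device] = object()  # placeholder for MemPool(...)
    return _device_pools[device]


print(get_mem_pool("cuda:0") is get_mem_pool("cuda:0"))  # True
```

The key property is that repeated calls for the same device return the same pool object, so allocations can be served from and returned to one shared pool.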

Collaborator

@ngimel ngimel left a comment


Thank you!

@kwen2501
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 13, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@yangw-dev
Contributor

@pytorchbot revert -m "sorry but it seems your pr failed internal test. error: torch._inductor.exc.InductorError: ImportError: undefined symbol: cuCtxGetCurrent, please reach out intenral staff for further debugging" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Jan 14, 2026
This reverts commit 4301818.

Reverted #172292 on behalf of https://github.com/yangw-dev due to sorry but it seems your pr failed internal test. error: torch._inductor.exc.InductorError: ImportError: undefined symbol: cuCtxGetCurrent, please reach out intenral staff for further debugging ([comment](#172292 (comment)))
@pytorchmergebot
Collaborator

@kwen2501 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Jan 14, 2026
@ngimel
Collaborator

ngimel commented Jan 14, 2026

@kwen2501 the error is actually in dynamo tests, e.g. test/dynamo:test_aot_autograd_cache - test_autograd_function (caffe2.test.dynamo.test_aot_autograd_cache.AOTAutogradCacheBundledTests), so it should be reproducible. Don't know, tbh, why it is triggered by this PR.

  File "/re_cwd/buck-out/v2/gen/fbcode/3753cce6484443c5/caffe2/test/dynamo/__test_aot_autograd_cache__/test_aot_autograd_cache#link-tree/triton/runtime/driver.py", line 10, in _create_driver
    return active_drivers[0]()
  File "/re_cwd/buck-out/v2/gen/fbcode/3753cce6484443c5/caffe2/test/dynamo/__test_aot_autograd_cache__/test_aot_autograd_cache#link-tree/triton/backends/nvidia/driver.py", line 783, in __init__
    self.utils = CudaUtils()  # TODO: make static
  File "/re_cwd/buck-out/v2/gen/fbcode/3753cce6484443c5/caffe2/test/dynamo/__test_aot_autograd_cache__/test_aot_autograd_cache#link-tree/triton/backends/nvidia/driver.py", line 63, in __init__
    mod = compile_module_from_src(
  File "/re_cwd/buck-out/v2/gen/fbcode/3753cce6484443c5/caffe2/test/dynamo/__test_aot_autograd_cache__/test_aot_autograd_cache#link-tree/triton/runtime/build.py", line 93, in compile_module_from_src
    return _load_module_from_path(name, cache_path)
  File "/re_cwd/buck-out/v2/gen/fbcode/3753cce6484443c5/caffe2/test/dynamo/__test_aot_autograd_cache__/test_aot_autograd_cache#link-tree/triton/runtime/build.py", line 65, in _load_module_from_path
    mod = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 730, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1176, in create_module
  File "<frozen importlib._bootstrap>", line 400, in _call_with_frames_removed
torch._inductor.exc.InductorError: ImportError: /re_tmp/tmpupz8kvm_/triton/YG7MWGMAX4UH3KIF6MQ665T4HJZLYF2T43HWBZOGN6O66MLEQHKA/cuda_utils.cpython-310-fb-010-x86_64.so: undefined symbol: cuCtxGetCurrent
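For context on where the error surfaces: Triton loads its compiled `cuda_utils` extension from a file path via `importlib`, and an unresolved symbol in the `.so` (here `cuCtxGetCurrent`) raises an `ImportError` during module creation, as in the traceback above. A minimal pure-Python equivalent of that loading path (the function name is illustrative, demonstrated with a plain `.py` file rather than a compiled extension):

```python
import importlib.util
import os
import tempfile

def load_module_from_path(name, path):
    """Load a module from an arbitrary file path, as Triton's
    _load_module_from_path does for its compiled cuda_utils extension.
    For a .so with an unresolved symbol, the ImportError is raised while
    the module object is being created from the spec."""
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

# Demo with a plain Python file instead of a compiled .so:
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "demo_mod.py")
    with open(path, "w") as f:
        f.write("ANSWER = 42\n")
    mod = load_module_from_path("demo_mod", path)
    print(mod.ANSWER)  # 42
```

This is why the failure points at the Triton build/cache machinery rather than at anything the PR itself touches.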

mattteochen pushed a commit to mattteochen/pytorch that referenced this pull request Jan 15, 2026
Resolves pytorch#172050

Two motivations:
- Give better UX and perf to users who explicitly use `symm_mem.empty()`.
- Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously.
Pull Request resolved: pytorch#172292
Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba
ghstack dependencies: pytorch#172163
mattteochen pushed a commit to mattteochen/pytorch that referenced this pull request Jan 15, 2026
…72292)"

This reverts commit 4301818.

Reverted pytorch#172292 on behalf of https://github.com/yangw-dev due to sorry but it seems your pr failed internal test. error: torch._inductor.exc.InductorError: ImportError: undefined symbol: cuCtxGetCurrent, please reach out intenral staff for further debugging ([comment](pytorch#172292 (comment)))
@kwen2501
Collaborator Author

kwen2501 commented Jan 16, 2026

Hi @yangw-dev thanks for the heads-up!
I ran the same test locally and it passed for me.
Looking at the error message, it seems to be a Triton issue:

/triton/runtime/driver.py
/re_tmp/tmpupz8kvm_/triton/.../cuda_utils.cpython-310-fb-010-x86_64.so: undefined symbol: cuCtxGetCurrent

My PR did not modify anything related to Triton, nor did it change CUDA driver linkage.
Could the error be caused by other internal changes or some flaky tmp folder failing to be cleaned?

@kwen2501
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/21079055965

@eqy
Collaborator

eqy commented Jan 16, 2026

Hmm, I hope it's not related to #171116 ...

@kwen2501
Collaborator Author

@pytorchbot merge -f "Triaged CI error; seems unrelated"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

DustyL pushed a commit to DustyL/pytorch that referenced this pull request Jan 17, 2026
Cherry-picked from upstream main:

- [SymmMem] Back symm_mem.empty() with implicit pool (pytorch#172292)
  Automatic memory reuse for symmetric memory allocations

- [SymmMem] Add multimem support for NCCL and NVSHMEM (pytorch#172185)
  Enhanced multi-GPU memory support

- [inductor] Basic Comm Buffer Reuse for Symmetric Memory (pytorch#171909)
  Memory optimization for torch.compile with symmetric buffers

- [BE] Don't print 12 `triton not found` on import (pytorch#172614)
  QoL fix for flop_counter imports

- [inductor] Use custom triton kernel subclass when available (pytorch#167456)
  Enables custom backend heuristics for Triton kernels

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@github-actions github-actions bot deleted the gh/kwen2501/308/head branch February 16, 2026 02:23

Labels

ci-no-td Do not run TD on this PR ciflow/h100-symm-mem ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: distributed (symm_mem) release note label for symmetric memory Reverted


7 participants