
[library] Add registration API for symmetric memory arguments #173513

Open

RohitRathore1 wants to merge 13 commits into pytorch:main from RohitRathore1:symm-mem-registration-api

Conversation

@RohitRathore1
Collaborator

@RohitRathore1 RohitRathore1 commented Jan 27, 2026

This PR adds a registration API that lets operators declare which of their arguments require symmetric memory allocation.

Operators can register their symm_mem args via Library::register_symm_mem_args() in C++ or lib.register_symm_mem_args() in Python. The registrations are validated against kernel schemas to make sure only valid argument names are accepted. For example, multimem_all_gather_out registers out rather than input, because it's the output tensor that gets passed to rendezvous().
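For illustration, here is a torch-free sketch of the registry pattern described above. All names (`SymmMemArgsHolder`, `register`, `lookup`) and the hand-written schema lists are stand-ins, not the actual PyTorch API; the real registration validates against the kernel's parsed schema rather than a plain list of names.

```python
# Hypothetical, simplified sketch of a symm_mem args registry: a mapping from
# qualified op name to the argument names that must live in symmetric memory,
# validated against the op's argument names (here hand-written for brevity).
class SymmMemArgsHolder:
    def __init__(self):
        self._registry: dict[str, tuple[str, ...]] = {}

    def register(self, qualname, schema_arg_names, symm_args):
        # Reject argument names absent from the kernel schema, mirroring the
        # schema validation the PR describes.
        unknown = [a for a in symm_args if a not in schema_arg_names]
        if unknown:
            raise ValueError(f"{qualname}: unknown args {unknown}")
        self._registry[qualname] = tuple(symm_args)

    def lookup(self, qualname):
        return self._registry.get(qualname)


holder = SymmMemArgsHolder()
# multimem_all_gather_out registers "out", not "input": it is the output
# tensor that gets passed to rendezvous().
holder.register(
    "symm_mem::multimem_all_gather_out",
    schema_arg_names=["input", "group_name", "out"],
    symm_args=["out"],
)
print(holder.lookup("symm_mem::multimem_all_gather_out"))  # ('out',)
```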

During compilation, FallbackKernel.create calls _maybe_realize_symm_mem_args(), which looks up the registry and automatically realizes the registered args as symmetric memory comm buffers. This works alongside the existing per-op lowerings in comm_lowering.py, which are fully preserved.

The main benefit is that new symm_mem ops — including those from downstream projects like vLLM — can register their args and get automatic buffer realization through FallbackKernel, without needing to add manual per-op lowerings.
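The lookup-and-realize step can be sketched without torch as follows. `FakeBuffer`, `realize`, and `maybe_realize_symm_mem_args` are stand-ins for inductor's `TensorBox`, `realize_as_comm_buffer`, and the `_maybe_realize_symm_mem_args` hook; only the control flow (registry lookup, handling of single and list-valued args, no-op for unregistered ops) reflects the description above.

```python
# Torch-free sketch of the realization hook: look up the registered arg names
# for an op and "realize" each matching argument, including ones nested in
# lists/tuples, as a symmetric-memory buffer.
REGISTRY = {"symm_mem::multimem_all_gather_out": ("out",)}

class FakeBuffer:
    def __init__(self, name):
        self.name = name
        self.is_symm_mem = False

def realize(buf):
    buf.is_symm_mem = True

def maybe_realize_symm_mem_args(qualname, all_args):
    symm_args = REGISTRY.get(qualname)
    if symm_args is None:
        return  # op never registered: nothing to do
    for arg_name in symm_args:
        value = all_args.get(arg_name)
        # Registered args may be a single buffer or a list/tuple of buffers.
        items = value if isinstance(value, (list, tuple)) else [value]
        for item in items:
            if isinstance(item, FakeBuffer):
                realize(item)

out = FakeBuffer("out")
maybe_realize_symm_mem_args(
    "symm_mem::multimem_all_gather_out",
    {"input": FakeBuffer("input"), "out": out},
)
print(out.is_symm_mem)  # True
```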

Changes:

  • C++ SymmMemArgsRegistry and Library::register_symm_mem_args method
  • Python SymmMemArgsHolder in simple_registry.py and Library.register_symm_mem_args
  • _maybe_realize_symm_mem_args in FallbackKernel
  • Registrations for all existing symm_mem ops
  • 23 tests covering registry basics, Library integration, symm_mem ops, and torch.compile

A follow-up PR will refactor the existing per-op lowerings to use the registry and add auto-functionalize support for in-place ops.

Fixes #172345

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@pytorch-bot

pytorch-bot bot commented Jan 27, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@pytorch-bot

pytorch-bot bot commented Jan 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173513

Note: Links to docs will display an error until the docs builds have been completed.

❌ 33 New Failures, 13 Unrelated Failures

As of commit 67a06f0 with merge base 495c0a9:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eee4017
Collaborator

eee4017 commented Jan 30, 2026

Hi @RohitRathore1

Do you think we need a C++ API for this?

See the discussion here: #171909 (comment)

cc @kwen2501

@RohitRathore1
Collaborator Author

Hi @eee4017, thanks for pointing this out! Looking at the discussion in #171909, I see that @kwen2501 mentioned we need both Python and C++ APIs - C++ for torch's internal op development, Python for DSL op development.

But currently, my PR only implements the Python API. Should I add the C++ API as well so that the symm_mem operator definitions in SymmetricMemory.cpp can directly register their symm_mem args, something like:

TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
  m.def("one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor");
  m.register_symm_mem_args("one_shot_all_reduce", {"input"});
}

This would co-locate the operator definition with its metadata. Let me know if this is the right direction!

@eee4017
Collaborator

eee4017 commented Jan 30, 2026

Yes, then we wouldn’t need what you added in torch/distributed/_symmetric_memory/_register_ops.py, right?

@RohitRathore1
Collaborator Author

RohitRathore1 commented Jan 30, 2026

Yes, then we wouldn’t need what you added in torch/distributed/_symmetric_memory/_register_ops.py, right?

Yes, correct.

Let me come up with the necessary changes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Jan 31, 2026
@RohitRathore1 RohitRathore1 changed the title [WIP][library] Add registration API for symmetric memory arguments [library] Add registration API for symmetric memory arguments Jan 31, 2026
@RohitRathore1
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased symm-mem-registration-api onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout symm-mem-registration-api && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the symm-mem-registration-api branch from a40e031 to 5a356bc Compare February 3, 2026 17:56
torch/library.py Outdated

if op_overload is None:
    try:
        if "::" in qualname:
Collaborator

Should we use namespace, opname = torch._library.utils.parse_namespace(qualname) ?

Collaborator Author

but parse_namespace only parses - we'd still need to write the op lookup. we can do something like this

try:
    op_overload = torch._library.utils.lookup_op(qualname)
except (ValueError, AttributeError):
    pass

lookup_op uses parse_namespace internally and handles the full lookup


result = ir.FallbackKernel.create(op, *args, **kwargs)

return pytree.tree_map(ir.TensorBox.create, result)
Collaborator

Does this handle inplace operation correctly? Should we return a mutated argument (e.g. Tensor(a!) out), but not create a new tensor?

_all_reduce returns the input directly

return inp # type: ignore[return-value]

Collaborator Author

valid point. FallbackKernel tracks mutations via handle_aliasing_and_mutation, but mutates_and_returns_first_arg only works for aten ops. For inplace symm_mem ops, we should either verify the current behavior is correct or handle them like _all_reduce_ does. let me investigate it

Collaborator Author

I just checked: mutable ops returning Tensor(a!) will fail, because can_auto_functionalize returns False for non-aten ops with aliased returns. We can use _CollectiveKernel.create_inplace instead. I will test it.

continue

try:
    _, op_part = qualname.rsplit("::", 1)
Collaborator

Should we also reuse the name parsing logic in torch._library.utils here?

Collaborator Author

aah, correct. We should

@RohitRathore1 RohitRathore1 marked this pull request as ready for review February 8, 2026 08:19
@jbschlosser jbschlosser added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 9, 2026
@kwen2501
Collaborator

Looks good in general. Hi @zou3519, any comments?

@kwen2501
Collaborator

@eellison @tianrengao Can you please have a look at this PR?

Contributor

@eellison eellison left a comment

The changes I am expecting from this PR are that the custom lowerings in #171909 have been removed and ir.FallbackKernel generically handles the symm memory inputs using the API you have added. Would you make those changes in this PR and leave the Collective Inplace refactoring for a separate PR?

cc @zou3519 for custom op library changes and also cc @tianrengao if this is related to your sym mem stack

return *this;
}

// Static registry for C++ symm_mem args registrations
Contributor

cc @zou3519 mind looking at the library changes ?

if isinstance(item, ir.TensorBox) and group_name is not None:
    _maybe_realize_symm_mem(item, group_name)

if mutated_return is not None:
Contributor

I'm not expecting a CollectiveKernel create_inplace refactoring in this PR. Can we move that to a separate PR?

Collaborator Author

Agreed, I'll remove the _CollectiveKernel.create_inplace usage from this PR entirely. The symm_mem buffer realization will now happen generically inside ir.FallbackKernel.create via _maybe_realize_symm_mem_args, which checks the registry and realizes the registered args as comm buffers. The per-op custom lowerings and the _create_symm_mem_lowering / _get_mutated_return_arg helpers have been removed. I'll handle the CollectiveKernel.create_inplace refactoring in a follow-up PR.

Comment on lines +498 to +506
for arg_name in symm_mem_args_set:
    arg_value = all_args.get(arg_name)
    if isinstance(arg_value, ir.TensorBox):
        if group_name is not None:
            _maybe_realize_symm_mem(arg_value, group_name)
    elif isinstance(arg_value, (list, tuple)):
        for item in arg_value:
            if isinstance(item, ir.TensorBox) and group_name is not None:
                _maybe_realize_symm_mem(item, group_name)
Contributor

The maybe_realize_symm_memory handling I'm expecting to just go in ir.FallbackKernel. Can you put it there?

Remove per-op custom lowerings from comm_lowering.py and instead
handle symm_mem arg realization generically in FallbackKernel.create
via _maybe_realize_symm_mem_args, which checks the registration
metadata to determine which args need symmetric memory.

Also adds a symm_mem bypass in FallbackKernel.__init__ for mutable
ops, with a TODO for follow-up CollectiveKernel.create_inplace
refactoring.
…-api

# Conflicts:
#	torch/_inductor/comm_lowering.py
@RohitRathore1
Collaborator Author

@eellison Thanks for the review feedback! I've updated the PR based on your suggestions:

  • The symm_mem buffer realization now happens generically inside FallbackKernel.create instead of per-op custom lowerings. Removed register_symm_mem_lowerings() entirely.
  • Left the CollectiveKernel.create_inplace refactoring for a follow-up PR as you suggested.
  • Also addressed @kwen2501's comment about only registering the args that actually need symmetric memory.

Would appreciate another look when you have time!

Collaborator

@kwen2501 kwen2501 left a comment

Changes look good to me. A couple minor comments, non-blocking tho.
Can you please fix the lint errors?

Comment on lines +8749 to +8766
if isinstance(arg_value, TensorBox):
    if can_realize_as_comm_buffer(arg_value, CommBufferType.SYMM_MEM):
        realize_as_comm_buffer(
            arg_value, CommBufferType.SYMM_MEM, group_name
        )
    else:
        log.warning(
            "Failed to realize %s as a symmetric memory buffer for %s",
            arg_name,
            qualname,
        )
elif isinstance(arg_value, (list, tuple)):
    for item in arg_value:
        if isinstance(item, TensorBox):
            if can_realize_as_comm_buffer(item, CommBufferType.SYMM_MEM):
                realize_as_comm_buffer(
                    item, CommBufferType.SYMM_MEM, group_name
                )
Collaborator

nit: you can use a tree_map util to recursively do this instead of manual inspection.
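For illustration, a tiny torch-free version of this suggestion: flattening the possibly nested argument value first removes the duplicated isinstance branches. In PyTorch itself this would be a pytree utility such as torch.utils._pytree.tree_flatten; the `tree_flatten` below is a minimal stand-in.

```python
def tree_flatten(value):
    # Minimal stand-in for a pytree flatten: yields all leaves of nested
    # lists/tuples, treating anything else as a leaf.
    if isinstance(value, (list, tuple)):
        for item in value:
            yield from tree_flatten(item)
    else:
        yield value

def realize_all(arg_value, realize, is_tensor):
    # With flattening, single tensors and lists of tensors take the same path.
    for leaf in tree_flatten(arg_value):
        if is_tensor(leaf):
            realize(leaf)

realized = []
realize_all([1, (2, "skip"), 3], realized.append, lambda x: isinstance(x, int))
print(realized)  # [1, 2, 3]
```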

// Static registry for C++ symm_mem args registrations
namespace {
struct SymmMemArgsRegistry {
  std::unordered_map<std::string, std::vector<std::string>> registry;
Collaborator

Can you please add documentation to these fields?

Use pytree.tree_flatten in _maybe_realize_symm_mem_args instead of
manual isinstance checks for nested args. Add documentation to the
C++ SymmMemArgsRegistry struct. Replace Optional[X] with X | None
per modern Python style.
@RohitRathore1 RohitRathore1 requested a review from kwen2501 March 25, 2026 04:30
…-api

Resolve conflict in comm_lowering.py: remove new per-op symm_mem
lowerings (tile_reduce, multi_root_tile_reduce) from main since
FallbackKernel now handles realization generically via the registry.
Add the new ops to test_symm_mem_correct_args.
Comment on lines +8295 to +8300
if "symm_mem" in self.op_overload.name():
    # symm_mem kernels are mutable custom ops whose mutation is
    # handled by the lowering that created this FallbackKernel.
    # TODO: handle mutable symm_mem ops via _CollectiveKernel.create_inplace
    # instead of bypassing here. See follow-up to #173513.
    return
Contributor

sorry, why do we need this?

Collaborator Author

Without the bypass, mutable symm_mem ops (e.g., multimem_all_reduce_, two_shot_all_reduce_) hit the NotImplementedError. They have schema.is_mutable = True, but can_auto_functionalize returns False for them, and they don't have a Functionalize dispatch key implementation.

Contributor

Could we fix this before landing? cc @zou3519 to comment. I don't know why can_auto_functionalize matters in inductor lowering.

Collaborator Author

@eellison the mutable symm_mem ops (e.g., multimem_all_reduce_, two_shot_all_reduce_) return aliased tensors (Tensor(a!) -> Tensor(a!)), so they fail both mutates_and_returns_first_arg (aten-only gate) and can_auto_functionalize (no aliased returns), hitting the NotImplementedError... the proper fix would be routing these through _CollectiveKernel.create_inplace, but you mentioned earlier that refactoring should be a separate PR. Should I go ahead and do it here instead? @zou3519 any thoughts on whether the can_auto_functionalize guard should be relaxed for this case?

Contributor

@RohitRathore1 can you change the mutable symm ops to not return the aliased tensor? Instead you can write any code that calls them to just use the input tensor.

Otherwise, we can special case these but we are going to need to work a bit to special case them

Collaborator Author

@zou3519 no callers rely on the return values of the mutable symm_mem ops, so changing them to return () is safe. However, that's a schema change to the ops themselves, which feels like it should be a separate PR. Would it be OK to keep the bypass for now and do the schema change as a follow-up? Or would you prefer it in this PR?

Contributor

We should figure it out now; we risk silent incorrectness otherwise. Either we fix the schema, or we special-case these operators in functionalization, which would require specific changes.

@kwen2501 would you be okay if we changed the schema for these operators? The TL;DR is that custom operators are not allowed to do something like the following:

def op_(x):
    mutate(x)
    return x

Instead they should be written like the following, to avoid returning the input directly as the output:

def op_(x):
    mutate(x)
If there are BC concerns here, then we can try to special case them in compile.
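The aliasing problem described above can be illustrated without torch, using a Python list as a stand-in for a tensor: in eager mode the disallowed pattern returns the very object it mutated, but once the mutation is replayed on a copy (as functionalization does), that identity no longer holds, so callers relying on it would be silently wrong.

```python
def op_(buf):
    # Disallowed pattern: mutate the input and return it (an alias of a
    # mutated input).
    buf.append(1)
    return buf

def functionalized_op_(buf):
    # Functionalization replays the mutation on a fresh copy and hands the
    # caller the new value instead of mutating in place.
    out = list(buf)
    out.append(1)
    return out

x = [0]
assert op_(x) is x                      # eager: the return aliases the input
x = [0]
assert functionalized_op_(x) is not x   # functionalized: the aliasing breaks
```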

Collaborator

Thanks @zou3519 . I am fine with breaking the API signatures now.
@RohitRathore1 can you change the APIs for those couple ops? Thanks!

Collaborator

@RohitRathore1 You can also pull a separate PR and base this PR on top. Either way works.

Collaborator Author

@kwen2501 created #179144 with the schema changes... will rebase #173513 on top of it once this lands.

RohitRathore1 added a commit to RohitRathore1/pytorch that referenced this pull request Apr 2, 2026
Custom operators are not allowed to return an alias of a mutated input
(e.g. `Tensor(a!) -> Tensor(a!)`), as this pattern is incompatible with
torch.compile's functionalization. Change all mutable symm_mem ops to
return void (`-> ()`) instead. No callers in the codebase rely on the
return values — all use these ops for their in-place side effects.

Discussed with zou3519 and kwen2501 on pytorch#173513.
@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

Alright. Now I have a clearer idea:

  1. Existing ops in torch.ops.symm_mem do not need this PR for auto-lowering. That's already done in a previous PR, manually, op-by-op.
  2. vLLM just needs the library registration part in this PR for their custom ops.

The registration in 2 alone may or may not be enough for torch.compile to work in vLLM:

  • If their op is in functional form, they don't need auto-functionalize support.
  • Otherwise, they do.

So I recommend the following strategy:

  • keep only the registration part in this PR, and spin off the "better engineering" part for lowering torch.ops.symm_mem (which can be done in the future when auto-functionalize is ready).
  • if some auto-lowering is already solid in this PR, e.g. handling of functional ops (instead of in-place ops), we can keep it in this PR. For in-place ops, keep throwing the NotImplementedError as today (until auto-functionalize is ready).
  • if functional ops can indeed be supported with this PR, let's add a test to verify that.

@RohitRathore1 @zou3519 @eellison @ngimel wdyt?

@RohitRathore1
Collaborator Author

@kwen2501 sounds good to me. Keeping just the registration part and deferring the in-place op lowering makes the PR much cleaner. I can trim it down to keep the SymmMemArgsRegistry + Python lookup, drop the FallbackKernel bypass, and add a test for functional op support. Will wait to hear from @zou3519 @eellison @ngimel as well.

@zou3519
Contributor

zou3519 commented Apr 4, 2026

Sgtm

…d tests

Remove the FallbackKernel bypass block for symm_mem ops and restore the
existing per-op lowerings in comm_lowering.py. This keeps the registration
API (symm_mem_args registry + _maybe_realize_symm_mem_args) while preserving
backward compatibility with the current lowering paths.

Add comprehensive test suite (23 tests) covering registry core functionality,
Library integration, symm_mem ops registration, and torch.compile integration.
@RohitRathore1
Collaborator Author

PR has been updated
cc: @kwen2501

@RohitRathore1
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/173513/head returned non-zero exit code 1

Rebasing (1/11)
Rebasing (2/11)
Auto-merging torch/_inductor/comm_lowering.py
CONFLICT (content): Merge conflict in torch/_inductor/comm_lowering.py
Auto-merging torch/distributed/_symmetric_memory/__init__.py
error: could not apply d5a5be8f53d... [library] Complete RFC #172345: Generic lowering for symm_mem operations
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply d5a5be8f53d... # [library] Complete RFC #172345: Generic lowering for symm_mem operations

Raised by https://github.com/pytorch/pytorch/actions/runs/24037374529


Labels

module: inductor open source release notes: distributed (c10d) release notes category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


Development

Successfully merging this pull request may close these issues.

[RFC] m.register_symm_mem_args

9 participants