Change mutable symm_mem ops to return void instead of aliased tensors#179144

Open
RohitRathore1 wants to merge 3 commits into pytorch:main from RohitRathore1:symm-mem-void-schema

Conversation

@RohitRathore1
Collaborator

Custom operators are not allowed to return an alias of a mutated input (e.g. Tensor(a!) -> Tensor(a!)), as this pattern is incompatible with torch.compile's functionalization. Change all mutable symm_mem ops to return void (-> ()) instead. No callers in the codebase rely on the return values — all use these ops for their in-place side effects.

Discussed with @zou3519 and @kwen2501 on #173513.

@pytorch-bot

pytorch-bot bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179144

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit a1fd65e with merge base d386e0b:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
  m.def(
-     "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
+     "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> ()");
Contributor


@kwen2501 an alternative to BC-breaking this is to define a new operator that has no return and have the old operator call it. So something like:

m.def(
    "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
m.def(
    "multimem_all_reduce_noreturn_(Tensor(a!) input, str reduce_op, str group_name) -> ()");

TORCH_LIBRARY_IMPL(symm_mem, m, CompositeImplicitAutograd) {
  m.impl("multimem_all_reduce_", &multimem_all_reduce_);
}

Tensor multimem_all_reduce_(...) {
  multimem_all_reduce_noreturn_(...);
  return ...;
}

I'm not sure how worth it this is. I looked at the docs and the APIs are "alpha", so the BC-break seems reasonable to do.

Collaborator Author

@RohitRathore1 RohitRathore1 Apr 2, 2026


Hmm, since these APIs are marked alpha and no callers in the codebase (or downstream in vLLM) use the return values, the BC-break seems like the simpler path. Happy to add the _noreturn_ wrapper approach if you'd prefer, though.
cc: @kwen2501

@kwen2501 kwen2501 added labels release notes: distributed (symm_mem), module: symm_mem, topic: bc breaking, suppress-bc-linter, suppress-api-compatibility-check on Apr 2, 2026
@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

Thanks @RohitRathore1 @zou3519, I am checking with team members to sign off on this.

Collaborator

@kwen2501 kwen2501 left a comment


I tend to approve this change.
Looking at the impact radius:

  • Most of the changed ops are of the "_out" form. For these, users pass in the out tensor and would not use the return value anyway.
  • Four ops are of the in-place "_" form. For these, it should already be clear that the input is modified in place.

@kwen2501 kwen2501 requested review from fegin, kwen2501 and ngimel April 2, 2026 16:17
@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

@RohitRathore1
Please check the binding sites at init.cpp lines 1247 and 1259-1260. The .typed<at::Tensor(...)>() calls for stream_write_value32_ and memset32_ need to be updated to .typed<void(...)>() to match the new void return type.

stream_write_value32 (lines 1244–1248):

  auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("symm_mem::stream_write_value32_", "")
          .typed<at::Tensor(at::Tensor&, int64_t, int64_t)>();  // ← wrong return type
  return op.call(input, offset, val);  // ← returns Tensor, but op now returns void

memset32 (lines 1257–1261):

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("symm_mem::memset32_", "")
                .typed<at::Tensor(                          // ← wrong return type
                    at::Tensor&, int64_t, int64_t, int64_t)>();
  return op.call(input, offset, val, count);  // ← same problem

Both lambdas also use return, which will fail to compile once the op returns void. The fix is to change at::Tensor → void in the .typed<>() calls and drop the return.

@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

Claude checked the vLLM repo; here is the verdict:

vLLM is not impacted.

@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

Claude checked the SGLang repo; here is the verdict:

SGLang does use two of the changed ops, both in /python/sglang/srt/distributed/device_communicators/torch_symm_mem.py:

  # multimem path
  torch.ops.symm_mem.multimem_all_reduce_(
      self.buffer[: inp.numel()], "sum", self.group.group_name
  )

  # two-shot fallback path
  torch.ops.symm_mem.two_shot_all_reduce_(
      self.buffer[: inp.numel()], "sum", self.group.group_name
  )

Both calls discard the return value, so the return type change from Tensor(a!) to () does not break SGLang. The ops are used purely for their in-place side effects.

@RohitRathore1
Collaborator Author

@kwen2501 thanks for providing all these verdicts!

@fegin
Contributor

fegin commented Apr 2, 2026

@RohitRathore1

Can you check if torch/_inductor/comm_lowering.py is going to be affected? More specifically, will this change break the torch.compile + symmetric memory use cases due to Inductor expecting a TensorBox? If it will, you will need to add similar wrapping in comm_lowering.py.

Also, when you search "codebase", what codebase did you mean? Kraken is a public repo and it will be broken after this change. Kraken is fine as we own it and it is just a benchmark repo. But I just want to understand the "codebase" you referred to.

@RohitRathore1
Collaborator Author

@RohitRathore1

Can you check if torch/_inductor/comm_lowering.py is going to be affected? More specifically, will this change break the torch.compile + symmetric memory use cases due to Inductor expecting a TensorBox? If it will, you will need to add similar wrapping in comm_lowering.py.

Also, when you search "codebase", what codebase did you mean? Kraken is a public repo and it will be broken after this change. Kraken is fine as we own it and it is just a benchmark repo. But I just want to understand the "codebase" you referred to.

When I said "codebase", I mainly meant vLLM. Let me check more thoroughly.

@RohitRathore1
Collaborator Author

@fegin yes, this change does affect the torch.compile path, but the follow-up PR #173513 is designed to handle it.

    std::string group_name,
    at::Tensor out) {
-   return one_shot_all_reduce_out_impl(
+   one_shot_all_reduce_out_impl(
Collaborator


Other general nit: a lot of these strings should be std::moved
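A minimal sketch of the nit (the helper name qualified_op_name is made up, not code from this PR): taking the string by value and std::move-ing it forwards the buffer instead of copying it.

```cpp
#include <string>
#include <utility>

// Illustrative sink-parameter helper: group_name is taken by value so callers
// can move a string in, and std::move hands the buffer onward to operator+
// instead of making a second copy.
std::string qualified_op_name(std::string group_name, const std::string& op) {
    return std::move(group_name) + "::" + op;
}
```

A caller that no longer needs its string would invoke it as qualified_op_name(std::move(name), "multimem_all_reduce_").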

Collaborator Author


Thanks for the suggestion! That was a pre-existing pattern; I missed it :(

@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

What Kraken uses:
symm_mem_hdl.stream_write_value32(...) in kraken/comm/copy_engine_all_gather.py:

The return value is not used.

But we'd need to fix the binding in init.cpp in torch. @RohitRathore1

@RohitRathore1
Collaborator Author

What kraken uses: symm_mem_hdl.stream_write_value32(...) in kraken/comm/copy_engine_all_gather.py:

The return value is not used.

But we'd need to fix the binding in init.cpp in torch. @RohitRathore1

@kwen2501 already fixed in an earlier commit: I updated the .typed<>() calls and removed the return statements for both stream_write_value32_ and memset32_ in init.cpp.

@zou3519
Contributor

zou3519 commented Apr 2, 2026

I chatted a bit with @ngimel on this. The current thinking is:

  1. we should fix the problem where torch.compile does not like the original symm_mem custom ops.
  2. we can have temporary variants of the symm_mem custom ops that do support torch.compile while we wait for a fix for (1). I think ideally we are able to support these with torch.compile in pytorch 2.12, unless we think there are other blockers to this, at which point we should just fix (1).

The general motivation is that (1) is something we should do and that we don't want to BC-break people now and also BC-break them later when we get (1) to work.

Fixing (1) will take me or someone ~2 weeks, but I don't have bandwidth to do this for another couple of weeks. I will get to it sometime in the medium term.

Thoughts @kwen2501 @RohitRathore1 ?

Collaborator

@albanD albanD left a comment


I'm very confused here.
Why would we ever consider this to be a good way to go?

This is an antipattern we don't want, as there is no way to get autograd to work.
And I don't see how this fixes the functionalization problem. It just avoids triggering the particular error we put there to prevent silent correctness issues; it doesn't make the pattern right.

@kwen2501
Collaborator

kwen2501 commented Apr 2, 2026

@albanD We are in the same boat of trying to figure out what's a good practice :)
My two cents re autograd:
The ops here (referring to ops.symm_mem) are the bare-minimum of a collective implementation.
They are not meant to be autograd'able.
If someone wants an autograd'able form, they can create a functional form wrapping these bare-minimum implementations, such as:

def foo(x) -> Tensor:
  y = torch.empty(...)
  ops.symm_mem.foo(x, y)
  return y

And add backward formula for it.
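The wrapper idea can be sketched in plain Python (no torch; all names are illustrative, not the real symm_mem ops): the in-place op returns None, matching the `-> ()` schema, and the functional form allocates, mutates, and returns.

```python
def all_reduce_(buf):
    # In-place "collective": every element becomes the sum of the buffer.
    # Returns None, like the "-> ()" schema in this PR.
    s = sum(buf)
    for i in range(len(buf)):
        buf[i] = s

def all_reduce(x):
    # Functional wrapper: allocate a fresh output, run the in-place op on it,
    # and return it. This is the form a backward formula would attach to.
    out = list(x)
    all_reduce_(out)
    return out
```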

Functional collectives in PyTorch do exactly this. They define functional forms that return a new tensor and register proper backward implementations via torch.library.register_autograd. The pattern in torch/distributed/_functional_collectives.py:

  # Forward: y = all_reduce(x)  →  backward: all_reduce the grad too
  torch.library.register_autograd(
      "_c10d_functional::all_reduce",
      all_reduce_backward,          # does all_reduce on grad_output
      setup_context=all_reduce_setup_context,
  )

  # Forward: y = all_gather(x)  →  backward: reduce_scatter the grad
  torch.library.register_autograd(
      "_c10d_functional::all_gather_into_tensor",
      all_gather_into_tensor_backward,  # does reduce_scatter on grad_output
      ...
  )

  # Forward: y = reduce_scatter(x)  →  backward: all_gather the grad
  torch.library.register_autograd(
      "_c10d_functional::reduce_scatter_tensor",
      reduce_scatter_tensor_backward,  # does all_gather on grad_output
      ...
  )

Internally, they just call the in-place dist.reduce_scatter(x, y) version. But that's not visible to autograd.

@ngimel
Collaborator

ngimel commented Apr 2, 2026

It doesn't matter whether these ops are meant to be autograd'able or not: in-place ops return their output, out ops return their outputs. That's the current convention, and we shouldn't be breaking it for unclear reasons.

@albanD
Collaborator

albanD commented Apr 3, 2026

I would expect that the pattern we want is just one op that does what we need it to do, not 3 different ops wrapping each other.

I would argue that the functional collectives for distributed should NOT do the weird wrapping they do and should just be one op each. The fact that there are so many layers of wrapping (including ops that are silently wrong) is a BAD thing; looking back, we should never have done that.
If you want an in-place op, make it a proper op. If it is just an implementation detail of your out-of-place op, then it doesn't need to be an op at all (to avoid confusions like the one here).
