Change mutable symm_mem ops to return void instead of aliased tensors #179144
RohitRathore1 wants to merge 3 commits into pytorch:main
Conversation
Custom operators are not allowed to return an alias of a mutated input (e.g. `Tensor(a!) -> Tensor(a!)`), as this pattern is incompatible with torch.compile's functionalization. Change all mutable symm_mem ops to return void (`-> ()`) instead. No callers in the codebase rely on the return values — all use these ops for their in-place side effects. Discussed with zou3519 and kwen2501 on pytorch#173513.
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/179144
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure as of commit a1fd65e with merge base d386e0b. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
  m.def(
-     "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
+     "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> ()");
@kwen2501 an alternative to BC-breaking this is to define a new operator that has no return and have the old operator call it. So something like:
m.def(
    "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
m.def(
    "multimem_all_reduce_noreturn_(Tensor(a!) input, str reduce_op, str group_name) -> ()");

TORCH_LIBRARY_IMPL(symm_mem, m, CompositeImplicitAutograd) {
  m.impl("multimem_all_reduce_", &multimem_all_reduce_);
}

Tensor multimem_all_reduce_(Tensor input, std::string reduce_op, std::string group_name) {
  multimem_all_reduce_noreturn_(input, reduce_op, group_name);
  return input;
}
I'm not sure how worth it this is. I looked at the docs and the APIs are "alpha", so a BC-break seems reasonable to do.
Ummm, since these APIs are marked alpha and no callers in the codebase (or downstream in vLLM) use the return values, the BC-break seems like the simpler path. Happy to add the `_noreturn_` wrapper approach if you'd prefer, though.
cc: @kwen2501
Thanks @RohitRathore1 @zou3519, I am checking with team members to sign off on this.
kwen2501 left a comment:
I tend to approve this change.
Looking at the impact radius:
- Most ops changed are of the `_out` form. For these ops, the user passes in an `out` tensor and would not use the return anyway.
- Four ops are of the in-place `_` form. For these, it should also be clear that the `input` is modified in place.
@RohitRathore1 stream_write_value32 (lines 1244–1248) and memset32 (lines 1257–1261): both lambdas also use `return`, which will fail to compile once the op returns void. The fix is changing `at::Tensor` → `void` in the `.typed<>()` calls and dropping the `return`.
Claude checked the vLLM repo; here is the verdict:
Claude checked the SGLang repo; here is the verdict: SGLang does use two of the changed ops. Both calls discard the return value, so the return type change from `Tensor(a!)` to `()` does not break SGLang. The ops are used purely for their in-place side effects.
@kwen2501 thanks for providing all these verdicts!
Also, when you search "codebase", what codebase did you mean? Kraken is a public repo and it will be broken after this change. Kraken is fine as we own it and it is just a benchmark repo. But I just want to understand the "codebase" you referred to.
When I said "codebase", I meant vLLM... let me check more thoroughly.
    std::string group_name,
    at::Tensor out) {
-     return one_shot_all_reduce_out_impl(
+     one_shot_all_reduce_out_impl(
Other general nit: a lot of these strings should be std::moved
Thanks for the suggestion! That was a pre-existing pattern... I missed it :(
What Kraken uses: the return value is not used. But we'd need to fix the binding in `init.cpp` in torch. @RohitRathore1
@kwen2501 already fixed in my earlier commit to update the `init.cpp` binding.
I chatted a bit with @ngimel on this. The current thinking is:
The general motivation is that (1) is something we should do, and that we don't want to BC-break people now and also BC-break them later when we get (1) to work. Fixing (1) will take me or someone ~2 weeks, but I don't have bandwidth to do this for another couple of weeks. I will get to it sometime in the medium term. Thoughts @kwen2501 @RohitRathore1?
albanD left a comment:
I'm very confused here.
Why would we ever consider this to be a good way to go?
This is an antipattern we don't want to do as there is no way to get autograd to work.
And I don't see how this fixes the functionalization problem? It just doesn't trigger the particular error we put there to avoid silent correctness issues. But it doesn't make it right?
@albanD We are in the same boat of trying to figure out what's a good practice :) And add a backward formula for it. Functional collectives in PyTorch do exactly this: they define functional forms that return a new tensor and register proper backward implementations. Internally, they just call the in-place variants.
It doesn't matter whether these ops are meant to be autogradable or not: in-place ops return their output, out ops return their outputs. That's the current convention, and we shouldn't be breaking it for unclear reasons.
I would expect that the pattern we want is just to have one op that does what we need it to do, and not have 3 different ops wrapping each other. I would argue that the functional collectives for distributed should NOT do weird wrapping like they do and should just be one op each. The fact that there are so many layers of wrapping (including ops that are silently wrong) is a BAD thing; looking back at it, we should never have done that.