[library] Add registration API for symmetric memory arguments#173513
RohitRathore1 wants to merge 13 commits into pytorch:main
Conversation
This PR needs a
Do you think we need a C++ API for this? See the discussion here: #171909 (comment). cc @kwen2501
Hi @eee4017, thanks for pointing this out! Looking at the discussion in #171909, I see that @kwen2501 mentioned we need both Python and C++ APIs: C++ for torch's internal op development, Python for DSL op development. Currently my PR only implements the Python API. Should I add the C++ API as well? This would co-locate the operator definition with its metadata. Let me know if this is the right direction!
Yes, then we wouldn’t need what you added in
Yes, correct. Let me come up with the necessary changes.
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased a40e031 to 5a356bc
torch/library.py (Outdated)

if op_overload is None:
    try:
        if "::" in qualname:
Should we use namespace, opname = torch._library.utils.parse_namespace(qualname)?
But parse_namespace only parses; we'd still need to write the op lookup. We can do something like this:

try:
    op_overload = torch._library.utils.lookup_op(qualname)
except (ValueError, AttributeError):
    pass

lookup_op uses parse_namespace internally and handles the full lookup.
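For context, a minimal sketch of what the parsing step amounts to in plain Python; this is an illustrative stand-in, not the actual torch._library.utils implementation:

```python
def parse_namespace(qualname: str) -> tuple[str, str]:
    # Split "namespace::opname" into its two parts; a qualified op
    # name is expected to contain exactly one "::" separator.
    parts = qualname.split("::")
    if len(parts) != 2:
        raise ValueError(
            f"Expected '{qualname}' to look like 'namespace::opname'"
        )
    return parts[0], parts[1]
```

lookup_op would then resolve the parsed namespace and opname against the op registry; the error handling above is only a sketch.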
torch/_inductor/comm_lowering.py (Outdated)

result = ir.FallbackKernel.create(op, *args, **kwargs)
return pytree.tree_map(ir.TensorBox.create, result)
Does this handle in-place operations correctly? Should we return a mutated argument (e.g. Tensor(a!) out) rather than create a new tensor? _all_reduce returns the input directly:

pytorch/torch/_inductor/comm_lowering.py, line 242 in 356094e
Valid point. FallbackKernel tracks mutations via handle_aliasing_and_mutation, but mutates_and_returns_first_arg only works for aten ops. For in-place symm_mem ops, we should either verify the current behavior is correct or handle them like _all_reduce_ does. Let me investigate.
I just checked: mutable ops returning Tensor(a!) will fail, since can_auto_functionalize returns False for non-aten ops with aliased returns. We can use _CollectiveKernel.create_inplace instead. I will test it.
torch/_inductor/comm_lowering.py (Outdated)

    continue

try:
    _, op_part = qualname.rsplit("::", 1)
Should we also reuse the name parsing logic in torch._library.utils here?
Ah, correct. We should.
Looks good in general. Hi @zou3519, any comments you may have?
d6deb3d to 1d76c03
@eellison @tianrengao Can you please have a look at this PR?
eellison left a comment
The changes I am expecting from this PR are that the custom lowerings in #171909 have been removed and ir.FallbackKernel generically handles the symm mem inputs using the API you have added. Would you make those changes in this PR and leave the Collective inplace refactoring for a separate PR?
cc @zou3519 for custom op library changes and also cc @tianrengao if this is related to your sym mem stack
aten/src/ATen/core/library.cpp (Outdated)

  return *this;
}

// Static registry for C++ symm_mem args registrations
cc @zou3519 mind looking at the library changes?
torch/_inductor/comm_lowering.py (Outdated)

if isinstance(item, ir.TensorBox) and group_name is not None:
    _maybe_realize_symm_mem(item, group_name)

if mutated_return is not None:
I'm not expecting a _CollectiveKernel.create_inplace refactoring in this PR. Can we split that into a separate PR?
Agreed, I'll remove the _CollectiveKernel.create_inplace usage from this PR entirely. The symm_mem buffer realization will now happen generically inside ir.FallbackKernel.create via _maybe_realize_symm_mem_args, which checks the registry and realizes the registered args as comm buffers. The per-op custom lowerings and the _create_symm_mem_lowering / _get_mutated_return_arg helpers have been removed. I'll handle the _CollectiveKernel.create_inplace refactoring in a follow-up PR.
torch/_inductor/comm_lowering.py (Outdated)

for arg_name in symm_mem_args_set:
    arg_value = all_args.get(arg_name)
    if isinstance(arg_value, ir.TensorBox):
        if group_name is not None:
            _maybe_realize_symm_mem(arg_value, group_name)
    elif isinstance(arg_value, (list, tuple)):
        for item in arg_value:
            if isinstance(item, ir.TensorBox) and group_name is not None:
                _maybe_realize_symm_mem(item, group_name)
The maybe_realize_symm_memory handling i'm expecting to just go in ir.FallbackKernel. Can you put it there ?
Remove per-op custom lowerings from comm_lowering.py and instead handle symm_mem arg realization generically in FallbackKernel.create via _maybe_realize_symm_mem_args, which checks the registration metadata to determine which args need symmetric memory. Also adds a symm_mem bypass in FallbackKernel.__init__ for mutable ops, with a TODO for follow-up CollectiveKernel.create_inplace refactoring.
…-api

# Conflicts:
#	torch/_inductor/comm_lowering.py
@eellison Thanks for the review feedback! I've updated the PR based on your suggestions:

Would appreciate another look when you have time!
kwen2501 left a comment
Changes look good to me. A couple of minor comments, non-blocking though.

Can you please fix the lint errors?
torch/_inductor/ir.py (Outdated)

if isinstance(arg_value, TensorBox):
    if can_realize_as_comm_buffer(arg_value, CommBufferType.SYMM_MEM):
        realize_as_comm_buffer(
            arg_value, CommBufferType.SYMM_MEM, group_name
        )
    else:
        log.warning(
            "Failed to realize %s as a symmetric memory buffer for %s",
            arg_name,
            qualname,
        )
elif isinstance(arg_value, (list, tuple)):
    for item in arg_value:
        if isinstance(item, TensorBox):
            if can_realize_as_comm_buffer(item, CommBufferType.SYMM_MEM):
                realize_as_comm_buffer(
                    item, CommBufferType.SYMM_MEM, group_name
                )
nit: you can use a tree_map util to recursively do this instead of manual inspection.
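A rough sketch of what the tree-based version could look like. Here tree_flatten is a minimal stand-in for pytree.tree_flatten, and TensorBox with its realized flag is a placeholder for the real ir.TensorBox and the realize_as_comm_buffer call; the function name mirrors the PR's _maybe_realize_symm_mem_args but the details are illustrative:

```python
def tree_flatten(obj):
    # Minimal stand-in for pytree.tree_flatten: collect leaves from
    # (possibly nested) lists and tuples.
    if isinstance(obj, (list, tuple)):
        leaves = []
        for item in obj:
            leaves.extend(tree_flatten(item))
        return leaves
    return [obj]

class TensorBox:  # placeholder for ir.TensorBox
    def __init__(self, name):
        self.name = name
        self.realized = False

def maybe_realize_symm_mem_args(all_args, symm_mem_args_set, group_name):
    # Realize every TensorBox leaf of each registered arg, instead of
    # manually special-casing scalars vs. lists/tuples.
    if group_name is None:
        return
    for arg_name in symm_mem_args_set:
        for leaf in tree_flatten(all_args.get(arg_name)):
            if isinstance(leaf, TensorBox):
                leaf.realized = True  # stands in for realize_as_comm_buffer
```

The flatten-then-filter shape removes the duplicated isinstance branches from the manual version while handling arbitrarily nested argument structures.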
// Static registry for C++ symm_mem args registrations
namespace {
struct SymmMemArgsRegistry {
  std::unordered_map<std::string, std::vector<std::string>> registry;
Can you please add documentation to these fields?
Use pytree.tree_flatten in _maybe_realize_symm_mem_args instead of manual isinstance checks for nested args. Add documentation to the C++ SymmMemArgsRegistry struct. Replace Optional[X] with X | None per modern Python style.
…-api

Resolve conflict in comm_lowering.py: remove new per-op symm_mem lowerings (tile_reduce, multi_root_tile_reduce) from main since FallbackKernel now handles realization generically via the registry. Add the new ops to test_symm_mem_correct_args.
torch/_inductor/ir.py (Outdated)

if "symm_mem" in self.op_overload.name():
    # symm_mem kernels are mutable custom ops whose mutation is
    # handled by the lowering that created this FallbackKernel.
    # TODO: handle mutable symm_mem ops via _CollectiveKernel.create_inplace
    # instead of bypassing here. See follow-up to #173513.
    return
sorry, why do we need this?
Without the bypass, mutable symm_mem ops (e.g., multimem_all_reduce_, two_shot_all_reduce_) hit the NotImplementedError: they have schema.is_mutable = True, but can_auto_functionalize returns False for them, and they don't have a Functionalize dispatch key implementation.
Could we fix this before landing? cc @zou3519 to comment. I don't know why can_auto_functionalize matters in inductor lowering.
@eellison The mutable symm_mem ops (e.g., multimem_all_reduce_, two_shot_all_reduce_) return aliased tensors (Tensor(a!) -> Tensor(a!)), so they fail both mutates_and_returns_first_arg (an aten-only gate) and can_auto_functionalize (which rejects aliased returns), hitting the NotImplementedError. The proper fix would be routing these through _CollectiveKernel.create_inplace, but you mentioned earlier that refactoring should be a separate PR. Should I go ahead and do it here instead? @zou3519, any thoughts on whether the can_auto_functionalize guard should be relaxed for this case?
@RohitRathore1 can you change the mutable symm ops to not return the aliased tensor? Instead, any code that calls them can just use the input tensor.

Otherwise we can special-case these, but it is going to take a bit of work to do so.
@zou3519 No callers rely on the return values of the mutable symm_mem ops, so changing them to return () is safe. However, that's a schema change to the ops themselves, which feels like it should be a separate PR. Would it be OK to keep the bypass for now and do the schema change as a follow-up, or would you prefer it in this PR?
We should figure it out now; we risk silent incorrectness otherwise. Either we fix the schema, or we need to special-case these operators in functionalization, which will require specific changes of its own.
@kwen2501 would you be okay if we changed the schema for these operators? The TL;DR is that custom operators are not allowed to do something like the following:

def op_(x):
    mutate(x)
    return x

Instead they should be written like the following, to avoid returning the input directly as the output:

def op_(x):
    mutate(x)

If there are BC concerns here, then we can try to special-case them in compile.
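As a plain-Python analogy (not the actual operator code), the recommended pattern mutates in place and returns nothing, so callers keep using the object they passed in rather than a returned alias of the mutated input:

```python
def all_reduce_(buf: list) -> None:
    # In-place op: mutate the input and return nothing, instead of
    # returning buf itself (which would alias a mutated input).
    total = sum(buf)
    for i in range(len(buf)):
        buf[i] = total

x = [1, 2, 3]
all_reduce_(x)  # no return value; the caller just keeps using x
```

After the call, x holds the reduced values and nothing aliases the mutated input through a return value, which is the property functionalization relies on.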
Thanks @zou3519. I am fine with breaking the API signatures now.

@RohitRathore1 can you change the APIs for those couple of ops? Thanks!
@RohitRathore1 You can also split this out into a separate PR and base this PR on top. Either way works.
Custom operators are not allowed to return an alias of a mutated input (e.g. `Tensor(a!) -> Tensor(a!)`), as this pattern is incompatible with torch.compile's functionalization. Change all mutable symm_mem ops to return void (`-> ()`) instead. No callers in the codebase rely on the return values — all use these ops for their in-place side effects. Discussed with zou3519 and kwen2501 on pytorch#173513.
Alright. Now I have a clearer idea:
The registration in 2 alone may or may not be enough for torch.compile to work in vLLM:
So I recommend the following strategy:
@kwen2501 sounds good to me. Keeping just the registration part and deferring the in-place op lowering makes the PR much cleaner. I can trim it down to keep the SymmMemArgsRegistry + Python lookup, drop the FallbackKernel bypass, and add a test for functional op support. Will wait to hear from @zou3519 @eellison @ngimel as well.
Sgtm |
…d tests Remove the FallbackKernel bypass block for symm_mem ops and restore the existing per-op lowerings in comm_lowering.py. This keeps the registration API (symm_mem_args registry + _maybe_realize_symm_mem_args) while preserving backward compatibility with the current lowering paths. Add comprehensive test suite (23 tests) covering registry core functionality, Library integration, symm_mem ops registration, and torch.compile integration.
PR has been updated

@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Rebase failed due to Command. Raised by https://github.com/pytorch/pytorch/actions/runs/24037374529
This PR adds a registration API that lets operators declare which of their arguments require symmetric memory allocation.

Operators can register their symm_mem args via Library::register_symm_mem_args() in C++ or lib.register_symm_mem_args() in Python. The registrations are validated against kernel schemas to make sure only valid argument names are accepted. For example, multimem_all_gather_out registers out rather than input, because it's the output tensor that gets passed to rendezvous().

During compilation, FallbackKernel.create calls _maybe_realize_symm_mem_args(), which looks up the registry and automatically realizes the registered args as symmetric memory comm buffers. This works alongside the existing per-op lowerings in comm_lowering.py, which are fully preserved.

The main benefit is that new symm_mem ops, including those from downstream projects like vLLM, can register their args and get automatic buffer realization through FallbackKernel, without needing to add manual per-op lowerings.

Changes:
- Add SymmMemArgsRegistry and the Library::register_symm_mem_args method
- Add SymmMemArgsHolder in simple_registry.py and Library.register_symm_mem_args
- Add _maybe_realize_symm_mem_args in FallbackKernel

A follow-up PR will refactor the existing per-op lowerings to use the registry and add auto-functionalize support for in-place ops.
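As a rough illustration of what the registry amounts to (a qualname-to-argument-names map with schema validation), here is a Python sketch; the names and the shape of the validation are assumptions for illustration, while the real API lives on Library::register_symm_mem_args / lib.register_symm_mem_args and validates against the actual kernel schema:

```python
class SymmMemArgsRegistry:
    """Maps an op qualname to the argument names that must be
    allocated as symmetric memory buffers."""

    def __init__(self):
        self._registry: dict[str, list[str]] = {}

    def register(self, qualname, arg_names, schema_args):
        # Validate against the kernel schema so only real argument
        # names are accepted, mirroring the PR's validation step.
        unknown = [a for a in arg_names if a not in schema_args]
        if unknown:
            raise ValueError(f"{qualname}: unknown args {unknown}")
        self._registry[qualname] = list(arg_names)

    def get(self, qualname):
        # Ops with no registration get an empty list, so lowering
        # code can iterate unconditionally.
        return self._registry.get(qualname, [])

registry = SymmMemArgsRegistry()
registry.register(
    "symm_mem::multimem_all_gather_out",
    ["out"],  # the output tensor is what hits rendezvous()
    {"input", "out", "group_name"},
)
```

During lowering, FallbackKernel.create would consult registry.get(qualname) to decide which arguments to realize as comm buffers.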
Fixes #172345
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo