Get remote tensors inside @helion.kernel #1122
Conversation
yf225 left a comment
Thanks! I believe we need to update test_examples_dist.py as well, but Helion distributed CI currently has a bug causing the test error to not surface. I will land a PR to fix the bug and then we can rebase this PR
@kwen2501 in case adding the following helps:

```python
lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901
lib.define(
    "get_remote_tensors(Tensor x, str group_name) -> Tensor[]"
)

@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
    local: torch.Tensor,
    group_name: str,
):
    hdl = torch.distributed._symmetric_memory.rendezvous(local, group_name)
    return tuple(
        hdl.get_remote_tensor(peer, local.size(), local.dtype)
        for peer in range(hdl.world_size)
    )

@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
    local: torch.Tensor,
    group_name: str,
):
    # TODO: correct world size
    world_size = torch.distributed.get_world_size()
    return (local,) * world_size
```
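The idea behind registering separate "CUDA" and "Meta" implementations can be pictured with a small dependency-free sketch (hypothetical names, not the real `torch.library` machinery): the dispatcher selects the meta kernel when the input is fake, so tracing never reaches the rendezvous path.

```python
# Toy sketch of dispatch-key-style op registration. All names here are
# illustrative stand-ins, not PyTorch APIs.

impls = {}

def register(key):
    """Register an implementation of the op under a dispatch key."""
    def deco(fn):
        impls[key] = fn
        return fn
    return deco

class FakeTensor:
    """Stand-in for a tensor with no real storage, as used during tracing."""

@register("CUDA")
def get_remote_tensors_real(local, group_name):
    # In the real op this would call rendezvous(...) and gather
    # one tensor per peer rank.
    return ("peer0", "peer1")

@register("Meta")
def get_remote_tensors_meta(local, group_name):
    # Mirrors the meta impl above: placeholders only, no communication.
    world_size = 2  # TODO in the real op: query the process group
    return (local,) * world_size

def get_remote_tensors(local, group_name):
    # Pick the meta kernel for fake inputs, the device kernel otherwise.
    key = "Meta" if isinstance(local, FakeTensor) else "CUDA"
    return impls[key](local, group_name)

fake = FakeTensor()
assert get_remote_tensors(fake, "g") == (fake, fake)
```

The design point is that the caller sees a single op; which body runs is decided by the kind of input, so a tracer holding only fake tensors stays entirely on the meta path.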
9ec6b6e to 7fda194 (Compare)
@kwen2501 I'll rebase this PR so that it has the distributed CI error propagation fix. Thanks!
7fda194 to 47f5294 (Compare)
@yf225 What torch version does CI use?
Yes, it uses torch nightly; it should be able to pick it up very soon.
To support the use case in pytorch/helion#1122, i.e.

```python
@helion.kernel
def foo(
    x: Tensor,
    group_name: str,
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
```

Helion uses fake tensors to trace a program, thus we cannot use the following code in a Helion function:

```python
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```

The reason is that when `tensor` is fake, the returned `hdl` is None, so any subsequent call on it will fail. This PR wraps the above functionality as an op:

```python
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```

so that things like `hdl` are not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.

Pull Request resolved: #167779
Approved by: https://github.com/yf225
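The failure mode described in the commit message can be illustrated with a dependency-free toy (illustrative classes, not the real symmetric-memory API): rendezvous cannot build a handle from a fake tensor, so the direct-call version crashes during tracing.

```python
# Toy illustration of why calling rendezvous directly fails under
# fake-tensor tracing. Names are hypothetical stand-ins.

class FakeTensor:
    """Stand-in for a tensor with no real storage, as used during tracing."""

class Handle:
    def __init__(self, world_size):
        self.world_size = world_size

    def get_remote_tensor(self, peer):
        return f"remote[{peer}]"

def rendezvous(tensor):
    # The real rendezvous needs actual device memory, so it cannot
    # produce a handle for a fake tensor; it returns None instead.
    if isinstance(tensor, FakeTensor):
        return None
    return Handle(world_size=2)

hdl = rendezvous(FakeTensor())
try:
    hdl.get_remote_tensor(0)  # hdl is None -> AttributeError
except AttributeError as e:
    print("tracing would crash:", e)
```

Hiding `rendezvous` behind an op with a meta implementation sidesteps exactly this: the tracer never receives the None handle.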
f342333 to 1db590a (Compare)
Hi @kwen2501! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Instead of the user passing a tuple of tensors into the kernel, we get the tuple of remote tensors by calling `torch.ops.symm_mem.get_remote_tensors` in the CPU part of the Helion function. This op is yet to be upstreamed on the PyTorch side. Naively, it is nothing but:

The "Meta" impl is necessary because Helion seems to trace the function in Fake mode.
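A minimal sketch of why a Meta implementation suffices for Fake-mode tracing, using a stand-in `FakeTensor` class (hypothetical, no torch dependency): the tracer only needs placeholder outputs with the right count and shape, not real remote memory.

```python
# Toy model of the meta kernel's contract: return world_size placeholders
# shaped like the local tensor, performing no communication.

class FakeTensor:
    """Stand-in for a storage-less tensor seen during tracing."""
    def __init__(self, shape):
        self.shape = shape

def get_remote_tensors_meta(local, world_size):
    # Mirrors the real meta impl's `(local,) * world_size`: one
    # placeholder per rank, same shape/dtype as the local tensor.
    return (local,) * world_size

x = FakeTensor(shape=(4, 4))
outs = get_remote_tensors_meta(x, world_size=4)
# A tracer can now unroll `for t in outs: ...` over 4 fake tensors.
assert len(outs) == 4 and all(t.shape == (4, 4) for t in outs)
```

This is enough for Helion to trace the loop over remote tensors; the real "CUDA" implementation only runs later, on real tensors.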