
[RFC] support symmetric memory in torch.compile #162859

@zou3519

This proposal originally came up in the vLLM-compile sync with @ProExpertProg, @Chillee, and @Amir-19, and was also discussed with @ngimel and @kwen2501. I'm recording it here to make sure we're all on the same page.

Pitch

For any collective operator (built-in or custom), a user can specify which of its inputs must be allocated in symmetric memory.

torch.compile (Inductor) will then find the operator that produces that input and ensure that its output is allocated in symmetric memory.

There are two cases, depending on what type of operator produced the input.

  1. Built-in operator. Inductor may already preallocate the buffer that holds the operator's output (via memory planning), so it only needs to allocate that buffer in symmetric memory.

```python
requires_symmetric_memory(collective, input=0)  # proposed API: input 0 must be in symmetric memory

def user_code(x):
    y = x.sin()
    z = y.cos()
    return collective(z)

# Illustrative Inductor output: sin and cos are fused into one Triton kernel
def inductor_generated_code(x):
    with symmetric_memory():
        buffer = torch.empty_like(x)
    triton_inplace_sin_cos_fused(x, buffer)  # writes cos(sin(x)) into buffer
    return collective(buffer)
```
  2. Custom operator. Inductor just needs to run the custom operator under the symmetric memory context manager. The main risk is that more buffers than needed get allocated in symmetric memory (all tensors produced by the custom op are allocated there), but the user can rewrite their custom op to minimize this.

```python
requires_symmetric_memory(collective, input=0)

def user_code(x):
    y = custom_op(x)
    return collective(y)

def inductor_generated_code(x):
    # every tensor custom_op produces here is allocated in symmetric memory
    with symmetric_memory():
        y = custom_op(x)
    return collective(y)
```
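To make the two cases concrete, here is a toy, pure-Python model of the planning step: walk the graph, find each annotated collective, look up the producer of the annotated input, and decide between case 1 (allocate the producer's output buffer in symmetric memory) and case 2 (run the custom op under the context manager). It does not use Inductor's IR or scheduler; `Node`, `SYMM_MEM_INPUTS`, and `plan_symmetric_memory` are hypothetical names for illustration only.

```python
from dataclasses import dataclass, field

# Hypothetical annotation table mirroring requires_symmetric_memory(collective, input=0):
# collective op name -> index of the input that must live in symmetric memory.
SYMM_MEM_INPUTS = {"collective": 0}

@dataclass
class Node:
    name: str                                   # value this node produces
    op: str                                     # operator name
    inputs: list = field(default_factory=list)  # names of the values it reads
    is_custom: bool = False                     # True for user-defined custom ops

def plan_symmetric_memory(nodes):
    """Return (buffers, wrapped_ops) for a toy graph given in topological order."""
    producers = {n.name: n for n in nodes}
    buffers, wrapped_ops = set(), set()
    for n in nodes:
        idx = SYMM_MEM_INPUTS.get(n.op)
        if idx is None:
            continue  # not an annotated collective
        src = producers.get(n.inputs[idx])
        if src is None:
            continue  # produced outside the graph; out of scope for this sketch
        if src.is_custom:
            wrapped_ops.add(src.name)  # case 2: run the op under symmetric_memory()
        else:
            buffers.add(src.name)      # case 1: allocate its output buffer in symmetric memory
    return buffers, wrapped_ops

# Case 1: built-in producers (sin -> cos -> collective)
g1 = [Node("y", "sin", ["x"]), Node("z", "cos", ["y"]), Node("out", "collective", ["z"])]
# Case 2: a custom-op producer
g2 = [Node("y", "custom_op", ["x"], is_custom=True), Node("out", "collective", ["y"])]
```

For `g1` the sketch marks the buffer of `z`; for `g2` it marks the `custom_op` call for wrapping. In the real proposal, the case-1 decision would feed Inductor's memory planning; this toy only shows the decision, not the allocation.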

What about eager-mode?

The API for specifying which input needs symmetric memory applies only to torch.compile, so a user whose code also runs in eager mode would end up writing code that looks like:

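```python
import torch

requires_symmetric_memory(collective, input=0)

def user_code(x):
    if torch.compiler.is_compiling():
        # torch.compile sees the annotation and allocates symmetric memory itself
        y = custom_op(x)
    else:
        # eager mode ignores the annotation, so allocate symmetric memory explicitly
        with symmetric_memory():
            y = custom_op(x)
    return collective(y)
```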

What is the API to specify which input needs symmetric memory?

@kwen2501 noted that which input needs symmetric memory is a property of the collective operator itself. So one design is to specify it at operator registration time, for example:

  1. A new schema type: `torch.library.define("mylib::my_collective", "(SymmMemTensor x) -> Tensor")`
  2. A new registration argument: `torch.library.define("mylib::my_collective", "(Tensor x) -> Tensor", symm_mem_hint="x")`
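Both registration-time designs carry the same information: the name of the argument that must be in symmetric memory. The toy parser below extracts it from either form. `SymmMemTensor` and `symm_mem_hint` are the proposal's hypothetical spellings, not existing `torch.library` features, and `symm_mem_args` is a hypothetical helper that handles only simple `Type name` arguments.

```python
import re

def symm_mem_args(schema, symm_mem_hint=None):
    """Return the argument names that must be allocated in symmetric memory.

    Accepts either design: arguments typed SymmMemTensor in the schema, or an
    explicit symm_mem_hint naming one argument. Only simple "Type name"
    arguments are supported (no default values, no "*" separator).
    """
    m = re.fullmatch(r"\s*[\w:.]*\((.*?)\)\s*->\s*.+", schema)
    if m is None:
        raise ValueError(f"unparseable schema: {schema!r}")
    names, marked = [], []
    for arg in (a.strip() for a in m.group(1).split(",")):
        if not arg:
            continue  # e.g. an empty argument list "()"
        typ, name = arg.rsplit(maxsplit=1)
        names.append(name)
        if typ == "SymmMemTensor":
            marked.append(name)  # design 1: the type carries the requirement
    if symm_mem_hint is not None:
        if symm_mem_hint not in names:
            raise ValueError(f"symm_mem_hint names unknown argument {symm_mem_hint!r}")
        if symm_mem_hint not in marked:
            marked.append(symm_mem_hint)  # design 2: an explicit hint
    return marked
```

One trade-off: design 1 makes the requirement part of the schema, visible to anything that reads schemas, while design 2 leaves existing schemas unchanged.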

Another design is a standalone torch.compiler API:

```python
torch.compiler.specify_symmetric_memory(my_collective, "x")
```

If the choice is actually dynamic (for example, if some collectives can accept both symmetric and non-symmetric memory), this could instead be a context manager:

```python
import torch

@torch.compile
def user_code(y):
    x = custom_op(y)
    # the hint applies only to my_collective calls inside this block
    with torch.compiler.specify_symmetric_memory(my_collective, "x"):
        return my_collective(x)
```

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @coconutruben
