[RFC] PT2-Friendly Traceable, Functional Collective Communication APIs #93173
🚀 Traceable Collectives!
Collective APIs (e.g. all_reduce, all_gather, ...) are used in distributed PyTorch programs, but they do not compose cleanly with compilers.
Specifically, TorchDynamo and the AOTAutograd pipeline for decompositions and functionalization do not work with the existing c10d collective APIs:
- there are no functional variants of these collectives
- ProcessGroup and Work objects interfere with graph tracing and pollute the IR with non-tensor objects
XLA also currently has to implement workarounds to marry the XLA collective ops (traced via lazy tensors) with the existing PyTorch / c10d side: it uses a custom ProcessGroup implementation and swizzles the PTD ProcessGroup creation functions.
Goals
- provide collectives that are traceable with the PT2 stack and XLA stack
- provide functional collectives, which are easier for IR transformations to reason about
- support eager and compiled flows with the same API
- use plain data types in the traced API
- allow tracing/compilation without requiring process group init
- support different frontends (DTensors, ProcessGroups, etc)
- support autograd for collective ops
- clean up c10d python bindings and dispatcher registrations
Non-goals
- Introduce multiple stream semantics in inductor
New traceable collectives python API
```python
def collective(input: Tensor, *, group: GROUP_TYPE) -> AsyncTensor
```
GROUP_TYPE is a Union over List, DeviceMesh, ProcessGroup, etc. It allows flexible usage by different frontends.
AsyncTensor is a Tensor subclass that calls wait() automatically when the tensor is used by another op.
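As a rough illustration of the AsyncTensor idea, here is a pure-Python simulation (not the actual PyTorch implementation; `FakeWork`, `AsyncValue`, and `async_all_reduce` are hypothetical names): the wrapper holds a pending work handle and only waits when another operation consumes the value.

```python
import threading

class FakeWork:
    """Hypothetical stand-in for a c10d Work handle: completes on a background thread."""
    def __init__(self, compute):
        self._result = None
        self._done = threading.Event()
        def run():
            self._result = compute()
            self._done.set()
        threading.Thread(target=run).start()

    def wait(self):
        self._done.wait()
        return self._result

class AsyncValue:
    """Sketch of the AsyncTensor idea: waits lazily, only when the value is used."""
    def __init__(self, work):
        self._work = work
        self._value = None
        self._waited = False

    def _materialize(self):
        if not self._waited:
            self._value = self._work.wait()  # implicit wait() on first use
            self._waited = True
        return self._value

    def __add__(self, other):  # any consuming op triggers the wait
        return self._materialize() + other

def async_all_reduce(values):
    """Hypothetical functional collective: returns immediately with an async handle."""
    return AsyncValue(FakeWork(lambda: sum(values)))

out = async_all_reduce([1, 2, 3])  # returns without blocking
print(out + 10)                    # consuming the result waits first; prints 16
```

The point is that user code never calls wait() explicitly; the subclass inserts it at the last possible moment, which is what makes the eager API ergonomic while keeping the traced ops functional.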
New Dispatcher Collectives
```
aten::collective(Tensor, *, str tag, int[] ranks, int stride) -> Tensor
```
These are the ops that actually get traced into a graph and can be manipulated by compiler passes.
The collective ops are functional, but compilers may be able to convert them to inplace. They are asynchronous.
These ops support meta device (for traceability), and support backwards via derivatives.yaml.
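For instance, since the gradient of a sum all_reduce is itself an all_reduce of the incoming gradient, a derivatives.yaml entry could look roughly like the following (the op name, signature, and formula here are illustrative only, mirroring the dispatcher signature above, and may differ from what actually lands):

```yaml
# Illustrative sketch, not the merged schema.
- name: all_reduce(Tensor self, *, str tag, int[] ranks, int stride) -> Tensor
  self: all_reduce(grad, tag, ranks, stride)
```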
The semantics of these ops are that they return a real tensor, but you aren't allowed to access its data or storage.
```
c10d.wait(Tensor) -> Tensor
```
wait() must be called on the output of any collective before its underlying data or storage is accessed.
- It is valid to peek at the size() or stride() (or probably other metadata) of a tensor returned from a collective, but not its data.
- wait() is the only way to make an output from collectives safe to use by other non collective ops
- we are considering whether wait(collective(collective)) can be implemented safely, but by default we assume it is not
The semantics of wait are that you must only access the storage of the tensor returned from wait; you cannot think of wait as mutating its input tensor in place and thereby making the input safe to use.
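The contract above can be sketched in a pure-Python simulation (hypothetical names, not the real c10d API): metadata on the collective's output is readable, but the data only becomes accessible through the value that wait returns.

```python
class PendingTensor:
    """Simulated output of a collective: metadata is readable, data is not until wait()."""
    def __init__(self, shape, producer):
        self._shape = shape          # size()/stride() metadata may be peeked at
        self._producer = producer    # deferred computation standing in for the comm kernel
        self._data = None

    def size(self):
        return self._shape           # allowed without waiting

def wait(t):
    """Returns a new, safe-to-use value; the input tensor itself stays off-limits."""
    if t._data is None:
        t._data = t._producer()
    return t._data

def fake_all_reduce(xs):
    # Hypothetical functional collective: returns immediately, data arrives "later".
    return PendingTensor((len(xs),), lambda: [sum(xs)] * len(xs))

pending = fake_all_reduce([1, 2, 3])
assert pending.size() == (3,)        # metadata peek: fine without wait()
result = wait(pending)               # only `result` may feed other non-collective ops
```

Note that wait returns a distinct value rather than blessing its input, matching the rule that wait is not a mutation of the collective's output.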
Alternatives
The following style of API has also been considered. Its main disadvantage is that it requires the user to first initialize a ProcessGroup; the group handle is also opaque, not easily interchangeable with lists of ranks or DTensors, and makes it hard to represent MPMD collectives.
```python
pg = init_process_group()
pg_id = dist.register_process_group(pg)
collective(tensor, pg_id)
```
Detailed Proposal
See Traceable Collectives Design
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @chauhang @penguinwu @zou3519 @bdhirsh @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @XilunWu @ezyang @msaroufim @anijain2305 @soumith @ngimel
