[RFC] PT2-Friendly Traceable, Functional Collective Communication APIs #93173

@wconstab

Description

🚀 Traceable Collectives!

Collective APIs (e.g. all_reduce, all_gather, ...) are used in distributed PyTorch programs, but do not compose cleanly with compilers.

Specifically, TorchDynamo and the AOTAutograd pipeline for decompositions and functionalization do not work with the existing c10d collective APIs:

  • there are no functional variants of these collectives
  • ProcessGroup and Work objects interfere with graph tracing and pollute the IR with non-tensor objects

XLA also currently has to implement workarounds to marry the XLA collective ops (traced via lazy tensors) with the existing PyTorch/c10d side: it has to use a custom ProcessGroup implementation and swizzle the PTD ProcessGroup creation functions.

Goals

  1. provide collectives that are traceable with the PT2 stack and XLA stack
  2. provide functional collectives, which are easier for IR transformations to reason about
  3. support eager and compiled flows with the same API
  4. use plain data types in the traced API
  5. allow tracing/compilation without requiring process group init
  6. support different frontends (DTensors, ProcessGroups, etc)
  7. support autograd for collective ops
  8. clean up c10d python bindings and dispatcher registrations

Non-goals

  1. Introduce multiple stream semantics in inductor

New traceable collectives python API

def collective(input: Tensor, *, group: GROUP_TYPE) -> AsyncTensor

GROUP_TYPE is a Union over List, DeviceMesh, ProcessGroup, etc. It allows flexible usage by different frontends.
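As a sketch of that union (the concrete member types below are illustrative stand-ins based on the frontends listed, not the final definition):

```python
from typing import List, Union

# Illustrative stand-ins; in real code these would be
# torch.distributed.ProcessGroup and torch.distributed.DeviceMesh.
class ProcessGroup: ...
class DeviceMesh: ...

# One union type lets each frontend pass its natural group representation.
GROUP_TYPE = Union[List[int], ProcessGroup, DeviceMesh]

def resolve_ranks(group: GROUP_TYPE) -> List[int]:
    """Normalize any accepted group type to the plain list of ranks that
    the traced dispatcher op ultimately consumes (sketch: only the
    rank-list case is implemented here)."""
    if isinstance(group, list):
        return group
    raise NotImplementedError("mesh/ProcessGroup resolution elided")

print(resolve_ranks([0, 1, 2, 3]))  # [0, 1, 2, 3]
```

Accepting plain rank lists in the union is what lets tracing proceed without a ProcessGroup ever being initialized (goal 5 above).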

AsyncTensor is a Tensor subclass that calls wait() automatically when the tensor is used by another op.
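The implicit-wait behavior can be sketched with a `__torch_function__` subclass; this is only an illustration of the idea (the `work` callback stands in for a real c10d Work handle), not the actual implementation:

```python
import torch

class AsyncTensor(torch.Tensor):
    """Sketch: a Tensor subclass that remembers a pending communication
    handle and waits on it the first time any other op reads the tensor."""

    @staticmethod
    def __new__(cls, elem, work=None):
        t = elem.as_subclass(cls)
        t._work = work
        return t

    def wait(self):
        # Complete the pending communication, if any, exactly once.
        if getattr(self, "_work", None) is not None:
            self._work()       # a real impl would call c10d Work.wait()
            self._work = None
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Implicitly finish communication before another op touches us.
        for a in args:
            if isinstance(a, AsyncTensor):
                a.wait()
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **(kwargs or {}))

# Usage: the add below triggers wait() automatically.
done = []
t = AsyncTensor(torch.ones(2), work=lambda: done.append(True))
out = t + 1
print(done)  # [True] — the wait ran before the add
```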

New Dispatcher Collectives

aten::collective(Tensor, *, str tag, int[] ranks, int stride) -> Tensor

These are the ops that actually get traced into a graph and can be manipulated by compiler passes.

The collective ops are functional, but compilers may be able to convert them to in-place variants. They are asynchronous.

These ops support the meta device (for traceability) and support backward via derivatives.yaml.
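Meta-device support can be illustrated with stock PyTorch: a meta tensor carries shape and dtype but no storage, which is what makes shape-only tracing possible (the plain add below is a stand-in for a collective op):

```python
import torch

# A meta tensor has metadata (shape, dtype, device) but no backing
# storage, so ops on it propagate shapes without doing any computation
# and without needing an initialized process group.
t = torch.empty(8, 4, device="meta")
out = t + 1                  # stand-in for a traced collective op
print(out.shape)             # torch.Size([8, 4])
print(out.device)            # meta
```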

The semantics of these ops are that they return a real tensor, but you aren't allowed to access its data or storage.

c10d.wait(Tensor) -> Tensor

wait() must be called on the output of any collective before its underlying data or storage is accessed.

  • It is valid to peek at the size() or stride() (or probably other metadata) of a tensor returned from a collective, but not its data.
  • wait() is the only way to make an output from collectives safe to use by other non collective ops
  • we are considering whether wait(collective(collective)) can be implemented safely, but by default we assume it is not
  • the semantics of wait are that you must only access the storage of the tensor returned from wait; wait must not be thought of as mutating its input tensor and thereby making the input safe to use

Alternatives

The following style of API has also been considered. Its main disadvantage is that it requires the user to first initialize a ProcessGroup; it is also opaque, not easily interchangeable with lists of ranks or DTensors, and does not let us easily represent MPMD collectives.

pg = init_process_group()
pg_id = dist.register_process_group(pg)
collective(tensor, pg_id)

Detailed Proposal

See Traceable Collectives Design

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @chauhang @penguinwu @zou3519 @bdhirsh @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @XilunWu @ezyang @msaroufim @anijain2305 @soumith @ngimel
