with @pritamdamania87 @zhaojuanmao @aazzolini @gqchen @pietern @satgera @ezyang @zdevito @suo @manojkris @gchanan @soumith @dzhulgakov @yifuwang @bddppq @joxu-cn @dwarakrajagopal @jspisak
PyTorch currently provides simple APIs for single machine data parallel, distributed data parallel, and single machine model parallel. However, when it comes to distributed model parallel, applications have to build their own scaffold to stitch together local autograd graphs into one global graph. This proposal aims to fill in that gap by providing an RPC-Based distributed model parallel API. In short, applications may run RPC to execute code remotely in the forward pass, and autograd will automatically travel across RPC boundaries in the backward pass.
API
Core Concepts
RRef[T] - (abbreviation ref) A reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote tensor value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. It is valid to have a reference to a local value as well, and values of type T can be implicitly converted to RRef[T]. This implicit conversion will be critical later to allow the expression of different types of RPC. Think of it like the implicit conversion from std::string to const std::string &. See the System Design section for more details about RRef.
ref.owner() # what is the worker this value lives on
v = ref.local_value() # if ref.owner() is the local worker, then
# this returns the underlying value, otherwise it raises an error.
# you can create a ref to a local tensor
t = torch.rand(3, 4)
ref2 = torch.RRef(t)
# in TorchScript, T can be automatically converted to RRef[T]
ref3 : RRef[Tensor] = t
Future[T] - (abbreviation fut) a guarantee that at some future point in time the value of type T will be available locally. The action to create T locally is assumed to be scheduled and in-progress. Future is already supported in TorchScript and we are extending this to remote calls.
v = fut.wait() # block the current thread until v is ready
# local cpu task creation returns a future to the computed tensors
fut = torch.fork(lambda x, y: x + y, torch.rand(3, 4), torch.rand(3, 4))
Core Functions
# synchronous
result : T = torch.rpc(on : Worker, remote_callable : Callable, *args)
# asynchronous
result : Future[T] = torch.async_rpc(on : Worker, remote_callable : Callable, *args)
# remote reference
result : RRef[T] = torch.remote(on : Worker, remote_callable : Callable, *args)
Each function above invokes remote_callable on a remote worker. Value types in the args list are copied by value to the remote worker. RRef[T] types in the args list are copied by reference to the remote worker (again see the analogy between std::string and const std::string&).
The synchronous variant copies the result value back, blocking the calling thread until the response occurs. The asynchronous variant returns immediately with a future. The remote knows that the call will expect to receive the value so it will send a message back at some point with the result without further prompting.
The remote reference variant returns immediately with an RRef of the return value. The remote knows that the caller does not expect to receive the result value.
The example below shows how these functions are used:
# make some local tensors
a : Tensor = torch.rand(3, 4)
b : Tensor = torch.rand(3, 4)
# define a remote function, visible to all machines.
# type annotations define expected input/output types.
def remote_function(a : Tensor, b : RRef[Tensor]) -> Tensor:
    # 'b' in the type signature is a remote reference, so we must copy it here
    # to use it locally.
    # to_here() is defined later in the syntax sugar section, it synchronously
    # copies the tensor to this worker.
    b_l : Tensor = b.to_here()
    return a + b_l
# run remote_function on a different device.
# a is copied by value since it is a Tensor
# b is copied by reference to the remote machine due to the RRef[Tensor]
# type annotation in the signature, which causes an implicit conversion to a
# reference type.
# torch.remote always creates an RRef of the result type.
# It does not wait for the remote's response.
# There is no implied copy of the tensor data yet.
c : RRef[Tensor] = torch.remote("worker1", remote_function, a, b)
# we can explicitly request the data to be copied back here:
c_l : Tensor = c.to_here()
# another example:
def remote_function2(a : Tensor, b : Tensor) -> Tensor:
    return a + b
# Here we call torch.rpc which returns the value directly without
# creating a remote reference.
# we synchronously wait for remote_function2 to return.
c : Tensor = torch.rpc("worker2", remote_function2, a, b)
# When the RPC call is returning a non-reference type, we need to wait for
# a response from the remote host. To avoid synchronously waiting, use the
# async flag to get a future instead.
c_f : Future[Tensor] = torch.async_rpc("worker2", remote_function2, a, b)
# even before calling wait, the remote knows that the data should be sent back
# to the caller as soon as it is ready.
# force the local thread to wait for the remote's response
c = c_f.wait()
# if you omit type annotations in the remote function, the assumption is that
# arguments are passed without any implicit conversions
def remote_function3(a, b):
    # no annotations mean that a, b will be Tensor since there is no conversion
    return a + b
c: Tensor = torch.rpc("worker2", remote_function3, a, b)
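The three call styles above can be tried out without a cluster. The following is a toy, single-process sketch that uses a thread pool as a stand-in for a remote worker. All names in it (ToyWorker, ToyRRef, toy_rpc, toy_async_rpc, toy_remote) are hypothetical and are not part of the proposed API; the sketch only illustrates who blocks and where the result lives.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class ToyWorker:
    """Hypothetical stand-in for a remote worker: one thread plus a value store."""

    def __init__(self, name):
        self.name = name
        self.pool = ThreadPoolExecutor(max_workers=1)
        self.store = {}  # ref_id -> Future for a value owned by this worker


class ToyRRef:
    """Handle to a value that stays on its owner until to_here() is called."""

    _next_id = 0

    def __init__(self, owner):
        self.owner = owner
        self.ref_id = ToyRRef._next_id
        ToyRRef._next_id += 1

    def to_here(self):
        # Explicit copy back to the caller; blocks until the value exists.
        return self.owner.store[self.ref_id].result()


def toy_async_rpc(on, fn, *args):
    # Returns immediately; the result is delivered without further prompting.
    return on.pool.submit(fn, *args)


def toy_rpc(on, fn, *args):
    # Synchronous variant: block the calling thread until the result arrives.
    return toy_async_rpc(on, fn, *args).result()


def toy_remote(on, fn, *args):
    # The result stays on the owner; the caller only receives a handle.
    ref = ToyRRef(on)
    on.store[ref.ref_id] = on.pool.submit(fn, *args)
    return ref


def add(a, b):
    return a + b


worker1 = ToyWorker("worker1")
sync_result = toy_rpc(worker1, add, 1, 2)
fut = toy_async_rpc(worker1, add, 3, 4)
ref = toy_remote(worker1, add, 5, 6)
async_result = fut.result()
ref_result = ref.to_here()
worker1.pool.shutdown()
```

In this sketch only toy_rpc blocks at the call site; toy_async_rpc defers the wait to fut.result(), and toy_remote defers both the wait and the copy to an explicit to_here().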
RRef Forks
Implicit Conversions for RRef Arguments
We allow implicit conversion between T and RRef[T] for arguments of RPC functions. Both the actual and formal parameter can either be a T or an RRef[T], leading to four cases that might occur:
T → T (passing a T to an rpc that accepts a T): the value T is copied by value and sent over the wire as part of the message invoking the RPC.
T → RRef[T] (passing a T to an rpc that accepts RRef[T]): The caller constructs a remote reference to the argument, and sends the reference over the wire to the callee. The data is not sent. The callee can then use the reference as a handle to either request the data later or to make further remote calls.
RRef[T] → T (passing an RRef[T] to an rpc that accepts T): The callee expects to get an actual value, so the callee needs to turn the reference into a value. The network behavior depends on where the RRef[T] lives.
- If the RRef[T] lives on the caller, then the implementation looks up the actual value of T locally and passes it by value along the wire, similar to the T → T case.
- If the RRef[T] lives on the callee, then the implementation just sends the reference and the callee does the lookup locally.
- If the RRef[T] lives on some third machine, then the caller sends two messages: one to the third machine telling it to send the data in the remote reference directly to the callee, and one to the callee telling it to start the RPC and expect this input to come from the third machine. This effectively forwards the value of the RRef[T] to the callee without the caller having to load it or the callee having to request it later.
Examples:
def remote_function1() -> Tensor:
    return torch.ones(2)
def remote_function2(a : Tensor) -> Tensor:
    b = a * 2
    return b
aref : RRef[Tensor] = remote("worker1", remote_function1)
# this local worker will make two RPC calls: one to tell worker1 to send the
# tensor to worker2, and another one to tell worker2 to expect this Tensor input
# from worker1. remote_function2 will run on worker2 only after it received the
# tensor from worker1.
bref : RRef[Tensor] = remote("worker2", remote_function2, aref)
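The three RRef[T] → T sub-cases amount to a small routing decision made by the caller. The sketch below is a hypothetical illustration of that decision only: plan_value_transfer and its message strings are invented for this example, and a real implementation would send actual RPC messages rather than return descriptions of them.

```python
def plan_value_transfer(caller, callee, owner):
    """List the messages the caller sends when an RRef[T] owned by `owner`
    is passed to an RPC on `callee` whose parameter type is a plain T.

    Each message is a (sender, receiver, description) tuple.
    """
    if owner == caller:
        # The caller holds the value, so it travels with the call (like T -> T).
        return [(caller, callee, "run RPC, value attached")]
    if owner == callee:
        # The callee already holds the value; only the reference is sent.
        return [(caller, callee, "run RPC, look up reference locally")]
    # Third machine: ask the owner to forward the value, and tell the callee
    # to wait for it before running the function.
    return [
        (caller, owner, f"send value to {callee}"),
        (caller, callee, f"run RPC, expect input from {owner}"),
    ]


local_case = plan_value_transfer("worker0", "worker2", owner="worker0")
callee_case = plan_value_transfer("worker0", "worker2", owner="worker2")
third_case = plan_value_transfer("worker0", "worker2", owner="worker1")
```

The third case corresponds to the example above: the local worker sends one message to worker1 and one to worker2, and never loads the tensor itself.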
RRef[T] → RRef[T] (passing an RRef[T] to an RPC that accepts RRef[T]): The callee expects an RRef[T], but we must make sure we correctly keep track of references to the value on a remote. So the actual behavior depends on where the RRef[T] lives.
- If the RRef[T] lives on the caller, then we simply pass it to the remote and record that this remote now has a live reference to the value.
- If the RRef[T] lives on the callee, then we pass it to the remote, and it becomes a local reference on the remote.
- If the RRef[T] lives on some third machine, then we must forward the reference. To do this, the caller sends two messages: one to the third machine telling it to create a remote reference and send it to the callee, and one to the callee telling it where to expect the remote reference from. The callee code is not invoked until the remote reference is transferred, to ensure sane reference counting.
Examples:
def remote_function1() -> Tensor:
    return torch.ones(2)
def remote_function2(a : RRef[Tensor]) -> Tensor:
    delta = 10
    return a.to_here() + delta
aref : RRef[Tensor] = remote("worker1", remote_function1)
# this local worker will make two RPC calls: one to tell worker1 to create a
# remote reference and send it to worker2, and another one to tell worker2 to
# expect this remote reference input from worker1. remote_function2 code will
# not run on worker2 until it receives the remote reference from worker1 to
# ensure proper reference counting.
bref : RRef[Tensor] = remote("worker2", remote_function2, aref)
When an RRef[T] goes dead on machine A, a message is sent to the owner of T telling it that the reference from machine A is dead.
Explicit RRef type for return values
The above implicit RRef argument conversion does not apply to return values. If remote_function returns RRef[T], calling it remotely using torch.remote would return RRef[RRef[T]] instead of RRef[T]. This is because the return value RRef of torch.remote is first created on the caller, which does not know the owner of the real data T. T could be stored on the callee of torch.remote, but it could also be on a different worker, as the callee may make another remote call within remote_function and return an RRef[T] owned by a different worker. Moreover, the caller is allowed to share the returned RRef with other workers immediately after torch.remote returns. However, since the caller does not yet know the real owner of T at that point, sharing the RRef would break the reference counting algorithm.
Examples:
def remote_function3() -> RRef[Tensor]:
    return torch.remote("Worker2", torch.ones, 2, 2)
cref : RRef[RRef[Tensor]] = remote("worker1", remote_function3)
Initialization API
Users may choose a communication backend for RPC, and users are responsible for setting up the backend properly before calling the init_rpc method.
# backend: specifies the underlying communication implementation
# init_method: contains the information to initialize/connect a name store to
# resolve names
# name: is a unique identifier for the current worker
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker1")
The init_rpc method will create an RpcAgent under the hood and will make the current worker ready to send and receive RPC calls. If you call init_rpc and use the ProcessGroup (pg) backend, it acts as a global barrier, where all the node names are collectively synchronized before continuing. This is not the case if you use a peer-to-peer backend (e.g. tensor pipes), where calling init_rpc will register the node name in the specified store and start serving.
Applications don’t need to explicitly register functions for remote execution, but we do assume the same functions are defined on both caller and callee. This is often true as all workers can import the same set of libraries or even share the same Python script.
Syntax Sugar
Other operations are now implementable using syntax sugar.
Retrieving Value From RRef
# helper private RPC functions
def _identity(v : Tensor) -> Tensor:
    # copy the tensor by value to this remote
    return v
def _to_here(v : RRef[T]) -> T:
    # take a reference, send it to the device that owns it
    # and have that device return the actual tensor by value
    return v.local_value()
class RRef[T]:
    ...
    # copy a remote tensor to the local worker, sync version
    def to_here(self) -> T:
        return torch.rpc(self.owner(), _to_here, self)
Builtin Operators
# proxy methods for all builtin functions exist on references for
# existing TorchScript types like Tensors. They always follow a fixed pattern:
def _mm(a : RRef[Tensor], b : RRef[Tensor]) -> Tensor:
    return a.local_value().mm(b.local_value())
class RRef[Tensor]:
    def mm(self : RRef[Tensor], other : RRef[Tensor]) -> RRef[Tensor]:
        on = same_worker(self.owner(), other.owner())
        return torch.remote(on, _mm, self, other)
c : Tensor = a.mm(b).to_here()
Callable and RRef
If RRef[T] holds a callable object T, the application may directly call the RRef, which will be translated into a torch.remote call to the owner of the callable.
# if T is callable for RRef[T], rref(x) will be translated to calling T(x)
# on the owner of the RRef
def _call_rref(v : RRef[T], *args):
    return v.local_value()(*args)
class RRef[T]:
    def __call__(self, *args):
        return torch.remote(self.owner(), _call_rref, self, *args)
net = torch.remote("Worker1", Net)
net(inputs)
Optimizer and RRef
As models might have remote sub-modules (i.e., RRef[nn.Module]), we should provide an optimizer sugar to handle them. The optimizer sugar (torch.optim.remote) takes a local optimizer constructor, a distributed model parallel model, and an argument list for the local optimizer constructor. It recursively creates a local optimizer on every remote sub-module owner, and exposes the same step API as a local optimizer, which recursively calls every local optimizer.
class Net1(nn.Module):
    ...
class Net2(nn.Module):
    ...
class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = dist.remote("worker1", Net1)
        self.net2 = dist.remote("worker2", Net2)
dmp = dist.remote("worker0", DMP)
# dist.optimizer creates an optimizer on all RRef owners
optimizer = dist.optimizer(torch.optim.SGD, dmp, lr=0.1)
with dist.autograd.context():
    loss = dmp(inputs)
    dist.autograd.backward(loss)
    optimizer.step()
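The recursive structure of the optimizer sugar can be sketched in plain Python. This is a toy under stated assumptions: ToyRRef, ToyModule, ToySGD, and ToyRemoteOptimizer are hypothetical names, parameters are plain dicts, and everything runs in one process. In a real system, constructing each local optimizer and calling its step would be RPCs to the sub-module's owner.

```python
class ToyRRef:
    """Hypothetical reference: records the owner name and holds the value."""

    def __init__(self, owner, value):
        self.owner = owner
        self.value = value


class ToyModule:
    def __init__(self, params=(), children=()):
        self.params = list(params)      # parameters stored on this module
        self.children = list(children)  # ToyRRef handles to remote sub-modules


class ToySGD:
    """Minimal local optimizer: value -= lr * grad for each parameter."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def step(self):
        for p in self.params:
            p["value"] -= self.lr * p["grad"]


class ToyRemoteOptimizer:
    """Creates one local optimizer per sub-module owner and fans out step()."""

    def __init__(self, optim_ctor, module_rref, **kwargs):
        self.local_optims = {}  # owner name -> list of local optimizers
        self._build(optim_ctor, module_rref, kwargs)

    def _build(self, optim_ctor, module_rref, kwargs):
        module = module_rref.value
        if module.params:
            local = optim_ctor(module.params, **kwargs)
            self.local_optims.setdefault(module_rref.owner, []).append(local)
        for child in module.children:
            self._build(optim_ctor, child, kwargs)  # recurse into sub-modules

    def step(self):
        for optims in self.local_optims.values():
            for opt in optims:
                opt.step()


net1 = ToyRRef("worker1", ToyModule(params=[{"value": 1.0, "grad": 0.5}]))
net2 = ToyRRef("worker2", ToyModule(params=[{"value": 2.0, "grad": 1.0}]))
dmp = ToyRRef("worker0", ToyModule(children=[net1, net2]))

optimizer = ToyRemoteOptimizer(ToySGD, dmp, lr=0.5)
optimizer.step()
```

The caller sees a single step() call, while the parameters on worker1 and worker2 are each updated by their own local optimizer.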
Model Parallel Training Examples
Multi-Machine Model Training
# 1. load data
inputs_rref = torch.remote("worker1", load_inputs, path_to_inputs)
labels_rref = torch.remote("worker2", load_labels, path_to_inputs)
# 2. define model
class Net1(nn.Module):
    ...
class Net2(nn.Module):
    ...
class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
    def forward(self, inputs_rref):
        # RRef[T].__call__(args) is a sugar that translates to
        # dist.remote(T, RRef.owner(), args)
        outputs1_rref = self.net1(inputs_rref)
        outputs2_rref = self.net2(outputs1_rref)
        return outputs2_rref
# 3. training, run it where you want to call autograd
def train(inputs_rref, labels_rref):
    dmp = DMP()
    # dist.optimizer creates an optimizer on every RRef owner
    optimizer = dist.optimizer(torch.optim.SGD, dmp, lr=0.1)
    outputs_rref = dmp(inputs_rref)
    loss = loss_func(outputs_rref.to_here(), labels_rref.to_here())
    autograd_ctx_id = dist.autograd.backward(loss)
    optimizer.step(autograd_ctx_id)
dist.rpc(dev2, train, args=(inputs_rref, labels_rref))
Parameter Server Training
class ParameterServer:
    def __init__(self):
        self.params = torch.zeros(100, 100).to(0)
    def get_params(self) -> Tensor:
        return self.params
    def add_grads(self, grad: Tensor):
        self.params += grad.to(0)
def train(ps):
    for _ in range(10):
        params = torch.rpc("ps", ParameterServer.get_params, args=(ps, ))
        # run forward and backward
        torch.rpc("ps", ParameterServer.add_grads, args=(ps, params.grad))
        torch.distributed.barrier(group=TRAINER_GROUP)
ps = torch.remote("worker1",ParameterServer)
torch.remote("worker2", train, args=(ps,))
torch.remote("worker3", train, args=(ps,))
System Design
Distributed Autograd
Basic Idea
In the first version, dist.autograd.backward does not support RRef arguments, but RRef can still help build the autograd graph. The overall idea is as follows.
- When calling torch.rpc or RRef.to_here(), send and recv autograd functions will be inserted to connect local autograd graphs on multiple workers into one distributed autograd graph.
- Every distributed backward pass is assigned a globally unique id (autograd_context_id), and every participating worker will keep a dedicated context for it.
- When the backward computation reaches a recv function, it packs the gradient and the autograd_context_id in the message, and passes it to its send counterpart.
- Upon receiving a message for a send function in the backward pass, the worker uses the autograd_context_id in the message to identify which backward pass it belongs to, and uses the gradient in the message to continue autograd computation locally.
Send and Recv Autograd Functions
Let’s start with a simple example where there is just one synchronous RPC call and only one tensor is passed across worker boundaries. The code is shown below; in the accompanying autograd graph, AccumulateGrad autograd functions for leaf nodes are omitted for simplicity.
# the add function should be
# defined on both workers
def add() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
# make RPC call from worker0
# to execute add on worker1
c1 = dist.rpc(add, on="worker1")
d = torch.ones_like(c1)
e = c1 * d
e.sum().backward()

The send and recv autograd functions are inserted during the forward pass, connecting two local graphs into one distributed graph. In the backward pass, the gradient will be passed to the recv autograd function on worker0, and the recv autograd function will then transmit the gradient tensor to worker1’s send autograd function. Then, worker1 can kick off the local autograd engine to resume the backward pass. A few more details need to be clarified in this simple example:
- On worker1, how do we keep the autograd graph alive after the RPC call returns?
  - In short, the distributed autograd engine on worker1 will keep a reference to the send function, which keeps the graph alive.
  - Reasoning: The graph can be kept alive by keeping a reference to either tensor C or the send autograd function, as both of them hold a reference to the add autograd function. We choose to keep a reference to the send function instead of tensor C, because C, as a non-leaf node produced by add, is not needed in the backward pass. It should be freed as soon as possible. It is not memory efficient to keep C alive just because we want to have an entry point to the autograd graph.
- In the backward pass, how does recv on worker0 find the correct send on worker1 to talk to?
  - This can be done by assigning a globally unique ID (worker_id + local send/recv id) for each send / recv function pair.
- When can worker1 delete its local autograd graph?
  - send should have the same lifetime as its corresponding recv function. This can be done by sending a message from worker0 to worker1 when recv is destructed on worker0. The recv function is kept alive by the loss tensor. So, conceptually, the global autograd graph will be deleted when the final loss tensor is gone.
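The send / recv pairing discussed above can be illustrated with a toy scalar autograd in plain Python. This is a single-process sketch, not the proposed engine: Leaf, Add, Mul, ToyWorker, SendFunc, RecvFunc, and toy_rpc are hypothetical names, and the "network message" in the backward pass is just a dictionary lookup in the callee's send_map. It shows the pair id (worker id plus a local id) and how the send_map, rather than the output tensor, keeps the callee's graph alive.

```python
class Leaf:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

    def backward(self, grad):
        self.grad += grad  # plays the role of AccumulateGrad


class Add:
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.value = lhs.value + rhs.value

    def backward(self, grad):
        self.lhs.backward(grad)
        self.rhs.backward(grad)


class Mul:
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.value = lhs.value * rhs.value

    def backward(self, grad):
        self.lhs.backward(grad * self.rhs.value)
        self.rhs.backward(grad * self.lhs.value)


class ToyWorker:
    def __init__(self, name):
        self.name = name
        self.send_map = {}  # pair id -> SendFunc; keeps the local graph alive


class SendFunc:
    def __init__(self, next_fn):
        self.next_fn = next_fn

    def backward(self, grad):
        self.next_fn.backward(grad)  # resume the local backward pass


class RecvFunc:
    def __init__(self, pair_id, value, peer):
        self.pair_id, self.value, self.peer = pair_id, value, peer

    def backward(self, grad):
        # Stands in for a network message carrying (pair_id, grad).
        self.peer.send_map[self.pair_id].backward(grad)


def toy_rpc(callee, fn):
    out = fn()  # runs "remotely" on the callee
    pair_id = (callee.name, len(callee.send_map))  # worker id + local id
    callee.send_map[pair_id] = SendFunc(out)
    return RecvFunc(pair_id, out.value, callee)


a, b = Leaf(2.0), Leaf(5.0)  # leaves that live on worker1


def add():
    return Add(a, b)


worker1 = ToyWorker("worker1")
c1 = toy_rpc(worker1, add)  # recv side, as seen on worker0
d = Leaf(3.0)
e = Mul(c1, d)
e.backward(1.0)
```

After the call, the gradients of a and b on worker1 have been populated through the recv → send hop, even though worker0 never held a reference to the Add node directly.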
Hidden Autograd Path and Circular Dependency
Things can become complicated when an autograd graph contains multiple send/recv pairs. Consider the following example.
# all functions should be defined on all workers
def worker0_func(c2: Tensor) -> Tensor:
    g = torch.rand(2, 2)
    h = g + c2
    return h
def worker1_func_top() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
def worker1_func_bottom(c: Tensor, e1: Tensor) -> Tensor:
    f = c + e1
    return f
def worker2_func(c1: Tensor) -> Tensor:
    d = torch.rand(2, 2)
    e = c1 + d
    return e
# on Worker3
c_ref = torch.remote(worker1_func_top, on="Worker1")
h1 = torch.rpc(worker0_func, c_ref, on="Worker0")
e_ref = torch.remote(worker2_func, c_ref, on="Worker2")
f1 = torch.rpc(worker1_func_bottom, c_ref, e_ref, on="Worker1")
i = h1 + f1
i.sum().backward()

This example highlights two problems that we need to address:
- Hidden Autograd Path: The existing local autograd engine starts from the loss (or all outputs) and does a discovery/marking phase to identify all participating functions before executing the real autograd computation, so that all paths in the autograd graph are known upfront. However, we don’t have this luxury in distributed autograd because some parts of the autograd graph reside on remote workers. For example, when a grad arrives at send5, worker1 cannot tell whether send3 will be in the backward pass if it only looks at local information. More specifically, i.sum().backward() will be the same as f1.sum().backward() from worker1’s perspective, but the former involves send3 and the latter does not.
  - To address this problem, we propose to record all globally upstream (upstream in the forward pass, downstream in the autograd graph) send / recv pairs in the forward pass, so that we know exactly which send / recv to wait for in the backward pass.
- Circular Dependency: There are circular dependencies between worker1 and worker2, i.e., it is impossible to finish autograd computation on one worker before kicking off on another worker. One option is to start autograd computation on worker1 first and have an autograd thread block there waiting for grads for send1, but this is less than ideal.
  - To address this problem, we propose to only create the send autograd function and put it in the ready queue when the grad is received. Note that, when computing the dependency count for add1, the autograd engine still takes send1 into account, so that the engine will only start computing grads for add1 after both add2 and send1 finish.
Note that we need to record information in the forward pass and do the discovery in the backward pass because we don’t know which send functions will participate in the autograd computation. However, if the application can guarantee that all send functions will receive a grad in the backward pass, we can skip all this complexity and have a more efficient version. Both scenarios are useful, so we propose to have two modes:
- Smart Mode supports running backward on a subgraph of the global autograd graph, but there will be extra overhead in both the forward and backward pass.
- Fast Mode skips dependency recording in the forward pass and graph discovery in the backward pass, but the application needs to guarantee that all send autograd functions will receive a grad in the backward pass.
The two sections below describe the two algorithms in more detail.
Distributed Autograd Algorithm Smart mode
Forward pass:
For every send x:
1. Find send functions in x’s lineage, by:
   - Finding all locally reachable recv functions from send x in the autograd graph. In the example above, send2 finds recv1, send4 finds recv3, and send5 finds recv2.
   - Using those recv functions to find globally reachable recv functions in send x’s lineage. Note that this can be done because in step 2 we send enough information from send to recv. In the example above, send4 knows send3, and send5 knows send1 and send2.
2. Then, send x includes the ids of its lineage send functions in the message. Intuitively, it means that if there is a grad received for send x, the backward pass must reach all send functions in its lineage as well. It helps a node to determine whether it should wait for a send grad.
# pseudo code to demonstrate how send works in forward
def find_global_lineage(tensor):
    # find local lineage
    recvs = find_recvs(tensor.grad_fn)
    dep_ids = {recv.id for recv in recvs}
    # find global lineage
    dep_ids.update({dep_id for recv in recvs for dep_id in recv.dep_ids})
    return dep_ids
def send(func, tensors, on):
    msg = Message(func)
    for tensor in tensors:
        lineage = find_global_lineage(tensor)
        # connect send to autograd graph
        send_fn = SendFunc()
        send_fn.next = tensor.grad_fn
        # remember the send by its id
        RpcAgent.send_map[send_fn.id] = send_fn
        # coalesce data
        msg.data.append((tensor, send_fn.id, lineage))
    send_msg(msg, on)
def recv(func, data, src):
    # src is the sending worker ("from" is a reserved word in Python)
    tensors = []
    for tensor, send_id, lineage in data:
        # use send_id as recv_id, and remember global lineage
        recv_fn = RecvFunc(send_id, lineage)
        tensor.grad_fn = recv_fn
        tensors.append(tensor)
    return func(tensors)
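The lineage bookkeeping can be checked on the five send / recv pairs of the example with a few lines of plain Python. This sketch makes two assumptions that are not stated explicitly in the text: send ids are numbered in forward order, and the loss i on the driver is reachable from recv4 and recv5 (with f1 reachable from recv5 only). The LOCAL_RECVS table itself is taken from the statements "send2 finds recv1, send4 finds recv3, and send5 finds recv2". The function names are hypothetical.

```python
# recv ids locally reachable from each send, as stated in the example text
LOCAL_RECVS = {1: [], 2: [1], 3: [], 4: [3], 5: [2]}


def build_lineages(local_recvs):
    """Return send id -> set of send ids in its global lineage."""
    recv_deps = {}  # recv id -> lineage carried by the message that created it
    lineages = {}
    for send_id in sorted(local_recvs):  # assumes ids follow forward order
        deps = set()
        for recv_id in local_recvs[send_id]:
            deps.add(recv_id)            # local lineage
            deps |= recv_deps[recv_id]   # global lineage stored on the recv
        lineages[send_id] = deps
        recv_deps[send_id] = deps        # the recv reuses the send's id
    return lineages


def loss_lineage(reachable_recvs, lineages):
    """All send ids that a backward pass from the loss must reach."""
    deps = set()
    for recv_id in reachable_recvs:
        deps.add(recv_id)
        deps |= lineages[recv_id]
    return deps


lineages = build_lineages(LOCAL_RECVS)
full_pass = loss_lineage([4, 5], lineages)  # assumed to model i.sum().backward()
sub_pass = loss_lineage([5], lineages)      # assumed to model f1.sum().backward()
```

Under these assumptions the full pass covers all five send functions, while the subgraph pass excludes send3, which is the information worker1 needs to avoid waiting for a grad that never arrives.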
Backward pass:
On the node that calls torch.distributed.backward:
- Find all send functions in the lineage of the loss tensor. In the above example, it will be all 5 send functions. These ids will be propagated to the recv functions and will be passed to the counterpart send functions accordingly.
  - Optimizations can be added, e.g., drop unnecessary ids in the backward pass to reduce message size.
On every node:
- Upon receiving the first message (be it a dedicated discovery message or the grad of a send), record its autograd_context_id, and retrieve all participating send ids from the message. Compute dependency counts from those send functions (and also from the loss grad_fn if the loss is on this node). Set the dependency count for send functions to 1. If any autograd function has a dependency count of 0, put it into the ready queue.
- Upon receiving a send grad, decrement the dependency count of that send by 1, and add it to the ready queue. Note this is done on an RpcAgent thread, and some autograd engine thread will pick up the autograd function for execution.
# pseudo code to demonstrate backward
graph_tasks = {}
def backward(loss):
    global graph_tasks
    autograd_context_id = gen_autograd_id()
    lineage = find_global_lineage(loss)
    # these send will participate in the autograd pass
    roots = local_sends.intersection(lineage)
    # propagate the autograd_id and deps info to all
    # participating workers. This is non-blocking and can
    # run concurrently with the real backward computation.
    # This step is not absolutely necessary, but can help other
    # workers to kick off autograd earlier.
    disseminate(autograd_context_id, lineage)
    # below is a handwaving impl to show how it works with local autograd engine
    graph_task = GraphTask()
    graph_tasks[autograd_context_id] = graph_task
    # roots is a set, so use add rather than append
    roots.add(loss.grad_fn)
    # setup dependency count properly
    compute_dependencies(GraphRoot(roots), graph_task)
    # insert the task to local engine ready queue. Only the FunctionTask
    # for loss is inserted now, send FunctionTasks will be inserted later
    # when their grad becomes available.
    ready_queue.push_back(FunctionTask(graph_task, loss.grad_fn, ...))
    return autograd_context_id
def on_grad_send(send_id, grad, autograd_id):
    global graph_tasks
    graph_task = graph_tasks[autograd_id]
    send_func = RpcAgent.send_map[send_id]
    ready_queue.push_back(FunctionTask(graph_task, send_func, grad))
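The dependency counting on a single worker can be simulated in a few lines. The sketch below models worker1 from the example; the EDGES table is a reconstruction from the function definitions and the node names used in the text (add1, add2, send1, send3, send5, recv2), not something the proposal spells out. It is also simplified: instead of giving each send a dependency count of 1, an arriving grad directly triggers the send. The function names are hypothetical.

```python
from collections import deque

# Backward edges on worker1 (assumed): each function lists what it feeds next.
EDGES = {
    "send5": ["add2"],          # grad of f arrives from the driver
    "add2": ["add1", "recv2"],  # f = c + e1
    "send1": ["add1"],          # grad of c arrives from worker2
    "send3": ["add1"],          # grad of c arrives from worker0
    "add1": [],                 # c = a + b (leaf accumulation omitted)
    "recv2": [],                # forwards the grad of e1 to worker2
}


def compute_dependencies(roots):
    """Count incoming edges, visiting only functions reachable from roots."""
    deps, seen, stack = {}, set(), list(roots)
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for nxt in EDGES[fn]:
            deps[nxt] = deps.get(nxt, 0) + 1
            stack.append(nxt)
    return deps


def run_backward(participating_sends, arrival_order):
    """Execute functions as send grads arrive; return the execution order."""
    deps = compute_dependencies(participating_sends)
    executed = []
    for send in arrival_order:
        ready = deque([send])  # a send becomes ready when its grad arrives
        while ready:
            fn = ready.popleft()
            executed.append(fn)
            for nxt in EDGES[fn]:
                deps[nxt] -= 1
                if deps[nxt] == 0:
                    ready.append(nxt)
    return executed


# Subgraph pass: send3 does not participate.
sub_order = run_backward(["send5", "send1"], ["send5", "send1"])
# Full pass: all three sends on worker1 participate.
full_order = run_backward(["send1", "send3", "send5"], ["send5", "send1", "send3"])
```

In both runs add1 executes last, only after every participating send that feeds it has delivered its grad, and recv2 fires early so that worker2 can produce the grad for send1 without any thread blocking on worker1.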
Distributed Autograd Algorithm Fast mode
The problem with the above approach is that including ids in send / recv messages incurs overhead, especially when there are a lot of tensors communicated across multiple workers. Moreover, this discovery phase is only necessary when running autograd on a subgraph. For example, f1.sum().backward() requires the discovery phase to avoid waiting for send3, but i.sum().backward() is easier because all send functions are involved in the backward pass. So, we propose one additional mode for distributed autograd that bypasses send / recv dependency discovery in both the forward and backward pass if all send functions for non-leaf or requires_grad tensors will receive a grad in the backward pass. The mode can be toggled when initializing RPC agents:
# all_requires_grad (bool): If True, the application guarantees that all
# send functions on non-leaf or requires_grad tensors will receive grad
# in the backward pass. Hence, we can skip the distributed dependency
# discovery algorithm (fast mode). If False, run smart mode, where
# messages between send/recv will contain dependency ids in both forward
# and backward pass. (default False)
torch.distributed.init_rpc(name, backend="pg", all_requires_grad=False)
Internally, RpcAgent will create a thread-local driver ID, where a driver is the worker that pieces together the autograd graph. In the above example, Worker3 is the driver. In the forward pass, every send function originating from this driver will be tagged with its thread-local driver ID, and this applies to all downstream (upstream in the autograd graph) send functions as well. This can be done either by propagating this driver ID to RPC calls recursively, or by actively discovering the driver ID by walking the autograd graph before sending a tensor. If this information is ambiguous, e.g., one send function traces back to two upstream (downstream in the autograd graph) recv functions from two different drivers, an error is thrown. In the backward pass, the thread-local driver ID of the loss will be included in the entire autograd execution to identify participating send functions. Note that, in this mode, the application cannot keep two disjoint autograd graphs alive at the same time, as that would break the assumption that all send functions (originating from the driver) will receive a grad in the backward pass.
Concurrent distributed Backward passes
A = torch.rand(2, 2)
B = torch.rand(2, 2)
# on all workers
def add() -> Tensor:
    global A, B
    return A + B
# on worker0
C = torch.remote(add, on="worker2").to_here()
C.sum().backward()
# on worker1
C = torch.remote(add, on="worker2").to_here()
C.sum().backward()
In the above example, there are two concurrent backward passes triggered by worker0 and worker1 respectively, and both will reach worker2. To avoid races, the distributed autograd engine will use the globally unique autograd_context_id to create a dedicated context on every participating worker. Later, this autograd_context_id is passed to the optimizer to apply gradients. More concretely, this would work as follows:
- Compute all the leaf nodes in the autograd graph.
- As part of running distributed backwards, use the outputs parameter of the autograd engine to avoid executing AccumulateGrad for the leaf nodes we have and instead return the appropriate output_edges to execute for accumulating gradients.
- Store the output_edges with the autograd_context_id. This would ensure multiple backward passes won't accumulate gradients in the same context.
- This completes the backward pass and gradients are accumulated in the autograd engine per autograd_context_id.
- Now we run the optimizer on each of the worker nodes and pass the autograd_context_id to the optimizer.
- The optimizer applies all the gradients to the leaf nodes that we computed originally.
- The context and enclosing gradients should be destroyed when the autograd_context_id is destructed on the caller of backward().
Some pseudo-code to illustrate this:
optimizer = dist.optimizer(model)
loss = model(inputs)
bw_ctx_id = dist.autograd.backward(loss, timeout=60) # timeout of 60s
optimizer.step(bw_ctx_id)
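The effect of keying gradients by autograd_context_id can be shown with a small plain-Python model of worker2. ToyWorkerGrads and its methods are hypothetical names, parameters are floats, and the context ids are hand-written tuples; the sketch only demonstrates that two concurrent backward passes do not mix their gradients.

```python
from collections import defaultdict


class ToyWorkerGrads:
    """Per-context gradient storage on one worker (hypothetical sketch)."""

    def __init__(self, params):
        self.params = dict(params)
        # autograd_context_id -> {parameter name -> accumulated grad}
        self.contexts = defaultdict(lambda: defaultdict(float))

    def accumulate(self, ctx_id, name, grad):
        # Replaces AccumulateGrad: the grad goes into the context, not the leaf.
        self.contexts[ctx_id][name] += grad

    def step(self, ctx_id, lr):
        # Optimizer step that applies only the gradients of one backward pass.
        for name, grad in self.contexts[ctx_id].items():
            self.params[name] -= lr * grad

    def release(self, ctx_id):
        # Called when the context is destroyed on the caller of backward().
        self.contexts.pop(ctx_id, None)


worker2 = ToyWorkerGrads({"A": 1.0, "B": 2.0})
ctx_from_worker0 = ("worker0", 0)
ctx_from_worker1 = ("worker1", 0)

# Two backward passes reach worker2 concurrently.
worker2.accumulate(ctx_from_worker0, "A", 1.0)
worker2.accumulate(ctx_from_worker1, "A", 8.0)
worker2.accumulate(ctx_from_worker0, "B", 2.0)

# Only the first pass's gradients are applied by this step.
worker2.step(ctx_from_worker0, lr=0.5)
worker2.release(ctx_from_worker0)
```

After the step, the gradient accumulated for the second context is untouched and can be applied or discarded independently.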
RRef
(more details are described in #26759)
RRef is an important concept for building a distributed autograd graph. Each RRef is owned by a single worker (i.e., owner) and can be used by multiple users. The owner stores the real data referenced by its RRefs, and keeps track of the global reference counts for its RRefs. Every RRef can be uniquely identified by a global id ref_id, which is assigned at the time it is first created either on a user or on the owner.
The owner only keeps one RRef instance for each data object, while users can fork as many RRef instances as necessary. All usage on the owner should retrieve the RRef instance using the globally unique ref_id. A fork of RRef will be created when it is used as an argument or return value in an RPC call, but users don't need to worry about forking/forwarding and reference counting (RC) RRefs. These will be handled transparently, and every fork will also have its own fork_id, which is guaranteed to be unique across all RRef instances for the same data object.
RRef needs to support fast and scalable RPC. Hence, in the RC design, we avoid using any global master to keep RRef states. Besides, when worker X invokes RPC on worker Y, Y should be able to start immediately after receiving the RPC request, without waiting for any third-party owner Z (unless Y needs to pull real data from Z), even if neither X nor Y owns the RRef. We propose the following algorithm:
- If the owner is the RPC caller, the owner will update RC for the RRef accordingly.
- If the owner is the RPC callee, the owner will drop the new fork, and use the unique RRef id in the fork to access its singleton local RRef instance.
- If the RPC is between two users:
  - The caller sends an RPC message to the callee, and also notifies the owner of the new fork.
  - The owner, upon receiving the notification, updates its local RC and then tells the callee that the new fork is now known by the owner.
  - The callee can start executing the RPC as soon as it receives the RPC message from the caller, and does not need to wait for the message from the owner. However, it cannot delete its local RRef fork until the owner's message arrives.
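The user-to-user case above can be sketched as a small state machine. This is a simplified model under stated assumptions: messages are plain method calls, and `Owner`, `CalleeFork`, and their methods are hypothetical names, not part of the proposed API.

```python
class Owner:
    """Toy owner: tracks which fork ids it currently counts (its RC)."""

    def __init__(self):
        self.known_forks = set()

    def on_new_fork(self, fork_id):
        # Notification sent by the caller; the return value stands in for
        # the confirmation the owner then sends to the callee.
        self.known_forks.add(fork_id)
        return fork_id


class CalleeFork:
    """Toy fork held by the callee of a user-to-user RPC."""

    def __init__(self, fork_id):
        self.fork_id = fork_id
        self.rpc_started = False
        self.owner_confirmed = False

    def on_rpc_message(self):
        # The RPC starts right away, without waiting for the owner.
        self.rpc_started = True

    def on_owner_confirmation(self, fork_id):
        if fork_id == self.fork_id:
            self.owner_confirmed = True

    def may_delete(self):
        # Deleting is only allowed once the owner knows this fork.
        return self.owner_confirmed


owner = Owner()
fork = CalleeFork(fork_id=("worker_x", 7))

fork.on_rpc_message()  # RPC message from the caller arrives first
started_before_confirm = fork.rpc_started and not fork.may_delete()

# The owner's confirmation arrives later and unblocks deletion.
fork.on_owner_confirmation(owner.on_new_fork(fork.fork_id))
```

The point of the split is that the owner's confirmation is off the critical path of the RPC itself and only gates deletion of the fork.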
Reference Count
The right time to delete an RRef on the owner is when there are no living forks on any user and Python GC also agrees to delete the RRef instance on the owner. The tricky part is to determine whether there are any living forks.
A user can get a fork in three situations:
- Receiving a fork from the owner.
- Receiving a fork from another user.
- Creating a new RRef fork owned by another worker.
#1 is the simplest case where the owner initiates the fork, and hence it can easily increase local RC. The only requirement is that any fork must notify the owner before destruction. Hence, we need the first guarantee:
- G1. The owner will be notified when any fork is deleted.
Note that the notification might come delayed or out-of-order.
With #2 and #3, it is possible that the owner knows the RRef fork graph only partially, or does not know it at all. For example, the RRef could be constructed on a user, and before the owner receives the RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is a tree rooted at the owner, because forking an RRef always creates a new RRef instance, and hence every RRef has a parent. One nasty detail is that when an RRef is created on a user, the owner is technically not its parent, but we still treat it as the parent, and this does not break the argument below.
The owner's view of any node (fork) in the tree goes through three stages, 1) unknown → 2) known → 3) deleted, and the owner's view of the entire tree keeps changing. The owner deletes its RRef instance when it thinks there are no living forks, i.e., when every fork is either actually deleted or unknown. Therefore, the dangerous case is when some forks are unknown and others are deleted. We only need a simple guarantee to prevent this situation:
- G2. No fork x can be deleted on a user before the owner knows x's parent fork.
This works because the owner's view of x can only change from known to deleted when x's parent is known or deleted. If the parent is known, the owner will not delete its local RRef. If the parent is deleted, this rule applies recursively to the parent's parent, until it reaches the root (the owner). To implement the guarantee, we only need to make the caller include its own fork_id when notifying the owner of a new fork.
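The owner's bookkeeping and the dangerous case (some forks unknown, others deleted) can be sketched as follows. The sketch does not implement G2 itself; it only models the owner's three-stage view and shows why the owner must learn about a fork before it sees the deletions around it. `OwnerView` and its methods are hypothetical names.

```python
KNOWN, DELETED = "known", "deleted"


class OwnerView:
    """Toy model of the owner's view of the fork tree."""

    def __init__(self):
        self.state = {}  # fork_id -> KNOWN or DELETED; absent means unknown

    def on_fork_known(self, fork_id):
        # Notifications may arrive out of order: a late "known" message
        # must not resurrect a fork that was already reported deleted.
        if self.state.get(fork_id) != DELETED:
            self.state[fork_id] = KNOWN

    def on_fork_deleted(self, fork_id):
        self.state[fork_id] = DELETED

    def thinks_no_living_forks(self):
        # Unknown forks are invisible here, which is exactly the risk.
        return all(s == DELETED for s in self.state.values())


# Dangerous case: fork A has shared the RRef with B, but the owner never
# hears about B before A's deletion is reported.
unsafe = OwnerView()
unsafe.on_fork_known("A")
unsafe.on_fork_deleted("A")
premature = unsafe.thinks_no_living_forks()  # True although B is alive

# Safe ordering: the owner learns about B (the child of A) first.
safe = OwnerView()
safe.on_fork_known("A")
safe.on_fork_known("B")
safe.on_fork_deleted("A")
still_alive = not safe.thinks_no_living_forks()  # B keeps the RRef alive
safe.on_fork_deleted("B")
done = safe.thinks_no_living_forks()
```

In the unsafe run the owner would free the data while B still holds a fork; in the safe run it waits until B is also reported deleted.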
G1 and G2 guarantee correct RC, but do not prevent a user from deleting an RRef fork before it finishes its own prior RPC calls that use that fork. This should be OK, because when the caller deserializes the RPC message, it would hold a reference () to that RRef, preventing it from being deleted.
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera