
[RFC] Async User Function for RPC #36071

@mrshenli

Description

Impact

Currently, every RPC request occupies an RPC thread on the server side until it is done. If the user function makes nested RPC calls or performs other IO, there is no way for it to yield on the server side, so the threads processing those requests sit idle while they wait. The Async User Function feature aims to fix this by allowing applications to mark a function as async. Such a function must return an rpc.Future, and the RPC framework installs response processing/handling as a callback on that Future, freeing the thread until the Future completes.
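The intended server-side behavior can be sketched with the standard library's concurrent.futures.Future standing in for the proposed rpc.Future. Here serve, my_add, and responses are hypothetical names used only for illustration, and appending to responses stands in for sending the RPC response back to the caller:

```python
from concurrent.futures import Future

# Collected "responses"; in the real framework this would be the
# network send back to the RPC caller.
responses = []

def serve(user_fn, is_async, *args):
    # Sketch of the dispatch described above: for an async user
    # function, the return value is a Future and response delivery is
    # installed as a callback, so the RPC thread is freed immediately.
    ret = user_fn(*args)
    if is_async:
        ret.add_done_callback(lambda f: responses.append(f.result()))
    else:
        responses.append(ret)

def my_add(a, b, c):
    fut = Future()
    # In practice this Future would be completed later, e.g. by a
    # nested RPC callback; completed inline here to keep the sketch short.
    fut.set_result(a + b + c)
    return fut

serve(my_add, True, 1, 2, 3)
print(responses[0])  # 6
```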

Pitch

API

import torch.distributed.rpc as rpc

def my_add_2(a, b):
    return a + b

@rpc.async_function
def my_add_3(a, b, c):
    # final output future of this user function
    future_ret = rpc.Future()
    # nested RPC calls
    future_value = rpc.rpc_async(dst2, my_add_2, args=(a, b))
    # finish processing asynchronously when the nested RPC is done.
    future_value.add_callback(
        lambda ret: future_ret.mark_complete(ret + c)
    )
    return future_ret
    
rpc.rpc_sync(dst1, my_add_3, args=(1, 2, 3))

New Concepts

  1. @rpc.async_function decorator marks a function as an async function.
    1. It gives the server side a hint that this is an async user function, which can be done by registering the function name on the receiver, similar to how we handle TorchScript functions.
    2. All async functions must return a future.
  2. rpc.Future() creates a future object which the user function can mark complete when the value is ready. This is not exactly a new concept; we just need to expose it and allow applications to create an empty Future.
  3. Future.add_callback is already available in C++; we just need to expose it to Python.
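The two Future primitives above map closely onto the standard library's concurrent.futures.Future, which can serve as a mental model: mark_complete corresponds roughly to set_result, and add_callback to add_done_callback (with the caveat that add_done_callback receives the completed future rather than the raw value):

```python
from concurrent.futures import Future

results = []

# Emulation of the proposed API: rpc.Future() ~ Future(),
# add_callback ~ add_done_callback, mark_complete ~ set_result.
fut = Future()
fut.add_done_callback(lambda f: results.append(f.result() + 3))
fut.set_result(1 + 2)

print(results[0])  # 6
```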

Examples

Nested Async User Functions

import torch.distributed.rpc as rpc

def identity(a):
    return a

@rpc.async_function
def my_add_2(a, b):
    fut_ret = rpc.Future()
    rpc.rpc_async(dst, identity, args=(a, )).add_callback(
        lambda ret: fut_ret.mark_complete(ret + b)
    )
    return fut_ret
    

@rpc.async_function
def my_add_3(a, b, c):
    fut_ret = rpc.Future()
    rpc.rpc_async(dst, my_add_2, args=(a, b)).add_callback(
        lambda ret: fut_ret.mark_complete(ret + c)
    )
    return fut_ret
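This nesting is why the receiver needs to know which functions are async: when a caller does rpc_async on an async function, the framework must unwrap the inner future into the outer one. A runnable emulation, with a hypothetical call_async helper playing the role of rpc.rpc_async and a ThreadPoolExecutor standing in for remote workers:

```python
from concurrent.futures import Future, ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)
ASYNC_FUNCTIONS = set()  # functions registered as async on the receiver

def async_function(fn):
    # Stand-in for the proposed @rpc.async_function decorator.
    ASYNC_FUNCTIONS.add(fn)
    return fn

def call_async(fn, *args):
    # Stand-in for rpc.rpc_async: if fn is marked async it returns an
    # inner Future, and the "framework" chains it into the outer one.
    outer = Future()
    def run():
        ret = fn(*args)
        if fn in ASYNC_FUNCTIONS:
            ret.add_done_callback(lambda f: outer.set_result(f.result()))
        else:
            outer.set_result(ret)
    pool.submit(run)
    return outer

def identity(a):
    return a

@async_function
def my_add_2(a, b):
    fut_ret = Future()
    call_async(identity, a).add_done_callback(
        lambda f: fut_ret.set_result(f.result() + b))
    return fut_ret

@async_function
def my_add_3(a, b, c):
    fut_ret = Future()
    call_async(my_add_2, a, b).add_done_callback(
        lambda f: fut_ret.set_result(f.result() + c))
    return fut_ret

print(call_async(my_add_3, 1, 2, 3).result())  # 6
```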

Multiple Async Calls in One User Function

def my_add_2(a, b):
    return a + b

@rpc.async_function
def my_add_3(a, b, c):
    fut_ret = rpc.Future()
    def bottom_half(x):
        fut2 = rpc.rpc_async(dst, my_add_2, args=(x, x))
        fut2.add_callback(
            lambda ret: fut_ret.mark_complete(ret)
        )
    
    rpc.rpc_async(dst, my_add_2, args=(a, b)).add_callback(bottom_half)
    return fut_ret
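The "bottom half" chaining above can be exercised locally with concurrent.futures, where ThreadPoolExecutor.submit plays the role of rpc.rpc_async. Note that c is unused in the original sketch as well, so the result is (a + b) + (a + b):

```python
from concurrent.futures import Future, ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=2)

def my_add_2(a, b):
    return a + b

def my_add_3(a, b, c):
    # Hypothetical local stand-in for the decorated async user function.
    fut_ret = Future()

    def bottom_half(fut1):
        # Second async stage, started only once the first completes.
        x = fut1.result()
        fut2 = pool.submit(my_add_2, x, x)
        fut2.add_done_callback(lambda f: fut_ret.set_result(f.result()))

    pool.submit(my_add_2, a, b).add_done_callback(bottom_half)
    return fut_ret

print(my_add_3(1, 2, 3).result())  # 6
```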

Async RPC Fan Out

def my_add_2(a, b):
    return a + b

@rpc.async_function
def my_add_4(a, b, c, d):
    fut_ret = rpc.Future()
    rets = []
    lock = threading.Lock()
    
    def barrier(x):
        flag = False
        with lock:
            rets.append(x)
            if len(rets) == 2:
                flag = True
        if flag:
            fut_ret.mark_complete(rets)
            
    rpc.rpc_async(dst2, my_add_2, args=(a, b)).add_callback(barrier)
    rpc.rpc_async(dst2, my_add_2, args=(c, d)).add_callback(barrier) 
    return fut_ret
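The fan-out barrier can likewise be emulated locally. One detail worth noting about the original sketch: the order of rets depends on which callback fires first, so this variant completes the future with the sum instead, which is order-independent (a deliberate change for the emulation, not part of the proposal):

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=2)

def my_add_2(a, b):
    return a + b

def my_add_4(a, b, c, d):
    # Hypothetical local stand-in for the decorated async user function.
    fut_ret = Future()
    rets = []
    lock = threading.Lock()

    def barrier(f):
        # Complete the outer future only once both fan-out calls finish.
        with lock:
            rets.append(f.result())
            done = len(rets) == 2
        if done:
            fut_ret.set_result(sum(rets))

    pool.submit(my_add_2, a, b).add_done_callback(barrier)
    pool.submit(my_add_2, c, d).add_done_callback(barrier)
    return fut_ret

print(my_add_4(1, 2, 3, 4).result())  # 10
```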

Batch RPC Requests

lock = threading.Lock()
a_list = []
b_list = []
fut_rets = []
batch_size = 10

@rpc.async_function
def batch_add(a, b):
    global a_list, b_list, fut_rets
    flag = False
    fut_ret = rpc.Future()
    with lock:
        a_list.append(a)
        b_list.append(b)
        fut_rets.append(fut_ret)
        
        if len(a_list) == batch_size:
            tmp_a_list, a_list = a_list, []
            tmp_b_list, b_list = b_list, []
            tmp_fut_rets, fut_rets = fut_rets, []
            flag = True
            
    if flag:
        aa = torch.stack(tmp_a_list)
        bb = torch.stack(tmp_b_list)
        cc = aa + bb
        for i in range(len(tmp_fut_rets)):
            tmp_fut_rets[i].mark_complete(cc[i])
    
    return fut_ret
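The batching pattern runs end-to-end in plain Python if we substitute concurrent.futures.Future for rpc.Future and ints for tensors (batch_size of 3 here to keep the driver short; all names besides the pattern itself are illustrative):

```python
import threading
from concurrent.futures import Future

lock = threading.Lock()
a_list, b_list, fut_rets = [], [], []
batch_size = 3

def batch_add(a, b):
    # Each call contributes to the current batch and gets back a Future;
    # the call that fills the batch completes every pending Future.
    fut_ret = Future()
    with lock:
        a_list.append(a)
        b_list.append(b)
        fut_rets.append(fut_ret)
        if len(a_list) < batch_size:
            return fut_ret
        # Steal the accumulated batch while still holding the lock.
        batch = list(zip(a_list, b_list, fut_rets))
        a_list.clear(); b_list.clear(); fut_rets.clear()
    for a_i, b_i, fut in batch:
        fut.set_result(a_i + b_i)
    return fut_ret

futs = [batch_add(i, i) for i in range(3)]
print([f.result() for f in futs])  # [0, 2, 4]
```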

Discussion

Q: Why not use the native Python asyncio API?
A: The asyncio proposal started in Python 3.3. The async/await syntax was first added in Python 3.5 and only became stable in Python 3.6. As RPC also aims to support earlier Python 3 releases, we cannot rely on a feature that is only stabilized for 3.6+.
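For comparison, on Python 3.7+ (asyncio.run) the nested-RPC example would collapse to the following with native coroutines, where await replaces the explicit Future plus add_callback plumbing; local coroutines stand in for the remote calls:

```python
import asyncio

async def my_add_2(a, b):
    return a + b

async def my_add_3(a, b, c):
    # `await` yields here instead of installing a completion callback.
    ret = await my_add_2(a, b)
    return ret + c

result = asyncio.run(my_add_3(1, 2, 3))
print(result)  # 6
```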

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar

Metadata

Labels

feature (A request for a proper, new feature), module: rpc (Related to RPC, distributed autograd, RRef, and distributed optimizer), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
