@@ -789,6 +789,101 @@ def wrap_check_inputs(check_inputs):
789789
790790 return [{'forward' : c } for c in check_inputs ]
791791
792+
def fork(func, *args, **kwargs):
    """
    Creates an asynchronous task executing `func` and a reference to the value
    of the result of this execution. `fork` will return immediately,
    so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.

    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.

    Warning:
        `fork` tasks will execute non-deterministically. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Arguments:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously,
            otherwise it will not. Traced invocations of fork will be captured in the IR.

        *args, **kwargs: arguments to invoke `func` with.

    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. testcode::

        import torch
        from torch import Tensor

        def foo(a : Tensor, b : int) -> Tensor:
            return a + b

        def bar(a):
            fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)

        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)

        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)

        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. testcode::

        import torch
        from torch import Tensor

        class SubMod(torch.nn.Module):
            def forward(self, a: Tensor, b : int):
                return a + b

        class Mod(torch.nn.Module):
            def __init__(self):
                super(Mod, self).__init__()
                self.mod = SubMod()

            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)

        input = torch.tensor(2)
        mod = Mod()

        assert mod(input) == torch.jit.script(mod).forward(input)
    """
    # Thin pass-through: all fork semantics live in the native binding.
    return torch._C.fork(func, *args, **kwargs)
872+
def wait(future):
    """
    Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the
    result of the task. See :func:`~fork` for docs and examples.

    Arguments:
        future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`

    Returns:
        `T`: the return value of the completed task
    """
    # Thin pass-through: the native binding resolves the future's value.
    return torch._C.wait(future)
885+
886+
792887def trace (func ,
793888 example_inputs ,
794889 optimize = None ,
0 commit comments