Skip to content

Commit 0d06085

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
[JIT] Fork/Join inline docs (#39952)
Summary: Pull Request resolved: #39952 Differential Revision: D22165477 Pulled By: eellison fbshipit-source-id: 93132cd6987fdd2484112a57ef17912b8fcc5fab
1 parent 13a8ec3 commit 0d06085

2 files changed

Lines changed: 97 additions & 2 deletions

File tree

test/jit/test_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def foo(x):
2020
return torch.neg(x)
2121

2222
x = torch.rand(3, 4)
23-
fut = torch.jit._fork(foo, x)
23+
fut = torch.jit.fork(foo, x)
2424
y_hat = foo(x)
2525
y = torch.jit._wait(fut)
2626
# assert nothing; only to make sure the fake python path works
@@ -32,7 +32,7 @@ def foo(inp):
3232
futures.append(torch.jit._fork(lambda x: x, inp))
3333
all_outputs = []
3434
for future in futures:
35-
all_outputs.append(torch.jit._wait(future))
35+
all_outputs.append(torch.jit.wait(future))
3636
return all_outputs
3737

3838
# assert nothing, just to make sure python type parsing works

torch/jit/__init__.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,101 @@ def wrap_check_inputs(check_inputs):
789789

790790
return [{'forward' : c} for c in check_inputs]
791791

792+
793+
def fork(func, *args, **kwargs):
794+
"""
795+
Creates an asynchronous task executing `func` and a reference to the value
796+
of the result of this execution. `fork` will return immediately,
797+
so the return value of `func` may not have been computed yet. To force completion
798+
of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
799+
with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
800+
nested, and may be invoked with positional and keyword arguments.
801+
802+
Asynchronous execution will only occur when run in TorchScript. If run in pure python,
803+
`fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
804+
while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.
805+
806+
Warning:
807+
`fork` tasks will execute non-deterministicly. We recommend only spawning
808+
parallel fork tasks for pure functions that do not modify their inputs,
809+
module attributes, or global state.
810+
811+
Arguments:
812+
func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
813+
that will be invoked. If executed in TorchScript, it will execute asynchronously,
814+
otherwise it will not. Traced invocations of fork will be captured in the IR.
815+
816+
*args, **kwargs: arguments to invoke `func` with.
817+
818+
Returns:
819+
`torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
820+
can only be accessed by forcing completion of `func` through `torch.jit.wait`.
821+
822+
Example (fork a free function):
823+
824+
.. testcode::
825+
826+
import torch
827+
from torch import Tensor
828+
829+
def foo(a : Tensor, b : int) -> Tensor:
830+
return a + b
831+
832+
def bar(a):
833+
fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
834+
return torch.jit.wait(fut)
835+
836+
script_bar = torch.jit.script(bar)
837+
input = torch.tensor(2)
838+
839+
# only the scripted version executes asynchronously
840+
assert script_bar(input) == bar(input)
841+
842+
# trace is not run asynchronously, but fork is captured in IR
843+
graph = torch.jit.trace(bar, (input,)).graph
844+
assert "fork" in str(graph)
845+
846+
Example (fork a module method):
847+
848+
.. testcode::
849+
850+
import torch
851+
from torch import Tensor
852+
853+
class SubMod(torch.nn.Module):
854+
def forward(self, a: Tensor, b : int):
855+
return a + b
856+
857+
class Mod(torch.nn.Module):
858+
def __init__(self):
859+
super(self).__init__()
860+
self.mod = SubMod()
861+
862+
def forward(self, input):
863+
fut = torch.jit.fork(self.mod, a, b=2)
864+
return torch.jit.wait(fut)
865+
866+
input = torch.tensor(2)
867+
mod = Mod()
868+
869+
assert mod(input) == torch.jit.script(mod).forward(input)
870+
"""
871+
return torch._C.fork(func, *args, **kwargs)
872+
873+
def wait(future):
874+
"""
875+
Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the
876+
result of the task. See :func:`~fork` for docs and examples.
877+
878+
Arguments:
879+
func (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
880+
881+
Returns:
882+
`T`: the return value of the the completed task
883+
"""
884+
return torch._C.wait(future)
885+
886+
792887
def trace(func,
793888
example_inputs,
794889
optimize=None,

0 commit comments

Comments
 (0)