# mypy: allow-untyped-defs
"""Async API.

This module contains the API for parallelism in TorchScript, notably:
    * torch.jit.fork
    * torch.jit.wait

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

import torch
from torch._jit_internal import Future
from torch.jit._builtins import _register_builtin
from torch.utils import set_module


# `Future` is imported from `torch._jit_internal`, but this module's docs refer
# to it as `torch.jit.Future`; tag it with that public module name.
set_module(Future, "torch.jit")


def fork(func, *args, **kwargs):
    r"""
    Create an asynchronous task executing `func` and a reference to the value of the result of this execution.

    `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.
    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.

    .. warning::
        `fork` tasks will execute non-deterministically. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Args:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously,
            otherwise it will not. Traced invocations of fork will be captured in the IR.
        ``*args``, ``**kwargs``: arguments to invoke `func` with.

    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor


        def foo(a: Tensor, b: int) -> Tensor:
            return a + b


        def bar(a):
            fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)


        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor


        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b: int):
                return a + b


        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = AddMod()

            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)


        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
    """
    # Thin Python entry point: the call is forwarded unchanged to the C++ binding.
    return torch._C.fork(func, *args, **kwargs)


def wait(future):
    r"""
    Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task.

    See :func:`~fork` for docs and examples.

    Args:
        future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`

    Returns:
        `T`: the return value of the completed task
    """
    # Thin Python entry point: the call is forwarded unchanged to the C++ binding.
    return torch._C.wait(future)


# Associate the Python-level `wait` with the builtin operator name "aten::wait".
_register_builtin(wait, "aten::wait")