xref: /aosp_15_r20/external/pytorch/torch/jit/_async.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # mypy: allow-untyped-defs
2 """Async API.
3 
4 This module contains the API for parallelism in TorchScript, notably:
5     * torch.jit.fork
6     * torch.jit.wait
7 
8 This is not intended to be imported directly; please use the exposed
9 functionalities in `torch.jit`.
10 """
11 
12 import torch
13 from torch._jit_internal import Future
14 from torch.jit._builtins import _register_builtin
15 from torch.utils import set_module
16 
17 
# `Future` is defined in torch._jit_internal (see import above). This call
# presumably re-labels its `__module__` so it is presented as part of the public
# `torch.jit` namespace — confirm against torch.utils.set_module.
set_module(Future, "torch.jit")
19 
20 
def fork(func, *args, **kwargs):
    r"""
    Create an asynchronous task executing `func` and return a reference to the value of its result.

    `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.
    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.

    .. warning::
        `fork` tasks will execute non-deterministically. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Args:
        func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously,
            otherwise it will not. Traced invocations of fork will be captured in the IR.
        ``*args``, ``**kwargs``: arguments to invoke `func` with.

    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor
        def foo(a : Tensor, b : int) -> Tensor:
            return a + b
        def bar(a):
            fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)
        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor
        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b : int):
                return a + b
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = AddMod()
            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)
        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
    """
    # All work is delegated to the native binding. The eager/tracing/scripted
    # behaviours described in the docstring come from it, not from Python code here.
    return torch._C.fork(func, *args, **kwargs)
87 
88 
def wait(future):
    r"""
    Block until a `torch.jit.Future[T]` task has finished and hand back its result.

    See :func:`~fork` for docs and examples.

    Args:
        future (torch.jit.Future[T]): the task handle previously returned by `torch.jit.fork`

    Returns:
        `T`: whatever value the forked task produced
    """
    # Bind the native entry point locally, then resolve the future through it.
    resolve_future = torch._C.wait
    return resolve_future(future)
100 
101 
# Associate this Python-level `wait` with the TorchScript builtin op "aten::wait".
# Presumably this lets scripted code that calls torch.jit.wait compile to that op
# instead of to this Python wrapper — confirm against torch.jit._builtins.
_register_builtin(wait, "aten::wait")
103