# mypy: allow-untyped-defs
import contextlib
from typing import List, Tuple

import torch


@contextlib.contextmanager
def optimized_execution(should_optimize):
    """Context manager that controls whether the JIT's executor will run optimizations before executing a function."""
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        torch._C._set_graph_executor_optimize(stored_flag)

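
# A minimal usage sketch (the example function below is hypothetical and never
# called): run a scripted function with the executor's optimizations disabled;
# the previous setting is restored when the block exits.
def _example_optimized_execution():
    @torch.jit.script
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    with optimized_execution(False):
        return double(torch.randn(4))  # executed without graph optimizations
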

@contextlib.contextmanager
def fuser(name):
    """Context manager that facilitates switching between backend fusers.

    Valid names:
    * ``fuser0`` - enables only the legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph
    * ``none`` - disables all fusers
    """
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    old_llga_state = torch._C._jit_llga_enabled()
    if name == "fuser0":  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser1":  # NNC
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser2":  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser3":  # oneDNN Graph
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(True)
    elif name == "none":  # turn the PyTorch fuser off
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")  # noqa: TRY002
    try:
        yield
    finally:
        if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
        # restore the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
        torch._C._jit_set_llga_enabled(old_llga_state)

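
# A minimal usage sketch (the example function below is hypothetical and never
# called): select the NNC fuser ("fuser1") while running a traced function so
# its fusible element-wise ops can be compiled by NNC; the previous fuser
# settings are restored on exit from the block.
def _example_fuser():
    x = torch.randn(8)
    traced = torch.jit.trace(lambda t: (t * 2).relu(), x)
    with fuser("fuser1"):
        for _ in range(3):  # a few runs let the profiling executor specialize and fuse
            traced(x)
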

last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _get_differentiable_graph_node(node, diff_node):
    # Recursively collect every prim::DifferentiableGraph node reachable from
    # `node` (including nodes nested in sub-blocks) into the `diff_node` list.
    if node.kind() == "prim::DifferentiableGraph":
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)


def _graph_for(self, *args, **kwargs):
    return _script_method_graph_for(self, self, *args, **kwargs)


def _script_method_graph_for(self, parent, *args, **kwargs):
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph executor states for the differentiable nodes
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        assert len(fw_states) == len(diff_nodes)
        # swap each differentiable graph with the optimized graph in its execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # We can only update the subgraph when there is a unique execution
            # plan. Avoid asserting here so that nodes that can't be updated are
            # skipped while we make a best effort to update the other nodes.
            if len(fw_execution_plans) == 1:
                n.g_("Subgraph", fw_execution_plans[0].graph)

        return graph
    except Exception:
        # Fallback: run the function once and return the last recorded
        # optimized graph.
        self(*args, **kwargs)
        return last_executed_optimized_graph()

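
# A minimal usage sketch (the example function below is hypothetical and never
# called), assuming that, as elsewhere in torch.jit, _graph_for is installed as
# the ``graph_for`` method of scripted callables: after a few profiling runs the
# optimized graph recorded for a particular input can be inspected.
def _example_graph_for():
    @torch.jit.script
    def fn(x: torch.Tensor) -> torch.Tensor:
        return x.sin() + x.cos()

    x = torch.randn(4)
    for _ in range(3):  # warm up the profiling executor
        fn(x)
    return fn.graph_for(x)  # optimized graph used for inputs like `x`
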

def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """Set the type and number of specializations that can occur during fusion.

    Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
    and depth is an integer.

    Behavior - static vs dynamic:
        In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
        based on some initial profiling runs.
        In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
        shapes are possible.

    In both cases, we also recompile on new striding behavior, device, or dtype.

    Behavior - fallback functions & depth:
        When an input doesn't match the format required by the specialized compiled op, it will run
        a fallback function. Fallback functions are recursively compiled and specialized based
        on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
        limit the number of specializations that can be compiled, before giving up on recompiling and
        falling back to a completely un-fused, un-specialized implementation.

    The list of (type, depth) pairs controls the type of specializations and the number of
    specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first
    two specializations will use static fusion, the following two specializations will use
    dynamic fusion, and any inputs that satisfy none of the 4 options will run an
    unfused implementation.

    NB: in the future, as more fusion backends are added, there may be more granular
    APIs for specific fusers.
    """
    return torch._C._jit_set_fusion_strategy(strategy)
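

# A minimal usage sketch (the example function below is hypothetical and never
# called): allow two static-shape specializations, then two dynamic-shape ones,
# before the executor falls back to an unfused, unspecialized implementation.
def _example_set_fusion_strategy():
    set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 2)])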