# NVFuser - A Fusion Code Generator for NVIDIA GPUs
_NVFuser is integrated as a backend for TorchScript's Profiling Graph Executor. NVFuser is the default fuser for NVIDIA GPUs._

## Simple knobs to change fusion behavior

1. Allow single node fusion `torch._C._jit_set_nvfuser_single_node_mode(True)`
A fusion group is normally created only when two or more compatible ops can be grouped together. Turning on single node fusion allows the fusion pass to create a fusion group with a single node; this is very handy for testing and can be useful when a single-node generated kernel outperforms the native CUDA kernel in the framework.

2. Allow horizontal fusion `torch._C._jit_set_nvfuser_horizontal_mode(True)`
The fusion pass fuses producers into consumers; horizontal mode additionally allows sibling nodes that share a tensor input to be fused together. This can save input memory bandwidth.

3. Turn off guard for fusion `torch._C._jit_set_nvfuser_guard_mode(False)`
This disables the runtime check on fusion group pre-assumptions (tensor meta information / constant inputs / profiled constants). It is really only meant for testing, where we want to ensure generated kernels are indeed exercised; avoid using it in training scripts.

4. Turn off fusion for certain node kinds `torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True)`
This disables fusion for the given node kind while allowing other nodes to continue being fused. The first parameter is the node kind, and the second parameter toggles whether that node kind is skipped by fusion. A combined usage sketch follows this list.
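
A minimal sketch of how these knobs might be combined before running a scripted model; the `forward` function and the particular flag values are illustrative, not a recommended configuration:

```
import torch

# Illustrative toggles: enable single-node and horizontal fusion, keep the
# runtime guard on, and skip fusing aten::add nodes.
torch._C._jit_set_nvfuser_single_node_mode(True)
torch._C._jit_set_nvfuser_horizontal_mode(True)
torch._C._jit_set_nvfuser_guard_mode(True)
torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True)

def forward(x):
    return (x + 1.0).relu()

t = torch.jit.script(forward)
x = torch.rand(2, 32, 128, 512, device="cuda")

with torch.jit.fuser("fuser2"):
    for _ in range(5):
        o = t(x)
```
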

## Fusion Debugging

Given the following script as an example:

```
import torch

def forward(x):
    o = x + 1.0
    o = o.relu()
    return o

shape = (2, 32, 128, 512)
input = torch.rand(*shape).cuda()
t = torch.jit.script(forward)

with torch.jit.fuser("fuser2"):
    for k in range(4):
        o = t(input)
```

### TorchScript Based Debugging

#### 1. TorchScript IR Graph

##### Usage

There are two easy ways to check fusion for a graph. The first one is to print out the graph in the Python script after a few runs (so that optimization has kicked in):

`print(t.graph_for(input))`

The second way is to turn on graph dumping in the profiling executor via the command line below:

```
PYTORCH_JIT_LOG_LEVEL="profiling_graph_executor_impl" python <your pytorch script>
```

##### Example Output

The graph printout is straightforward: look for `prim::CudaFusionGroup_X` for fused kernels. The profiling executor dumps many things, but the most important part is the `Optimized Graph`. In this example, it shows a fusion group, which is an indication that fusion is happening and you should expect a fused kernel!

```
  Optimized Graph:
  graph(%x.1 : Tensor):
    %12 : bool = prim::CudaFusionGuard[types=[Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)]](%x.1)
    %11 : Tensor = prim::If(%12)
      block0():
        %o.8 : Tensor = prim::CudaFusionGroup_0[cache_id=0](%x.1)
        -> (%o.8)
      block1():
        %18 : Function = prim::Constant[name="fallback_function", fallback=1]()
        %19 : (Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)) = prim::CallFunction(%18, %x.1)
        %20 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = prim::TupleUnpack(%19)
        -> (%20)
    return (%11)
  with prim::CudaFusionGroup_0 = graph(%2 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)):
    %4 : int = prim::Constant[value=1]()
    %3 : float = prim::Constant[value=1.]() # test.py:6:12
    %o.1 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::add(%2, %3, %4) # test.py:6:8
    %o.5 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::relu(%o.1) # test.py:7:8
    return (%o.5)
```

Note that one thing that can prevent fusion when you are running training is autodiff. The fusion pass only runs within `prim::DifferentiableGraph`, so the first thing to check is that the targeted ops are within differentiable graph subgraphs.
The graph dump can be quite confusing to look at, since it naively dumps all graphs executed by the profiling executor, and differentiable graphs are executed via a nested graph executor. So for each graph, you might see a few segmented `Optimized Graph` dumps, each corresponding to a differentiable node in the original graph.
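
Below is a minimal sketch of checking this from Python; the scripted function, input, and warm-up count are illustrative. After a few warm-up runs, and provided the differentiable subgraph is large enough not to be inlined (see the notes further down), the optimized graph should contain `prim::DifferentiableGraph` nodes, and any fusion groups live inside their subgraphs:

```
import torch

def forward(x):
    return (x + 1.0).relu()

t = torch.jit.script(forward)
x = torch.rand(2, 32, 128, 512, device="cuda", requires_grad=True)

with torch.jit.fuser("fuser2"):
    for _ in range(5):  # warm up so profiling and optimization kick in
        t(x).sum().backward()
        x.grad = None

graph = t.graph_for(x)
for node in graph.nodes():
    if node.kind() == "prim::DifferentiableGraph":
        # The differentiable subgraph is stored in the "Subgraph" attribute;
        # look for prim::CudaFusionGroup nodes inside it.
        print(node.g("Subgraph"))
```
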

#### 2. Cuda Fusion Graphs

##### Usage

The CUDA fusion dump gives the input and output graphs of the fusion pass. This is a good place to check the fusion pass logic:

```
PYTORCH_JIT_LOG_LEVEL="graph_fuser" python <your pytorch script>
```

##### Example Output

Running the same script above, you should look for two graphs in the log: `Before Fusion` shows the subgraph that the fusion pass runs on; `Before Compilation` shows the graph sent to the codegen backend, where each `CudaFusionGroup` will trigger the codegen runtime system to generate kernel(s) to execute the subgraph.

```
  Before Fusion:
  graph(%x.1 : Tensor):
    %2 : float = prim::Constant[value=1.]()
    %1 : int = prim::Constant[value=1]()
    %3 : Tensor = prim::profile[profiled_type=Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)](%x.1)
    %o.10 : Tensor = aten::add(%3, %2, %1) # test.py:6:8
    %5 : Tensor = prim::profile[profiled_type=Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)](%o.10)
    %o.7 : Tensor = aten::relu(%5) # test.py:7:8
    %7 : Tensor = prim::profile[profiled_type=Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)](%o.7)
    %8 : Tensor = prim::profile[profiled_type=Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)](%o.7)
    return (%7, %8)

  Before Compilation:
  graph(%x.1 : Tensor):
    %13 : bool = prim::CudaFusionGuard[types=[Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)]](%x.1)
    %12 : Tensor = prim::If(%13)
      block0():
        %o.11 : Tensor = prim::CudaFusionGroup_0(%x.1)
        -> (%o.11)
      block1():
        %o.7 : Tensor = prim::FallbackGraph_1(%x.1)
        -> (%o.7)
    return (%12, %12)
  with prim::CudaFusionGroup_0 = graph(%2 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)):
    %4 : int = prim::Constant[value=1]()
    %3 : float = prim::Constant[value=1.]()
    %o.10 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::add(%2, %3, %4) # test.py:6:8
    %o.7 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::relu(%o.10) # test.py:7:8
    return (%o.7)
  with prim::FallbackGraph_1 = graph(%x.1 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0)):
    %1 : int = prim::Constant[value=1]()
    %2 : float = prim::Constant[value=1.]()
    %o.10 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::add(%x.1, %2, %1) # test.py:6:8
    %o.7 : Float(2, 32, 128, 512, strides=[2097152, 65536, 512, 1], requires_grad=0, device=cuda:0) = aten::relu(%o.10) # test.py:7:8
    return (%o.7)
```

### General ideas for debugging no-fusion

Currently we have a few consumers that utilize nvfuser by lowering computations to TorchScript and executing them through a ProfilingExecutor.

Without going into too much detail about how the integration is done, here are a few notes on debugging no-fusion on the ProfilingExecutor:

1. Run the TorchScript module multiple times (5 could be a lucky number) to enable fusion.
   Because the ProfilingExecutor takes the first (few) runs for profiling, later optimization (including the fusion pass that enables nvfuser) relies on profiling information to run, so your initial runs are not going to trigger fused kernels.
   Note that the number of profiling runs depends on your model.

2. Fused kernels should show up in the TorchScript IR as `prim::CudaFusionGroup`. You can look at your TorchScript optimized graph to see whether fusion is happening: `jit_model.graph_for(*inputs)`.

3. If your scripted model has inputs requiring gradient, fusion only happens for graphs inside `prim::DifferentiableGraph`.
   There are many reasons why your graph may not be autodiff-able. Take a look at `/torch/csrc/jit/runtime/symbolic_script.cpp`, which lists all autodiff-able ops (note that this is a different list from autograd-supported ops). There is also a threshold below which tiny autodiff graphs are inlined/reverted, which can be disabled via `torch._C._debug_set_autodiff_subgraph_inlining(False)`. A short sketch combining these checks follows this list.
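
A minimal sketch of those checks; the scripted function, input shape, and number of warm-up runs are illustrative:

```
import torch

# Optional: keep tiny autodiff subgraphs from being inlined so that
# prim::DifferentiableGraph nodes remain visible in the optimized graph.
torch._C._debug_set_autodiff_subgraph_inlining(False)

def forward(x):
    return (x + 1.0).relu()

jit_model = torch.jit.script(forward)
x = torch.rand(2, 32, 128, 512, device="cuda", requires_grad=True)

with torch.jit.fuser("fuser2"):
    for _ in range(5):  # several profiling runs before fusion kicks in
        jit_model(x)

graph_str = str(jit_model.graph_for(x))
print("has DifferentiableGraph:", "prim::DifferentiableGraph" in graph_str)
print("has CudaFusionGroup:", "prim::CudaFusionGroup" in graph_str)
```
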

### General ideas for debugging nvfuser malfunctions

Assuming the ProfilingExecutor part worked out properly, that is, you see a region that is supposed to be fused but it did not end up in a fused kernel, here are some ways to dig deeper:

1. Dump the fusion pass result:
   `PYTORCH_JIT_LOG_LEVEL=graph_fuser python your_script.py &> log`

   Look for graphs dumped with `Before Fusion` & `Before Compilation`, which show the portion of the graph that the fusion pass runs on and the result of fusion (`CudaFusionGroup`).

2. Check which ops are not fused and roughly why:
   `PYTORCH_JIT_LOG_LEVEL=">partition:graph_fuser" python your_script.py &> log`

   Enabling GRAPH_UPDATE logging from partition.cpp dumps a message whenever a given node is rejected by fusion.

3. Disable the FALLBACK path:
   If you see a warning that a FALLBACK path has been taken while executing your model with nvfuser enabled, it indicates that either codegen or the fusion pass has failed unexpectedly. This is likely to cause a regression in model performance, even though it is still functionally correct. We recommend disabling the FALLBACK path, so the error is reported properly and you can open an informative issue:

   `PYTORCH_NVFUSER_DISABLE=fallback python your_script.py &> log`

4. Pinpoint the kernel/fusion pattern that is causing the error:
   With a larger model that includes multiple fusion patterns, it can be tricky to figure out which exact fusion is causing the FALLBACK and to build up a minimal Python repro.
   One quick thing to try is to run the example with a few knobs turned on:

   ```
   PYTORCH_NVFUSER_DISABLE=fallback \
   PYTORCH_JIT_LOG_LEVEL=">partition:graph_fuser:>>kernel_cache" \
   python your_script.py &> log
   ```

   This logs all TorchScript IR parsed to codegen IR, as well as the kernels generated and executed by nvfuser. Since the fallback path is disabled, the last log entry is likely to indicate the failing fusion.

   Hint: look for the last `Before Compilation:`, which indicates a parsing failure, or `running GraphCache: xxxxx`, which indicates a JIT compilation/execution failure (also search for the GraphCache address, which should have dumped a TorchScript IR earlier).

### Query nvfuser codegen kernels

There are a few debug dumps that can be turned on via environment variables. Look for `PYTORCH_NVFUSER_DUMP` inside `[pytorch_source_path]/torch/csrc/jit/codegen/cuda/utils.cpp`. A few useful ones are listed below, and a sketch of enabling them follows the list:
1. `dump_eff_bandwidth`: print out the effective bandwidth of each generated kernel. This naively measures the kernel time divided by the I/O buffer size and is a good/simple metric of performance for bandwidth-bound kernels.
2. `cuda_kernel`: print out the generated CUDA kernels.
3. `launch_param`: print out the launch config of the generated kernels.
4. `kernel_args`: print out the input/output/buffer tensors of all executed codegen kernels. Note that for buffers, we indicate whether they are zero-initialized, which hints at an extra kernel being used to fill the tensor before the codegen kernels run.
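
These dumps are typically enabled from the shell, e.g. `PYTORCH_NVFUSER_DUMP=cuda_kernel python your_script.py`. The sketch below sets the variable from inside the script instead; it assumes the variable is read lazily when the first fused kernel is compiled and that multiple options are comma-separated, so setting it in the shell remains the safer option. The model is illustrative:

```
import os

# Assumption: read when the first kernel is compiled; options comma-separated.
os.environ["PYTORCH_NVFUSER_DUMP"] = "dump_eff_bandwidth,cuda_kernel"

import torch

def forward(x):
    return (x + 1.0).relu()

t = torch.jit.script(forward)
x = torch.rand(2, 32, 128, 512, device="cuda")

with torch.jit.fuser("fuser2"):
    for _ in range(5):
        t(x)
```
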

### FAQs

1. There is a regression after turning on nvfuser.

The first thing to check is that your fused kernels are running properly. Try running your model with the fallback disabled via `export PYTORCH_NVFUSER_DISABLE=fallback` to see if you hit any errors that caused the fallback.

If turning on NVFuser produces unexpected outputs, set the `PYTORCH_NVFUSER_DISABLE` environment variable to disable some of the optional features, e.g.:
- `fma`: disable using FMA instructions
- `index_hoist`: disable the optimization that hoists common index expressions
- `predicate_elimination`: disable the optimization that eliminates redundant predicates
- `unroll_with_rng`: disable unrolling when RNG is used

For example, `export PYTORCH_NVFUSER_DISABLE=fma,index_hoist` would disable FMA and index hoisting.

2. I didn't see any speedup with nvfuser.

Check whether there is fusion in your scripted model. Run your script with `PYTORCH_JIT_LOG_LEVEL="graph_fuser"`; you should see some log dump of the before/after graphs from the fusion pass. If nothing shows up in the log, something in TorchScript is not right and the fusion pass is not executed. Check [General ideas for debugging no-fusion] for more details.

3. I ran into codegen issues with nvfuser. How do I disable nvfuser?

There are three ways to disable nvfuser, listed below in descending priority (a short sketch of the API option follows the list):

- Force using NNC instead of nvfuser for GPU fusion with the env variable `export PYTORCH_JIT_USE_NNC_NOT_NVFUSER=1`.
- Disable nvfuser with the torch API `torch._C._jit_set_nvfuser_enabled(False)`.
- Disable nvfuser with the env variable `export PYTORCH_JIT_ENABLE_NVFUSER=0`.
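
A minimal sketch of the API option, toggling nvfuser off around a region of a script (the scripted model here is illustrative):

```
import torch

def forward(x):
    return (x + 1.0).relu()

t = torch.jit.script(forward)
x = torch.rand(2, 32, 128, 512, device="cuda")

# Turn nvfuser off before running the scripted model ...
torch._C._jit_set_nvfuser_enabled(False)
t(x)

# ... and turn it back on afterwards if desired.
torch._C._jit_set_nvfuser_enabled(True)
```
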

4. Are there any more knobs to tune nvfuser fusion?

Some opt-out features in nvfuser are exposed via the env variable `PYTORCH_NVFUSER_DISABLE`, e.g. `fallback` to disable the aten fallback during compilation failure and `fma` to disable fused multiply-add; you would set `export PYTORCH_NVFUSER_DISABLE="fallback,fma"`. Note that disabling fma usually regresses performance, so we strongly encourage not disabling it.

There are also opt-in features via the env variable `PYTORCH_NVFUSER_ENABLE`:
- `complex` enables complex floating type support in nvfuser (currently experimental and turned off by default to avoid functional regressions);
- `linear_decomposition` enables decomposition of the bias add in linear layers. Similarly, `conv_decomposition` enables decomposition of the bias add in conv layers. In some small benchmark models, we noticed that such decompositions added more compilation overhead than the benefit of the faster kernels; hence we decided to make these opt-in instead.
225