#!/usr/bin/env python3
"""Microbenchmark individual aten operators: eager vs. Inductor (optionally TorchScript).

For each operator, sample inputs are loaded through ``OperatorInputsLoader``
(from a model suite or a custom input file). Each sample is wrapped in a
single-node FX graph, compiled with Inductor, and timed against the eager
baseline. The 20th/50th/80th-percentile speedups are printed per operator.
"""

import contextlib

import click
import numpy as np
from operator_inp_utils import OperatorInputsLoader

import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.decomposition import decompositions
from torch._inductor.lowering import lowerings
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import gen_gm_and_inputs, timed
from torch.utils._pytree import tree_map_only


aten = torch.ops.aten


def compute_speedups(
    operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda"
):
    """Time every callable in ``models`` on ``example_inputs``.

    Args:
        operator: the operator under test; used only in log messages.
        models: callables to compare. ``models[0]`` is the reference whose
            output the others are checked against.
        example_inputs: positional arguments passed to every model.
        repeats: number of interleaved timing rounds.
        accuracy_checking: if True, compare each model's output to the
            reference output and print a message on mismatch (non-fatal).
        device: "cuda" uses ``benchmarker.benchmark_gpu``; anything else uses
            wall-clock ``timed``.

    Returns:
        np.ndarray of shape ``(len(models),)`` with the median time per model.
    """
    expected = models[0](*example_inputs)
    if accuracy_checking:
        for model in models[1:]:
            actual = model(*example_inputs)
            # change to assert later
            # `same` reports a mismatch by returning False; some internal
            # checks may raise AssertionError instead, so handle both.
            try:
                matches = same(actual, expected, cos_similarity=True, equal_nan=True)
            except AssertionError as e:
                print(e)
                matches = False
            if not matches:
                print(f"Accuracy check failed: {operator}")
                print((expected[0] - actual[0]).abs().max())

    timings = np.zeros((repeats, len(models)), np.float64)
    for rep in range(repeats):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            # Warm-up run so one-time costs stay out of the measurement
            # (previously done only on the CUDA path).
            model(*example_inputs)
            if device == "cuda":
                # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time
                # along with cuda synchronization
                timings[rep, m] = benchmarker.benchmark_gpu(
                    lambda: model(*example_inputs)
                )
            else:
                # `timed` defaults to device="cuda"; pass the real device so
                # CPU runs do not synchronize (or require) CUDA.
                timings[rep, m] = timed(model, example_inputs, device=device)
    return np.median(timings, axis=0)


def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Each ``OpOverload`` target is replaced by its ``OpOverloadPacket``, and
    the module is recompiled. The graph module is mutated in place.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


def convert_to_jit(gm, gm_args):
    """Convert ``gm`` to TorchScript, preferring scripting over tracing.

    Note: mutates ``gm`` via ``strip_overloads`` before conversion.
    """
    strip_overloads(gm)
    try:
        return torch.jit.script(gm)
    except Exception:
        # Deliberate best-effort: scripting fails for some graphs; fall back
        # to tracing with the sample inputs.
        pass
    return torch.jit.trace(gm, gm_args)


def to_channels_last(ten):
    """Return ``ten`` in channels_last memory format if 4-D, else unchanged."""
    return ten if ten.ndim != 4 else ten.to(memory_format=torch.channels_last)


def microbenchmark(
    operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser, device
):
    """Benchmark one operator invocation; return median times per variant.

    The returned array is ordered [eager, inductor] plus [torchscript] when
    ``measure_nvfuser`` is set. On CUDA the eager and inductor variants are
    wrapped in CUDA graphs. ``dtype`` is accepted for interface compatibility
    and is not used here.
    """
    gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
    torch.jit._builtins._register_builtin(
        torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
    )
    if device == "cuda":
        cudagraphs_eager = cudagraphs_inner(
            gm, gm_args, copy_outputs=False, copy_inputs=False
        )
        compiled_fn = compile_fx(gm, gm_args)
        cudagraphs_compiled = cudagraphs_inner(
            compiled_fn, gm_args, copy_outputs=False, copy_inputs=False
        )
        compiled = [cudagraphs_eager, cudagraphs_compiled]
    else:
        compiled_fn = compile_fx(gm, gm_args)
        compiled = [gm, compiled_fn]
    if measure_nvfuser:
        # NOTE(review): this wraps the scripted graph in CUDA graphs, so it
        # presumably only works with device == "cuda" — confirm.
        g = convert_to_jit(gm, gm_args)
        cudagraphs_jit = cudagraphs_inner(
            g, gm_args, copy_outputs=False, copy_inputs=False
        )
        compiled += [cudagraphs_jit]
    if accuracy_checking:
        repeats = 1

    medians = compute_speedups(
        operator, compiled, gm_args, repeats, accuracy_checking, device
    )
    return medians


def skip_operator(operator):
    """Return True (after printing the reason) if ``operator`` should not be benchmarked."""
    nyi_strings = (
        "aten.gather.default",
        "nll_loss",
        "aten.index",
        "aten.scatter_",
        "masked_fill_.Scalar",
    )

    if any(nyi_string in str(operator) for nyi_string in nyi_strings):
        # maybe disable aten.native_layer_norm.default
        # TODO - inputs cannot be randomly initialized, causes cuda failures
        print(f"Skipping {operator}, input generator nyi")
        return True

    # not covered by other non-compute operator heuristics
    if operator == torch.ops.aten._unsafe_view.default:
        print(f"Skipping {operator}, non compute operator")
        return True

    # some of inductor registered to the OpOverload, some registered to OpOverloadPacket
    op_impls = [operator]
    if isinstance(operator, torch._ops.OpOverload):
        op_impls.append(operator.overloadpacket)

    # TODO - skip benchmarking fallbacks. for some ops we have both lowerings and fallbacks
    # so its not clear just from operator what will be lowered.

    if all(op not in decompositions and op not in lowerings for op in op_impls):
        print(f"Skipping {operator}, no inductor impl")
        return True

    if "convolution" in str(operator):
        return True

    return False


def _collect_timings(
    operator,
    loader,
    dtype,
    device,
    start_idx,
    max_samples,
    channels_last,
    accuracy_checking,
    repeats,
    measure_nvfuser,
):
    """Benchmark samples ``start_idx`` .. ``max_samples - 1`` of ``operator``.

    Returns a list with one ``microbenchmark`` result per successful sample.
    Samples that raise are reported and skipped.
    """
    inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device)
    timings = []

    for i in range(min(max_samples, 1000000)):
        try:
            inps = next(inp_gen)
        except StopIteration:
            break
        if inps is None:
            break
        if i < start_idx:
            continue
        print(f"Iter {i}")
        args, kwargs = inps
        if channels_last:
            args, kwargs = tree_map_only(
                torch.Tensor, to_channels_last, (args, kwargs)
            )

        try:
            # aten, nvfuser, inductor
            timings.append(
                microbenchmark(
                    operator,
                    args,
                    kwargs,
                    dtype,
                    accuracy_checking,
                    repeats,
                    measure_nvfuser,
                    device,
                )
            )
        except Exception as e:
            print(f"error {operator}")
            print(e)
            # add `raise` here to stop on the first failure; errors are
            # swallowed by default to avoid blocking other tests
    return timings


def _format_speedups(operator, timings, measure_nvfuser):
    """Render 20/50/80th-percentile speedups over eager for one operator."""
    # Rows are samples, columns are variants; transpose so row k is variant k.
    timings = torch.from_numpy(np.stack(timings)).T
    q = torch.tensor([0.2, 0.5, 0.8], dtype=torch.float64)
    output = f"{operator}:\nInductor Speedups : {(torch.quantile(timings[0] / timings[1], q)).tolist()}\n"
    if measure_nvfuser:
        output += f"NVFUSER Speedups :{(torch.quantile(timings[0] / timings[2], q)).tolist()}\n"
    return output


@click.command()
@click.option(
    "--suite",
    help="suite to load inps from: options: timm, huggingface, torchbench",
    default="torchbench",
)
@click.option("--op", help="operator overload to benchmark")
@click.option("--dtype", help="dtype to benchmark")
@click.option("--max-samples", help="max samples per op", default=15)
@click.option("--accuracy-checking", help="check accuracy", default=False)
@click.option(
    "--repeats", help="how many times to repeat for perf measurement", default=3
)
@click.option(
    "--measure-nvfuser", help="default we only measure inductor", default=False
)
@click.option("--device", help="cpu or cuda", default="cuda")
@click.option("--inp-file", help="use custom input file instead of suite", default=None)
@click.option("--start-idx", help="specify start index of samples", default=0)
@click.option(
    "--channels-last", help="force inputs to channels last", is_flag=True, default=False
)
def benchmark(
    suite,
    op,
    dtype,
    max_samples,
    accuracy_checking,
    repeats,
    measure_nvfuser,
    device,
    inp_file,
    start_idx,
    channels_last,
):
    """Benchmark one operator (``--op aten.foo.bar``) or every op (``--op all``).

    With ``--op all``, results are also appended to a timings_*.txt file.
    """
    if inp_file is not None:
        loader = OperatorInputsLoader(inp_file)
    else:
        assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
        if suite == "timm":
            loader = OperatorInputsLoader.get_timm_loader()
        elif suite == "huggingface":
            loader = OperatorInputsLoader.get_huggingface_loader()
        else:
            loader = OperatorInputsLoader.get_torchbench_loader()

    assert dtype in ("float16", "float32"), f"got {dtype}"

    # Only the "all" sweep persists results. The `with` block guarantees the
    # file is closed even if the sweep is interrupted.
    if op == "all":
        filename = f"timings_{suite}_{op.replace('.', '_')}{dtype}.txt"
        out_ctx = open(filename, "a")
    else:
        out_ctx = contextlib.nullcontext()

    with out_ctx as out_file:
        dtype = torch.float16 if dtype == "float16" else torch.float32

        if op == "all":
            ops = loader.get_all_ops()
        else:
            # NOTE(review): eval() executes arbitrary Python from the command
            # line; acceptable for a local dev tool, never for untrusted input.
            ops = [eval(op)]

        max_samples = max_samples + start_idx
        for operator in ops:
            if skip_operator(operator):
                continue

            print(f"Running {operator}")
            timings = _collect_timings(
                operator,
                loader,
                dtype,
                device,
                start_idx,
                max_samples,
                channels_last,
                accuracy_checking,
                repeats,
                measure_nvfuser,
            )
            if not timings:
                continue

            output = _format_speedups(operator, timings, measure_nvfuser)
            if out_file is not None:
                out_file.write(output)
            print(output)


if __name__ == "__main__":
    benchmark()