xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/operatorbench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3import click
4import numpy as np
5from operator_inp_utils import OperatorInputsLoader
6
7import torch
8from torch._dynamo.backends.cudagraphs import cudagraphs_inner
9from torch._dynamo.testing import same
10from torch._inductor.compile_fx import compile_fx
11from torch._inductor.decomposition import decompositions
12from torch._inductor.lowering import lowerings
13from torch._inductor.runtime.benchmarking import benchmarker
14from torch._inductor.utils import gen_gm_and_inputs
15from torch.utils._pytree import tree_map_only
16
17
18aten = torch.ops.aten
19
20
21def compute_speedups(
22    operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda"
23):
24    expected = models[0](*example_inputs)
25    if accuracy_checking:
26        for model in models[1:]:
27            actual = model(*example_inputs)
28            # change to assert later
29            try:
30                same(actual, expected, cos_similarity=True, equal_nan=True)
31            except AssertionError as e:
32                print(e)
33                print(f"Accuracy check failed: {operator}")
34                print((expected[0] - actual[0]).abs().max())
35
36    timings = np.zeros((repeats, len(models)), np.float64)
37    for rep in range(repeats):
38        # interleave the runs to handle frequency scaling and load changes
39        for m, model in enumerate(models):
40            if device == "cuda":
41                model(*example_inputs)
42
43                # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time
44                # along with cuda synchronization
45                timings[rep, m] = benchmarker.benchmark_gpu(
46                    lambda: model(*example_inputs)
47                )
48            else:
49                from torch._inductor.utils import timed
50
51                timings[rep, m] = timed(model, example_inputs)
52    return np.median(timings, axis=0)
53
54
def strip_overloads(gm):
    """
    Retarget every OpOverload call in :attr:`gm` to its overload packet.

    The graph module is mutated in place and then recompiled so its
    generated code reflects the new targets.
    Args:
        gm(fx.GraphModule): The Fx graph module to rewrite in place
    """
    overloaded_nodes = [
        node
        for node in gm.graph.nodes
        if isinstance(node.target, torch._ops.OpOverload)
    ]
    for node in overloaded_nodes:
        node.target = node.target.overloadpacket
    gm.recompile()
65
66
def convert_to_jit(gm, gm_args):
    """Turn ``gm`` into a TorchScript module.

    Overloads are stripped from ``gm`` first (mutating it). Scripting is tried
    first; if ``torch.jit.script`` raises, the module is traced with
    ``gm_args`` instead.
    """
    strip_overloads(gm)
    try:
        scripted = torch.jit.script(gm)
    except Exception:
        # Scripting is best-effort; fall through to tracing below.
        pass
    else:
        return scripted
    return torch.jit.trace(gm, gm_args)
74
75
def to_channels_last(ten):
    """Return ``ten`` in channels-last memory format when it is 4-D.

    Tensors of any other rank are returned unchanged (the same object).
    """
    if ten.ndim == 4:
        return ten.to(memory_format=torch.channels_last)
    return ten
78
79
def microbenchmark(
    operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser, device
):
    """Benchmark one input sample of ``operator`` eagerly, under Inductor, and optionally TorchScript.

    A single-op FX graph is built from ``operator``/``args``/``kwargs``. On
    ``device == "cuda"`` the eager graph and the Inductor-compiled graph are both
    wrapped with ``cudagraphs_inner``; on other devices they run as-is. When
    ``measure_nvfuser`` is set, a TorchScript version (from ``convert_to_jit``)
    wrapped with ``cudagraphs_inner`` is added as a third candidate.
    ``dtype`` is accepted but not used here.

    Returns:
        The per-candidate medians from ``compute_speedups``, ordered
        ``[eager, inductor]`` plus ``[torchscript]`` when ``measure_nvfuser``.
        With ``accuracy_checking`` only a single timing round is run.
    """
    gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
    # Presumably needed so TorchScript can resolve convolution_backward when
    # scripting the graph in convert_to_jit -- confirm.
    torch.jit._builtins._register_builtin(
        torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
    )

    def _with_cudagraphs(fn):
        return cudagraphs_inner(fn, gm_args, copy_outputs=False, copy_inputs=False)

    if device == "cuda":
        eager = _with_cudagraphs(gm)
        inductor = _with_cudagraphs(compile_fx(gm, gm_args))
    else:
        eager = gm
        inductor = compile_fx(gm, gm_args)
    candidates = [eager, inductor]

    if measure_nvfuser:
        candidates.append(_with_cudagraphs(convert_to_jit(gm, gm_args)))

    rounds = 1 if accuracy_checking else repeats
    return compute_speedups(
        operator, candidates, gm_args, rounds, accuracy_checking, device
    )
112
113
def skip_operator(operator):
    """Decide whether ``operator`` should be excluded from benchmarking.

    Every skip is reported with a ``Skipping <operator>, <reason>`` line so the
    benchmark log explains why an operator produced no timings.

    Returns:
        True if the operator should be skipped, False otherwise.
    """
    op_name = str(operator)

    # Operators whose inputs cannot be generated yet (matched by substring).
    nyi_strings = (
        "aten.gather.default",
        "nll_loss",
        "aten.index",
        "aten.scatter_",
        "masked_fill_.Scalar",
    )
    if any(nyi_string in op_name for nyi_string in nyi_strings):
        # maybe disable aten.native_layer_norm.default
        # TODO - inputs cannot be randomly initialized, causes cuda failures
        print(f"Skipping {operator}, input generator nyi")
        return True

    # not covered by other non-compute operator heuristics
    if operator == torch.ops.aten._unsafe_view.default:
        print(f"Skipping {operator}, non compute operator")
        return True

    # some of inductor registered to the OpOverload, some registered to OpOverloadPacket
    op_impls = [operator]
    if isinstance(operator, torch._ops.OpOverload):
        op_impls.append(operator.overloadpacket)

    # TODO - skip benchmarking fallbacks. for some ops we have both lowerings and fallbacks
    # so it's not clear just from operator what will be lowered.
    if all(op not in decompositions and op not in lowerings for op in op_impls):
        print(f"Skipping {operator}, no inductor impl")
        return True

    if "convolution" in op_name:
        # Report this skip like every other one instead of dropping it silently.
        print(f"Skipping {operator}, convolution op")
        return True

    return False
150
151
@click.command()
@click.option(
    "--suite",
    help="suite to load inps from: options: timm, huggingface, torchbench",
    default="torchbench",
)
@click.option("--op", help="operator overload to benchmark", required=True)
@click.option("--dtype", help="dtype to benchmark")
@click.option("--max-samples", help="max samples per op", default=15)
@click.option("--accuracy-checking", help="check accuracy", default=False)
@click.option(
    "--repeats", help="how many times to repeat for perf measurement", default=3
)
@click.option(
    "--measure-nvfuser", help="default we only measure inductor", default=False
)
@click.option("--device", help="cpu or cuda", default="cuda")
@click.option("--inp-file", help="use custom input file instead of suite", default=None)
@click.option("--start-idx", help="specify start index of samples", default=0)
@click.option(
    "--channels-last", help="force inputs to channels last", is_flag=True, default=False
)
def benchmark(
    suite,
    op,
    dtype,
    max_samples,
    accuracy_checking,
    repeats,
    measure_nvfuser,
    device,
    inp_file,
    start_idx,
    channels_last,
):
    """Benchmark eager vs. Inductor (and optionally TorchScript) on recorded operator inputs.

    Inputs come from ``inp_file`` when given, otherwise from the named ``suite``.
    For each selected operator (``--op all`` means every operator the loader
    knows), samples ``start_idx`` .. ``start_idx + max_samples - 1`` are
    benchmarked. The 20th/50th/80th percentile speedups (eager time divided by
    compiled time) are printed and, for ``--op all``, also appended to a
    ``timings_*.txt`` file.
    """
    import contextlib

    if inp_file is not None:
        loader = OperatorInputsLoader(inp_file)
    else:
        assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
        if suite == "timm":
            loader = OperatorInputsLoader.get_timm_loader()
        elif suite == "huggingface":
            loader = OperatorInputsLoader.get_huggingface_loader()
        else:
            loader = OperatorInputsLoader.get_torchbench_loader()

    assert dtype in ("float16", "float32"), f"got {dtype}"

    # Only `--op all` persists results. The name uses the dtype *string*, so it
    # must be built before dtype is converted to a torch.dtype below.
    filename = (
        f"timings_{suite}_{op.replace('.', '_')}{dtype}.txt" if op == "all" else None
    )

    dtype = torch.float16 if dtype == "float16" else torch.float32

    if op == "all":
        ops = loader.get_all_ops()
    else:
        # NOTE(review): eval() of a command-line argument. Acceptable for a local
        # benchmark script, but never pass untrusted input here.
        ops = [eval(op)]

    max_samples = max_samples + start_idx

    # A context manager guarantees the results file is closed (and flushed)
    # even if an exception escapes the loop; nullcontext() yields None.
    out_cm = open(filename, "a") if filename is not None else contextlib.nullcontext()
    with out_cm as out_file:
        for operator in ops:
            if skip_operator(operator):
                continue

            print(f"Running {operator}")
            inp_gen = loader.get_inputs_for_operator(
                operator, dtype=dtype, device=device
            )
            timings = []

            for i in range(min(max_samples, 1000000)):
                try:
                    inps = next(inp_gen)
                    if inps is None:
                        break
                    if i < start_idx:
                        continue
                    print(f"Iter {i}")
                    args, kwargs = inps
                    if channels_last:
                        args, kwargs = tree_map_only(
                            torch.Tensor, to_channels_last, (args, kwargs)
                        )

                except StopIteration:
                    break
                try:
                    # candidate order: eager, inductor[, torchscript ("nvfuser")]
                    timings.append(
                        microbenchmark(
                            operator,
                            args,
                            kwargs,
                            dtype,
                            accuracy_checking,
                            repeats,
                            measure_nvfuser,
                            device,
                        )
                    )
                except Exception as e:
                    print(f"error {operator}")
                    print(e)
                    # uncomment to stop at the first failure instead of
                    # continuing with the remaining samples/operators
                    # raise e

            if not timings:
                continue

            # (num_samples, num_candidates) -> (num_candidates, num_samples)
            timings = torch.tensor(np.stack(timings)).T
            q = torch.tensor([0.2, 0.5, 0.8], dtype=torch.float64)
            output = f"{operator}:\nInductor Speedups : {(torch.quantile(timings[0] / timings[1], q)).tolist()}\n"
            if measure_nvfuser:
                output += f"NVFUSER Speedups :{(torch.quantile(timings[0] / timings[2], q)).tolist()}\n"
            if out_file is not None:
                out_file.write(output)
            print(output)


if __name__ == "__main__":
    # click parses sys.argv and invokes benchmark() with the parsed options.
    benchmark()
274