xref: /aosp_15_r20/external/pytorch/benchmarks/framework_overhead_benchmark/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import time
2from collections import namedtuple
3
4from torch.utils import ThroughputBenchmark
5
6
# Per-op normalization factor: measured latencies below are divided by this,
# so each benchmarked module presumably runs its op NUM_LOOP_ITERS times per
# forward call — confirm against the module wrappers used with this file.
NUM_LOOP_ITERS = 1000

# num_warmup_iters: untimed iterations run before measurement;
# num_iters: iterations included in the reported timing.
BenchmarkConfig = namedtuple("BenchmarkConfig", "num_warmup_iters num_iters")

# Describes a module under benchmark. NOTE(review): field semantics are not
# visible here — pt_fn looks like a PyTorch callable and c2_op a Caffe2 op
# name; verify against the callers that construct ModuleConfig.
ModuleConfig = namedtuple("ModuleConfig", "pt_fn c2_op num_params graph_mode")
10
11
def ms_to_us(time_ms):
    """Convert a duration from milliseconds to microseconds."""
    us_per_ms = 1e3
    return us_per_ms * time_ms
14
15
def secs_to_us(time_s):
    """Convert a duration from seconds to microseconds."""
    us_per_second = 1e6
    return us_per_second * time_s
18
19
def secs_to_ms(time_s):
    """Convert a duration from seconds to milliseconds."""
    ms_per_second = 1e3
    return ms_per_second * time_s
22
23
def benchmark_using_throughput_benchmark(config, module):
    """Measure per-op latency of ``module`` with torch's ThroughputBenchmark.

    Args:
        config: object providing ``num_warmup_iters`` and ``num_iters``.
        module: wrapper exposing ``.module`` (the module to execute) and
            ``.tensor_inputs`` (positional inputs fed to it).

    Returns:
        Average latency in milliseconds divided by NUM_LOOP_ITERS — the
        wrapped module presumably loops its op NUM_LOOP_ITERS times per
        forward call (see the normalization in benchmark_module as well).
    """
    print("Benchmarking via ThroughputBenchmark")
    runner = ThroughputBenchmark(module.module)
    runner.add_input(*module.tensor_inputs)
    # One calling thread; warmup iterations are excluded from the stats.
    exec_stats = runner.benchmark(1, config.num_warmup_iters, config.num_iters)
    return exec_stats.latency_avg_ms / NUM_LOOP_ITERS
30
31
def benchmark_module(config, module, use_throughput_benchmark=False):
    """Benchmark ``module`` and return the average per-op latency in ms.

    Args:
        config: BenchmarkConfig with ``num_warmup_iters`` and ``num_iters``.
        module: wrapper exposing ``forward(num_iters)`` (and, for the
            ThroughputBenchmark path, ``.module`` and ``.tensor_inputs``).
        use_throughput_benchmark: when True, delegate the measurement to
            benchmark_using_throughput_benchmark instead of manual timing.

    Returns:
        Elapsed time in milliseconds per iteration, further divided by
        NUM_LOOP_ITERS to normalize to a single op invocation.
    """
    if use_throughput_benchmark:
        return benchmark_using_throughput_benchmark(config, module)
    # Warm up outside the timed region.
    module.forward(config.num_warmup_iters)
    print(f"Running module for {config.num_iters} iterations")
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump with system clock adjustments, skewing the measurement.
    start = time.perf_counter()
    module.forward(config.num_iters)
    elapsed_s = time.perf_counter() - start
    return secs_to_ms(elapsed_s) / config.num_iters / NUM_LOOP_ITERS
42