xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/benchmark_runner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import random
3import time
4from abc import abstractmethod
5from typing import Any, Tuple
6
7from tqdm import tqdm  # type: ignore[import-untyped]
8
9import torch
10
11
class BenchmarkRunner:
    """
    BenchmarkRunner is a base class for all benchmark runners. It provides an interface to run benchmarks in order to
    collect data with AutoHeuristic.

    Subclasses must override ``create_input`` (return one input tuple) and ``run_benchmark`` (run the benchmark on the
    unpacked tuple). ``run`` parses the command line, configures AutoHeuristic accordingly and then calls ``main``.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name: AutoHeuristic name; ``run`` writes it to ``autoheuristic_use`` (with ``--use-heuristic``) or to
                ``autoheuristic_collect`` (otherwise).
        """
        self.name = name
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        # Populated by run() with the parsed command-line namespace; None until then.
        self.args: "argparse.Namespace | None" = None

    def add_base_arguments(self) -> None:
        """Register the command-line options shared by every benchmark runner on ``self.parser``."""
        self.parser.add_argument(
            "--device",
            type=int,
            default=None,
            help="torch.cuda.set_device(device) will be used",
        )
        self.parser.add_argument(
            "--use-heuristic",
            action="store_true",
            help="Use learned heuristic instead of collecting data.",
        )
        self.parser.add_argument(
            "-o",
            type=str,
            default="ah_data.txt",
            help="Path to file where AutoHeuristic will log results.",
        )
        self.parser.add_argument(
            "--num-samples",
            type=int,
            default=1000,
            help="Number of samples to collect.",
        )
        self.parser.add_argument(
            "--num-reps",
            type=int,
            default=3,
            help="Number of measurements to collect for each input.",
        )

    def run(self) -> None:
        """
        Parse the command line, configure AutoHeuristic and run ``main``.

        With ``--use-heuristic`` the learned heuristic ``self.name`` is used and nothing is collected; otherwise data
        for ``self.name`` is collected. In both cases AutoHeuristic logs to the path given by ``-o``. The parsed
        arguments are stored on ``self.args``.
        """
        # Import the config submodule explicitly: a plain ``import torch`` is not guaranteed to have loaded it.
        import torch._inductor.config as inductor_config

        # Parse first so that ``--help`` or an invalid flag exits before any global torch state is modified.
        args = self.parser.parse_args()
        self.args = args
        torch.set_default_device("cuda")
        if args.use_heuristic:
            inductor_config.autoheuristic_use = self.name
            inductor_config.autoheuristic_collect = ""
        else:
            inductor_config.autoheuristic_use = ""
            inductor_config.autoheuristic_collect = self.name
        inductor_config.autoheuristic_log_path = args.o
        if args.device is not None:
            torch.cuda.set_device(args.device)
        # Seed from wall-clock time so separate runs do not reproduce the same random sequence.
        random.seed(time.time())
        self.main(args.num_samples, args.num_reps)

    @abstractmethod
    def run_benchmark(self, *args: Any) -> None:
        """Run the benchmark once on the unpacked tuple returned by ``create_input``."""
        # This class does not use ABCMeta, so @abstractmethod is not enforced at instantiation; raise so that a
        # missing override fails loudly instead of silently benchmarking nothing.
        raise NotImplementedError(f"{type(self).__name__} must implement run_benchmark")

    @abstractmethod
    def create_input(self) -> Tuple[Any, ...]:
        """Return one input tuple; ``main`` unpacks it into ``run_benchmark``."""
        # See run_benchmark: raise rather than implicitly returning None, which would fail later with a less clear
        # error when main() tries to unpack it.
        raise NotImplementedError(f"{type(self).__name__} must implement create_input")

    def main(self, num_samples: int, num_reps: int) -> None:
        """
        Create ``num_samples`` inputs and run the benchmark ``num_reps`` times on each, showing a progress bar.

        Args:
            num_samples: Number of inputs to create via ``create_input``.
            num_reps: Number of ``run_benchmark`` calls per input.
        """
        for _ in tqdm(range(num_samples)):
            inputs = self.create_input()
            for _ in range(num_reps):
                self.run_benchmark(*inputs)
83