import argparse
import random
import time
from abc import abstractmethod
from typing import Any, Optional, Sequence, Tuple

from tqdm import tqdm  # type: ignore[import-untyped]

import torch


class BenchmarkRunner:
    """
    BenchmarkRunner is a base class for all benchmark runners. It provides an interface to run benchmarks in order to
    collect data with AutoHeuristic.

    Subclasses implement ``create_input`` (build one tuple of positional arguments) and ``run_benchmark`` (execute the
    benchmark on those arguments). ``run`` parses the command line, configures AutoHeuristic through
    ``torch._inductor.config`` and then calls ``main``, which drives the sampling loop.
    """

    def __init__(self, name: str) -> None:
        # Name written to ``torch._inductor.config.autoheuristic_use`` or
        # ``autoheuristic_collect`` by ``run``, depending on --use-heuristic.
        self.name = name
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        # Parsed command-line arguments; filled in by ``run`` (None until then) so that
        # subclasses can read options they added to ``self.parser``.
        self.args: Optional[argparse.Namespace] = None

    def add_base_arguments(self) -> None:
        """Register the command-line options shared by every benchmark runner.

        Subclasses may register additional options on ``self.parser`` before ``run`` is called.
        """
        self.parser.add_argument(
            "--device",
            type=int,
            default=None,
            help="torch.cuda.set_device(device) will be used",
        )
        self.parser.add_argument(
            "--use-heuristic",
            action="store_true",
            help="Use learned heuristic instead of collecting data.",
        )
        self.parser.add_argument(
            "-o",
            type=str,
            default="ah_data.txt",
            help="Path to file where AutoHeuristic will log results.",
        )
        self.parser.add_argument(
            "--num-samples",
            type=int,
            default=1000,
            help="Number of samples to collect.",
        )
        self.parser.add_argument(
            "--num-reps",
            type=int,
            default=3,
            help="Number of measurements to collect for each input.",
        )

    def run(self, argv: Optional[Sequence[str]] = None) -> None:
        """Parse arguments, configure AutoHeuristic and the CUDA device, then collect samples.

        With --use-heuristic the learned heuristic named ``self.name`` is used and collection is
        disabled; otherwise data for ``self.name`` is collected and logged to the ``-o`` path.

        Args:
            argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the default) keeps
                the usual command-line behavior.
        """
        # Parse first so that --help or a bad argument exits before any global torch
        # state (default device, inductor config) has been modified.
        args = self.parser.parse_args(argv)
        self.args = args

        if args.use_heuristic:
            torch._inductor.config.autoheuristic_use = self.name
            torch._inductor.config.autoheuristic_collect = ""
        else:
            torch._inductor.config.autoheuristic_use = ""
            torch._inductor.config.autoheuristic_collect = self.name
        torch._inductor.config.autoheuristic_log_path = args.o

        # Select the requested GPU before making "cuda" the default device, so the default
        # refers to that GPU regardless of when the device index gets resolved.
        if args.device is not None:
            torch.cuda.set_device(args.device)
        torch.set_default_device("cuda")

        # Time-based seed: each invocation draws a different sequence of random inputs.
        random.seed(time.time())
        self.main(args.num_samples, args.num_reps)

    # NOTE: the class does not use ABCMeta, so @abstractmethod is not enforced at
    # instantiation time; the NotImplementedError bodies make a missing override fail
    # loudly at the call site instead of silently returning None.
    @abstractmethod
    def run_benchmark(self, *args: Any) -> None:
        """Run the benchmark once on the positional arguments produced by ``create_input``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement run_benchmark()"
        )

    @abstractmethod
    def create_input(self) -> Tuple[Any, ...]:
        """Create one input, returned as a tuple of positional arguments for ``run_benchmark``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement create_input()"
        )

    def main(self, num_samples: int, num_reps: int) -> None:
        """Collect ``num_samples`` samples, benchmarking each generated input ``num_reps`` times.

        A fresh input is created once per sample and reused for all repetitions of that sample.
        """
        for _ in tqdm(range(num_samples)):
            inputs = self.create_input()
            for _ in range(num_reps):
                self.run_benchmark(*inputs)