"""Run benchmarks while handling parallelism, isolation, and fault tolerance."""
# mypy: ignore-errors
import math
import multiprocessing
import subprocess
import textwrap
import threading
import time
from typing import Dict, List, Optional, Set, Tuple, Union

from worker.main import WorkerFailure, WorkerOutput

from execution.work import InProgress, PYTHON_CMD, SHELL, WorkOrder


CPU_COUNT: int = multiprocessing.cpu_count()


class WorkerFailed(Exception):
    """Raised in the main process when a worker failure is detected."""

    def __init__(self, cmd: str, wrapped_trace: Optional[str] = None) -> None:
        self.cmd: str = cmd
        self.wrapped_trace: Optional[str] = wrapped_trace
        super().__init__()


class CorePool:
    """Allocator-style helper class to assign individual tasks to a core range.

    Pinning tasks to separate cores (or core ranges if `num_threads` > 1)
    serves two purposes. First, it prevents the machine from being overloaded,
    which can result in OOMs or Callgrind crashes. Second, it helps reduce
    noise in the wall times, which are collected as a secondary metric. For
    multi-threaded workloads, adjacency is important: pairs of cores often
    share silicon (e.g. cache), while far-apart cores may lie on separate NUMA
    nodes. For this reason, CorePool only allocates contiguous core ranges.
    This falls short of full architecture awareness, and instead tries to
    strike a balance between rigor and engineering complexity.
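
    A minimal usage sketch (the core ids below are hypothetical; real ranges
    depend on the machine running the benchmarks):

        pool = CorePool(0, 7)
        cpu_list = pool.reserve(2)  # e.g. "0-1", a `taskset`-style core range
        ...                         # run the pinned job
        pool.release(cpu_list)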
    """

    def __init__(self, min_core_id: int, max_core_id: int) -> None:
        assert min_core_id >= 0
        assert max_core_id >= min_core_id
        assert max_core_id < CPU_COUNT

        self._min_core_id: int = min_core_id
        self._max_core_id: int = max_core_id
        self._num_cores = max_core_id - min_core_id + 1
        print(f"Core pool created: cores {self._min_core_id}-{self._max_core_id}")

        self._available: List[bool] = [
            True for _ in range(min_core_id, min_core_id + self._num_cores)
        ]

        self._reservations: Dict[str, Tuple[int, ...]] = {}
        self._lock = threading.Lock()

    def reserve(self, n: int) -> Optional[str]:
        """Simple first-fit policy.

        If successful, return a string for `taskset`. Otherwise, return None.
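
        For example, `reserve(2)` might return "0-1" (assuming cores 0 and 1
        are the first free contiguous pair); passing that key back to
        `release` frees them again.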
        """
        with self._lock:
            for lower_index in range(self._num_cores - n + 1):
                indices = tuple(range(lower_index, lower_index + n))
                if all(self._available[i] for i in indices):
                    for i in indices:
                        self._available[i] = False

                    lower_core = indices[0] + self._min_core_id
                    upper_core = indices[-1] + self._min_core_id
                    key = f"{lower_core}-{upper_core}" if n > 1 else f"{lower_core}"
                    self._reservations[key] = indices
                    return key
        return None

    def release(self, key: str) -> None:
        with self._lock:
            for i in self._reservations[key]:
                self._available[i] = True
            self._reservations.pop(key)


class Runner:
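    """Drive a collection of WorkOrders to completion.

    Jobs run as core-pinned subprocesses (see `InProgress`), are polled roughly
    every `cadence` seconds, and their outputs are collected keyed by WorkOrder.
    A minimal usage sketch (WorkOrder construction lives in `execution.work`
    and is elided here):

        runner = Runner(work_items=work_orders)  # work_orders: Tuple[WorkOrder, ...]
        results = runner.run()                   # Dict[WorkOrder, WorkerOutput]
    """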
    def __init__(
        self,
        work_items: Tuple[WorkOrder, ...],
        core_pool: Optional[CorePool] = None,
        cadence: float = 1.0,
    ) -> None:
        self._work_items: Tuple[WorkOrder, ...] = work_items
        self._core_pool: CorePool = core_pool or CorePool(0, CPU_COUNT - 4)
        self._cadence: float = cadence

        # Working state.
        self._work_queue: List[WorkOrder] = list(work_items)
        self._active_jobs: List[InProgress] = []
        self._results: Dict[WorkOrder, WorkerOutput] = {}

        # Debug information for ETA and error messages.
        self._start_time: float = -1
        self._durations: Dict[WorkOrder, float] = {}
        self._currently_processed: Optional[WorkOrder] = None

        if len(work_items) != len(set(work_items)):
            raise ValueError("Duplicate work items.")

    def run(self) -> Dict[WorkOrder, WorkerOutput]:
        try:
            return self._run()

        except KeyboardInterrupt:
            print("\n\nKeyboardInterrupt (ctrl-c) detected. Shutting down children.")
            self._force_shutdown(verbose=False)
            raise

        except subprocess.TimeoutExpired:
            print("\n\nJob timed out. Shutting down children.")
            self._force_shutdown(verbose=True)
            raise

        except WorkerFailed as e:
            print("Shutting down all outstanding jobs before re-raising.")
            self._force_shutdown(verbose=True)
            print(f"Cmd: {e.cmd}")
            if e.wrapped_trace:
                print(e.wrapped_trace)
            else:
                print("Unknown failure. (Worker did not report exception contents.)")
            raise

        except BaseException:
            print("\n\nUnknown exception. Shutting down jobs before re-raising.")
            self._force_shutdown(verbose=True)
            raise

    def _run(self) -> Dict[WorkOrder, WorkerOutput]:
        self._start_time = time.time()
        self._canary_import()
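        # Main polling loop: harvest finished jobs, launch queued work as cores
        # free up, and refresh the progress line roughly every `cadence` seconds.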
        while self._work_queue or self._active_jobs:
            t0 = time.time()
            self._update_active_jobs()
            self._enqueue_new_jobs()
            self._print_progress()
            time.sleep(max(self._cadence - (time.time() - t0), 0.0))
        print(f"\nTotal time: {time.time() - self._start_time:.0f} seconds")
        return self._results.copy()

    def _update_active_jobs(self) -> None:
        active_jobs: List[InProgress] = []
        for job in self._active_jobs:
            self._currently_processed = job.work_order
            if not job.check_finished():
                active_jobs.append(job)
                continue

            result: Union[WorkerOutput, WorkerFailure] = job.result
            if isinstance(result, WorkerOutput):
                self._results[job.work_order] = result
                assert job.cpu_list is not None
                self._core_pool.release(job.cpu_list)
                self._durations[job.work_order] = job.duration

            else:
                assert isinstance(result, WorkerFailure)
                raise WorkerFailed(cmd=job.proc.cmd, wrapped_trace=result.failure_trace)
        self._currently_processed = None
        self._active_jobs.clear()
        self._active_jobs.extend(active_jobs)

    def _enqueue_new_jobs(self) -> None:
        work_queue: List[WorkOrder] = []
        for work_order in self._work_queue:
            self._currently_processed = work_order
            cpu_list = self._core_pool.reserve(work_order.timer_args.num_threads)

            if cpu_list is None:
                work_queue.append(work_order)
            else:
                self._active_jobs.append(InProgress(work_order, cpu_list))

                # Stagger creation. This helps with contention.
                time.sleep(0.5)
        self._currently_processed = None
        self._work_queue.clear()
        self._work_queue.extend(work_queue)

    def _print_progress(self) -> None:
        fraction = f"{len(self._results)} / {len(self._work_items)}"
        elapsed = f"{time.time() - self._start_time:.0f} seconds"
        if len(self._results) < 5:
            eta = "Unknown"
        else:
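            # Rough ETA: assume the remaining jobs run in batches of `_num_cores`
            # and that each batch takes the mean duration observed so far.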
            remaining = len(self._work_items) - len(self._results)
            iters_remaining = math.ceil(remaining / self._core_pool._num_cores)
            mean_time = sum(self._durations.values()) / len(self._durations)
            eta_minutes = math.ceil(iters_remaining * mean_time / 60)
            eta = f"~{eta_minutes:.0f} minute{'s' if eta_minutes > 1 else ''}"
        print(f"\r{fraction} ({elapsed}), ETA: {eta}", end="")

    def _force_shutdown(self, verbose: bool = False) -> None:
        """Try to interrupt jobs, and kill if need be.

        We would prefer to softly terminate jobs so that they have a chance to
        clean up before shutting down.
        """
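        # Ask each worker to stop via SIGINT first; stragglers are polled below
        # and forcibly terminated if they have not exited after a grace period.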
        for job in self._active_jobs:
            job.proc.interrupt()

        if verbose and self._currently_processed is not None:
            print(
                textwrap.dedent(
                    f"""
                Failed when processing the following Job:
                  Label:      {self._currently_processed.label}
                  AutoLabels: {self._currently_processed.autolabels}
                  Source cmd: {self._currently_processed.source_cmd}
            """
                ).strip()
                + "\n"
            )

        if self._active_jobs:
            time.sleep(0.5)

        remaining_jobs = [j for j in self._active_jobs if j.proc.poll() is None]
        if remaining_jobs:
            print(
                f"SIGINT sent to {len(self._active_jobs)} jobs, "
                f"{len(remaining_jobs)} have not yet exited.\n"
                "Entering short cleanup loop, after which stragglers will "
                "be forcibly terminated."
            )

            for _ in range(5):
                time.sleep(2.0)
                remaining_jobs = [j for j in remaining_jobs if j.proc.poll() is None]
                if remaining_jobs:
                    print(f"{len(remaining_jobs)} still remain.")
                else:
                    print("All remaining jobs have gracefully terminated.")
                    return

            print(f"{len(remaining_jobs)} jobs refused to exit. Forcibly terminating.")
            for j in remaining_jobs:
                j.proc.terminate()

    def _canary_import(self) -> None:
        """Make sure we can import torch before launching a slew of workers."""
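        # Each distinct `source_cmd` (environment setup, e.g. activating a
        # different build) gets its own canary import, since workers may run
        # under different environments.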
        source_cmds: Set[str] = set()
        for w in self._work_items:
            if w.source_cmd is not None:
                source_cmds.add(f"{w.source_cmd} && ")

        for source_cmd in source_cmds or {""}:
            cmd = f'{source_cmd}{PYTHON_CMD} -c "import torch"'
            proc = subprocess.run(
                cmd,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                encoding="utf-8",
                executable=SHELL,
            )

            if proc.returncode:
                raise ImportError(
                    f"Failed to import torch in subprocess: {cmd}\n{proc.stdout}"
                )
270