xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/ddp/diff.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #!/usr/bin/env python3
2 #
3 # Computes difference between measurements produced by ./benchmark.py.
4 #
5 
6 import argparse
7 import json
8 
9 import numpy as np
10 
11 
def load(path):
    """Read the JSON document stored at *path* and return the parsed objects."""
    with open(path) as handle:
        raw = handle.read()
    return json.loads(raw)
15 
16 
def main():
    """Diff two benchmark result files produced by ./benchmark.py.

    Usage: diff.py <baseline.json> <test.json>

    Prints the top-level metadata keys of both files side by side, then for
    each benchmark (model + batch size) a per-GPU-count comparison of the
    p75/p95 iteration latency, throughput, and relative change.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse itself reject any argument count other than 2,
    # so no extra length check is needed here.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args()

    with open(args.file[0]) as f:
        ja = json.load(f)
    with open(args.file[1]) as f:
        jb = json.load(f)

    _print_metadata_diff(ja, jb)

    for ra, rb in zip(ja["benchmark_results"], jb["benchmark_results"]):
        _print_benchmark_diff(ra, rb)


def _print_metadata_diff(ja, jb):
    """Print every top-level key of both documents (except the raw results) side by side."""
    keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
    print(f"{'':20s} {'baseline':>20s}      {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s}      {'-' * 20:>20s}")
    for key in sorted(keys):
        # "-" marks a key present in only one of the two files.
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s}  vs  {vb:>20s}")
    print()


def _print_benchmark_diff(ra, rb):
    """Print the p75/p95 latency comparison for one benchmark entry pair.

    Silently skips the pair when the two entries are not the same
    model/batch-size configuration, since the comparison would be meaningless.
    """
    if ra["model"] != rb["model"]:
        return
    if ra["batch_size"] != rb["batch_size"]:
        return

    model = ra["model"]
    batch_size = int(ra["batch_size"])
    print(f"Benchmark: {model} with batch size {batch_size}")

    # Header: one column group (sec/iter, ex/sec, diff) per percentile.
    print()
    print(f"{'':>10s}", end="")
    for _ in [75, 95]:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
    print()

    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Round 0 runs without DDP; it is not comparable, so skip it.
        if i == 0:
            continue
        # Sanity check: skip rounds where the number of ranks differs.
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ngpus = len(xa["ranks"])
        ma = sorted(xa["measurements"])
        mb = sorted(xb["measurements"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in [75, 95]:
            va = np.percentile(ma, p)
            vb = np.percentile(mb, p)
            # We're measuring time, so lower is better (hence the negation):
            # a positive delta means the test run is faster than baseline.
            delta = -100 * ((vb - va) / va)
            print(
                f"  p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                end="",
            )
        print()
    print()
83 
84 
# Script entry point: run the diff only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
87