xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/ddp/diff.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 *da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2 *da0073e9SAndroid Build Coastguard Worker#
3 *da0073e9SAndroid Build Coastguard Worker# Computes difference between measurements produced by ./benchmark.py.
4 *da0073e9SAndroid Build Coastguard Worker#
5 *da0073e9SAndroid Build Coastguard Worker
6 *da0073e9SAndroid Build Coastguard Workerimport argparse
7 *da0073e9SAndroid Build Coastguard Workerimport json
8 *da0073e9SAndroid Build Coastguard Worker
9 *da0073e9SAndroid Build Coastguard Workerimport numpy as np
10 *da0073e9SAndroid Build Coastguard Worker
11 *da0073e9SAndroid Build Coastguard Worker
def load(path):
    """Read and parse the JSON results file at *path*.

    The file is decoded as UTF-8 explicitly: JSON is a UTF-8 format, while
    ``open()``'s default encoding depends on the platform locale.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
15 *da0073e9SAndroid Build Coastguard Worker
16 *da0073e9SAndroid Build Coastguard Worker
# Percentiles reported per configuration. The table header and the
# measurement rows both iterate this tuple so their columns stay in sync.
_PERCENTILES = (75, 95)


def _print_config_table(ja, jb):
    """Print every top-level field except the results, baseline vs test.

    Keys present in only one of the two runs are shown as ``-`` on the other.
    """
    keys = (set(ja) | set(jb)) - {"benchmark_results"}
    print(f"{'':20s} {'baseline':>20s}      {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s}      {'-' * 20:>20s}")
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s}  vs  {vb:>20s}")
    print()


def _format_percentile(p, va, vb, batch_size):
    """Return one table cell: test sec/iter, test ex/sec, and % change.

    *va* and *vb* are the baseline and test values at percentile *p*.
    A zero value would divide by zero; the affected field prints ``n/a``.
    """
    if vb:
        throughput = f"{int(batch_size / vb):7d}"
    else:
        throughput = f"{'n/a':>7s}"
    if va:
        # We're measuring time, so lower is better (hence the negation)
        delta = f"{-100 * ((vb - va) / va):+8.1f}"
    else:
        delta = f"{'n/a':>8s}"
    return f"  p{p:02d}: {vb:8.3f}s {throughput}/s {delta}%"


def _print_benchmark(ra, rb):
    """Print the percentile comparison table for one model/batch-size pair."""
    model = ra["model"]
    batch_size = int(ra["batch_size"])
    name = f"{model} with batch size {batch_size}"
    print(f"Benchmark: {name}")

    # Print header
    print()
    print(f"{'':>10s}", end="")
    for _ in _PERCENTILES:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
    print()

    # Print measurements
    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Ignore round without ddp
        if i == 0:
            continue
        # Sanity check: ignore if number of ranks is not equal
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue
        ma = xa["measurements"]
        mb = xb["measurements"]
        # np.percentile raises on empty input, and there is nothing to compare.
        if not ma or not mb:
            continue

        ngpus = len(xa["ranks"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in _PERCENTILES:
            # np.percentile orders its input itself; no pre-sorting needed.
            va = np.percentile(ma, p)
            vb = np.percentile(mb, p)
            print(_format_percentile(p, va, vb, batch_size), end="")
        print()
    print()


def main(argv=None):
    """Diff two result files written by ./benchmark.py and print the comparison.

    Args:
        argv: Argument list without the program name. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which preserves the original
            command-line behavior.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse reject any other number of files (it prints usage
    # and exits), so no separate length check is needed.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args(argv)

    ja = load(args.file[0])
    jb = load(args.file[1])

    _print_config_table(ja, jb)

    # Results are paired by position. Pairs whose model or batch size differ
    # are skipped, and zip() drops any extra entries of the longer list.
    for ra, rb in zip(ja["benchmark_results"], jb["benchmark_results"]):
        if ra["model"] != rb["model"] or ra["batch_size"] != rb["batch_size"]:
            continue
        _print_benchmark(ra, rb)
83 *da0073e9SAndroid Build Coastguard Worker
84 *da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only run the diff when executed as a script, not when imported.
    main()
87