#!/usr/bin/env python3
#
# Computes difference between measurements produced by ./benchmark.py.
#

import argparse
import json

import numpy as np


# Percentiles of the timing measurements reported for every GPU count.
# Shared by the table header and the rows so the two cannot drift apart.
_PERCENTILES = (75, 95)


def load(path):
    """Parse the JSON document stored at ``path`` and return the result.

    Raises whatever ``open`` or ``json.load`` raise for a missing or
    malformed file.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _relative_gain(baseline, candidate):
    """Return by how many percent ``candidate`` improves on ``baseline``.

    Positive means the candidate took less time. A zero baseline has no
    meaningful relative change, so NaN is returned instead of dividing.
    """
    try:
        # We're measuring time, so lower is better (hence the negation).
        return -100 * ((candidate - baseline) / baseline)
    except ZeroDivisionError:
        return float("nan")


def _format_rate(batch_size, seconds):
    """Return ``batch_size / seconds`` as a 7-wide integer column.

    Falls back to a right-aligned ``n/a`` when the rate cannot be expressed
    as an integer (``seconds`` is zero, or the quotient is infinite or NaN).
    """
    try:
        return f"{int(batch_size / seconds):7d}"
    except (ZeroDivisionError, OverflowError, ValueError):
        return f"{'n/a':>7s}"


def _print_metadata(ja, jb):
    """Print every top-level key except ``benchmark_results`` side by side.

    ``ja`` is the baseline document and ``jb`` the test document. A key that
    is missing from one of them is shown as ``-`` on that side.
    """
    keys = (set(ja) | set(jb)) - {"benchmark_results"}
    # NOTE(review): the header separates its columns with one space while the
    # rows use " vs ", so the "test" title sits 3 columns left of its values.
    # Confirm the intended spacing before changing either string.
    print(f"{'':20s} {'baseline':>20s} {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s} {'-' * 20:>20s}")
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
    print()


def _print_benchmark(ra, rb):
    """Print the comparison table for one pair of benchmark results.

    ``ra`` (baseline) and ``rb`` (test) must provide the keys ``model``,
    ``batch_size`` and ``result``; every item of ``result`` must provide
    ``ranks`` and ``measurements``. The printed timings and rates are those
    of ``rb``; ``ra`` only feeds the ``diff`` column.
    """
    model = ra["model"]
    batch_size = int(ra["batch_size"])
    print(f"Benchmark: {model} with batch size {batch_size}")

    # Print header
    # NOTE(review): each header group is 36 characters wide but each row
    # group below is 35 -- confirm the intended spacing.
    print()
    print(f"{'':>10s}", end="")
    for _ in _PERCENTILES:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
    print()

    # Print measurements
    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Ignore round without ddp
        if i == 0:
            continue
        # Sanity check: ignore if number of ranks is not equal
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ma = sorted(xa["measurements"])
        mb = sorted(xb["measurements"])
        # A percentile of an empty sample is undefined; skip such rounds the
        # same way mismatched rounds are skipped.
        if not ma or not mb:
            continue

        ngpus = len(xa["ranks"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in _PERCENTILES:
            # Plain floats make a zero divisor raise ZeroDivisionError, which
            # the helpers handle, instead of yielding inf plus a warning.
            va = float(np.percentile(ma, p))
            vb = float(np.percentile(mb, p))
            rate = _format_rate(batch_size, vb)
            delta = _relative_gain(va, vb)
            print(f" p{p:02d}: {vb:8.3f}s {rate}/s {delta:+8.1f}%", end="")
        print()
    print()


def main(argv=None):
    """Diff two benchmark result files and print the comparison to stdout.

    Args:
        argv: Command-line arguments, exactly two paths (baseline file, then
            test file). ``None`` makes argparse read ``sys.argv[1:]``.

    Benchmark results are paired positionally between the two files; a pair
    whose ``model`` or ``batch_size`` differ is skipped.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse itself reject any other number of files.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args(argv)

    baseline_path, test_path = args.file
    ja = load(baseline_path)
    jb = load(test_path)

    _print_metadata(ja, jb)

    for ra, rb in zip(ja["benchmark_results"], jb["benchmark_results"]):
        if ra["model"] != rb["model"]:
            continue
        if ra["batch_size"] != rb["batch_size"]:
            continue
        _print_benchmark(ra, rb)


if __name__ == "__main__":
    main()