xref: /aosp_15_r20/external/pytorch/functorch/benchmarks/chrome_trace_parser.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2import argparse
3import logging
4import os
5
6import pandas as pd
7
8from torch._functorch.benchmark_utils import compute_utilization
9
10
# Process the chrome traces output by the PyTorch profiler.
# Each JSON input file must be named {model_name}_chrome_trace_*.json.
# The runtime file is a CSV with a header row containing the columns `name` (the model name) and `runtime`.
14
15
def get_model_name(filename):
    """
    Get model name from a file in format {model_name}_chrome_trace_*.json

    Args:
        filename: path (or bare file name) of a chrome trace file. Any
            directory components are ignored.

    Returns:
        The part of the file name that precedes the first "_chrome_trace".
        If the file name does not contain that marker, the file name without
        its extension is returned instead.
    """
    _, tail = os.path.split(filename)
    # partition() leaves `marker` empty when the separator is absent, whereas
    # slicing with str.find()'s -1 would silently drop the last character.
    modelname, marker, _ = tail.partition("_chrome_trace")
    if not marker:
        return os.path.splitext(tail)[0]
    return modelname
23
24
def get_total_length(run_times_df, modelname):
    """
    Look up the recorded runtime of a model.

    Args:
        run_times_df: DataFrame with a "name" column and a "runtime" column.
        modelname: value to match against the "name" column.

    Returns:
        The matching "runtime" entry converted to a float.

    Raises:
        ValueError: if the number of rows whose "name" equals modelname is
            not exactly one.
    """
    runtimes = run_times_df.loc[run_times_df["name"] == modelname, "runtime"]
    if len(runtimes) != 1:
        raise ValueError(
            f"expected exactly one runtime entry for {modelname!r}, "
            f"found {len(runtimes)}"
        )
    # Pull the scalar out explicitly: calling float() on a one-element Series
    # is deprecated in recent pandas releases.
    return float(runtimes.iloc[0])
27
28
def main():
    """
    Command-line entry point: print utilization figures for chrome traces.

    Arguments (parsed from the command line):
        --runtime / -runf: CSV file read with pandas; required.
        --filename / -f:   a trace JSON file; may be given several times.
        --folder / -fd:    a folder whose ``*.json`` files are all processed.
        Exactly one of --filename and --folder must be supplied.

    Output:
        A header line followed by one ``modelname, utilization,
        mm_conv_utilization`` line per trace on stdout. A trace that fails to
        process is logged with its traceback and reported as
        ``<filename>, ERROR`` so the remaining traces are still handled.
    """
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    parser.add_argument(
        "--runtime", "-runf", help="file name of the runtime file", required=True
    )
    group.add_argument(
        "--filename",
        "-f",
        action="append",
        help="a filename of the json file to process",
    )
    group.add_argument("--folder", "-fd", help="a folder of the json files to process")
    args = parser.parse_args()

    if args.filename:
        filenames = args.filename
    elif args.folder:
        directory = args.folder
        # Sort the listing so the report order is reproducible; os.listdir()
        # returns entries in arbitrary order.
        candidates = (
            os.path.join(directory, entry) for entry in sorted(os.listdir(directory))
        )
        filenames = [
            path
            for path in candidates
            if os.path.isfile(path) and path.endswith(".json")
        ]
    else:
        # Reachable when an empty string is passed (e.g. --folder ""): stop
        # here instead of falling through with `filenames` undefined.
        parser.error("Please provide a filename or a folder name")

    print("modelname, GPU Utilization, MM and Conv time")

    run_times_df = pd.read_csv(args.runtime)
    for filename in filenames:
        try:
            modelname = get_model_name(filename)
            # NOTE(review): the 1e6 factor presumably converts seconds to
            # microseconds -- confirm against compute_utilization.
            total_length = get_total_length(run_times_df, modelname) * 1e6
            # Pass the current trace file, not the whole list of files.
            utilization, mm_conv_utilization = compute_utilization(
                filename, total_length
            )
            print(f"{modelname}, {utilization}, {mm_conv_utilization}")
        except Exception:
            # Keep going on per-file failures, but let KeyboardInterrupt and
            # SystemExit propagate so the run can still be aborted.
            logging.exception("%s, ERROR", filename)
            print(f"{filename}, ERROR")
70
71
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
74