#!/usr/bin/env python3
import argparse
import logging
import os

import pandas as pd

from torch._functorch.benchmark_utils import compute_utilization


# Process the chrome traces output by the PyTorch profiler.
# Each JSON input file's name must have the form {model_name}_chrome_trace_*.json.
# The runtimes file is a CSV read with pandas; it must have a "name" column
# (the model name) and a "runtime" column.


def get_model_name(filename):
    """
    Get the model name from a path whose basename is {model_name}_chrome_trace_*.json.

    Raises:
        ValueError: if the basename does not contain "_chrome_trace".
    """
    _, tail = os.path.split(filename)
    marker_pos = tail.find("_chrome_trace")
    # str.find returns -1 when the marker is missing; slicing with -1 would
    # silently drop the last character instead of reporting the bad name.
    if marker_pos == -1:
        raise ValueError(
            f"{filename!r} does not match {{model_name}}_chrome_trace_*.json"
        )
    return tail[:marker_pos]


def get_total_length(run_times_df, modelname):
    """
    Return the "runtime" value (as float) of the row whose "name" equals modelname.

    Raises:
        ValueError: if zero or more than one row matches modelname.
    """
    runtimes = run_times_df.loc[run_times_df["name"] == modelname, "runtime"]
    if len(runtimes) != 1:
        raise ValueError(
            f"expected exactly one runtime for model {modelname!r}, "
            f"found {len(runtimes)}"
        )
    # .iloc[0] instead of float(Series): the latter is deprecated in pandas >= 2.1.
    return float(runtimes.iloc[0])


def main():
    """
    Parse CLI arguments and print utilization numbers for each chrome trace.

    Prints a header line, then one line per trace: either
    "modelname, utilization, mm_conv_utilization" or "filename, ERROR".
    """
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    parser.add_argument(
        "--runtime", "-runf", help="file name of the runtime file", required=True
    )
    group.add_argument(
        "--filename",
        "-f",
        action="append",
        help="a filename of the json file to process",
    )
    group.add_argument("--folder", "-fd", help="a folder of the json files to process")
    args = parser.parse_args()

    if args.filename:
        filenames = args.filename
    elif args.folder:
        directory = args.folder
        # Sorted so the output order is deterministic (os.listdir order is arbitrary).
        filenames = [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith(".json") and os.path.isfile(os.path.join(directory, name))
        ]
    else:
        # Reachable e.g. with `--folder ""`. Stop here instead of falling
        # through to an undefined `filenames`.
        parser.error("Please provide a filename or a folder name")

    print("modelname, GPU Utilization, MM and Conv time")

    run_times_df = pd.read_csv(args.runtime)
    for filename in filenames:
        try:
            modelname = get_model_name(filename)
            # NOTE(review): presumably converts seconds to microseconds to match
            # chrome-trace timestamps; confirm the runtime file's units.
            total_length = get_total_length(run_times_df, modelname) * 1e6
            # Analyze this one trace file (previously the whole list was passed).
            utilization, mm_conv_utilization = compute_utilization(
                filename, total_length
            )
            print(f"{modelname}, {utilization}, {mm_conv_utilization}")
        except Exception:
            # Report per-file failures and keep going; Ctrl-C / SystemExit
            # still propagate because only Exception is caught.
            logging.exception("%s, ERROR", filename)
            print(f"{filename}, ERROR")


if __name__ == "__main__":
    main()