1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport importlib 4*da0073e9SAndroid Build Coastguard Workerimport logging 5*da0073e9SAndroid Build Coastguard Workerimport os 6*da0073e9SAndroid Build Coastguard Workerimport re 7*da0073e9SAndroid Build Coastguard Workerimport subprocess 8*da0073e9SAndroid Build Coastguard Workerimport sys 9*da0073e9SAndroid Build Coastguard Workerimport warnings 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workertry: 13*da0073e9SAndroid Build Coastguard Worker from .common import ( 14*da0073e9SAndroid Build Coastguard Worker BenchmarkRunner, 15*da0073e9SAndroid Build Coastguard Worker download_retry_decorator, 16*da0073e9SAndroid Build Coastguard Worker load_yaml_file, 17*da0073e9SAndroid Build Coastguard Worker main, 18*da0073e9SAndroid Build Coastguard Worker reset_rng_state, 19*da0073e9SAndroid Build Coastguard Worker ) 20*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 21*da0073e9SAndroid Build Coastguard Worker from common import ( 22*da0073e9SAndroid Build Coastguard Worker BenchmarkRunner, 23*da0073e9SAndroid Build Coastguard Worker download_retry_decorator, 24*da0073e9SAndroid Build Coastguard Worker load_yaml_file, 25*da0073e9SAndroid Build Coastguard Worker main, 26*da0073e9SAndroid Build Coastguard Worker reset_rng_state, 27*da0073e9SAndroid Build Coastguard Worker ) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workerimport torch 30*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import collect_results 31*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import clone_inputs 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Workerlog = logging.getLogger(__name__) 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker# Enable FX graph caching 37*da0073e9SAndroid Build Coastguard Workerif "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: 38*da0073e9SAndroid Build Coastguard Worker torch._inductor.config.fx_graph_cache = True 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Workerdef pip_install(package): 42*da0073e9SAndroid Build Coastguard Worker subprocess.check_call([sys.executable, "-m", "pip", "install", package]) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker# Disable the flake warnings for the imports. Flake8 does not provide a way to 46*da0073e9SAndroid Build Coastguard Worker# disable just warning for the entire file. Disabling flake8 entirely. 47*da0073e9SAndroid Build Coastguard Worker# flake8: noqa 48*da0073e9SAndroid Build Coastguard Workerimports = [ 49*da0073e9SAndroid Build Coastguard Worker "AlbertForPreTraining", 50*da0073e9SAndroid Build Coastguard Worker "AutoConfig", 51*da0073e9SAndroid Build Coastguard Worker "AutoModelForCausalLM", 52*da0073e9SAndroid Build Coastguard Worker "AutoModelForMaskedLM", 53*da0073e9SAndroid Build Coastguard Worker "AutoModelForSeq2SeqLM", 54*da0073e9SAndroid Build Coastguard Worker "BigBirdConfig", 55*da0073e9SAndroid Build Coastguard Worker "BlenderbotForConditionalGeneration", 56*da0073e9SAndroid Build Coastguard Worker "BlenderbotModel", 57*da0073e9SAndroid Build Coastguard Worker "BlenderbotSmallForConditionalGeneration", 58*da0073e9SAndroid Build Coastguard Worker "BlenderbotSmallModel", 59*da0073e9SAndroid Build Coastguard Worker "CLIPModel", 60*da0073e9SAndroid Build Coastguard Worker "CLIPVisionModel", 61*da0073e9SAndroid Build Coastguard Worker "ElectraForPreTraining", 62*da0073e9SAndroid Build Coastguard Worker "GPT2ForSequenceClassification", 63*da0073e9SAndroid Build Coastguard Worker "GPTJForSequenceClassification", 64*da0073e9SAndroid Build Coastguard Worker "GPTNeoForSequenceClassification", 65*da0073e9SAndroid Build Coastguard Worker "HubertForSequenceClassification", 66*da0073e9SAndroid Build Coastguard Worker "LxmertForPreTraining", 67*da0073e9SAndroid Build Coastguard Worker "LxmertForQuestionAnswering", 68*da0073e9SAndroid Build Coastguard Worker "MarianForCausalLM", 69*da0073e9SAndroid Build Coastguard Worker "MarianModel", 70*da0073e9SAndroid Build Coastguard Worker "MarianMTModel", 71*da0073e9SAndroid Build Coastguard Worker "PegasusForConditionalGeneration", 72*da0073e9SAndroid Build Coastguard Worker "PegasusModel", 73*da0073e9SAndroid Build Coastguard Worker "ReformerConfig", 74*da0073e9SAndroid Build Coastguard Worker "ViTForImageClassification", 75*da0073e9SAndroid Build Coastguard Worker "ViTForMaskedImageModeling", 76*da0073e9SAndroid Build Coastguard Worker "ViTModel", 77*da0073e9SAndroid Build Coastguard Worker] 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Workerdef process_hf_reformer_output(out): 81*da0073e9SAndroid Build Coastguard Worker assert isinstance(out, list) 82*da0073e9SAndroid Build Coastguard Worker # second output is unstable 83*da0073e9SAndroid Build Coastguard Worker return [elem for i, elem in enumerate(out) if i != 1] 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Workertry: 87*da0073e9SAndroid Build Coastguard Worker mod = importlib.import_module("transformers") 88*da0073e9SAndroid Build Coastguard Worker for cls in imports: 89*da0073e9SAndroid Build Coastguard Worker if not hasattr(mod, cls): 90*da0073e9SAndroid Build Coastguard Worker raise ModuleNotFoundError 91*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 92*da0073e9SAndroid Build Coastguard Worker print("Installing HuggingFace Transformers...") 93*da0073e9SAndroid Build Coastguard Worker pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers") 94*da0073e9SAndroid Build Coastguard Workerfinally: 95*da0073e9SAndroid Build Coastguard Worker for cls in imports: 96*da0073e9SAndroid Build Coastguard Worker exec(f"from transformers import {cls}") 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker# These models contain the models present in huggingface_models_list. It is a 100*da0073e9SAndroid Build Coastguard Worker# combination of models supported by HF Fx parser and some manually supplied 101*da0073e9SAndroid Build Coastguard Worker# models. For these models, we already know the largest batch size that can fit 102*da0073e9SAndroid Build Coastguard Worker# on A100 GPUs - 40 GB. 103*da0073e9SAndroid Build Coastguard WorkerBATCH_SIZE_KNOWN_MODELS = {} 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py 107*da0073e9SAndroid Build Coastguard Worker# Get the list of models and their batch sizes 108*da0073e9SAndroid Build Coastguard WorkerMODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt") 109*da0073e9SAndroid Build Coastguard Workerassert os.path.exists(MODELS_FILENAME) 110*da0073e9SAndroid Build Coastguard Workerwith open(MODELS_FILENAME, "r") as fh: 111*da0073e9SAndroid Build Coastguard Worker lines = fh.readlines() 112*da0073e9SAndroid Build Coastguard Worker lines = [line.rstrip() for line in lines] 113*da0073e9SAndroid Build Coastguard Worker for line in lines: 114*da0073e9SAndroid Build Coastguard Worker model_name, batch_size = line.split(",") 115*da0073e9SAndroid Build Coastguard Worker batch_size = int(batch_size) 116*da0073e9SAndroid Build Coastguard Worker BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size 117*da0073e9SAndroid Build Coastguard Workerassert len(BATCH_SIZE_KNOWN_MODELS) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Workerdef get_module_cls_by_model_name(model_cls_name): 121*da0073e9SAndroid Build Coastguard Worker _module_by_model_name = { 122*da0073e9SAndroid Build Coastguard Worker "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2", 123*da0073e9SAndroid Build Coastguard Worker "TrOCRDecoder": "transformers.models.trocr.modeling_trocr", 124*da0073e9SAndroid Build Coastguard Worker } 125*da0073e9SAndroid Build Coastguard Worker module_name = _module_by_model_name.get(model_cls_name, "transformers") 126*da0073e9SAndroid Build Coastguard Worker module = importlib.import_module(module_name) 127*da0073e9SAndroid Build Coastguard Worker return getattr(module, model_cls_name) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Workerdef get_sequence_length(model_cls, model_name): 131*da0073e9SAndroid Build Coastguard Worker if model_name.startswith(("Blenderbot",)): 132*da0073e9SAndroid Build Coastguard Worker seq_length = 128 133*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): 134*da0073e9SAndroid Build Coastguard Worker seq_length = 1024 135*da0073e9SAndroid Build Coastguard Worker elif model_name in ("AllenaiLongformerBase", "BigBird"): 136*da0073e9SAndroid Build Coastguard Worker seq_length = 1024 137*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith("OPT"): 138*da0073e9SAndroid Build Coastguard Worker seq_length = 2048 139*da0073e9SAndroid Build Coastguard Worker elif "Reformer" in model_name: 140*da0073e9SAndroid Build Coastguard Worker seq_length = 4096 141*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith( 142*da0073e9SAndroid Build Coastguard Worker ( 143*da0073e9SAndroid Build Coastguard Worker "Albert", 144*da0073e9SAndroid Build Coastguard Worker "Deberta", 145*da0073e9SAndroid Build Coastguard Worker "Layout", 146*da0073e9SAndroid Build Coastguard Worker "Electra", 147*da0073e9SAndroid Build Coastguard Worker "XLNet", 148*da0073e9SAndroid Build Coastguard Worker "MegatronBert", 149*da0073e9SAndroid Build Coastguard Worker "Bert", 150*da0073e9SAndroid Build Coastguard Worker "Roberta", 151*da0073e9SAndroid Build Coastguard Worker ) 152*da0073e9SAndroid Build Coastguard Worker ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): 153*da0073e9SAndroid Build Coastguard Worker seq_length = 512 154*da0073e9SAndroid Build Coastguard Worker elif model_name in ("TrOCRForCausalLM"): 155*da0073e9SAndroid Build Coastguard Worker seq_length = 256 156*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith("MobileBert"): 157*da0073e9SAndroid Build Coastguard Worker seq_length = 128 158*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith("Wav2Vec2"): 159*da0073e9SAndroid Build Coastguard Worker # If too short, will fail with something like 160*da0073e9SAndroid Build Coastguard Worker # ValueError: `mask_length` has to be smaller than `sequence_length`, 161*da0073e9SAndroid Build Coastguard Worker # but got `mask_length`: 10 and `sequence_length`: 9` 162*da0073e9SAndroid Build Coastguard Worker seq_length = 10000 # NB: a more realistic size is 155136 163*da0073e9SAndroid Build Coastguard Worker else: 164*da0073e9SAndroid Build Coastguard Worker log.info( 165*da0073e9SAndroid Build Coastguard Worker f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily" 166*da0073e9SAndroid Build Coastguard Worker ) 167*da0073e9SAndroid Build Coastguard Worker seq_length = 128 168*da0073e9SAndroid Build Coastguard Worker return seq_length 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Workerdef generate_inputs_for_model( 172*da0073e9SAndroid Build Coastguard Worker model_cls, model, model_name, bs, device, include_loss_args=False 173*da0073e9SAndroid Build Coastguard Worker): 174*da0073e9SAndroid Build Coastguard Worker # TODO - Check if following values are representative 175*da0073e9SAndroid Build Coastguard Worker num_choices = 3 176*da0073e9SAndroid Build Coastguard Worker num_visual_features = 42 177*da0073e9SAndroid Build Coastguard Worker seq_length = get_sequence_length(model_cls, model_name) 178*da0073e9SAndroid Build Coastguard Worker vocab_size = model.config.vocab_size 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker if model_name.startswith("Wav2Vec2"): 181*da0073e9SAndroid Build Coastguard Worker # TODO: If we add more input_values style models, try to work this 182*da0073e9SAndroid Build Coastguard Worker # into the overall control flow 183*da0073e9SAndroid Build Coastguard Worker target_length = 100 184*da0073e9SAndroid Build Coastguard Worker return { 185*da0073e9SAndroid Build Coastguard Worker "input_values": torch.randn((bs, seq_length), device=device), 186*da0073e9SAndroid Build Coastguard Worker # Added because that's what the example training script has 187*da0073e9SAndroid Build Coastguard Worker "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)), 188*da0073e9SAndroid Build Coastguard Worker "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)), 189*da0073e9SAndroid Build Coastguard Worker } 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker if model_name.endswith("MultipleChoice"): 192*da0073e9SAndroid Build Coastguard Worker input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length)) 193*da0073e9SAndroid Build Coastguard Worker elif model_name.startswith("Roberta"): 194*da0073e9SAndroid Build Coastguard Worker input = rand_int_tensor(device, 0, 1, (bs, seq_length)) 195*da0073e9SAndroid Build Coastguard Worker else: 196*da0073e9SAndroid Build Coastguard Worker input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length)) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker if "Bart" in model_name: 199*da0073e9SAndroid Build Coastguard Worker input[:, -1] = model.config.eos_token_id 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker input_dict = {"input_ids": input} 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker if ( 204*da0073e9SAndroid Build Coastguard Worker model_name.startswith("T5") 205*da0073e9SAndroid Build Coastguard Worker or model_name.startswith("M2M100") 206*da0073e9SAndroid Build Coastguard Worker or model_name.startswith("MT5") 207*da0073e9SAndroid Build Coastguard Worker or model_cls 208*da0073e9SAndroid Build Coastguard Worker in [ 209*da0073e9SAndroid Build Coastguard Worker BlenderbotModel, 210*da0073e9SAndroid Build Coastguard Worker BlenderbotSmallModel, 211*da0073e9SAndroid Build Coastguard Worker BlenderbotForConditionalGeneration, 212*da0073e9SAndroid Build Coastguard Worker BlenderbotSmallForConditionalGeneration, 213*da0073e9SAndroid Build Coastguard Worker PegasusModel, 214*da0073e9SAndroid Build Coastguard Worker PegasusForConditionalGeneration, 215*da0073e9SAndroid Build Coastguard Worker MarianModel, 216*da0073e9SAndroid Build Coastguard Worker MarianMTModel, 217*da0073e9SAndroid Build Coastguard Worker ] 218*da0073e9SAndroid Build Coastguard Worker ): 219*da0073e9SAndroid Build Coastguard Worker input_dict["decoder_input_ids"] = input 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker if model_name.startswith("Lxmert"): 222*da0073e9SAndroid Build Coastguard Worker visual_feat_dim, visual_pos_dim = ( 223*da0073e9SAndroid Build Coastguard Worker model.config.visual_feat_dim, 224*da0073e9SAndroid Build Coastguard Worker model.config.visual_pos_dim, 225*da0073e9SAndroid Build Coastguard Worker ) 226*da0073e9SAndroid Build Coastguard Worker input_dict["visual_feats"] = torch.randn( 227*da0073e9SAndroid Build Coastguard Worker bs, num_visual_features, visual_feat_dim 228*da0073e9SAndroid Build Coastguard Worker ) 229*da0073e9SAndroid Build Coastguard Worker input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker if include_loss_args: 232*da0073e9SAndroid Build Coastguard Worker if model_name.endswith("PreTraining"): 233*da0073e9SAndroid Build Coastguard Worker if model_cls in [ElectraForPreTraining, LxmertForPreTraining]: 234*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length)) 235*da0073e9SAndroid Build Coastguard Worker else: 236*da0073e9SAndroid Build Coastguard Worker label_name = ( 237*da0073e9SAndroid Build Coastguard Worker "sentence_order_label" 238*da0073e9SAndroid Build Coastguard Worker if model_cls in [AlbertForPreTraining] 239*da0073e9SAndroid Build Coastguard Worker else "next_sentence_label" 240*da0073e9SAndroid Build Coastguard Worker ) 241*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = ( 242*da0073e9SAndroid Build Coastguard Worker rand_int_tensor(device, 0, vocab_size, (bs, seq_length)), 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,)) 245*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("QuestionAnswering"): 246*da0073e9SAndroid Build Coastguard Worker input_dict["start_positions"] = rand_int_tensor( 247*da0073e9SAndroid Build Coastguard Worker device, 0, seq_length, (bs,) 248*da0073e9SAndroid Build Coastguard Worker ) 249*da0073e9SAndroid Build Coastguard Worker input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,)) 250*da0073e9SAndroid Build Coastguard Worker elif ( 251*da0073e9SAndroid Build Coastguard Worker model_name.endswith("MaskedLM") 252*da0073e9SAndroid Build Coastguard Worker or model_name.endswith("HeadModel") 253*da0073e9SAndroid Build Coastguard Worker or model_name.endswith("CausalLM") 254*da0073e9SAndroid Build Coastguard Worker or model_name.endswith("DoubleHeadsModel") 255*da0073e9SAndroid Build Coastguard Worker ): 256*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor( 257*da0073e9SAndroid Build Coastguard Worker device, 0, vocab_size, (bs, seq_length) 258*da0073e9SAndroid Build Coastguard Worker ) 259*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("TokenClassification"): 260*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor( 261*da0073e9SAndroid Build Coastguard Worker device, 0, model.config.num_labels - 1, (bs, seq_length) 262*da0073e9SAndroid Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("MultipleChoice"): 264*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,)) 265*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("SequenceClassification"): 266*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor( 267*da0073e9SAndroid Build Coastguard Worker device, 0, model.config.num_labels - 1, (bs,) 268*da0073e9SAndroid Build Coastguard Worker ) 269*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("NextSentencePrediction"): 270*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,)) 271*da0073e9SAndroid Build Coastguard Worker elif model_name.endswith("ForConditionalGeneration"): 272*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor( 273*da0073e9SAndroid Build Coastguard Worker device, 0, vocab_size - 1, (bs, seq_length) 274*da0073e9SAndroid Build Coastguard Worker ) 275*da0073e9SAndroid Build Coastguard Worker elif model_name in EXTRA_MODELS: 276*da0073e9SAndroid Build Coastguard Worker input_dict["labels"] = rand_int_tensor( 277*da0073e9SAndroid Build Coastguard Worker device, 0, vocab_size, (bs, seq_length) 278*da0073e9SAndroid Build Coastguard Worker ) 279*da0073e9SAndroid Build Coastguard Worker else: 280*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 281*da0073e9SAndroid Build Coastguard Worker f"Class {model_name} unsupported for training test " 282*da0073e9SAndroid Build Coastguard Worker ) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker return input_dict 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Workerdef rand_int_tensor(device, low, high, shape): 288*da0073e9SAndroid Build Coastguard Worker return torch.randint( 289*da0073e9SAndroid Build Coastguard Worker low, 290*da0073e9SAndroid Build Coastguard Worker high, 291*da0073e9SAndroid Build Coastguard Worker shape, 292*da0073e9SAndroid Build Coastguard Worker device=device, 293*da0073e9SAndroid Build Coastguard Worker dtype=torch.int64, 294*da0073e9SAndroid Build Coastguard Worker requires_grad=False, 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard WorkerEXTRA_MODELS = { 299*da0073e9SAndroid Build Coastguard Worker "AllenaiLongformerBase": ( 300*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("allenai/longformer-base-4096"), 301*da0073e9SAndroid Build Coastguard Worker AutoModelForMaskedLM, 302*da0073e9SAndroid Build Coastguard Worker ), 303*da0073e9SAndroid Build Coastguard Worker "Reformer": ( 304*da0073e9SAndroid Build Coastguard Worker ReformerConfig(), 305*da0073e9SAndroid Build Coastguard Worker AutoModelForMaskedLM, 306*da0073e9SAndroid Build Coastguard Worker ), 307*da0073e9SAndroid Build Coastguard Worker "T5Small": ( 308*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("t5-small"), 309*da0073e9SAndroid Build Coastguard Worker AutoModelForSeq2SeqLM, 310*da0073e9SAndroid Build Coastguard Worker ), 311*da0073e9SAndroid Build Coastguard Worker # "BigBird": ( 312*da0073e9SAndroid Build Coastguard Worker # BigBirdConfig(attention_type="block_sparse"), 313*da0073e9SAndroid Build Coastguard Worker # AutoModelForMaskedLM, 314*da0073e9SAndroid Build Coastguard Worker # ), 315*da0073e9SAndroid Build Coastguard Worker "DistillGPT2": ( 316*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("distilgpt2"), 317*da0073e9SAndroid Build Coastguard Worker AutoModelForCausalLM, 318*da0073e9SAndroid Build Coastguard Worker ), 319*da0073e9SAndroid Build Coastguard Worker "GoogleFnet": ( 320*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("google/fnet-base"), 321*da0073e9SAndroid Build Coastguard Worker AutoModelForMaskedLM, 322*da0073e9SAndroid Build Coastguard Worker ), 323*da0073e9SAndroid Build Coastguard Worker "YituTechConvBert": ( 324*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("YituTech/conv-bert-base"), 325*da0073e9SAndroid Build Coastguard Worker AutoModelForMaskedLM, 326*da0073e9SAndroid Build Coastguard Worker ), 327*da0073e9SAndroid Build Coastguard Worker "CamemBert": ( 328*da0073e9SAndroid Build Coastguard Worker AutoConfig.from_pretrained("camembert-base"), 329*da0073e9SAndroid Build Coastguard Worker AutoModelForMaskedLM, 330*da0073e9SAndroid Build Coastguard Worker ), 331*da0073e9SAndroid Build Coastguard Worker} 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Workerclass HuggingfaceRunner(BenchmarkRunner): 335*da0073e9SAndroid Build Coastguard Worker def __init__(self): 336*da0073e9SAndroid Build Coastguard Worker super().__init__() 337*da0073e9SAndroid Build Coastguard Worker self.suite_name = "huggingface" 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker @property 340*da0073e9SAndroid Build Coastguard Worker def _config(self): 341*da0073e9SAndroid Build Coastguard Worker return load_yaml_file("huggingface.yaml") 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker @property 344*da0073e9SAndroid Build Coastguard Worker def _skip(self): 345*da0073e9SAndroid Build Coastguard Worker return self._config["skip"] 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker @property 348*da0073e9SAndroid Build Coastguard Worker def _accuracy(self): 349*da0073e9SAndroid Build Coastguard Worker return self._config["accuracy"] 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker @property 352*da0073e9SAndroid Build Coastguard Worker def skip_models(self): 353*da0073e9SAndroid Build Coastguard Worker return self._skip["all"] 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker @property 356*da0073e9SAndroid Build Coastguard Worker def skip_models_for_cpu(self): 357*da0073e9SAndroid Build Coastguard Worker return self._skip["device"]["cpu"] 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker @property 360*da0073e9SAndroid Build Coastguard Worker def fp32_only_models(self): 361*da0073e9SAndroid Build Coastguard Worker return self._config["only_fp32"] 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker @property 364*da0073e9SAndroid Build Coastguard Worker def skip_models_due_to_control_flow(self): 365*da0073e9SAndroid Build Coastguard Worker return self._skip["control_flow"] 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker def _get_model_cls_and_config(self, model_name): 368*da0073e9SAndroid Build Coastguard Worker if model_name not in EXTRA_MODELS: 369*da0073e9SAndroid Build Coastguard Worker model_cls = get_module_cls_by_model_name(model_name) 370*da0073e9SAndroid Build Coastguard Worker config_cls = model_cls.config_class 371*da0073e9SAndroid Build Coastguard Worker config = config_cls() 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker # NB: some models need a pad token defined to handle BS > 1 374*da0073e9SAndroid Build Coastguard Worker if ( 375*da0073e9SAndroid Build Coastguard Worker model_cls 376*da0073e9SAndroid Build Coastguard Worker in [ 377*da0073e9SAndroid Build Coastguard Worker GPT2ForSequenceClassification, 378*da0073e9SAndroid Build Coastguard Worker GPTNeoForSequenceClassification, 379*da0073e9SAndroid Build Coastguard Worker GPTJForSequenceClassification, 380*da0073e9SAndroid Build Coastguard Worker ] 381*da0073e9SAndroid Build Coastguard Worker or model_cls.__name__.startswith("Roberta") 382*da0073e9SAndroid Build Coastguard Worker or model_cls.__name__.startswith("Marian") 383*da0073e9SAndroid Build Coastguard Worker ): 384*da0073e9SAndroid Build Coastguard Worker config.pad_token_id = 0 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker else: 387*da0073e9SAndroid Build Coastguard Worker config, model_cls = EXTRA_MODELS[model_name] 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker return model_cls, config 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker @download_retry_decorator 392*da0073e9SAndroid Build Coastguard Worker def _download_model(self, model_name): 393*da0073e9SAndroid Build Coastguard Worker model_cls, config = self._get_model_cls_and_config(model_name) 394*da0073e9SAndroid Build Coastguard Worker if "auto" in model_cls.__module__: 395*da0073e9SAndroid Build Coastguard Worker # Handle auto classes 396*da0073e9SAndroid Build Coastguard Worker model = model_cls.from_config(config) 397*da0073e9SAndroid Build Coastguard Worker else: 398*da0073e9SAndroid Build Coastguard Worker model = model_cls(config) 399*da0073e9SAndroid Build Coastguard Worker return model 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker def load_model( 402*da0073e9SAndroid Build Coastguard Worker self, 403*da0073e9SAndroid Build Coastguard Worker device, 404*da0073e9SAndroid Build Coastguard Worker model_name, 405*da0073e9SAndroid Build Coastguard Worker batch_size=None, 406*da0073e9SAndroid Build Coastguard Worker extra_args=None, 407*da0073e9SAndroid Build Coastguard Worker ): 408*da0073e9SAndroid Build Coastguard Worker is_training = self.args.training 409*da0073e9SAndroid Build Coastguard Worker use_eval_mode = self.args.use_eval_mode 410*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 411*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 412*da0073e9SAndroid Build Coastguard Worker model_cls, config = self._get_model_cls_and_config(model_name) 413*da0073e9SAndroid Build Coastguard Worker model = self._download_model(model_name) 414*da0073e9SAndroid Build Coastguard Worker model = model.to(device, dtype=dtype) 415*da0073e9SAndroid Build Coastguard Worker if self.args.enable_activation_checkpointing: 416*da0073e9SAndroid Build Coastguard Worker model.gradient_checkpointing_enable() 417*da0073e9SAndroid Build Coastguard Worker if model_name in BATCH_SIZE_KNOWN_MODELS: 418*da0073e9SAndroid Build Coastguard Worker batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name] 419*da0073e9SAndroid Build Coastguard Worker elif batch_size is None: 420*da0073e9SAndroid Build Coastguard Worker batch_size_default = 16 421*da0073e9SAndroid Build Coastguard Worker log.info( 422*da0073e9SAndroid Build Coastguard Worker f"Batch size not specified for {model_name}. Setting batch_size=16" 423*da0073e9SAndroid Build Coastguard Worker ) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker if batch_size is None: 426*da0073e9SAndroid Build Coastguard Worker batch_size = batch_size_default 427*da0073e9SAndroid Build Coastguard Worker batch_size_divisors = self._config["batch_size"]["divisors"] 428*da0073e9SAndroid Build Coastguard Worker if model_name in batch_size_divisors: 429*da0073e9SAndroid Build Coastguard Worker batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1) 430*da0073e9SAndroid Build Coastguard Worker log.info( 431*da0073e9SAndroid Build Coastguard Worker f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" 432*da0073e9SAndroid Build Coastguard Worker ) 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker example_inputs = generate_inputs_for_model( 435*da0073e9SAndroid Build Coastguard Worker model_cls, model, model_name, batch_size, device, include_loss_args=True 436*da0073e9SAndroid Build Coastguard Worker ) 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker # So we can check for correct gradients without eliminating the dropout computation 439*da0073e9SAndroid Build Coastguard Worker for attr in dir(config): 440*da0073e9SAndroid Build Coastguard Worker if "drop" in attr and isinstance(getattr(config, attr), float): 441*da0073e9SAndroid Build Coastguard Worker setattr(config, attr, 1e-30) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker if ( 444*da0073e9SAndroid Build Coastguard Worker is_training 445*da0073e9SAndroid Build Coastguard Worker and not use_eval_mode 446*da0073e9SAndroid Build Coastguard Worker and not ( 447*da0073e9SAndroid Build Coastguard Worker self.args.accuracy and model_name in self._config["only_inference"] 448*da0073e9SAndroid Build Coastguard Worker ) 449*da0073e9SAndroid Build Coastguard Worker ): 450*da0073e9SAndroid Build Coastguard Worker model.train() 451*da0073e9SAndroid Build Coastguard Worker else: 452*da0073e9SAndroid Build Coastguard Worker model.eval() 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker self.validate_model(model, example_inputs) 455*da0073e9SAndroid Build Coastguard Worker return device, model_name, model, example_inputs, batch_size 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker def iter_model_names(self, args): 458*da0073e9SAndroid Build Coastguard Worker model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys()) 459*da0073e9SAndroid Build Coastguard Worker model_names = set(model_names) 460*da0073e9SAndroid Build Coastguard Worker model_names = sorted(model_names) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker start, end = self.get_benchmark_indices(len(model_names)) 463*da0073e9SAndroid Build Coastguard Worker for index, model_name in enumerate(model_names): 464*da0073e9SAndroid Build Coastguard Worker if index < start or index >= end: 465*da0073e9SAndroid Build Coastguard Worker continue 466*da0073e9SAndroid Build Coastguard Worker if ( 467*da0073e9SAndroid Build Coastguard Worker not re.search("|".join(args.filter), model_name, re.I) 468*da0073e9SAndroid Build Coastguard Worker or re.search("|".join(args.exclude), model_name, re.I) 469*da0073e9SAndroid Build Coastguard Worker or model_name in args.exclude_exact 470*da0073e9SAndroid Build Coastguard Worker or model_name in self.skip_models 471*da0073e9SAndroid Build Coastguard Worker ): 472*da0073e9SAndroid Build Coastguard Worker continue 473*da0073e9SAndroid Build Coastguard Worker yield model_name 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker @property 476*da0073e9SAndroid Build Coastguard Worker def skip_accuracy_checks_large_models_dashboard(self): 477*da0073e9SAndroid Build Coastguard Worker if self.args.dashboard or self.args.accuracy: 478*da0073e9SAndroid Build Coastguard Worker return self._accuracy["skip"]["large_models"] 479*da0073e9SAndroid Build Coastguard Worker return set() 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker @property 482*da0073e9SAndroid Build Coastguard Worker def get_output_amp_train_process_func(self): 483*da0073e9SAndroid Build Coastguard Worker return {} 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker def pick_grad(self, name, is_training): 486*da0073e9SAndroid Build Coastguard Worker if is_training: 487*da0073e9SAndroid Build Coastguard Worker return torch.enable_grad() 488*da0073e9SAndroid Build Coastguard Worker else: 489*da0073e9SAndroid Build Coastguard Worker return torch.no_grad() 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker def get_tolerance_and_cosine_flag(self, is_training, current_device, name): 492*da0073e9SAndroid Build Coastguard Worker cosine = self.args.cosine 493*da0073e9SAndroid Build Coastguard Worker if is_training: 494*da0073e9SAndroid Build Coastguard Worker from torch._inductor import config as inductor_config 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker if (name in self._config["tolerance"]["higher_training"]) or ( 497*da0073e9SAndroid Build Coastguard Worker inductor_config.max_autotune 498*da0073e9SAndroid Build Coastguard Worker and name in self._config["tolerance"]["higher_max_autotune_training"] 499*da0073e9SAndroid Build Coastguard Worker ): 500*da0073e9SAndroid Build Coastguard Worker return 2e-2, cosine 501*da0073e9SAndroid Build Coastguard Worker else: 502*da0073e9SAndroid Build Coastguard Worker return 1e-2, cosine 503*da0073e9SAndroid Build Coastguard Worker else: 504*da0073e9SAndroid Build Coastguard Worker if name in self._config["tolerance"]["higher_inference"]: 505*da0073e9SAndroid Build Coastguard Worker return 4e-3, cosine 506*da0073e9SAndroid Build Coastguard Worker if ( 507*da0073e9SAndroid Build Coastguard Worker current_device == "cpu" 508*da0073e9SAndroid Build Coastguard Worker and name in self._config["tolerance"]["higher_inference_cpu"] 509*da0073e9SAndroid Build Coastguard Worker ): 510*da0073e9SAndroid Build Coastguard Worker return 4e-3, cosine 511*da0073e9SAndroid Build Coastguard Worker return 1e-3, cosine 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker def compute_loss(self, pred): 514*da0073e9SAndroid Build Coastguard Worker return pred[0] 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker def forward_pass(self, mod, inputs, collect_outputs=True): 517*da0073e9SAndroid Build Coastguard Worker with self.autocast(**self.autocast_arg): 518*da0073e9SAndroid Build Coastguard Worker return mod(**inputs) 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): 521*da0073e9SAndroid Build Coastguard Worker cloned_inputs = clone_inputs(inputs) 522*da0073e9SAndroid Build Coastguard Worker self.optimizer_zero_grad(mod) 523*da0073e9SAndroid Build Coastguard Worker with self.autocast(**self.autocast_arg): 524*da0073e9SAndroid Build Coastguard Worker pred = mod(**cloned_inputs) 525*da0073e9SAndroid Build Coastguard Worker loss = self.compute_loss(pred) 526*da0073e9SAndroid Build Coastguard Worker self.grad_scaler.scale(loss).backward() 527*da0073e9SAndroid Build Coastguard Worker self.optimizer_step() 528*da0073e9SAndroid Build Coastguard Worker if collect_outputs: 529*da0073e9SAndroid Build Coastguard Worker return collect_results(mod, pred, loss, cloned_inputs) 530*da0073e9SAndroid Build Coastguard Worker return None 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Workerdef refresh_model_names_and_batch_sizes(): 534*da0073e9SAndroid Build Coastguard Worker """ 535*da0073e9SAndroid Build Coastguard Worker This function reads the HF Fx tracer supported models and finds the largest 536*da0073e9SAndroid Build Coastguard Worker batch size that could fit on the GPU with PyTorch eager. 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker The resulting data is written in huggingface_models_list.txt. 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker Note - We only need to run this function if we believe that HF Fx tracer now 541*da0073e9SAndroid Build Coastguard Worker supports more models. 542*da0073e9SAndroid Build Coastguard Worker """ 543*da0073e9SAndroid Build Coastguard Worker import transformers.utils.fx as hf_fx 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker family = {} 546*da0073e9SAndroid Build Coastguard Worker lm_seen = set() 547*da0073e9SAndroid Build Coastguard Worker family_seen = set() 548*da0073e9SAndroid Build Coastguard Worker for cls_name in hf_fx._SUPPORTED_MODELS: 549*da0073e9SAndroid Build Coastguard Worker if "For" not in cls_name: 550*da0073e9SAndroid Build Coastguard Worker continue 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker model_cls = get_module_cls_by_model_name(cls_name) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker # TODO: AttributeError: '*Config' object has no attribute 'vocab_size' 555*da0073e9SAndroid Build Coastguard Worker if model_cls in [ 556*da0073e9SAndroid Build Coastguard Worker CLIPModel, 557*da0073e9SAndroid Build Coastguard Worker CLIPVisionModel, 558*da0073e9SAndroid Build Coastguard Worker # SwinForImageClassification, 559*da0073e9SAndroid Build Coastguard Worker # SwinForImageClassification, 560*da0073e9SAndroid Build Coastguard Worker # SwinForMaskedImageModeling, 561*da0073e9SAndroid Build Coastguard Worker # SwinModel, 562*da0073e9SAndroid Build Coastguard Worker ViTForImageClassification, 563*da0073e9SAndroid Build Coastguard Worker ViTForMaskedImageModeling, 564*da0073e9SAndroid Build Coastguard Worker ViTModel, 565*da0073e9SAndroid Build Coastguard Worker ]: 566*da0073e9SAndroid Build Coastguard Worker continue 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker # TODO: AssertionError: Padding_idx must be within num_embeddings 569*da0073e9SAndroid Build Coastguard Worker if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]: 570*da0073e9SAndroid Build Coastguard Worker continue 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker # TODO: "model is not supported yet" from HFTracer 573*da0073e9SAndroid Build Coastguard Worker if model_cls in [HubertForSequenceClassification]: 574*da0073e9SAndroid Build Coastguard Worker continue 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker # TODO: shape mismatch in loss calculation 577*da0073e9SAndroid Build Coastguard Worker if model_cls in [LxmertForQuestionAnswering]: 578*da0073e9SAndroid Build Coastguard Worker continue 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker family_name = cls_name.split("For")[0] 581*da0073e9SAndroid Build Coastguard Worker if family_name not in family: 582*da0073e9SAndroid Build Coastguard Worker family[family_name] = [] 583*da0073e9SAndroid Build Coastguard Worker if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen: 584*da0073e9SAndroid Build Coastguard Worker family[family_name].append(cls_name) 585*da0073e9SAndroid Build Coastguard Worker lm_seen.add(family_name) 586*da0073e9SAndroid Build Coastguard Worker elif ( 587*da0073e9SAndroid Build Coastguard Worker cls_name.endswith( 588*da0073e9SAndroid Build Coastguard Worker ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering") 589*da0073e9SAndroid Build Coastguard Worker ) 590*da0073e9SAndroid Build Coastguard Worker and family_name not in family_seen 591*da0073e9SAndroid Build Coastguard Worker ): 592*da0073e9SAndroid Build Coastguard Worker family[family_name].append(cls_name) 593*da0073e9SAndroid Build Coastguard Worker family_seen.add(family_name) 594*da0073e9SAndroid Build Coastguard Worker elif cls_name.endswith("ImageClassification"): 595*da0073e9SAndroid Build Coastguard Worker family[family_name].append(cls_name) 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker chosen_models = set() 598*da0073e9SAndroid Build Coastguard Worker for members in family.values(): 599*da0073e9SAndroid Build Coastguard Worker chosen_models.update(set(members)) 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker # Add the EXTRA_MODELS 602*da0073e9SAndroid Build Coastguard Worker chosen_models.update(set(EXTRA_MODELS.keys())) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker for model_name in sorted(chosen_models): 605*da0073e9SAndroid Build Coastguard Worker try: 606*da0073e9SAndroid Build Coastguard Worker subprocess.check_call( 607*da0073e9SAndroid Build Coastguard Worker [sys.executable] 608*da0073e9SAndroid Build Coastguard Worker + sys.argv 609*da0073e9SAndroid Build Coastguard Worker + ["--find-batch-sizes"] 610*da0073e9SAndroid Build Coastguard Worker + [f"--only={model_name}"] 611*da0073e9SAndroid Build Coastguard Worker + [f"--output={MODELS_FILENAME}"] 612*da0073e9SAndroid Build Coastguard Worker ) 613*da0073e9SAndroid Build Coastguard Worker except subprocess.SubprocessError: 614*da0073e9SAndroid Build Coastguard Worker log.warning(f"Failed to find suitable batch size for {model_name}") 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Workerdef huggingface_main(): 618*da0073e9SAndroid Build Coastguard Worker # Code to refresh model names and batch sizes 619*da0073e9SAndroid Build Coastguard Worker # if "--find-batch-sizes" not in sys.argv: 620*da0073e9SAndroid Build Coastguard Worker # refresh_model_names_and_batch_sizes() 621*da0073e9SAndroid Build Coastguard Worker logging.basicConfig(level=logging.WARNING) 622*da0073e9SAndroid Build Coastguard Worker warnings.filterwarnings("ignore") 623*da0073e9SAndroid Build Coastguard Worker main(HuggingfaceRunner()) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 627*da0073e9SAndroid Build Coastguard Worker huggingface_main() 628