xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/huggingface.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport importlib
4*da0073e9SAndroid Build Coastguard Workerimport logging
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport re
7*da0073e9SAndroid Build Coastguard Workerimport subprocess
8*da0073e9SAndroid Build Coastguard Workerimport sys
9*da0073e9SAndroid Build Coastguard Workerimport warnings
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Allow running either as a module inside the benchmarks package (relative
# import) or directly as a script from this directory (plain import).
try:
    from .common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )
except ImportError:
    from common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerimport torch
30*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import collect_results
31*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import clone_inputs
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
# Module-level logger for this benchmark suite.
log = logging.getLogger(__name__)

# Enable FX graph caching
# Default the inductor FX graph cache to on (speeds up repeated runs), but
# respect an explicit TORCHINDUCTOR_FX_GRAPH_CACHE environment override.
# NOTE(review): relies on `torch._inductor` being importable as an attribute
# of `torch` at this point — presumably imported transitively via common.
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def pip_install(package):
    """Install *package* into the current interpreter's environment via pip."""
    # List form (shell=False) avoids any shell interpretation of the spec.
    cmd = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(cmd)
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker
# Disable the flake warnings for the imports. Flake8 does not provide a way to
# disable just warning for the entire file. Disabling flake8 entirely.
# flake8: noqa
# Class names expected to be exported by `transformers`. Their presence is
# verified below; missing names trigger an install, then each is bound into
# module scope via exec().
imports = [
    "AlbertForPreTraining",
    "AutoConfig",
    "AutoModelForCausalLM",
    "AutoModelForMaskedLM",
    "AutoModelForSeq2SeqLM",
    "BigBirdConfig",
    "BlenderbotForConditionalGeneration",
    "BlenderbotModel",
    "BlenderbotSmallForConditionalGeneration",
    "BlenderbotSmallModel",
    "CLIPModel",
    "CLIPVisionModel",
    "ElectraForPreTraining",
    "GPT2ForSequenceClassification",
    "GPTJForSequenceClassification",
    "GPTNeoForSequenceClassification",
    "HubertForSequenceClassification",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "MarianForCausalLM",
    "MarianModel",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "PegasusModel",
    "ReformerConfig",
    "ViTForImageClassification",
    "ViTForMaskedImageModeling",
    "ViTModel",
]
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
def process_hf_reformer_output(out):
    """Drop the second element of Reformer's output list.

    That element is numerically unstable across runs, so it is excluded
    before accuracy comparison.
    """
    assert isinstance(out, list)
    # Everything except index 1: the head element plus the tail past it.
    return out[:1] + out[2:]
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker
# Import-time bootstrap: make sure `transformers` is installed and exports
# every class in `imports`; older releases may lack some, in which case we
# install from source.
try:
    mod = importlib.import_module("transformers")
    for cls in imports:
        if not hasattr(mod, cls):
            raise ModuleNotFoundError
except ModuleNotFoundError:
    print("Installing HuggingFace Transformers...")
    pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
finally:
    # Bind each class name at module scope. The exec() string is built only
    # from the hard-coded `imports` list above, never from external input.
    for cls in imports:
        exec(f"from transformers import {cls}")
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
# These models contain the models present in huggingface_models_list. It is a
# combination of models supported by HF Fx parser and some manually supplied
# models. For these models, we already know the largest batch size that can fit
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = {}


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
assert os.path.exists(MODELS_FILENAME)
with open(MODELS_FILENAME, "r") as fh:
    # Each line is "<model_name>,<batch_size>".
    for entry in fh:
        name, size = entry.rstrip().split(",")
        BATCH_SIZE_KNOWN_MODELS[name] = int(size)
assert len(BATCH_SIZE_KNOWN_MODELS)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker
def get_module_cls_by_model_name(model_cls_name):
    """Resolve a model class name to its class object.

    Most classes are re-exported at the top level of ``transformers``; a few
    decoder-only classes must be loaded from their defining submodule.
    """
    special_cases = {
        "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
        "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
    }
    module_path = special_cases.get(model_cls_name, "transformers")
    return getattr(importlib.import_module(module_path), model_cls_name)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker
def get_sequence_length(model_cls, model_name):
    """Pick a representative input sequence length for ``model_name``.

    ``model_cls`` is currently unused but kept for interface compatibility
    with callers. Unrecognized models fall back to 128 (logged).
    """
    if model_name.startswith(("Blenderbot",)):
        seq_length = 128
    elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")):
        seq_length = 1024
    elif model_name in ("AllenaiLongformerBase", "BigBird"):
        seq_length = 1024
    elif model_name.startswith("OPT"):
        seq_length = 2048
    elif "Reformer" in model_name:
        seq_length = 4096
    elif model_name.startswith(
        (
            "Albert",
            "Deberta",
            "Layout",
            "Electra",
            "XLNet",
            "MegatronBert",
            "Bert",
            "Roberta",
        )
    ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
        seq_length = 512
    # BUG FIX: this previously read `model_name in ("TrOCRForCausalLM")`.
    # The parentheses do not make a tuple, so that was a *substring* test
    # against a plain string — any substring of the name (e.g. "TrOCR")
    # incorrectly matched. An explicit 1-tuple restores exact matching.
    elif model_name in ("TrOCRForCausalLM",):
        seq_length = 256
    elif model_name.startswith("MobileBert"):
        seq_length = 128
    elif model_name.startswith("Wav2Vec2"):
        # If too short, will fail with something like
        # ValueError: `mask_length` has to be smaller than `sequence_length`,
        # but got `mask_length`: 10 and `sequence_length`: 9`
        seq_length = 10000  # NB: a more realistic size is 155136
    else:
        log.info(
            f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
        )
        seq_length = 128
    return seq_length
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker
def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, include_loss_args=False
):
    """Build a kwargs dict of random example inputs for a HuggingFace model.

    Input shapes are derived from the model name/class and `model.config`
    (vocab size, label counts, Lxmert visual dims). When `include_loss_args`
    is True, label tensors matching the task head (MaskedLM, QA, etc.) are
    added so a training loss can be computed; unrecognized heads raise
    NotImplementedError.
    """
    # TODO - Check if following values are representative
    num_choices = 3
    num_visual_features = 42
    seq_length = get_sequence_length(model_cls, model_name)
    vocab_size = model.config.vocab_size

    if model_name.startswith("Wav2Vec2"):
        # Audio models take raw float waveforms, not token ids.
        # TODO: If we add more input_values style models, try to work this
        # into the overall control flow
        target_length = 100
        return {
            "input_values": torch.randn((bs, seq_length), device=device),
            # Added because that's what the example training script has
            "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)),
            "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)),
        }

    if model_name.endswith("MultipleChoice"):
        input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length))
    elif model_name.startswith("Roberta"):
        # NOTE(review): high bound of 1 yields all-zero input ids for
        # Roberta — presumably to stay within reserved token ids; confirm.
        input = rand_int_tensor(device, 0, 1, (bs, seq_length))
    else:
        input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))

    if "Bart" in model_name:
        # Bart variants expect the final position to be the EOS token.
        input[:, -1] = model.config.eos_token_id

    input_dict = {"input_ids": input}

    if (
        model_name.startswith("T5")
        or model_name.startswith("M2M100")
        or model_name.startswith("MT5")
        or model_cls
        in [
            BlenderbotModel,
            BlenderbotSmallModel,
            BlenderbotForConditionalGeneration,
            BlenderbotSmallForConditionalGeneration,
            PegasusModel,
            PegasusForConditionalGeneration,
            MarianModel,
            MarianMTModel,
        ]
    ):
        # Encoder-decoder models also need decoder-side input ids; the
        # encoder input is reused as-is.
        input_dict["decoder_input_ids"] = input

    if model_name.startswith("Lxmert"):
        visual_feat_dim, visual_pos_dim = (
            model.config.visual_feat_dim,
            model.config.visual_pos_dim,
        )
        # NOTE(review): these tensors are created without device=..., so they
        # remain on CPU even when `device` is a GPU — confirm intended.
        input_dict["visual_feats"] = torch.randn(
            bs, num_visual_features, visual_feat_dim
        )
        input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim)

    if include_loss_args:
        if model_name.endswith("PreTraining"):
            if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
                # Binary pretraining objective: labels in {0, 1}.
                input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
            else:
                # Albert uses sentence-order prediction; BERT-style models
                # use next-sentence prediction.
                label_name = (
                    "sentence_order_label"
                    if model_cls in [AlbertForPreTraining]
                    else "next_sentence_label"
                )
                # NOTE(review): the trailing comma makes this a 1-tuple
                # containing a tensor rather than a bare tensor — verify
                # this is what downstream consumers expect.
                input_dict["labels"] = (
                    rand_int_tensor(device, 0, vocab_size, (bs, seq_length)),
                )
                input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("QuestionAnswering"):
            # Span-extraction QA: random start/end token indices.
            input_dict["start_positions"] = rand_int_tensor(
                device, 0, seq_length, (bs,)
            )
            input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
        elif (
            model_name.endswith("MaskedLM")
            or model_name.endswith("HeadModel")
            or model_name.endswith("CausalLM")
            or model_name.endswith("DoubleHeadsModel")
        ):
            # Token-level LM objectives: one vocab label per position.
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        elif model_name.endswith("TokenClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs, seq_length)
            )
        elif model_name.endswith("MultipleChoice"):
            input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
        elif model_name.endswith("SequenceClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs,)
            )
        elif model_name.endswith("NextSentencePrediction"):
            input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("ForConditionalGeneration"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size - 1, (bs, seq_length)
            )
        elif model_name in EXTRA_MODELS:
            # Manually-curated models default to a MaskedLM-style objective.
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        else:
            raise NotImplementedError(
                f"Class {model_name} unsupported for training test "
            )

    return input_dict
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker
def rand_int_tensor(device, low, high, shape):
    """Draw random int64s uniformly from ``[low, high)``.

    The resulting tensor lives on ``device`` and never tracks gradients.
    """
    options = {"device": device, "dtype": torch.int64, "requires_grad": False}
    return torch.randint(low, high, shape, **options)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker
# Models not discoverable via the generic class-name lookup above.
# Maps benchmark name -> (config instance, model class used to build it).
# NOTE(review): the AutoConfig.from_pretrained calls execute at import time
# and hit the network / HF cache — confirm this is acceptable for every
# entry point that imports this module.
EXTRA_MODELS = {
    "AllenaiLongformerBase": (
        AutoConfig.from_pretrained("allenai/longformer-base-4096"),
        AutoModelForMaskedLM,
    ),
    "Reformer": (
        ReformerConfig(),
        AutoModelForMaskedLM,
    ),
    "T5Small": (
        AutoConfig.from_pretrained("t5-small"),
        AutoModelForSeq2SeqLM,
    ),
    # "BigBird": (
    #     BigBirdConfig(attention_type="block_sparse"),
    #     AutoModelForMaskedLM,
    # ),
    "DistillGPT2": (
        AutoConfig.from_pretrained("distilgpt2"),
        AutoModelForCausalLM,
    ),
    "GoogleFnet": (
        AutoConfig.from_pretrained("google/fnet-base"),
        AutoModelForMaskedLM,
    ),
    "YituTechConvBert": (
        AutoConfig.from_pretrained("YituTech/conv-bert-base"),
        AutoModelForMaskedLM,
    ),
    "CamemBert": (
        AutoConfig.from_pretrained("camembert-base"),
        AutoModelForMaskedLM,
    ),
}
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Workerclass HuggingfaceRunner(BenchmarkRunner):
    def __init__(self):
        """Initialize the runner and tag results with the suite name."""
        super().__init__()
        self.suite_name = "huggingface"
338*da0073e9SAndroid Build Coastguard Worker
    @property
    def _config(self):
        """Suite configuration parsed from huggingface.yaml."""
        return load_yaml_file("huggingface.yaml")

    @property
    def _skip(self):
        """The "skip" section of the suite config."""
        return self._config["skip"]

    @property
    def _accuracy(self):
        """The "accuracy" section of the suite config."""
        return self._config["accuracy"]

    @property
    def skip_models(self):
        """Models skipped on every device."""
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        """Models skipped when benchmarking on CPU."""
        return self._skip["device"]["cpu"]

    @property
    def fp32_only_models(self):
        """Models that must be run in fp32 only."""
        return self._config["only_fp32"]

    @property
    def skip_models_due_to_control_flow(self):
        """Models skipped because of data-dependent control flow."""
        return self._skip["control_flow"]
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker    def _get_model_cls_and_config(self, model_name):
368*da0073e9SAndroid Build Coastguard Worker        if model_name not in EXTRA_MODELS:
369*da0073e9SAndroid Build Coastguard Worker            model_cls = get_module_cls_by_model_name(model_name)
370*da0073e9SAndroid Build Coastguard Worker            config_cls = model_cls.config_class
371*da0073e9SAndroid Build Coastguard Worker            config = config_cls()
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker            # NB: some models need a pad token defined to handle BS > 1
374*da0073e9SAndroid Build Coastguard Worker            if (
375*da0073e9SAndroid Build Coastguard Worker                model_cls
376*da0073e9SAndroid Build Coastguard Worker                in [
377*da0073e9SAndroid Build Coastguard Worker                    GPT2ForSequenceClassification,
378*da0073e9SAndroid Build Coastguard Worker                    GPTNeoForSequenceClassification,
379*da0073e9SAndroid Build Coastguard Worker                    GPTJForSequenceClassification,
380*da0073e9SAndroid Build Coastguard Worker                ]
381*da0073e9SAndroid Build Coastguard Worker                or model_cls.__name__.startswith("Roberta")
382*da0073e9SAndroid Build Coastguard Worker                or model_cls.__name__.startswith("Marian")
383*da0073e9SAndroid Build Coastguard Worker            ):
384*da0073e9SAndroid Build Coastguard Worker                config.pad_token_id = 0
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker        else:
387*da0073e9SAndroid Build Coastguard Worker            config, model_cls = EXTRA_MODELS[model_name]
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker        return model_cls, config
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    @download_retry_decorator
392*da0073e9SAndroid Build Coastguard Worker    def _download_model(self, model_name):
393*da0073e9SAndroid Build Coastguard Worker        model_cls, config = self._get_model_cls_and_config(model_name)
394*da0073e9SAndroid Build Coastguard Worker        if "auto" in model_cls.__module__:
395*da0073e9SAndroid Build Coastguard Worker            # Handle auto classes
396*da0073e9SAndroid Build Coastguard Worker            model = model_cls.from_config(config)
397*da0073e9SAndroid Build Coastguard Worker        else:
398*da0073e9SAndroid Build Coastguard Worker            model = model_cls(config)
399*da0073e9SAndroid Build Coastguard Worker        return model
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker    def load_model(
402*da0073e9SAndroid Build Coastguard Worker        self,
403*da0073e9SAndroid Build Coastguard Worker        device,
404*da0073e9SAndroid Build Coastguard Worker        model_name,
405*da0073e9SAndroid Build Coastguard Worker        batch_size=None,
406*da0073e9SAndroid Build Coastguard Worker        extra_args=None,
407*da0073e9SAndroid Build Coastguard Worker    ):
408*da0073e9SAndroid Build Coastguard Worker        is_training = self.args.training
409*da0073e9SAndroid Build Coastguard Worker        use_eval_mode = self.args.use_eval_mode
410*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
411*da0073e9SAndroid Build Coastguard Worker        reset_rng_state()
412*da0073e9SAndroid Build Coastguard Worker        model_cls, config = self._get_model_cls_and_config(model_name)
413*da0073e9SAndroid Build Coastguard Worker        model = self._download_model(model_name)
414*da0073e9SAndroid Build Coastguard Worker        model = model.to(device, dtype=dtype)
415*da0073e9SAndroid Build Coastguard Worker        if self.args.enable_activation_checkpointing:
416*da0073e9SAndroid Build Coastguard Worker            model.gradient_checkpointing_enable()
417*da0073e9SAndroid Build Coastguard Worker        if model_name in BATCH_SIZE_KNOWN_MODELS:
418*da0073e9SAndroid Build Coastguard Worker            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
419*da0073e9SAndroid Build Coastguard Worker        elif batch_size is None:
420*da0073e9SAndroid Build Coastguard Worker            batch_size_default = 16
421*da0073e9SAndroid Build Coastguard Worker            log.info(
422*da0073e9SAndroid Build Coastguard Worker                f"Batch size not specified for {model_name}. Setting batch_size=16"
423*da0073e9SAndroid Build Coastguard Worker            )
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        if batch_size is None:
426*da0073e9SAndroid Build Coastguard Worker            batch_size = batch_size_default
427*da0073e9SAndroid Build Coastguard Worker            batch_size_divisors = self._config["batch_size"]["divisors"]
428*da0073e9SAndroid Build Coastguard Worker            if model_name in batch_size_divisors:
429*da0073e9SAndroid Build Coastguard Worker                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
430*da0073e9SAndroid Build Coastguard Worker                log.info(
431*da0073e9SAndroid Build Coastguard Worker                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
432*da0073e9SAndroid Build Coastguard Worker                )
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker        example_inputs = generate_inputs_for_model(
435*da0073e9SAndroid Build Coastguard Worker            model_cls, model, model_name, batch_size, device, include_loss_args=True
436*da0073e9SAndroid Build Coastguard Worker        )
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker        # So we can check for correct gradients without eliminating the dropout computation
439*da0073e9SAndroid Build Coastguard Worker        for attr in dir(config):
440*da0073e9SAndroid Build Coastguard Worker            if "drop" in attr and isinstance(getattr(config, attr), float):
441*da0073e9SAndroid Build Coastguard Worker                setattr(config, attr, 1e-30)
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        if (
444*da0073e9SAndroid Build Coastguard Worker            is_training
445*da0073e9SAndroid Build Coastguard Worker            and not use_eval_mode
446*da0073e9SAndroid Build Coastguard Worker            and not (
447*da0073e9SAndroid Build Coastguard Worker                self.args.accuracy and model_name in self._config["only_inference"]
448*da0073e9SAndroid Build Coastguard Worker            )
449*da0073e9SAndroid Build Coastguard Worker        ):
450*da0073e9SAndroid Build Coastguard Worker            model.train()
451*da0073e9SAndroid Build Coastguard Worker        else:
452*da0073e9SAndroid Build Coastguard Worker            model.eval()
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker        self.validate_model(model, example_inputs)
455*da0073e9SAndroid Build Coastguard Worker        return device, model_name, model, example_inputs, batch_size
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker    def iter_model_names(self, args):
458*da0073e9SAndroid Build Coastguard Worker        model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys())
459*da0073e9SAndroid Build Coastguard Worker        model_names = set(model_names)
460*da0073e9SAndroid Build Coastguard Worker        model_names = sorted(model_names)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker        start, end = self.get_benchmark_indices(len(model_names))
463*da0073e9SAndroid Build Coastguard Worker        for index, model_name in enumerate(model_names):
464*da0073e9SAndroid Build Coastguard Worker            if index < start or index >= end:
465*da0073e9SAndroid Build Coastguard Worker                continue
466*da0073e9SAndroid Build Coastguard Worker            if (
467*da0073e9SAndroid Build Coastguard Worker                not re.search("|".join(args.filter), model_name, re.I)
468*da0073e9SAndroid Build Coastguard Worker                or re.search("|".join(args.exclude), model_name, re.I)
469*da0073e9SAndroid Build Coastguard Worker                or model_name in args.exclude_exact
470*da0073e9SAndroid Build Coastguard Worker                or model_name in self.skip_models
471*da0073e9SAndroid Build Coastguard Worker            ):
472*da0073e9SAndroid Build Coastguard Worker                continue
473*da0073e9SAndroid Build Coastguard Worker            yield model_name
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker    @property
476*da0073e9SAndroid Build Coastguard Worker    def skip_accuracy_checks_large_models_dashboard(self):
477*da0073e9SAndroid Build Coastguard Worker        if self.args.dashboard or self.args.accuracy:
478*da0073e9SAndroid Build Coastguard Worker            return self._accuracy["skip"]["large_models"]
479*da0073e9SAndroid Build Coastguard Worker        return set()
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker    @property
482*da0073e9SAndroid Build Coastguard Worker    def get_output_amp_train_process_func(self):
483*da0073e9SAndroid Build Coastguard Worker        return {}
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker    def pick_grad(self, name, is_training):
486*da0073e9SAndroid Build Coastguard Worker        if is_training:
487*da0073e9SAndroid Build Coastguard Worker            return torch.enable_grad()
488*da0073e9SAndroid Build Coastguard Worker        else:
489*da0073e9SAndroid Build Coastguard Worker            return torch.no_grad()
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
492*da0073e9SAndroid Build Coastguard Worker        cosine = self.args.cosine
493*da0073e9SAndroid Build Coastguard Worker        if is_training:
494*da0073e9SAndroid Build Coastguard Worker            from torch._inductor import config as inductor_config
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker            if (name in self._config["tolerance"]["higher_training"]) or (
497*da0073e9SAndroid Build Coastguard Worker                inductor_config.max_autotune
498*da0073e9SAndroid Build Coastguard Worker                and name in self._config["tolerance"]["higher_max_autotune_training"]
499*da0073e9SAndroid Build Coastguard Worker            ):
500*da0073e9SAndroid Build Coastguard Worker                return 2e-2, cosine
501*da0073e9SAndroid Build Coastguard Worker            else:
502*da0073e9SAndroid Build Coastguard Worker                return 1e-2, cosine
503*da0073e9SAndroid Build Coastguard Worker        else:
504*da0073e9SAndroid Build Coastguard Worker            if name in self._config["tolerance"]["higher_inference"]:
505*da0073e9SAndroid Build Coastguard Worker                return 4e-3, cosine
506*da0073e9SAndroid Build Coastguard Worker            if (
507*da0073e9SAndroid Build Coastguard Worker                current_device == "cpu"
508*da0073e9SAndroid Build Coastguard Worker                and name in self._config["tolerance"]["higher_inference_cpu"]
509*da0073e9SAndroid Build Coastguard Worker            ):
510*da0073e9SAndroid Build Coastguard Worker                return 4e-3, cosine
511*da0073e9SAndroid Build Coastguard Worker        return 1e-3, cosine
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker    def compute_loss(self, pred):
514*da0073e9SAndroid Build Coastguard Worker        return pred[0]
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker    def forward_pass(self, mod, inputs, collect_outputs=True):
517*da0073e9SAndroid Build Coastguard Worker        with self.autocast(**self.autocast_arg):
518*da0073e9SAndroid Build Coastguard Worker            return mod(**inputs)
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
521*da0073e9SAndroid Build Coastguard Worker        cloned_inputs = clone_inputs(inputs)
522*da0073e9SAndroid Build Coastguard Worker        self.optimizer_zero_grad(mod)
523*da0073e9SAndroid Build Coastguard Worker        with self.autocast(**self.autocast_arg):
524*da0073e9SAndroid Build Coastguard Worker            pred = mod(**cloned_inputs)
525*da0073e9SAndroid Build Coastguard Worker            loss = self.compute_loss(pred)
526*da0073e9SAndroid Build Coastguard Worker        self.grad_scaler.scale(loss).backward()
527*da0073e9SAndroid Build Coastguard Worker        self.optimizer_step()
528*da0073e9SAndroid Build Coastguard Worker        if collect_outputs:
529*da0073e9SAndroid Build Coastguard Worker            return collect_results(mod, pred, loss, cloned_inputs)
530*da0073e9SAndroid Build Coastguard Worker        return None
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Workerdef refresh_model_names_and_batch_sizes():
534*da0073e9SAndroid Build Coastguard Worker    """
535*da0073e9SAndroid Build Coastguard Worker    This function reads the HF Fx tracer supported models and finds the largest
536*da0073e9SAndroid Build Coastguard Worker    batch size that could fit on the GPU with PyTorch eager.
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker    The resulting data is written in huggingface_models_list.txt.
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker    Note - We only need to run this function if we believe that HF Fx tracer now
541*da0073e9SAndroid Build Coastguard Worker    supports more models.
542*da0073e9SAndroid Build Coastguard Worker    """
543*da0073e9SAndroid Build Coastguard Worker    import transformers.utils.fx as hf_fx
544*da0073e9SAndroid Build Coastguard Worker
545*da0073e9SAndroid Build Coastguard Worker    family = {}
546*da0073e9SAndroid Build Coastguard Worker    lm_seen = set()
547*da0073e9SAndroid Build Coastguard Worker    family_seen = set()
548*da0073e9SAndroid Build Coastguard Worker    for cls_name in hf_fx._SUPPORTED_MODELS:
549*da0073e9SAndroid Build Coastguard Worker        if "For" not in cls_name:
550*da0073e9SAndroid Build Coastguard Worker            continue
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        model_cls = get_module_cls_by_model_name(cls_name)
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker        # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
555*da0073e9SAndroid Build Coastguard Worker        if model_cls in [
556*da0073e9SAndroid Build Coastguard Worker            CLIPModel,
557*da0073e9SAndroid Build Coastguard Worker            CLIPVisionModel,
558*da0073e9SAndroid Build Coastguard Worker            # SwinForImageClassification,
559*da0073e9SAndroid Build Coastguard Worker            # SwinForImageClassification,
560*da0073e9SAndroid Build Coastguard Worker            # SwinForMaskedImageModeling,
561*da0073e9SAndroid Build Coastguard Worker            # SwinModel,
562*da0073e9SAndroid Build Coastguard Worker            ViTForImageClassification,
563*da0073e9SAndroid Build Coastguard Worker            ViTForMaskedImageModeling,
564*da0073e9SAndroid Build Coastguard Worker            ViTModel,
565*da0073e9SAndroid Build Coastguard Worker        ]:
566*da0073e9SAndroid Build Coastguard Worker            continue
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker        # TODO: AssertionError: Padding_idx must be within num_embeddings
569*da0073e9SAndroid Build Coastguard Worker        if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]:
570*da0073e9SAndroid Build Coastguard Worker            continue
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        # TODO: "model is not supported yet" from HFTracer
573*da0073e9SAndroid Build Coastguard Worker        if model_cls in [HubertForSequenceClassification]:
574*da0073e9SAndroid Build Coastguard Worker            continue
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker        # TODO: shape mismatch in loss calculation
577*da0073e9SAndroid Build Coastguard Worker        if model_cls in [LxmertForQuestionAnswering]:
578*da0073e9SAndroid Build Coastguard Worker            continue
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker        family_name = cls_name.split("For")[0]
581*da0073e9SAndroid Build Coastguard Worker        if family_name not in family:
582*da0073e9SAndroid Build Coastguard Worker            family[family_name] = []
583*da0073e9SAndroid Build Coastguard Worker        if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
584*da0073e9SAndroid Build Coastguard Worker            family[family_name].append(cls_name)
585*da0073e9SAndroid Build Coastguard Worker            lm_seen.add(family_name)
586*da0073e9SAndroid Build Coastguard Worker        elif (
587*da0073e9SAndroid Build Coastguard Worker            cls_name.endswith(
588*da0073e9SAndroid Build Coastguard Worker                ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")
589*da0073e9SAndroid Build Coastguard Worker            )
590*da0073e9SAndroid Build Coastguard Worker            and family_name not in family_seen
591*da0073e9SAndroid Build Coastguard Worker        ):
592*da0073e9SAndroid Build Coastguard Worker            family[family_name].append(cls_name)
593*da0073e9SAndroid Build Coastguard Worker            family_seen.add(family_name)
594*da0073e9SAndroid Build Coastguard Worker        elif cls_name.endswith("ImageClassification"):
595*da0073e9SAndroid Build Coastguard Worker            family[family_name].append(cls_name)
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker    chosen_models = set()
598*da0073e9SAndroid Build Coastguard Worker    for members in family.values():
599*da0073e9SAndroid Build Coastguard Worker        chosen_models.update(set(members))
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker    # Add the EXTRA_MODELS
602*da0073e9SAndroid Build Coastguard Worker    chosen_models.update(set(EXTRA_MODELS.keys()))
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker    for model_name in sorted(chosen_models):
605*da0073e9SAndroid Build Coastguard Worker        try:
606*da0073e9SAndroid Build Coastguard Worker            subprocess.check_call(
607*da0073e9SAndroid Build Coastguard Worker                [sys.executable]
608*da0073e9SAndroid Build Coastguard Worker                + sys.argv
609*da0073e9SAndroid Build Coastguard Worker                + ["--find-batch-sizes"]
610*da0073e9SAndroid Build Coastguard Worker                + [f"--only={model_name}"]
611*da0073e9SAndroid Build Coastguard Worker                + [f"--output={MODELS_FILENAME}"]
612*da0073e9SAndroid Build Coastguard Worker            )
613*da0073e9SAndroid Build Coastguard Worker        except subprocess.SubprocessError:
614*da0073e9SAndroid Build Coastguard Worker            log.warning(f"Failed to find suitable batch size for {model_name}")
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Workerdef huggingface_main():
618*da0073e9SAndroid Build Coastguard Worker    # Code to refresh model names and batch sizes
619*da0073e9SAndroid Build Coastguard Worker    # if "--find-batch-sizes" not in sys.argv:
620*da0073e9SAndroid Build Coastguard Worker    #     refresh_model_names_and_batch_sizes()
621*da0073e9SAndroid Build Coastguard Worker    logging.basicConfig(level=logging.WARNING)
622*da0073e9SAndroid Build Coastguard Worker    warnings.filterwarnings("ignore")
623*da0073e9SAndroid Build Coastguard Worker    main(HuggingfaceRunner())
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
627*da0073e9SAndroid Build Coastguard Worker    huggingface_main()
628