#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings

try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching unless the environment variable is already set
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    lines = [line.rstrip() for line in fh.readlines()]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason for the cold-start memory spike
BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}

REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}
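

# NOTE: refresh_model_names() below is a maintenance helper, not invoked during
# normal benchmark runs. It regenerates timm_models_list.txt by picking one
# representative model per architecture family: one from the timm docs and one
# from the remaining pretrained models. The docs checkout path it globs over is
# still a TODO in the source.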


def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue
                    model = line.split("'")[1]
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename

    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")


class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size
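
    # Note (observation, not from upstream docs): load_model() builds example
    # inputs by drawing integer pixel values from [0, 256) and normalizing to
    # zero mean / unit std over the whole batch, rather than the per-channel
    # ImageNet statistics timm models usually expect. That is presumably
    # sufficient here: shape, dtype, and memory layout are what drive the
    # performance measurement, and accuracy checks compare eager vs. compiled
    # outputs on the same inputs.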

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values for these models need to be scaled down even further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()
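

# Example invocations (illustrative sketch; the exact flag spellings are
# defined by the shared harness in benchmarks/dynamo/common.py and may change
# between PyTorch versions):
#
#   python timm_models.py --performance --inference --backend=inductor --only resnest101e
#   python timm_models.py --accuracy --training --devices=cuda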