# mypy: allow-untyped-defs
import contextlib
import errno
import hashlib
import json
import os
import re
import shutil
import sys
import tempfile
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional
from typing_extensions import deprecated
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse  # noqa: F401
from urllib.request import Request, urlopen

import torch
from torch.serialization import MAP_LOCATION


class _Faketqdm:  # type: ignore[no-redef]
    """Minimal stand-in for ``tqdm`` used when the real package is not installed.

    Implements only a small subset of the ``tqdm`` API. Progress is written to
    ``sys.stderr`` as a percentage when ``total`` is known and non-zero, and as
    a running count of bytes otherwise.
    """

    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm

    def update(self, n=1):
        """Advance the counter by ``n`` (default 1, as in tqdm) and redraw the progress line."""
        if self.disable:
            return

        self.n += n
        if not self.total:
            # Unknown (None) or zero total: a percentage is meaningless, and a
            # zero total would divide by zero, so report the raw count instead.
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want
    def set_description(self, *args, **kwargs):
        pass

    def write(self, s):
        """Write ``s`` followed by a newline to stderr."""
        sys.stderr.write(f"{s}\n")

    def close(self):
        """Stop emitting any further progress output."""
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Terminate the ``\r``-rewritten progress line unless output is disabled.
        if self.disable:
            return

        sys.stderr.write("\n")


try:
    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
except ImportError:
    tqdm = _Faketqdm

__all__ = [
    "download_url_to_file",
    "get_dir",
    "help",
    "list",
    "load",
    "load_state_dict_from_url",
    "set_dir",
]

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

_TRUSTED_REPO_OWNERS = (
    "facebookresearch",
    "facebookincubator",
    "pytorch",
    "fairinternal",
)
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
VAR_DEPENDENCY = "dependencies"
MODULE_HUBCONF = "hubconf.py"
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None


@contextlib.contextmanager
def _add_to_sys_path(path):
    """Temporarily prepend ``path`` to ``sys.path`` for the duration of a ``with`` block."""
    sys.path.insert(0, path)
    try:
        yield
    finally:
        # Code run inside the block (e.g. an imported hubconf) may already have
        # removed the entry; a ValueError here must not mask the real outcome.
        with contextlib.suppress(ValueError):
            sys.path.remove(path)


# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
    """Execute the Python source file at ``path`` and return it as a module named ``name``.

    Raises:
        ImportError: if no module spec or no usable loader can be created for ``path``.
    """
    import importlib.util
    from importlib.abc import Loader

    spec = importlib.util.spec_from_file_location(name, path)
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if spec is None:
        raise ImportError(f"Cannot create a module spec for {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    if not isinstance(spec.loader, Loader):
        raise ImportError(f"No usable loader for module {name!r} at {path!r}")
    spec.loader.exec_module(module)
    return module


def _remove_if_exists(path):
    """Delete ``path`` if present, whether it is a file, a symlink or a directory tree.

    Symlinks, including dangling ones and links to directories, are unlinked
    rather than followed, so the link target is never touched.
    ``shutil.rmtree`` refuses to operate on a symlink.
    """
    if not os.path.lexists(path):
        return
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    else:
        shutil.rmtree(path)
Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Workerdef _git_archive_link(repo_owner, repo_name, ref): 128*da0073e9SAndroid Build Coastguard Worker # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip 129*da0073e9SAndroid Build Coastguard Worker return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Workerdef _load_attr_from_module(module, func_name): 133*da0073e9SAndroid Build Coastguard Worker # Check if callable is defined in the module 134*da0073e9SAndroid Build Coastguard Worker if func_name not in dir(module): 135*da0073e9SAndroid Build Coastguard Worker return None 136*da0073e9SAndroid Build Coastguard Worker return getattr(module, func_name) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Workerdef _get_torch_home(): 140*da0073e9SAndroid Build Coastguard Worker torch_home = os.path.expanduser( 141*da0073e9SAndroid Build Coastguard Worker os.getenv( 142*da0073e9SAndroid Build Coastguard Worker ENV_TORCH_HOME, 143*da0073e9SAndroid Build Coastguard Worker os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), 144*da0073e9SAndroid Build Coastguard Worker ) 145*da0073e9SAndroid Build Coastguard Worker ) 146*da0073e9SAndroid Build Coastguard Worker return torch_home 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Workerdef _parse_repo_info(github): 150*da0073e9SAndroid Build Coastguard Worker if ":" in github: 151*da0073e9SAndroid Build Coastguard Worker repo_info, ref = github.split(":") 152*da0073e9SAndroid Build Coastguard Worker else: 153*da0073e9SAndroid Build Coastguard Worker repo_info, ref = github, None 154*da0073e9SAndroid Build Coastguard Worker 
repo_owner, repo_name = repo_info.split("/") 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker if ref is None: 157*da0073e9SAndroid Build Coastguard Worker # The ref wasn't specified by the user, so we need to figure out the 158*da0073e9SAndroid Build Coastguard Worker # default branch: main or master. Our assumption is that if main exists 159*da0073e9SAndroid Build Coastguard Worker # then it's the default branch, otherwise it's master. 160*da0073e9SAndroid Build Coastguard Worker try: 161*da0073e9SAndroid Build Coastguard Worker with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): 162*da0073e9SAndroid Build Coastguard Worker ref = "main" 163*da0073e9SAndroid Build Coastguard Worker except HTTPError as e: 164*da0073e9SAndroid Build Coastguard Worker if e.code == 404: 165*da0073e9SAndroid Build Coastguard Worker ref = "master" 166*da0073e9SAndroid Build Coastguard Worker else: 167*da0073e9SAndroid Build Coastguard Worker raise 168*da0073e9SAndroid Build Coastguard Worker except URLError as e: 169*da0073e9SAndroid Build Coastguard Worker # No internet connection, need to check for cache as last resort 170*da0073e9SAndroid Build Coastguard Worker for possible_ref in ("main", "master"): 171*da0073e9SAndroid Build Coastguard Worker if os.path.exists( 172*da0073e9SAndroid Build Coastguard Worker f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}" 173*da0073e9SAndroid Build Coastguard Worker ): 174*da0073e9SAndroid Build Coastguard Worker ref = possible_ref 175*da0073e9SAndroid Build Coastguard Worker break 176*da0073e9SAndroid Build Coastguard Worker if ref is None: 177*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 178*da0073e9SAndroid Build Coastguard Worker "It looks like there is no internet connection and the " 179*da0073e9SAndroid Build Coastguard Worker f"repo could not be found in the cache ({get_dir()})" 180*da0073e9SAndroid Build Coastguard Worker ) from e 181*da0073e9SAndroid Build 
Coastguard Worker return repo_owner, repo_name, ref 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Workerdef _read_url(url): 185*da0073e9SAndroid Build Coastguard Worker with urlopen(url) as r: 186*da0073e9SAndroid Build Coastguard Worker return r.read().decode(r.headers.get_content_charset("utf-8")) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Workerdef _validate_not_a_forked_repo(repo_owner, repo_name, ref): 190*da0073e9SAndroid Build Coastguard Worker # Use urlopen to avoid depending on local git. 191*da0073e9SAndroid Build Coastguard Worker headers = {"Accept": "application/vnd.github.v3+json"} 192*da0073e9SAndroid Build Coastguard Worker token = os.environ.get(ENV_GITHUB_TOKEN) 193*da0073e9SAndroid Build Coastguard Worker if token is not None: 194*da0073e9SAndroid Build Coastguard Worker headers["Authorization"] = f"token {token}" 195*da0073e9SAndroid Build Coastguard Worker for url_prefix in ( 196*da0073e9SAndroid Build Coastguard Worker f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches", 197*da0073e9SAndroid Build Coastguard Worker f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags", 198*da0073e9SAndroid Build Coastguard Worker ): 199*da0073e9SAndroid Build Coastguard Worker page = 0 200*da0073e9SAndroid Build Coastguard Worker while True: 201*da0073e9SAndroid Build Coastguard Worker page += 1 202*da0073e9SAndroid Build Coastguard Worker url = f"{url_prefix}?per_page=100&page={page}" 203*da0073e9SAndroid Build Coastguard Worker response = json.loads(_read_url(Request(url, headers=headers))) 204*da0073e9SAndroid Build Coastguard Worker # Empty response means no more data to process 205*da0073e9SAndroid Build Coastguard Worker if not response: 206*da0073e9SAndroid Build Coastguard Worker break 207*da0073e9SAndroid Build Coastguard Worker for br in response: 
208*da0073e9SAndroid Build Coastguard Worker if br["name"] == ref or br["commit"]["sha"].startswith(ref): 209*da0073e9SAndroid Build Coastguard Worker return 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker raise ValueError( 212*da0073e9SAndroid Build Coastguard Worker f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. " 213*da0073e9SAndroid Build Coastguard Worker "If it's a commit from a forked repo, please call hub.load() with forked repo directly." 214*da0073e9SAndroid Build Coastguard Worker ) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Workerdef _get_cache_or_reload( 218*da0073e9SAndroid Build Coastguard Worker github, 219*da0073e9SAndroid Build Coastguard Worker force_reload, 220*da0073e9SAndroid Build Coastguard Worker trust_repo, 221*da0073e9SAndroid Build Coastguard Worker calling_fn, 222*da0073e9SAndroid Build Coastguard Worker verbose=True, 223*da0073e9SAndroid Build Coastguard Worker skip_validation=False, 224*da0073e9SAndroid Build Coastguard Worker): 225*da0073e9SAndroid Build Coastguard Worker # Setup hub_dir to save downloaded files 226*da0073e9SAndroid Build Coastguard Worker hub_dir = get_dir() 227*da0073e9SAndroid Build Coastguard Worker os.makedirs(hub_dir, exist_ok=True) 228*da0073e9SAndroid Build Coastguard Worker # Parse github repo information 229*da0073e9SAndroid Build Coastguard Worker repo_owner, repo_name, ref = _parse_repo_info(github) 230*da0073e9SAndroid Build Coastguard Worker # Github allows branch name with slash '/', 231*da0073e9SAndroid Build Coastguard Worker # this causes confusion with path on both Linux and Windows. 232*da0073e9SAndroid Build Coastguard Worker # Backslash is not allowed in Github branch name so no need to 233*da0073e9SAndroid Build Coastguard Worker # to worry about it. 
234*da0073e9SAndroid Build Coastguard Worker normalized_br = ref.replace("/", "_") 235*da0073e9SAndroid Build Coastguard Worker # Github renames folder repo-v1.x.x to repo-1.x.x 236*da0073e9SAndroid Build Coastguard Worker # We don't know the repo name before downloading the zip file 237*da0073e9SAndroid Build Coastguard Worker # and inspect name from it. 238*da0073e9SAndroid Build Coastguard Worker # To check if cached repo exists, we need to normalize folder names. 239*da0073e9SAndroid Build Coastguard Worker owner_name_branch = "_".join([repo_owner, repo_name, normalized_br]) 240*da0073e9SAndroid Build Coastguard Worker repo_dir = os.path.join(hub_dir, owner_name_branch) 241*da0073e9SAndroid Build Coastguard Worker # Check that the repo is in the trusted list 242*da0073e9SAndroid Build Coastguard Worker _check_repo_is_trusted( 243*da0073e9SAndroid Build Coastguard Worker repo_owner, 244*da0073e9SAndroid Build Coastguard Worker repo_name, 245*da0073e9SAndroid Build Coastguard Worker owner_name_branch, 246*da0073e9SAndroid Build Coastguard Worker trust_repo=trust_repo, 247*da0073e9SAndroid Build Coastguard Worker calling_fn=calling_fn, 248*da0073e9SAndroid Build Coastguard Worker ) 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker use_cache = (not force_reload) and os.path.exists(repo_dir) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker if use_cache: 253*da0073e9SAndroid Build Coastguard Worker if verbose: 254*da0073e9SAndroid Build Coastguard Worker sys.stderr.write(f"Using cache found in {repo_dir}\n") 255*da0073e9SAndroid Build Coastguard Worker else: 256*da0073e9SAndroid Build Coastguard Worker # Validate the tag/branch is from the original repo instead of a forked repo 257*da0073e9SAndroid Build Coastguard Worker if not skip_validation: 258*da0073e9SAndroid Build Coastguard Worker _validate_not_a_forked_repo(repo_owner, repo_name, ref) 259*da0073e9SAndroid Build Coastguard 
Worker 260*da0073e9SAndroid Build Coastguard Worker cached_file = os.path.join(hub_dir, normalized_br + ".zip") 261*da0073e9SAndroid Build Coastguard Worker _remove_if_exists(cached_file) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker try: 264*da0073e9SAndroid Build Coastguard Worker url = _git_archive_link(repo_owner, repo_name, ref) 265*da0073e9SAndroid Build Coastguard Worker sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') 266*da0073e9SAndroid Build Coastguard Worker download_url_to_file(url, cached_file, progress=False) 267*da0073e9SAndroid Build Coastguard Worker except HTTPError as err: 268*da0073e9SAndroid Build Coastguard Worker if err.code == 300: 269*da0073e9SAndroid Build Coastguard Worker # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch 270*da0073e9SAndroid Build Coastguard Worker # in the repo. This can be disambiguated by explicitely using refs/heads/ or refs/tags 271*da0073e9SAndroid Build Coastguard Worker # See https://git-scm.com/book/en/v2/Git-Internals-Git-References 272*da0073e9SAndroid Build Coastguard Worker # Here, we do the same as git: we throw a warning, and assume the user wanted the branch 273*da0073e9SAndroid Build Coastguard Worker warnings.warn( 274*da0073e9SAndroid Build Coastguard Worker f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " 275*da0073e9SAndroid Build Coastguard Worker "Torchhub will now assume that it's a branch. " 276*da0073e9SAndroid Build Coastguard Worker "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " 277*da0073e9SAndroid Build Coastguard Worker "refs/tags/tag_name as the ref. That might require using skip_validation=True." 
278*da0073e9SAndroid Build Coastguard Worker ) 279*da0073e9SAndroid Build Coastguard Worker disambiguated_branch_ref = f"refs/heads/{ref}" 280*da0073e9SAndroid Build Coastguard Worker url = _git_archive_link( 281*da0073e9SAndroid Build Coastguard Worker repo_owner, repo_name, ref=disambiguated_branch_ref 282*da0073e9SAndroid Build Coastguard Worker ) 283*da0073e9SAndroid Build Coastguard Worker download_url_to_file(url, cached_file, progress=False) 284*da0073e9SAndroid Build Coastguard Worker else: 285*da0073e9SAndroid Build Coastguard Worker raise 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker with zipfile.ZipFile(cached_file) as cached_zipfile: 288*da0073e9SAndroid Build Coastguard Worker extraced_repo_name = cached_zipfile.infolist()[0].filename 289*da0073e9SAndroid Build Coastguard Worker extracted_repo = os.path.join(hub_dir, extraced_repo_name) 290*da0073e9SAndroid Build Coastguard Worker _remove_if_exists(extracted_repo) 291*da0073e9SAndroid Build Coastguard Worker # Unzip the code and rename the base folder 292*da0073e9SAndroid Build Coastguard Worker cached_zipfile.extractall(hub_dir) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker _remove_if_exists(cached_file) 295*da0073e9SAndroid Build Coastguard Worker _remove_if_exists(repo_dir) 296*da0073e9SAndroid Build Coastguard Worker shutil.move(extracted_repo, repo_dir) # rename the repo 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker return repo_dir 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Workerdef _check_repo_is_trusted( 302*da0073e9SAndroid Build Coastguard Worker repo_owner, 303*da0073e9SAndroid Build Coastguard Worker repo_name, 304*da0073e9SAndroid Build Coastguard Worker owner_name_branch, 305*da0073e9SAndroid Build Coastguard Worker trust_repo, 306*da0073e9SAndroid Build Coastguard Worker 
calling_fn="load", 307*da0073e9SAndroid Build Coastguard Worker): 308*da0073e9SAndroid Build Coastguard Worker hub_dir = get_dir() 309*da0073e9SAndroid Build Coastguard Worker filepath = os.path.join(hub_dir, "trusted_list") 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker if not os.path.exists(filepath): 312*da0073e9SAndroid Build Coastguard Worker Path(filepath).touch() 313*da0073e9SAndroid Build Coastguard Worker with open(filepath) as file: 314*da0073e9SAndroid Build Coastguard Worker trusted_repos = tuple(line.strip() for line in file) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker # To minimize friction of introducing the new trust_repo mechanism, we consider that 317*da0073e9SAndroid Build Coastguard Worker # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist) 318*da0073e9SAndroid Build Coastguard Worker trusted_repos_legacy = next(os.walk(hub_dir))[1] 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker owner_name = "_".join([repo_owner, repo_name]) 321*da0073e9SAndroid Build Coastguard Worker is_trusted = ( 322*da0073e9SAndroid Build Coastguard Worker owner_name in trusted_repos 323*da0073e9SAndroid Build Coastguard Worker or owner_name_branch in trusted_repos_legacy 324*da0073e9SAndroid Build Coastguard Worker or repo_owner in _TRUSTED_REPO_OWNERS 325*da0073e9SAndroid Build Coastguard Worker ) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker # TODO: Remove `None` option in 2.0 and change the default to "check" 328*da0073e9SAndroid Build Coastguard Worker if trust_repo is None: 329*da0073e9SAndroid Build Coastguard Worker if not is_trusted: 330*da0073e9SAndroid Build Coastguard Worker warnings.warn( 331*da0073e9SAndroid Build Coastguard Worker "You are about to download and run code from an untrusted repository. 
In a future release, this won't " 332*da0073e9SAndroid Build Coastguard Worker "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., " 333*da0073e9SAndroid Build Coastguard Worker "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " 334*da0073e9SAndroid Build Coastguard Worker f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " 335*da0073e9SAndroid Build Coastguard Worker f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " 336*da0073e9SAndroid Build Coastguard Worker f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" 337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker return 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker if (trust_repo is False) or (trust_repo == "check" and not is_trusted): 341*da0073e9SAndroid Build Coastguard Worker response = input( 342*da0073e9SAndroid Build Coastguard Worker f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. " 343*da0073e9SAndroid Build Coastguard Worker "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?" 
344*da0073e9SAndroid Build Coastguard Worker ) 345*da0073e9SAndroid Build Coastguard Worker if response.lower() in ("y", "yes"): 346*da0073e9SAndroid Build Coastguard Worker if is_trusted: 347*da0073e9SAndroid Build Coastguard Worker print("The repository is already trusted.") 348*da0073e9SAndroid Build Coastguard Worker elif response.lower() in ("n", "no", ""): 349*da0073e9SAndroid Build Coastguard Worker raise Exception("Untrusted repository.") # noqa: TRY002 350*da0073e9SAndroid Build Coastguard Worker else: 351*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unrecognized response {response}.") 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker # At this point we're sure that the user trusts the repo (or wants to trust it) 354*da0073e9SAndroid Build Coastguard Worker if not is_trusted: 355*da0073e9SAndroid Build Coastguard Worker with open(filepath, "a") as file: 356*da0073e9SAndroid Build Coastguard Worker file.write(owner_name + "\n") 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Workerdef _check_module_exists(name): 360*da0073e9SAndroid Build Coastguard Worker import importlib.util 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker return importlib.util.find_spec(name) is not None 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Workerdef _check_dependencies(m): 366*da0073e9SAndroid Build Coastguard Worker dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker if dependencies is not None: 369*da0073e9SAndroid Build Coastguard Worker missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] 370*da0073e9SAndroid Build Coastguard Worker if len(missing_deps): 371*da0073e9SAndroid Build Coastguard Worker 
raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}") 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Workerdef _load_entry_from_hubconf(m, model): 375*da0073e9SAndroid Build Coastguard Worker if not isinstance(model, str): 376*da0073e9SAndroid Build Coastguard Worker raise ValueError("Invalid input: model should be a string of function name") 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker # Note that if a missing dependency is imported at top level of hubconf, it will 379*da0073e9SAndroid Build Coastguard Worker # throw before this function. It's a chicken and egg situation where we have to 380*da0073e9SAndroid Build Coastguard Worker # load hubconf to know what're the dependencies, but to import hubconf it requires 381*da0073e9SAndroid Build Coastguard Worker # a missing package. This is fine, Python will throw proper error message for users. 382*da0073e9SAndroid Build Coastguard Worker _check_dependencies(m) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker func = _load_attr_from_module(m, model) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker if func is None or not callable(func): 387*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Cannot find callable {model} in hubconf") 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker return func 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Workerdef get_dir(): 393*da0073e9SAndroid Build Coastguard Worker r""" 394*da0073e9SAndroid Build Coastguard Worker Get the Torch Hub cache directory used for storing downloaded models & weights. 
395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where 397*da0073e9SAndroid Build Coastguard Worker environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. 398*da0073e9SAndroid Build Coastguard Worker ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux 399*da0073e9SAndroid Build Coastguard Worker filesystem layout, with a default value ``~/.cache`` if the environment 400*da0073e9SAndroid Build Coastguard Worker variable is not set. 401*da0073e9SAndroid Build Coastguard Worker """ 402*da0073e9SAndroid Build Coastguard Worker # Issue warning to move data if old env is set 403*da0073e9SAndroid Build Coastguard Worker if os.getenv("TORCH_HUB"): 404*da0073e9SAndroid Build Coastguard Worker warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker if _hub_dir is not None: 407*da0073e9SAndroid Build Coastguard Worker return _hub_dir 408*da0073e9SAndroid Build Coastguard Worker return os.path.join(_get_torch_home(), "hub") 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Workerdef set_dir(d): 412*da0073e9SAndroid Build Coastguard Worker r""" 413*da0073e9SAndroid Build Coastguard Worker Optionally set the Torch Hub directory used to save downloaded models & weights. 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker Args: 416*da0073e9SAndroid Build Coastguard Worker d (str): path to a local folder to save downloaded models & weights. 
417*da0073e9SAndroid Build Coastguard Worker """ 418*da0073e9SAndroid Build Coastguard Worker global _hub_dir 419*da0073e9SAndroid Build Coastguard Worker _hub_dir = os.path.expanduser(d) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Workerdef list( 423*da0073e9SAndroid Build Coastguard Worker github, 424*da0073e9SAndroid Build Coastguard Worker force_reload=False, 425*da0073e9SAndroid Build Coastguard Worker skip_validation=False, 426*da0073e9SAndroid Build Coastguard Worker trust_repo=None, 427*da0073e9SAndroid Build Coastguard Worker verbose=True, 428*da0073e9SAndroid Build Coastguard Worker): 429*da0073e9SAndroid Build Coastguard Worker r""" 430*da0073e9SAndroid Build Coastguard Worker List all callable entrypoints available in the repo specified by ``github``. 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker Args: 433*da0073e9SAndroid Build Coastguard Worker github (str): a string with format "repo_owner/repo_name[:ref]" with an optional 434*da0073e9SAndroid Build Coastguard Worker ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if 435*da0073e9SAndroid Build Coastguard Worker it exists, and otherwise ``master``. 436*da0073e9SAndroid Build Coastguard Worker Example: 'pytorch/vision:0.10' 437*da0073e9SAndroid Build Coastguard Worker force_reload (bool, optional): whether to discard the existing cache and force a fresh download. 438*da0073e9SAndroid Build Coastguard Worker Default is ``False``. 439*da0073e9SAndroid Build Coastguard Worker skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit 440*da0073e9SAndroid Build Coastguard Worker specified by the ``github`` argument properly belongs to the repo owner. 
This will make 441*da0073e9SAndroid Build Coastguard Worker requests to the GitHub API; you can specify a non-default GitHub token by setting the 442*da0073e9SAndroid Build Coastguard Worker ``GITHUB_TOKEN`` environment variable. Default is ``False``. 443*da0073e9SAndroid Build Coastguard Worker trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. 444*da0073e9SAndroid Build Coastguard Worker This parameter was introduced in v1.12 and helps ensuring that users 445*da0073e9SAndroid Build Coastguard Worker only run code from repos that they trust. 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker - If ``False``, a prompt will ask the user whether the repo should 448*da0073e9SAndroid Build Coastguard Worker be trusted. 449*da0073e9SAndroid Build Coastguard Worker - If ``True``, the repo will be added to the trusted list and loaded 450*da0073e9SAndroid Build Coastguard Worker without requiring explicit confirmation. 451*da0073e9SAndroid Build Coastguard Worker - If ``"check"``, the repo will be checked against the list of 452*da0073e9SAndroid Build Coastguard Worker trusted repos in the cache. If it is not present in that list, the 453*da0073e9SAndroid Build Coastguard Worker behaviour will fall back onto the ``trust_repo=False`` option. 454*da0073e9SAndroid Build Coastguard Worker - If ``None``: this will raise a warning, inviting the user to set 455*da0073e9SAndroid Build Coastguard Worker ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This 456*da0073e9SAndroid Build Coastguard Worker is only present for backward compatibility and will be removed in 457*da0073e9SAndroid Build Coastguard Worker v2.0. 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker Default is ``None`` and will eventually change to ``"check"`` in v2.0. 
460*da0073e9SAndroid Build Coastguard Worker verbose (bool, optional): If ``False``, mute messages about hitting 461*da0073e9SAndroid Build Coastguard Worker local caches. Note that the message about first download cannot be 462*da0073e9SAndroid Build Coastguard Worker muted. Default is ``True``. 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker Returns: 465*da0073e9SAndroid Build Coastguard Worker list: The available callables entrypoint 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker Example: 468*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) 469*da0073e9SAndroid Build Coastguard Worker >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True) 470*da0073e9SAndroid Build Coastguard Worker """ 471*da0073e9SAndroid Build Coastguard Worker repo_dir = _get_cache_or_reload( 472*da0073e9SAndroid Build Coastguard Worker github, 473*da0073e9SAndroid Build Coastguard Worker force_reload, 474*da0073e9SAndroid Build Coastguard Worker trust_repo, 475*da0073e9SAndroid Build Coastguard Worker "list", 476*da0073e9SAndroid Build Coastguard Worker verbose=verbose, 477*da0073e9SAndroid Build Coastguard Worker skip_validation=skip_validation, 478*da0073e9SAndroid Build Coastguard Worker ) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker with _add_to_sys_path(repo_dir): 481*da0073e9SAndroid Build Coastguard Worker hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) 482*da0073e9SAndroid Build Coastguard Worker hub_module = _import_module(MODULE_HUBCONF, hubconf_path) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker # We take functions starts with '_' as internal helper functions 485*da0073e9SAndroid Build Coastguard Worker entrypoints = [ 486*da0073e9SAndroid Build Coastguard Worker f 487*da0073e9SAndroid Build Coastguard Worker for f in dir(hub_module) 
488*da0073e9SAndroid Build Coastguard Worker if callable(getattr(hub_module, f)) and not f.startswith("_") 489*da0073e9SAndroid Build Coastguard Worker ] 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker return entrypoints 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Workerdef help(github, model, force_reload=False, skip_validation=False, trust_repo=None): 495*da0073e9SAndroid Build Coastguard Worker r""" 496*da0073e9SAndroid Build Coastguard Worker Show the docstring of entrypoint ``model``. 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker Args: 499*da0073e9SAndroid Build Coastguard Worker github (str): a string with format <repo_owner/repo_name[:ref]> with an optional 500*da0073e9SAndroid Build Coastguard Worker ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed 501*da0073e9SAndroid Build Coastguard Worker to be ``main`` if it exists, and otherwise ``master``. 502*da0073e9SAndroid Build Coastguard Worker Example: 'pytorch/vision:0.10' 503*da0073e9SAndroid Build Coastguard Worker model (str): a string of entrypoint name defined in repo's ``hubconf.py`` 504*da0073e9SAndroid Build Coastguard Worker force_reload (bool, optional): whether to discard the existing cache and force a fresh download. 505*da0073e9SAndroid Build Coastguard Worker Default is ``False``. 506*da0073e9SAndroid Build Coastguard Worker skip_validation (bool, optional): if ``False``, torchhub will check that the ref 507*da0073e9SAndroid Build Coastguard Worker specified by the ``github`` argument properly belongs to the repo owner. This will make 508*da0073e9SAndroid Build Coastguard Worker requests to the GitHub API; you can specify a non-default GitHub token by setting the 509*da0073e9SAndroid Build Coastguard Worker ``GITHUB_TOKEN`` environment variable. Default is ``False``. 
510*da0073e9SAndroid Build Coastguard Worker trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. 511*da0073e9SAndroid Build Coastguard Worker This parameter was introduced in v1.12 and helps ensuring that users 512*da0073e9SAndroid Build Coastguard Worker only run code from repos that they trust. 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker - If ``False``, a prompt will ask the user whether the repo should 515*da0073e9SAndroid Build Coastguard Worker be trusted. 516*da0073e9SAndroid Build Coastguard Worker - If ``True``, the repo will be added to the trusted list and loaded 517*da0073e9SAndroid Build Coastguard Worker without requiring explicit confirmation. 518*da0073e9SAndroid Build Coastguard Worker - If ``"check"``, the repo will be checked against the list of 519*da0073e9SAndroid Build Coastguard Worker trusted repos in the cache. If it is not present in that list, the 520*da0073e9SAndroid Build Coastguard Worker behaviour will fall back onto the ``trust_repo=False`` option. 521*da0073e9SAndroid Build Coastguard Worker - If ``None``: this will raise a warning, inviting the user to set 522*da0073e9SAndroid Build Coastguard Worker ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This 523*da0073e9SAndroid Build Coastguard Worker is only present for backward compatibility and will be removed in 524*da0073e9SAndroid Build Coastguard Worker v2.0. 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker Default is ``None`` and will eventually change to ``"check"`` in v2.0. 
527*da0073e9SAndroid Build Coastguard Worker Example: 528*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) 529*da0073e9SAndroid Build Coastguard Worker >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True)) 530*da0073e9SAndroid Build Coastguard Worker """ 531*da0073e9SAndroid Build Coastguard Worker repo_dir = _get_cache_or_reload( 532*da0073e9SAndroid Build Coastguard Worker github, 533*da0073e9SAndroid Build Coastguard Worker force_reload, 534*da0073e9SAndroid Build Coastguard Worker trust_repo, 535*da0073e9SAndroid Build Coastguard Worker "help", 536*da0073e9SAndroid Build Coastguard Worker verbose=True, 537*da0073e9SAndroid Build Coastguard Worker skip_validation=skip_validation, 538*da0073e9SAndroid Build Coastguard Worker ) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker with _add_to_sys_path(repo_dir): 541*da0073e9SAndroid Build Coastguard Worker hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) 542*da0073e9SAndroid Build Coastguard Worker hub_module = _import_module(MODULE_HUBCONF, hubconf_path) 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker entry = _load_entry_from_hubconf(hub_module, model) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker return entry.__doc__ 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Workerdef load( 550*da0073e9SAndroid Build Coastguard Worker repo_or_dir, 551*da0073e9SAndroid Build Coastguard Worker model, 552*da0073e9SAndroid Build Coastguard Worker *args, 553*da0073e9SAndroid Build Coastguard Worker source="github", 554*da0073e9SAndroid Build Coastguard Worker trust_repo=None, 555*da0073e9SAndroid Build Coastguard Worker force_reload=False, 556*da0073e9SAndroid Build Coastguard Worker verbose=True, 557*da0073e9SAndroid Build Coastguard Worker 
skip_validation=False, 558*da0073e9SAndroid Build Coastguard Worker **kwargs, 559*da0073e9SAndroid Build Coastguard Worker): 560*da0073e9SAndroid Build Coastguard Worker r""" 561*da0073e9SAndroid Build Coastguard Worker Load a model from a github repo or a local directory. 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker Note: Loading a model is the typical use case, but this can also be used to 564*da0073e9SAndroid Build Coastguard Worker for loading other objects such as tokenizers, loss functions, etc. 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker If ``source`` is 'github', ``repo_or_dir`` is expected to be 567*da0073e9SAndroid Build Coastguard Worker of the form ``repo_owner/repo_name[:ref]`` with an optional 568*da0073e9SAndroid Build Coastguard Worker ref (a tag or a branch). 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker If ``source`` is 'local', ``repo_or_dir`` is expected to be a 571*da0073e9SAndroid Build Coastguard Worker path to a local directory. 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker Args: 574*da0073e9SAndroid Build Coastguard Worker repo_or_dir (str): If ``source`` is 'github', 575*da0073e9SAndroid Build Coastguard Worker this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with 576*da0073e9SAndroid Build Coastguard Worker an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified, 577*da0073e9SAndroid Build Coastguard Worker the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. 578*da0073e9SAndroid Build Coastguard Worker If ``source`` is 'local' then it should be a path to a local directory. 579*da0073e9SAndroid Build Coastguard Worker model (str): the name of a callable (entrypoint) defined in the 580*da0073e9SAndroid Build Coastguard Worker repo/dir's ``hubconf.py``. 
581*da0073e9SAndroid Build Coastguard Worker *args (optional): the corresponding args for callable ``model``. 582*da0073e9SAndroid Build Coastguard Worker source (str, optional): 'github' or 'local'. Specifies how 583*da0073e9SAndroid Build Coastguard Worker ``repo_or_dir`` is to be interpreted. Default is 'github'. 584*da0073e9SAndroid Build Coastguard Worker trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. 585*da0073e9SAndroid Build Coastguard Worker This parameter was introduced in v1.12 and helps ensuring that users 586*da0073e9SAndroid Build Coastguard Worker only run code from repos that they trust. 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker - If ``False``, a prompt will ask the user whether the repo should 589*da0073e9SAndroid Build Coastguard Worker be trusted. 590*da0073e9SAndroid Build Coastguard Worker - If ``True``, the repo will be added to the trusted list and loaded 591*da0073e9SAndroid Build Coastguard Worker without requiring explicit confirmation. 592*da0073e9SAndroid Build Coastguard Worker - If ``"check"``, the repo will be checked against the list of 593*da0073e9SAndroid Build Coastguard Worker trusted repos in the cache. If it is not present in that list, the 594*da0073e9SAndroid Build Coastguard Worker behaviour will fall back onto the ``trust_repo=False`` option. 595*da0073e9SAndroid Build Coastguard Worker - If ``None``: this will raise a warning, inviting the user to set 596*da0073e9SAndroid Build Coastguard Worker ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This 597*da0073e9SAndroid Build Coastguard Worker is only present for backward compatibility and will be removed in 598*da0073e9SAndroid Build Coastguard Worker v2.0. 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Worker Default is ``None`` and will eventually change to ``"check"`` in v2.0. 
601*da0073e9SAndroid Build Coastguard Worker force_reload (bool, optional): whether to force a fresh download of 602*da0073e9SAndroid Build Coastguard Worker the github repo unconditionally. Does not have any effect if 603*da0073e9SAndroid Build Coastguard Worker ``source = 'local'``. Default is ``False``. 604*da0073e9SAndroid Build Coastguard Worker verbose (bool, optional): If ``False``, mute messages about hitting 605*da0073e9SAndroid Build Coastguard Worker local caches. Note that the message about first download cannot be 606*da0073e9SAndroid Build Coastguard Worker muted. Does not have any effect if ``source = 'local'``. 607*da0073e9SAndroid Build Coastguard Worker Default is ``True``. 608*da0073e9SAndroid Build Coastguard Worker skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit 609*da0073e9SAndroid Build Coastguard Worker specified by the ``github`` argument properly belongs to the repo owner. This will make 610*da0073e9SAndroid Build Coastguard Worker requests to the GitHub API; you can specify a non-default GitHub token by setting the 611*da0073e9SAndroid Build Coastguard Worker ``GITHUB_TOKEN`` environment variable. Default is ``False``. 612*da0073e9SAndroid Build Coastguard Worker **kwargs (optional): the corresponding kwargs for callable ``model``. 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker Returns: 615*da0073e9SAndroid Build Coastguard Worker The output of the ``model`` callable when called with the given 616*da0073e9SAndroid Build Coastguard Worker ``*args`` and ``**kwargs``. 
617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker Example: 619*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) 620*da0073e9SAndroid Build Coastguard Worker >>> # from a github repo 621*da0073e9SAndroid Build Coastguard Worker >>> repo = "pytorch/vision" 622*da0073e9SAndroid Build Coastguard Worker >>> model = torch.hub.load( 623*da0073e9SAndroid Build Coastguard Worker ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" 624*da0073e9SAndroid Build Coastguard Worker ... ) 625*da0073e9SAndroid Build Coastguard Worker >>> # from a local directory 626*da0073e9SAndroid Build Coastguard Worker >>> path = "/some/local/path/pytorch/vision" 627*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP 628*da0073e9SAndroid Build Coastguard Worker >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT") 629*da0073e9SAndroid Build Coastguard Worker """ 630*da0073e9SAndroid Build Coastguard Worker source = source.lower() 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker if source not in ("github", "local"): 633*da0073e9SAndroid Build Coastguard Worker raise ValueError( 634*da0073e9SAndroid Build Coastguard Worker f'Unknown source: "{source}". Allowed values: "github" | "local".' 
635*da0073e9SAndroid Build Coastguard Worker ) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker if source == "github": 638*da0073e9SAndroid Build Coastguard Worker repo_or_dir = _get_cache_or_reload( 639*da0073e9SAndroid Build Coastguard Worker repo_or_dir, 640*da0073e9SAndroid Build Coastguard Worker force_reload, 641*da0073e9SAndroid Build Coastguard Worker trust_repo, 642*da0073e9SAndroid Build Coastguard Worker "load", 643*da0073e9SAndroid Build Coastguard Worker verbose=verbose, 644*da0073e9SAndroid Build Coastguard Worker skip_validation=skip_validation, 645*da0073e9SAndroid Build Coastguard Worker ) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker model = _load_local(repo_or_dir, model, *args, **kwargs) 648*da0073e9SAndroid Build Coastguard Worker return model 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Workerdef _load_local(hubconf_dir, model, *args, **kwargs): 652*da0073e9SAndroid Build Coastguard Worker r""" 653*da0073e9SAndroid Build Coastguard Worker Load a model from a local directory with a ``hubconf.py``. 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker Args: 656*da0073e9SAndroid Build Coastguard Worker hubconf_dir (str): path to a local directory that contains a 657*da0073e9SAndroid Build Coastguard Worker ``hubconf.py``. 658*da0073e9SAndroid Build Coastguard Worker model (str): name of an entrypoint defined in the directory's 659*da0073e9SAndroid Build Coastguard Worker ``hubconf.py``. 660*da0073e9SAndroid Build Coastguard Worker *args (optional): the corresponding args for callable ``model``. 661*da0073e9SAndroid Build Coastguard Worker **kwargs (optional): the corresponding kwargs for callable ``model``. 
662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker Returns: 664*da0073e9SAndroid Build Coastguard Worker a single model with corresponding pretrained weights. 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker Example: 667*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP("stub local path") 668*da0073e9SAndroid Build Coastguard Worker >>> path = "/some/local/path/pytorch/vision" 669*da0073e9SAndroid Build Coastguard Worker >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1") 670*da0073e9SAndroid Build Coastguard Worker """ 671*da0073e9SAndroid Build Coastguard Worker with _add_to_sys_path(hubconf_dir): 672*da0073e9SAndroid Build Coastguard Worker hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) 673*da0073e9SAndroid Build Coastguard Worker hub_module = _import_module(MODULE_HUBCONF, hubconf_path) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker entry = _load_entry_from_hubconf(hub_module, model) 676*da0073e9SAndroid Build Coastguard Worker model = entry(*args, **kwargs) 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker return model 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Workerdef download_url_to_file( 682*da0073e9SAndroid Build Coastguard Worker url: str, 683*da0073e9SAndroid Build Coastguard Worker dst: str, 684*da0073e9SAndroid Build Coastguard Worker hash_prefix: Optional[str] = None, 685*da0073e9SAndroid Build Coastguard Worker progress: bool = True, 686*da0073e9SAndroid Build Coastguard Worker) -> None: 687*da0073e9SAndroid Build Coastguard Worker r"""Download object at the given URL to a local path. 
688*da0073e9SAndroid Build Coastguard Worker 689*da0073e9SAndroid Build Coastguard Worker Args: 690*da0073e9SAndroid Build Coastguard Worker url (str): URL of the object to download 691*da0073e9SAndroid Build Coastguard Worker dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` 692*da0073e9SAndroid Build Coastguard Worker hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. 693*da0073e9SAndroid Build Coastguard Worker Default: None 694*da0073e9SAndroid Build Coastguard Worker progress (bool, optional): whether or not to display a progress bar to stderr 695*da0073e9SAndroid Build Coastguard Worker Default: True 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker Example: 698*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) 699*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(POSIX) 700*da0073e9SAndroid Build Coastguard Worker >>> torch.hub.download_url_to_file( 701*da0073e9SAndroid Build Coastguard Worker ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", 702*da0073e9SAndroid Build Coastguard Worker ... "/tmp/temporary_file", 703*da0073e9SAndroid Build Coastguard Worker ... 
) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker """ 706*da0073e9SAndroid Build Coastguard Worker file_size = None 707*da0073e9SAndroid Build Coastguard Worker req = Request(url, headers={"User-Agent": "torch.hub"}) 708*da0073e9SAndroid Build Coastguard Worker u = urlopen(req) 709*da0073e9SAndroid Build Coastguard Worker meta = u.info() 710*da0073e9SAndroid Build Coastguard Worker if hasattr(meta, "getheaders"): 711*da0073e9SAndroid Build Coastguard Worker content_length = meta.getheaders("Content-Length") 712*da0073e9SAndroid Build Coastguard Worker else: 713*da0073e9SAndroid Build Coastguard Worker content_length = meta.get_all("Content-Length") 714*da0073e9SAndroid Build Coastguard Worker if content_length is not None and len(content_length) > 0: 715*da0073e9SAndroid Build Coastguard Worker file_size = int(content_length[0]) 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker # We deliberately save it in a temp file and move it after 718*da0073e9SAndroid Build Coastguard Worker # download is complete. This prevents a local working checkpoint 719*da0073e9SAndroid Build Coastguard Worker # being overridden by a broken download. 720*da0073e9SAndroid Build Coastguard Worker # We deliberately do not use NamedTemporaryFile to avoid restrictive 721*da0073e9SAndroid Build Coastguard Worker # file permissions being applied to the downloaded file. 722*da0073e9SAndroid Build Coastguard Worker dst = os.path.expanduser(dst) 723*da0073e9SAndroid Build Coastguard Worker for seq in range(tempfile.TMP_MAX): 724*da0073e9SAndroid Build Coastguard Worker tmp_dst = dst + "." 
+ uuid.uuid4().hex + ".partial" 725*da0073e9SAndroid Build Coastguard Worker try: 726*da0073e9SAndroid Build Coastguard Worker f = open(tmp_dst, "w+b") 727*da0073e9SAndroid Build Coastguard Worker except FileExistsError: 728*da0073e9SAndroid Build Coastguard Worker continue 729*da0073e9SAndroid Build Coastguard Worker break 730*da0073e9SAndroid Build Coastguard Worker else: 731*da0073e9SAndroid Build Coastguard Worker raise FileExistsError(errno.EEXIST, "No usable temporary file name found") 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker try: 734*da0073e9SAndroid Build Coastguard Worker if hash_prefix is not None: 735*da0073e9SAndroid Build Coastguard Worker sha256 = hashlib.sha256() 736*da0073e9SAndroid Build Coastguard Worker with tqdm( 737*da0073e9SAndroid Build Coastguard Worker total=file_size, 738*da0073e9SAndroid Build Coastguard Worker disable=not progress, 739*da0073e9SAndroid Build Coastguard Worker unit="B", 740*da0073e9SAndroid Build Coastguard Worker unit_scale=True, 741*da0073e9SAndroid Build Coastguard Worker unit_divisor=1024, 742*da0073e9SAndroid Build Coastguard Worker ) as pbar: 743*da0073e9SAndroid Build Coastguard Worker while True: 744*da0073e9SAndroid Build Coastguard Worker buffer = u.read(READ_DATA_CHUNK) 745*da0073e9SAndroid Build Coastguard Worker if len(buffer) == 0: 746*da0073e9SAndroid Build Coastguard Worker break 747*da0073e9SAndroid Build Coastguard Worker f.write(buffer) # type: ignore[possibly-undefined] 748*da0073e9SAndroid Build Coastguard Worker if hash_prefix is not None: 749*da0073e9SAndroid Build Coastguard Worker sha256.update(buffer) # type: ignore[possibly-undefined] 750*da0073e9SAndroid Build Coastguard Worker pbar.update(len(buffer)) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker f.close() 753*da0073e9SAndroid Build Coastguard Worker if hash_prefix is not None: 754*da0073e9SAndroid Build Coastguard Worker digest = sha256.hexdigest() # 
type: ignore[possibly-undefined] 755*da0073e9SAndroid Build Coastguard Worker if digest[: len(hash_prefix)] != hash_prefix: 756*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 757*da0073e9SAndroid Build Coastguard Worker f'invalid hash value (expected "{hash_prefix}", got "{digest}")' 758*da0073e9SAndroid Build Coastguard Worker ) 759*da0073e9SAndroid Build Coastguard Worker shutil.move(f.name, dst) 760*da0073e9SAndroid Build Coastguard Worker finally: 761*da0073e9SAndroid Build Coastguard Worker f.close() 762*da0073e9SAndroid Build Coastguard Worker if os.path.exists(f.name): 763*da0073e9SAndroid Build Coastguard Worker os.remove(f.name) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker# Hub used to support automatically extracts from zipfile manually compressed by users. 767*da0073e9SAndroid Build Coastguard Worker# The legacy zip format expects only one file from torch.save() < 1.6 in the zip. 768*da0073e9SAndroid Build Coastguard Worker# We should remove this support since zipfile is now default zipfile format for torch.save(). 769*da0073e9SAndroid Build Coastguard Workerdef _is_legacy_zip_format(filename: str) -> bool: 770*da0073e9SAndroid Build Coastguard Worker if zipfile.is_zipfile(filename): 771*da0073e9SAndroid Build Coastguard Worker infolist = zipfile.ZipFile(filename).infolist() 772*da0073e9SAndroid Build Coastguard Worker return len(infolist) == 1 and not infolist[0].is_dir() 773*da0073e9SAndroid Build Coastguard Worker return False 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker@deprecated( 777*da0073e9SAndroid Build Coastguard Worker "Falling back to the old format < 1.6. This support will be " 778*da0073e9SAndroid Build Coastguard Worker "deprecated in favor of default zipfile format introduced in 1.6. 
" 779*da0073e9SAndroid Build Coastguard Worker "Please redo torch.save() to save it in the new zipfile format.", 780*da0073e9SAndroid Build Coastguard Worker category=FutureWarning, 781*da0073e9SAndroid Build Coastguard Worker) 782*da0073e9SAndroid Build Coastguard Workerdef _legacy_zip_load( 783*da0073e9SAndroid Build Coastguard Worker filename: str, 784*da0073e9SAndroid Build Coastguard Worker model_dir: str, 785*da0073e9SAndroid Build Coastguard Worker map_location: MAP_LOCATION, 786*da0073e9SAndroid Build Coastguard Worker weights_only: bool, 787*da0073e9SAndroid Build Coastguard Worker) -> Dict[str, Any]: 788*da0073e9SAndroid Build Coastguard Worker # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. 789*da0073e9SAndroid Build Coastguard Worker # We deliberately don't handle tarfile here since our legacy serialization format was in tar. 790*da0073e9SAndroid Build Coastguard Worker # E.g. resnet18-5c106cde.pth which is widely used. 791*da0073e9SAndroid Build Coastguard Worker with zipfile.ZipFile(filename) as f: 792*da0073e9SAndroid Build Coastguard Worker members = f.infolist() 793*da0073e9SAndroid Build Coastguard Worker if len(members) != 1: 794*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Only one file(not dir) is allowed in the zipfile") 795*da0073e9SAndroid Build Coastguard Worker f.extractall(model_dir) 796*da0073e9SAndroid Build Coastguard Worker extraced_name = members[0].filename 797*da0073e9SAndroid Build Coastguard Worker extracted_file = os.path.join(model_dir, extraced_name) 798*da0073e9SAndroid Build Coastguard Worker return torch.load( 799*da0073e9SAndroid Build Coastguard Worker extracted_file, map_location=map_location, weights_only=weights_only 800*da0073e9SAndroid Build Coastguard Worker ) 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Worker 803*da0073e9SAndroid Build Coastguard Workerdef load_state_dict_from_url( 804*da0073e9SAndroid Build Coastguard 
Worker url: str, 805*da0073e9SAndroid Build Coastguard Worker model_dir: Optional[str] = None, 806*da0073e9SAndroid Build Coastguard Worker map_location: MAP_LOCATION = None, 807*da0073e9SAndroid Build Coastguard Worker progress: bool = True, 808*da0073e9SAndroid Build Coastguard Worker check_hash: bool = False, 809*da0073e9SAndroid Build Coastguard Worker file_name: Optional[str] = None, 810*da0073e9SAndroid Build Coastguard Worker weights_only: bool = False, 811*da0073e9SAndroid Build Coastguard Worker) -> Dict[str, Any]: 812*da0073e9SAndroid Build Coastguard Worker r"""Loads the Torch serialized object at the given URL. 813*da0073e9SAndroid Build Coastguard Worker 814*da0073e9SAndroid Build Coastguard Worker If downloaded file is a zip file, it will be automatically 815*da0073e9SAndroid Build Coastguard Worker decompressed. 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker If the object is already present in `model_dir`, it's deserialized and 818*da0073e9SAndroid Build Coastguard Worker returned. 819*da0073e9SAndroid Build Coastguard Worker The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where 820*da0073e9SAndroid Build Coastguard Worker ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker Args: 823*da0073e9SAndroid Build Coastguard Worker url (str): URL of the object to download 824*da0073e9SAndroid Build Coastguard Worker model_dir (str, optional): directory in which to save the object 825*da0073e9SAndroid Build Coastguard Worker map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) 826*da0073e9SAndroid Build Coastguard Worker progress (bool, optional): whether or not to display a progress bar to stderr. 
827*da0073e9SAndroid Build Coastguard Worker Default: True 828*da0073e9SAndroid Build Coastguard Worker check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention 829*da0073e9SAndroid Build Coastguard Worker ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more 830*da0073e9SAndroid Build Coastguard Worker digits of the SHA256 hash of the contents of the file. The hash is used to 831*da0073e9SAndroid Build Coastguard Worker ensure unique names and to verify the contents of the file. 832*da0073e9SAndroid Build Coastguard Worker Default: False 833*da0073e9SAndroid Build Coastguard Worker file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. 834*da0073e9SAndroid Build Coastguard Worker weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects. 835*da0073e9SAndroid Build Coastguard Worker Recommended for untrusted sources. See :func:`~torch.load` for more details. 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker Example: 838*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) 839*da0073e9SAndroid Build Coastguard Worker >>> state_dict = torch.hub.load_state_dict_from_url( 840*da0073e9SAndroid Build Coastguard Worker ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" 841*da0073e9SAndroid Build Coastguard Worker ... 
) 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker """ 844*da0073e9SAndroid Build Coastguard Worker # Issue warning to move data if old env is set 845*da0073e9SAndroid Build Coastguard Worker if os.getenv("TORCH_MODEL_ZOO"): 846*da0073e9SAndroid Build Coastguard Worker warnings.warn( 847*da0073e9SAndroid Build Coastguard Worker "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" 848*da0073e9SAndroid Build Coastguard Worker ) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker if model_dir is None: 851*da0073e9SAndroid Build Coastguard Worker hub_dir = get_dir() 852*da0073e9SAndroid Build Coastguard Worker model_dir = os.path.join(hub_dir, "checkpoints") 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker os.makedirs(model_dir, exist_ok=True) 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker parts = urlparse(url) 857*da0073e9SAndroid Build Coastguard Worker filename = os.path.basename(parts.path) 858*da0073e9SAndroid Build Coastguard Worker if file_name is not None: 859*da0073e9SAndroid Build Coastguard Worker filename = file_name 860*da0073e9SAndroid Build Coastguard Worker cached_file = os.path.join(model_dir, filename) 861*da0073e9SAndroid Build Coastguard Worker if not os.path.exists(cached_file): 862*da0073e9SAndroid Build Coastguard Worker sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') 863*da0073e9SAndroid Build Coastguard Worker hash_prefix = None 864*da0073e9SAndroid Build Coastguard Worker if check_hash: 865*da0073e9SAndroid Build Coastguard Worker r = HASH_REGEX.search(filename) # r is Optional[Match[str]] 866*da0073e9SAndroid Build Coastguard Worker hash_prefix = r.group(1) if r else None 867*da0073e9SAndroid Build Coastguard Worker download_url_to_file(url, cached_file, hash_prefix, progress=progress) 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build 
Coastguard Worker if _is_legacy_zip_format(cached_file): 870*da0073e9SAndroid Build Coastguard Worker return _legacy_zip_load(cached_file, model_dir, map_location, weights_only) 871*da0073e9SAndroid Build Coastguard Worker return torch.load(cached_file, map_location=map_location, weights_only=weights_only) 872