xref: /aosp_15_r20/external/pytorch/torch/hub.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport contextlib
3*da0073e9SAndroid Build Coastguard Workerimport errno
4*da0073e9SAndroid Build Coastguard Workerimport hashlib
5*da0073e9SAndroid Build Coastguard Workerimport json
6*da0073e9SAndroid Build Coastguard Workerimport os
7*da0073e9SAndroid Build Coastguard Workerimport re
8*da0073e9SAndroid Build Coastguard Workerimport shutil
9*da0073e9SAndroid Build Coastguard Workerimport sys
10*da0073e9SAndroid Build Coastguard Workerimport tempfile
11*da0073e9SAndroid Build Coastguard Workerimport uuid
12*da0073e9SAndroid Build Coastguard Workerimport warnings
13*da0073e9SAndroid Build Coastguard Workerimport zipfile
14*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
15*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Optional
16*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import deprecated
17*da0073e9SAndroid Build Coastguard Workerfrom urllib.error import HTTPError, URLError
18*da0073e9SAndroid Build Coastguard Workerfrom urllib.parse import urlparse  # noqa: F401
19*da0073e9SAndroid Build Coastguard Workerfrom urllib.request import Request, urlopen
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerimport torch
22*da0073e9SAndroid Build Coastguard Workerfrom torch.serialization import MAP_LOCATION
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
class _Faketqdm:  # type: ignore[no-redef]
    """Minimal stand-in for ``tqdm.tqdm`` used when tqdm is not installed.

    Only the small subset of the tqdm API that this module relies on is
    provided: ``update``, ``set_description`` (no-op), ``write``, ``close`` and
    the context-manager protocol. Progress is written to ``sys.stderr``, as a
    percentage when a non-zero ``total`` is known and as a running byte count
    otherwise.
    """

    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm

    def update(self, n=1):
        """Advance the counter by ``n`` (default 1, matching tqdm) and redraw the line."""
        if self.disable:
            return

        self.n += n
        # A total of None (unknown) or 0 cannot yield a percentage, so show the
        # raw count instead of dividing by zero.
        if not self.total:
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want
    def set_description(self, *args, **kwargs):
        pass

    def write(self, s):
        """Write ``s`` followed by a newline to stderr (regardless of ``disable``)."""
        sys.stderr.write(f"{s}\n")

    def close(self):
        """Stop emitting progress updates."""
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return

        # Terminate the "\r"-redrawn progress line, then close the bar so later
        # updates are ignored, as tqdm does when its context manager exits.
        sys.stderr.write("\n")
        self.close()
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
# Use the real tqdm progress bar when it is importable; otherwise fall back to
# the minimal _Faketqdm stand-in so tqdm stays an optional dependency.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = _Faketqdm

# Public API of this module.
__all__ = [
    "download_url_to_file",
    "get_dir",
    "help",
    "list",
    "load",
    "load_state_dict_from_url",
    "set_dir",
]
77*da0073e9SAndroid Build Coastguard Worker
# Captures the hex hash prefix embedded in a file name,
# e.g. "bfd8deac" from "resnet18-bfd8deac.pth".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

# Repos owned by these GitHub accounts are considered trusted by
# _check_repo_is_trusted.
_TRUSTED_REPO_OWNERS = (
    "facebookresearch",
    "facebookincubator",
    "pytorch",
    "fairinternal",
)

# Environment variables read by this module.
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
# Cache root used when $XDG_CACHE_HOME is not set.
DEFAULT_CACHE_DIR = "~/.cache"

# Optional hubconf attribute listing the package names a repo requires.
VAR_DEPENDENCY = "dependencies"
# File name of a hub repo's entry-point module.
MODULE_HUBCONF = "hubconf.py"
# Read chunk size in bytes (128 KiB); presumably used by the download code
# outside this excerpt.
READ_DATA_CHUNK = 128 * 1024

# Override installed by set_dir(); None means get_dir() computes the default.
_hub_dir: Optional[str] = None
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def _add_to_sys_path(path):
    """Prepend ``path`` to ``sys.path`` for the duration of a ``with`` block.

    The entry is taken out again on exit, whether the body finished normally
    or raised.
    """
    sys.path.insert(0, path)
    try:
        yield
    finally:
        # list.remove drops the first equal entry, which is the one inserted
        # at index 0 above unless the body reordered sys.path.
        sys.path.remove(path)
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker
# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
    """Execute the Python file at ``path`` and return it as a module named ``name``.

    The new module is not added to ``sys.modules``.

    Args:
        name (str): name given to the new module object.
        path (str): filesystem path of the source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec or loader can be created for ``path``,
            for example when its extension is not a recognized Python suffix.
    """
    import importlib.util
    from importlib.abc import Loader

    spec = importlib.util.spec_from_file_location(name, path)
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if spec is None:
        raise ImportError(
            f"Cannot create an import spec for {path!r}", name=name, path=path
        )
    if not isinstance(spec.loader, Loader):
        raise ImportError(
            f"No usable module loader for {path!r}", name=name, path=path
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker
def _remove_if_exists(path):
    """Delete ``path`` if anything exists there; a missing path is a no-op.

    Real directories are removed recursively. Everything else is unlinked,
    including regular files and symbolic links. A link that points at a
    directory removes only the link, never the target, and a dangling link is
    removed too.
    """
    # Check islink first: isdir() follows links, and rmtree refuses to operate
    # on a symlink.
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
def _git_archive_link(repo_owner, repo_name, ref):
    """Return the GitHub URL serving ``repo_owner/repo_name`` at ``ref`` as a zip archive."""
    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip
    repo_url = f"https://github.com/{repo_owner}/{repo_name}"
    return f"{repo_url}/zipball/{ref}"
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker
def _load_attr_from_module(module, func_name):
    """Return ``module.<func_name>``, or ``None`` when ``dir(module)`` does not list that name."""
    available_names = dir(module)
    return getattr(module, func_name) if func_name in available_names else None
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker
def _get_torch_home():
    """Return the Torch home directory with any leading ``~`` expanded.

    Uses ``$TORCH_HOME`` when set, otherwise ``$XDG_CACHE_HOME/torch``, and
    ``~/.cache/torch`` when neither variable is set.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, "torch")
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, fallback_home))
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def _parse_repo_info(github):
    """Split a ``"repo_owner/repo_name[:ref]"`` spec into ``(repo_owner, repo_name, ref)``.

    When no ref is given, the default branch is guessed. It is ``main`` if
    ``https://github.com/<owner>/<name>/tree/main/`` can be opened, and
    ``master`` if that page answers HTTP 404. Without a network connection, an
    existing cached checkout of ``main`` or ``master`` in :func:`get_dir` is
    used instead.

    Raises:
        ValueError: if ``github`` is not of the form ``"owner/name[:ref]"``.
        RuntimeError: if the ref has to be guessed, there is no network
            connection, and no cached ``main``/``master`` checkout exists.
        HTTPError: if probing the ``main`` branch fails with a status other
            than 404.
    """
    repo_info, has_ref, ref = github.partition(":")
    repo_parts = repo_info.split("/")
    # Reject malformed specs up front with a readable message instead of an
    # opaque "too many / not enough values to unpack" error.
    if ":" in ref or len(repo_parts) != 2 or not all(repo_parts):
        raise ValueError(
            f'Invalid github repo spec "{github}": expected the format '
            '"repo_owner/repo_name[:ref]".'
        )
    repo_owner, repo_name = repo_parts
    if not has_ref:
        ref = None

    if ref is None:
        # The ref wasn't specified by the user, so we need to figure out the
        # default branch: main or master. Our assumption is that if main exists
        # then it's the default branch, otherwise it's master.
        try:
            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
                ref = "main"
        except HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first.
            if e.code == 404:
                ref = "master"
            else:
                raise
        except URLError as e:
            # No internet connection, need to check for cache as last resort
            hub_dir = get_dir()
            for possible_ref in ("main", "master"):
                cached_repo = os.path.join(
                    hub_dir, f"{repo_owner}_{repo_name}_{possible_ref}"
                )
                if os.path.exists(cached_repo):
                    ref = possible_ref
                    break
            if ref is None:
                raise RuntimeError(
                    "It looks like there is no internet connection and the "
                    f"repo could not be found in the cache ({hub_dir})"
                ) from e
    return repo_owner, repo_name, ref
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker
def _read_url(url):
    """Open ``url`` (a URL string or ``urllib.request.Request``) and return its body as text.

    The body is decoded with the charset declared in the response headers,
    falling back to UTF-8 when none is declared.
    """
    with urlopen(url) as response:
        payload = response.read()
        charset = response.headers.get_content_charset("utf-8")
    return payload.decode(charset)
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
    """Check that ``ref`` belongs to ``repo_owner/repo_name`` itself.

    ``ref`` is accepted if it equals the name of one of the repo's branches or
    tags, or is a prefix of the commit SHA one of them points to. The GitHub
    REST API is queried directly, and ``$GITHUB_TOKEN`` is sent for
    authentication when set.

    Raises:
        ValueError: if ``ref`` matches no branch or tag, e.g. because it is a
            commit that only exists in a fork.
    """
    # Use urlopen to avoid depending on local git.
    per_page = 100
    headers = {"Accept": "application/vnd.github.v3+json"}
    token = os.environ.get(ENV_GITHUB_TOKEN)
    if token is not None:
        headers["Authorization"] = f"token {token}"
    for url_prefix in (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
    ):
        page = 0
        while True:
            page += 1
            url = f"{url_prefix}?per_page={per_page}&page={page}"
            response = json.loads(_read_url(Request(url, headers=headers)))
            for br in response:
                if br["name"] == ref or br["commit"]["sha"].startswith(ref):
                    return
            # A page shorter than per_page (including an empty one) is the last
            # page. Stop instead of spending another request on an empty page,
            # since GitHub API calls are rate limited.
            if len(response) < per_page:
                break

    raise ValueError(
        f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
    )
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker
def _get_cache_or_reload(
    github,
    force_reload,
    trust_repo,
    calling_fn,
    verbose=True,
    skip_validation=False,
):
    """Return a local directory holding the source of hub repo ``github``, downloading it if needed.

    The trust check runs first, before the cache is consulted or anything is
    downloaded.

    Args:
        github (str): ``"repo_owner/repo_name[:ref]"``.
        force_reload (bool): download again even if a cached copy exists.
        trust_repo: forwarded to :func:`_check_repo_is_trusted`.
        calling_fn (str): public API name used in trust-related messages.
        verbose (bool): if False, suppress the "Using cache" message. The
            "Downloading" message is always written.
        skip_validation (bool): skip the GitHub API check that ``ref`` belongs
            to the repo rather than to a fork.

    Returns:
        str: ``<hub_dir>/<owner>_<name>_<ref with "/" replaced by "_">``.
    """
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    os.makedirs(hub_dir, exist_ok=True)
    # Parse github repo information
    repo_owner, repo_name, ref = _parse_repo_info(github)
    # GitHub allows branch names with slash '/', which would be confused with a
    # path separator on both Linux and Windows. Backslash is not allowed in
    # GitHub branch names, so there is no need to worry about it.
    normalized_br = ref.replace("/", "_")
    # GitHub renames folder repo-v1.x.x to repo-1.x.x. The real folder name is
    # only known after downloading and inspecting the zip file, so cached repos
    # are looked up under a normalized folder name instead.
    owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
    repo_dir = os.path.join(hub_dir, owner_name_branch)
    # Check that the repo is in the trusted list
    _check_repo_is_trusted(
        repo_owner,
        repo_name,
        owner_name_branch,
        trust_repo=trust_repo,
        calling_fn=calling_fn,
    )

    use_cache = (not force_reload) and os.path.exists(repo_dir)
    if use_cache:
        if verbose:
            sys.stderr.write(f"Using cache found in {repo_dir}\n")
        return repo_dir

    # Validate the tag/branch is from the original repo instead of a forked repo
    if not skip_validation:
        _validate_not_a_forked_repo(repo_owner, repo_name, ref)

    cached_file = os.path.join(hub_dir, normalized_br + ".zip")
    _remove_if_exists(cached_file)
    try:
        _download_repo_archive(repo_owner, repo_name, ref, cached_file)
        extracted_repo = _extract_repo_archive(cached_file, hub_dir)
    finally:
        # Never leave the archive behind, even if download or extraction failed.
        _remove_if_exists(cached_file)

    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo
    return repo_dir


def _download_repo_archive(repo_owner, repo_name, ref, cached_file):
    """Download the GitHub zipball of ``repo_owner/repo_name`` at ``ref`` into ``cached_file``.

    An HTTP 300 (ambiguous ref) response is retried once as the branch
    ``refs/heads/<ref>``. Other HTTP errors are re-raised.
    """
    url = _git_archive_link(repo_owner, repo_name, ref)
    sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
    try:
        download_url_to_file(url, cached_file, progress=False)
    except HTTPError as err:
        if err.code != 300:
            raise
        # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
        # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags
        # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
        # Here, we do the same as git: we throw a warning, and assume the user wanted the branch
        warnings.warn(
            f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
            "Torchhub will now assume that it's a branch. "
            "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
            "refs/tags/tag_name as the ref. That might require using skip_validation=True."
        )
        disambiguated_branch_ref = f"refs/heads/{ref}"
        url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
        download_url_to_file(url, cached_file, progress=False)


def _extract_repo_archive(cached_file, hub_dir):
    """Extract zip ``cached_file`` into ``hub_dir`` and return the path of its first entry.

    The first entry is taken to be the archive's top-level folder; any stale
    copy of it in ``hub_dir`` is removed before extraction.
    """
    with zipfile.ZipFile(cached_file) as cached_zipfile:
        extracted_repo_name = cached_zipfile.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
        _remove_if_exists(extracted_repo)
        # Unzip the code; the caller renames the base folder.
        cached_zipfile.extractall(hub_dir)
    return extracted_repo
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker
def _check_repo_is_trusted(
    repo_owner,
    repo_name,
    owner_name_branch,
    trust_repo,
    calling_fn="load",
):
    """Decide whether ``repo_owner/repo_name`` may be downloaded and run, prompting if needed.

    A repo counts as trusted if any of the following holds:

    - ``"<owner>_<name>"`` is listed in ``<hub_dir>/trusted_list``.
    - ``owner_name_branch`` is an existing subdirectory of the hub dir, i.e. a
      legacy download.
    - The owner is in ``_TRUSTED_REPO_OWNERS``.

    Args:
        trust_repo: Controls the prompt:

            - ``None``: only warn if the repo is untrusted, and change nothing.
            - ``False``: always prompt.
            - ``"check"``: prompt only if the repo is untrusted.
            - Any other value (e.g. ``True``): trust without prompting.

            Except for ``None``, a repo that ends up trusted but was not
            already is appended to the trusted list.
        calling_fn (str): public API name used in the messages.

    Raises:
        Exception: if the user answers no (or nothing) at the prompt.
        ValueError: if the prompt answer is not recognized.
    """
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath) as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = "_".join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 2.0 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                "confirmation if the repo is not already trusted. This will eventually be the default behaviour"
            )
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
        )
        # Tolerate surrounding whitespace (e.g. " y") in the typed answer.
        answer = response.strip().lower()
        if answer in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif answer in ("n", "no", ""):
            raise Exception("Untrusted repository.")  # noqa: TRY002
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker
def _check_module_exists(name):
    """Return whether a module called ``name`` can be located.

    A top-level module is located without being imported. For a dotted name
    such as ``"pkg.sub"``, ``importlib.util.find_spec`` imports the parent
    package first. A missing parent, or a parent that is not a package, is
    reported as ``False`` instead of letting ``ModuleNotFoundError`` escape.
    """
    import importlib.util

    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker
def _check_dependencies(m):
    """Raise if any package listed in module ``m``'s ``dependencies`` attribute is missing.

    A module whose ``dependencies`` attribute is absent or ``None`` passes.

    Raises:
        RuntimeError: naming every missing dependency, comma-separated.
    """
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
    if dependencies is None:
        return

    missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
    if missing_deps:
        raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}")
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker
def _load_entry_from_hubconf(m, model):
    """Return the callable entrypoint named ``model`` from hubconf module ``m``.

    Raises:
        ValueError: if ``model`` is not a string.
        RuntimeError: if a declared dependency is missing, or ``m`` has no
            callable attribute named ``model``.
    """
    if not isinstance(model, str):
        raise ValueError("Invalid input: model should be a string of function name")

    # A missing dependency imported at the top level of hubconf makes the
    # import itself fail before this function runs. That is a chicken and egg
    # situation (hubconf must be loaded to learn its dependencies), and
    # Python's own import error is an adequate message for users.
    _check_dependencies(m)

    entry = _load_attr_from_module(m, model)
    # callable(None) is False, so this also covers a missing attribute.
    if not callable(entry):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")
    return entry
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker
def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the XDG Base Directory Specification (X Desktop Group,
    now freedesktop.org) for the Linux filesystem layout, with a default value
    ``~/.cache`` if the environment variable is not set.
    """
    # The legacy TORCH_HUB variable is no longer honored; only warn about it so
    # users know to move their data.
    if os.getenv("TORCH_HUB"):
        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")

    # An explicit set_dir() override wins over the environment-derived default.
    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), "hub")
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker
def set_dir(d):
    r"""
    Override the directory Torch Hub uses to save downloaded models & weights.

    A leading ``~`` in ``d`` is expanded to the user's home directory before
    the value is stored in the module-level ``_hub_dir``, which
    :func:`~torch.hub.get_dir` returns in preference to the default location.

    Args:
        d (str): path to a local folder to save downloaded models & weights.
    """
    global _hub_dir
    expanded_dir = os.path.expanduser(d)
    _hub_dir = expanded_dir
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker
def list(
    github,
    force_reload=False,
    skip_validation=False,
    trust_repo=None,
    verbose=True,
):
    r"""
    Return the names of all public callable entrypoints in the repo given by ``github``.

    Args:
        github (str): a string of the form "repo_owner/repo_name[:ref]", where the
            optional ref is a tag or a branch. Without a ``ref``, the default branch
            is taken to be ``main`` if it exists and ``master`` otherwise.
            Example: 'pytorch/vision:0.10'
        force_reload (bool, optional): if ``True``, discard any cached copy of the
            repo and download it again. Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub verifies that the
            branch or commit named by ``github`` really belongs to the repo owner.
            This issues requests to the GitHub API; a non-default GitHub token can
            be supplied through the ``GITHUB_TOKEN`` environment variable.
            Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help make sure users only run code from repos
            they trust.

            - ``False``: the user is prompted to decide whether the repo should
              be trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the list of trusted repos in
              the cache; if it is missing, behave as with ``trust_repo=False``.
            - ``None``: a warning is raised asking the user to set ``trust_repo``
              to ``False``, ``True`` or ``"check"``. Kept only for backward
              compatibility and scheduled for removal in v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        verbose (bool, optional): if ``False``, silence messages about hitting the
            local cache. The message announcing the first download is always
            shown. Default is ``True``.

    Returns:
        list: names of the available callable entrypoints.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "list",
        verbose=verbose,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hub_module = _import_module(
            MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF)
        )

    # Names with a leading underscore are treated as private helpers of the
    # hubconf, not as entrypoints. (The builtin ``list`` is shadowed by this
    # function, so the result is built with a comprehension.)
    return [
        name
        for name in dir(hub_module)
        if callable(getattr(hub_module, name)) and not name.startswith("_")
    ]
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker
def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Show the docstring of entrypoint ``model``.

    Args:
        github (str): a string with format <repo_owner/repo_name[:ref]> with an optional
            ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed
            to be ``main`` if it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        model (str): a string of entrypoint name defined in repo's ``hubconf.py``
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the ref
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensuring that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.

    Returns:
        The ``__doc__`` attribute of the entrypoint (``None`` if it has no docstring).

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "help",
        verbose=True,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
        # Resolve the entrypoint (which also runs the hubconf dependency check)
        # while the repo is still added to the import path, exactly as
        # _load_local() does, so help() and load() see the same environment.
        entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker
def load(
    repo_or_dir,
    model,
    *args,
    source="github",
    trust_repo=None,
    force_reload=False,
    verbose=True,
    skip_validation=False,
    **kwargs,
):
    r"""
    Load a model from a github repo or a local directory.

    Note: loading a model is the typical use case, but this can also be used
    to load other objects such as tokenizers, loss functions, etc.

    With ``source='github'``, ``repo_or_dir`` must look like
    ``repo_owner/repo_name[:ref]``, where the optional ref is a tag or a branch.
    With ``source='local'``, ``repo_or_dir`` must be a path to a local directory.

    Args:
        repo_or_dir (str): for ``source='github'``, a github repo of the form
            ``repo_owner/repo_name[:ref]`` with an optional ref (tag or branch),
            e.g. 'pytorch/vision:0.10'. Without a ``ref``, the default branch is
            taken to be ``main`` if it exists and ``master`` otherwise.
            For ``source='local'``, a path to a local directory.
        model (str): name of a callable (entrypoint) defined in the repo/dir's
            ``hubconf.py``.
        *args (optional): positional arguments forwarded to ``model``.
        source (str, optional): 'github' or 'local' (case-insensitive); decides
            how ``repo_or_dir`` is interpreted. Default is 'github'.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help make sure users only run code from repos
            they trust.

            - ``False``: the user is prompted to decide whether the repo should
              be trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the list of trusted repos in
              the cache; if it is missing, behave as with ``trust_repo=False``.
            - ``None``: a warning is raised asking the user to set ``trust_repo``
              to ``False``, ``True`` or ``"check"``. Kept only for backward
              compatibility and scheduled for removal in v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        force_reload (bool, optional): if ``True``, always download the github
            repo again. Ignored for ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): if ``False``, silence messages about hitting the
            local cache. The message announcing the first download is always
            shown. Ignored for ``source = 'local'``. Default is ``True``.
        skip_validation (bool, optional): if ``False``, torchhub verifies that the
            branch or commit named by ``repo_or_dir`` really belongs to the repo
            owner. This issues requests to the GitHub API; a non-default GitHub
            token can be supplied through the ``GITHUB_TOKEN`` environment
            variable. Default is ``False``.
        **kwargs (optional): keyword arguments forwarded to ``model``.

    Returns:
        Whatever the ``model`` callable returns for the given ``*args`` and
        ``**kwargs``.

    Raises:
        ValueError: if ``source`` is neither 'github' nor 'local'.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # from a github repo
        >>> repo = "pytorch/vision"
        >>> model = torch.hub.load(
        ...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
        ... )
        >>> # from a local directory
        >>> path = "/some/local/path/pytorch/vision"
        >>> # xdoctest: +SKIP
        >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
    """
    normalized_source = source.lower()
    if normalized_source not in ("github", "local"):
        raise ValueError(
            f'Unknown source: "{normalized_source}". Allowed values: "github" | "local".'
        )

    # For github sources, fetch (or reuse) the cached checkout and then treat
    # it exactly like a local directory.
    hubconf_dir = repo_or_dir
    if normalized_source == "github":
        hubconf_dir = _get_cache_or_reload(
            repo_or_dir,
            force_reload,
            trust_repo,
            "load",
            verbose=verbose,
            skip_validation=skip_validation,
        )

    return _load_local(hubconf_dir, model, *args, **kwargs)
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker
def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Import ``hubconf.py`` from a local directory and call one of its entrypoints.

    Both the import and the entrypoint call happen inside
    ``_add_to_sys_path(hubconf_dir)``, i.e. while the directory is added to
    the import path.

    Args:
        hubconf_dir (str): path to a local directory containing a ``hubconf.py``.
        model (str): name of an entrypoint defined in that ``hubconf.py``.
        *args (optional): positional arguments forwarded to the entrypoint.
        **kwargs (optional): keyword arguments forwarded to the entrypoint.

    Returns:
        Whatever the entrypoint returns, typically a single model with its
        pretrained weights.

    Example:
        >>> # xdoctest: +SKIP("stub local path")
        >>> path = "/some/local/path/pytorch/vision"
        >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1")
    """
    with _add_to_sys_path(hubconf_dir):
        hub_module = _import_module(
            MODULE_HUBCONF, os.path.join(hubconf_dir, MODULE_HUBCONF)
        )
        entrypoint = _load_entry_from_hubconf(hub_module, model)
        # The call is evaluated before the context manager exits, so the
        # entrypoint runs with the directory still on the import path.
        return entrypoint(*args, **kwargs)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker
def download_url_to_file(
    url: str,
    dst: str,
    hash_prefix: Optional[str] = None,
    progress: bool = True,
) -> None:
    r"""Download object at the given URL to a local path.

    The payload is streamed into a uniquely named ``<dst>.<uuid>.partial`` file
    next to ``dst`` and only moved onto ``dst`` after the download (and the
    optional hash check) has succeeded. The partial file is removed on every
    exit path.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``.
            A leading ``~`` is expanded to the user's home directory.
        hash_prefix (str, optional): If not None, the SHA256 digest of the downloaded
            file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the SHA256 digest of the
            downloaded data does not start with it; ``dst`` is left untouched.
        FileExistsError: if no unused temporary file name could be found.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> torch.hub.download_url_to_file(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
        ...     "/tmp/temporary_file",
        ... )

    """
    dst = os.path.expanduser(dst)
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # Close the response on every path (success, hash mismatch, I/O error)
    # instead of leaving the connection to the garbage collector.
    with urlopen(req) as u:
        file_size = None
        meta = u.info()
        if hasattr(meta, "getheaders"):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        # We deliberately do not use NamedTemporaryFile to avoid restrictive
        # file permissions being applied to the downloaded file.
        for _ in range(tempfile.TMP_MAX):
            tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
            try:
                # "x" = exclusive creation: raises FileExistsError on a name
                # collision instead of silently truncating another file, which
                # is what makes this retry loop meaningful.
                f = open(tmp_dst, "xb")
            except FileExistsError:
                continue
            break
        else:
            raise FileExistsError(errno.EEXIST, "No usable temporary file name found")

        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(READ_DATA_CHUNK)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)  # type: ignore[possibly-undefined]
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            # Flush everything to disk before hashing/moving the file.
            f.close()
            if hash_prefix is not None and sha256 is not None:
                digest = sha256.hexdigest()
                if digest[: len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
                    )
            shutil.move(f.name, dst)
        finally:
            f.close()
            # Only still present if the download or hash check failed.
            if os.path.exists(f.name):
                os.remove(f.name)
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker
# Hub used to support automatically extracting zip files that users compressed manually.
# The legacy zip format expects exactly one file, written by torch.save() < 1.6, in the archive.
# This support should be removed, since the zipfile format is now the default for torch.save().
769*da0073e9SAndroid Build Coastguard Workerdef _is_legacy_zip_format(filename: str) -> bool:
770*da0073e9SAndroid Build Coastguard Worker    if zipfile.is_zipfile(filename):
771*da0073e9SAndroid Build Coastguard Worker        infolist = zipfile.ZipFile(filename).infolist()
772*da0073e9SAndroid Build Coastguard Worker        return len(infolist) == 1 and not infolist[0].is_dir()
773*da0073e9SAndroid Build Coastguard Worker    return False
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker
@deprecated(
    "Falling back to the old format < 1.6. This support will be "
    "deprecated in favor of default zipfile format introduced in 1.6. "
    "Please redo torch.save() to save it in the new zipfile format.",
    category=FutureWarning,
)
def _legacy_zip_load(
    filename: str,
    model_dir: str,
    map_location: MAP_LOCATION,
    weights_only: bool,
) -> Dict[str, Any]:
    """Extract the single file stored in a legacy zip archive and ``torch.load`` it.

    Args:
        filename: path to a zip archive that must contain exactly one regular
            file (a checkpoint written by ``torch.save()`` before 1.6).
        model_dir: directory the archive member is extracted into.
        map_location: forwarded to :func:`torch.load`.
        weights_only: forwarded to :func:`torch.load`.

    Returns:
        The object ``torch.load`` returns for the extracted file.

    Raises:
        RuntimeError: if the archive does not hold exactly one entry, or that
            entry is a directory.
    """
    # Note: extract() overwrites an existing file of the same name. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if len(members) != 1 or members[0].is_dir():
            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
        # extract() sanitizes the member name (drops absolute prefixes and ".."
        # components) and returns the path it actually wrote; joining model_dir
        # with the raw member name could point at a different file.
        extracted_file = f.extract(members[0], model_dir)
    return torch.load(
        extracted_file, map_location=map_location, weights_only=weights_only
    )
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker
def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None,
    weights_only: bool = False,
) -> Dict[str, Any]:
    r"""Download (if needed) and deserialize the Torch object stored at ``url``.

    The file is cached in ``model_dir``. If a file with the target name is
    already present there, nothing is downloaded and the cached copy is
    loaded. When ``model_dir`` is omitted it defaults to
    ``<hub_dir>/checkpoints``, where ``hub_dir`` is the directory returned by
    :func:`~torch.hub.get_dir`. The directory is created if it does not exist.

    If the cached file is a zip archive containing a single file (the legacy
    pre-1.6 layout), that file is extracted into ``model_dir`` and loaded.

    Args:
        url (str): URL of the object to download.
        model_dir (str, optional): directory in which the object is cached.
        map_location (optional): function or dict used to remap storage
            locations; passed through to :func:`torch.load`.
        progress (bool, optional): show a download progress bar on stderr.
            Default: True
        check_hash (bool, optional): if True, the filename part of the URL
            should follow the ``filename-<sha256>.ext`` naming convention,
            where ``<sha256>`` is the first eight or more digits of the SHA256
            hash of the file contents. The hash keeps names unique and is used
            to verify the downloaded contents. It only applies when a download
            actually happens. If the name does not match the convention, no
            hash is passed to the downloader.
            Default: False
        file_name (str, optional): name under which the downloaded file is
            stored. Defaults to the filename component of ``url``.
        weights_only (bool, optional): if True, only weights are loaded and no
            complex pickled objects. Recommended for untrusted sources. See
            :func:`~torch.load` for more details. Default: False

    Returns:
        The object produced by :func:`torch.load`.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> state_dict = torch.hub.load_state_dict_from_url(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
        ... )

    """
    # Nudge anyone still relying on the old environment variable.
    if os.getenv("TORCH_MODEL_ZOO"):
        warnings.warn(
            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
        )

    target_dir = model_dir
    if target_dir is None:
        target_dir = os.path.join(get_dir(), "checkpoints")
    os.makedirs(target_dir, exist_ok=True)

    # The URL is always parsed, even when an explicit file_name overrides it.
    url_basename = os.path.basename(urlparse(url).path)
    local_name = url_basename if file_name is None else file_name
    cached_file = os.path.join(target_dir, local_name)

    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        expected_prefix = None
        if check_hash:
            match = HASH_REGEX.search(local_name)
            if match:
                expected_prefix = match.group(1)
        download_url_to_file(url, cached_file, expected_prefix, progress=progress)

    if not _is_legacy_zip_format(cached_file):
        return torch.load(
            cached_file, map_location=map_location, weights_only=weights_only
        )
    return _legacy_zip_load(cached_file, target_dir, map_location, weights_only)
872