xref: /aosp_15_r20/external/pytorch/torch/serialization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport difflib
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport io
6*da0073e9SAndroid Build Coastguard Workerimport re
7*da0073e9SAndroid Build Coastguard Workerimport shutil
8*da0073e9SAndroid Build Coastguard Workerimport struct
9*da0073e9SAndroid Build Coastguard Workerimport sys
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerimport tarfile
12*da0073e9SAndroid Build Coastguard Workerimport tempfile
13*da0073e9SAndroid Build Coastguard Workerimport warnings
14*da0073e9SAndroid Build Coastguard Workerfrom contextlib import closing, contextmanager
15*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
16*da0073e9SAndroid Build Coastguard Workerfrom ._utils import _import_dotted_name
17*da0073e9SAndroid Build Coastguard Workerfrom torch._sources import get_source_lines_and_file
18*da0073e9SAndroid Build Coastguard Workerfrom torch.types import Storage
19*da0073e9SAndroid Build Coastguard Workerfrom torch.storage import _get_dtype_from_pickle_storage_type
20*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
21*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
22*da0073e9SAndroid Build Coastguard Workerimport copyreg
23*da0073e9SAndroid Build Coastguard Workerimport pickle
24*da0073e9SAndroid Build Coastguard Workerimport torch._weights_only_unpickler as _weights_only_unpickler
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard WorkerDEFAULT_PROTOCOL = 2
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard WorkerLONG_SIZE = struct.Struct('=l').size
29*da0073e9SAndroid Build Coastguard WorkerINT_SIZE = struct.Struct('=i').size
30*da0073e9SAndroid Build Coastguard WorkerSHORT_SIZE = struct.Struct('=h').size
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard WorkerMAGIC_NUMBER = 0x1950a86a20f9469cfc6c
33*da0073e9SAndroid Build Coastguard WorkerPROTOCOL_VERSION = 1001
34*da0073e9SAndroid Build Coastguard WorkerSTORAGE_KEY_SEPARATOR = ','
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard WorkerFILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
37*da0073e9SAndroid Build Coastguard WorkerMAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
38*da0073e9SAndroid Build Coastguard WorkerSTORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard WorkerIS_WINDOWS = sys.platform == "win32"
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Workerif not IS_WINDOWS:
43*da0073e9SAndroid Build Coastguard Worker    from mmap import MAP_SHARED, MAP_PRIVATE
44*da0073e9SAndroid Build Coastguard Workerelse:
45*da0073e9SAndroid Build Coastguard Worker    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker__all__ = [
48*da0073e9SAndroid Build Coastguard Worker    'SourceChangeWarning',
49*da0073e9SAndroid Build Coastguard Worker    'mkdtemp',
50*da0073e9SAndroid Build Coastguard Worker    'register_package',
51*da0073e9SAndroid Build Coastguard Worker    'check_module_version_greater_or_equal',
52*da0073e9SAndroid Build Coastguard Worker    'validate_cuda_device',
53*da0073e9SAndroid Build Coastguard Worker    'validate_hpu_device',
54*da0073e9SAndroid Build Coastguard Worker    'location_tag',
55*da0073e9SAndroid Build Coastguard Worker    'default_restore_location',
56*da0073e9SAndroid Build Coastguard Worker    'normalize_storage_type',
57*da0073e9SAndroid Build Coastguard Worker    'storage_to_tensor_type',
58*da0073e9SAndroid Build Coastguard Worker    'save',
59*da0073e9SAndroid Build Coastguard Worker    'load',
60*da0073e9SAndroid Build Coastguard Worker    'StorageType',
61*da0073e9SAndroid Build Coastguard Worker    'LoadEndianness',
62*da0073e9SAndroid Build Coastguard Worker    'get_default_load_endianness',
63*da0073e9SAndroid Build Coastguard Worker    'set_default_load_endianness',
64*da0073e9SAndroid Build Coastguard Worker    'clear_safe_globals',
65*da0073e9SAndroid Build Coastguard Worker    'get_safe_globals',
66*da0073e9SAndroid Build Coastguard Worker    'add_safe_globals',
67*da0073e9SAndroid Build Coastguard Worker]
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Workerclass SourceChangeWarning(Warning):
71*da0073e9SAndroid Build Coastguard Worker    pass
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker@contextmanager
75*da0073e9SAndroid Build Coastguard Workerdef mkdtemp():
76*da0073e9SAndroid Build Coastguard Worker    path = tempfile.mkdtemp()
77*da0073e9SAndroid Build Coastguard Worker    try:
78*da0073e9SAndroid Build Coastguard Worker        yield path
79*da0073e9SAndroid Build Coastguard Worker    finally:
80*da0073e9SAndroid Build Coastguard Worker        shutil.rmtree(path)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Workerclass LoadEndianness(Enum):
86*da0073e9SAndroid Build Coastguard Worker    NATIVE = 1
87*da0073e9SAndroid Build Coastguard Worker    LITTLE = 2
88*da0073e9SAndroid Build Coastguard Worker    BIG = 3
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker_default_load_endian: Optional[LoadEndianness] = None
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Workerdef get_default_load_endianness() -> Optional[LoadEndianness]:
93*da0073e9SAndroid Build Coastguard Worker    '''
94*da0073e9SAndroid Build Coastguard Worker    Get fallback byte order for loading files
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker    If byteorder mark is not present in saved checkpoint,
97*da0073e9SAndroid Build Coastguard Worker    this byte order is used as fallback.
98*da0073e9SAndroid Build Coastguard Worker    By default, it's "native" byte order.
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    Returns:
101*da0073e9SAndroid Build Coastguard Worker        default_load_endian: Optional[LoadEndianness]
102*da0073e9SAndroid Build Coastguard Worker    '''
103*da0073e9SAndroid Build Coastguard Worker    return _default_load_endian
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Workerdef set_default_load_endianness(endianness):
106*da0073e9SAndroid Build Coastguard Worker    '''
107*da0073e9SAndroid Build Coastguard Worker    Set fallback byte order for loading files
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker    If byteorder mark is not present in saved checkpoint,
110*da0073e9SAndroid Build Coastguard Worker    this byte order is used as fallback.
111*da0073e9SAndroid Build Coastguard Worker    By default, it's "native" byte order.
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker    Args:
114*da0073e9SAndroid Build Coastguard Worker        endianness: the new fallback byte order
115*da0073e9SAndroid Build Coastguard Worker    '''
116*da0073e9SAndroid Build Coastguard Worker    global _default_load_endian
117*da0073e9SAndroid Build Coastguard Worker    if not isinstance(endianness, LoadEndianness) and endianness is not None:
118*da0073e9SAndroid Build Coastguard Worker        raise TypeError("Invalid argument type in function set_default_load_endianness")
119*da0073e9SAndroid Build Coastguard Worker    _default_load_endian = endianness
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker_default_mmap_options: int = MAP_PRIVATE
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Workerdef get_default_mmap_options() -> int:
124*da0073e9SAndroid Build Coastguard Worker    '''
125*da0073e9SAndroid Build Coastguard Worker    Get default mmap options for :func:`torch.load` with ``mmap=True``.
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker    Defaults to ``mmap.MAP_PRIVATE``.
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker    Returns:
131*da0073e9SAndroid Build Coastguard Worker        default_mmap_options: int
132*da0073e9SAndroid Build Coastguard Worker    '''
133*da0073e9SAndroid Build Coastguard Worker    return _default_mmap_options
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Workerdef set_default_mmap_options(flags: int):
136*da0073e9SAndroid Build Coastguard Worker    '''
137*da0073e9SAndroid Build Coastguard Worker    Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker    For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
140*da0073e9SAndroid Build Coastguard Worker    Please open an issue if you need any other option to be added here.
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker    .. note::
143*da0073e9SAndroid Build Coastguard Worker        This feature is currently not supported for Windows.
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker    Args:
146*da0073e9SAndroid Build Coastguard Worker        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
147*da0073e9SAndroid Build Coastguard Worker    '''
148*da0073e9SAndroid Build Coastguard Worker    global _default_mmap_options
149*da0073e9SAndroid Build Coastguard Worker    if IS_WINDOWS:
150*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
151*da0073e9SAndroid Build Coastguard Worker    if (flags != MAP_PRIVATE and flags != MAP_SHARED):
152*da0073e9SAndroid Build Coastguard Worker        raise ValueError("Invalid argument in function set_default_mmap_options, "
153*da0073e9SAndroid Build Coastguard Worker                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
154*da0073e9SAndroid Build Coastguard Worker    _default_mmap_options = flags
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Workerdef clear_safe_globals() -> None:
157*da0073e9SAndroid Build Coastguard Worker    '''
158*da0073e9SAndroid Build Coastguard Worker    Clears the list of globals that are safe for ``weights_only`` load.
159*da0073e9SAndroid Build Coastguard Worker    '''
160*da0073e9SAndroid Build Coastguard Worker    _weights_only_unpickler._clear_safe_globals()
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Workerdef get_safe_globals() -> List[Any]:
163*da0073e9SAndroid Build Coastguard Worker    '''
164*da0073e9SAndroid Build Coastguard Worker    Returns the list of user-added globals that are safe for ``weights_only`` load.
165*da0073e9SAndroid Build Coastguard Worker    '''
166*da0073e9SAndroid Build Coastguard Worker    return _weights_only_unpickler._get_safe_globals()
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Workerdef add_safe_globals(safe_globals: List[Any]) -> None:
169*da0073e9SAndroid Build Coastguard Worker    '''
170*da0073e9SAndroid Build Coastguard Worker    Marks the given globals as safe for ``weights_only`` load. For example, functions
171*da0073e9SAndroid Build Coastguard Worker    added to this list can be called during unpickling, classes could be instantiated
172*da0073e9SAndroid Build Coastguard Worker    and have state set.
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker    Args:
175*da0073e9SAndroid Build Coastguard Worker        safe_globals (List[Any]): list of globals to mark as safe
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker    Example:
178*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
179*da0073e9SAndroid Build Coastguard Worker        >>> import tempfile
180*da0073e9SAndroid Build Coastguard Worker        >>> class MyTensor(torch.Tensor):
181*da0073e9SAndroid Build Coastguard Worker        ...     pass
182*da0073e9SAndroid Build Coastguard Worker        >>> t = MyTensor(torch.randn(2, 3))
183*da0073e9SAndroid Build Coastguard Worker        >>> with tempfile.NamedTemporaryFile() as f:
184*da0073e9SAndroid Build Coastguard Worker        ...     torch.save(t, f.name)
185*da0073e9SAndroid Build Coastguard Worker        # Running `torch.load(f.name, weights_only=True)` will fail with
186*da0073e9SAndroid Build Coastguard Worker        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
187*da0073e9SAndroid Build Coastguard Worker        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
188*da0073e9SAndroid Build Coastguard Worker        ...     torch.serialization.add_safe_globals([MyTensor])
189*da0073e9SAndroid Build Coastguard Worker        ...     torch.load(f.name, weights_only=True)
190*da0073e9SAndroid Build Coastguard Worker        # MyTensor([[-0.5024, -1.8152, -0.5455],
191*da0073e9SAndroid Build Coastguard Worker        #          [-0.8234,  2.0500, -0.3657]])
192*da0073e9SAndroid Build Coastguard Worker    '''
193*da0073e9SAndroid Build Coastguard Worker    _weights_only_unpickler._add_safe_globals(safe_globals)
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Workerdef _is_zipfile(f) -> bool:
196*da0073e9SAndroid Build Coastguard Worker    # This is a stricter implementation than zipfile.is_zipfile().
197*da0073e9SAndroid Build Coastguard Worker    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
198*da0073e9SAndroid Build Coastguard Worker    # binary. Since we expect the files here to be generated by torch.save or
199*da0073e9SAndroid Build Coastguard Worker    # torch.jit.save, it's safe to only check the start bytes and avoid
200*da0073e9SAndroid Build Coastguard Worker    # collisions and assume the zip has only 1 file.
201*da0073e9SAndroid Build Coastguard Worker    # See bugs.python.org/issue28494.
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker    start = f.tell()
204*da0073e9SAndroid Build Coastguard Worker    # Read the first few bytes and match against the ZIP file signature
205*da0073e9SAndroid Build Coastguard Worker    local_header_magic_number = b'PK\x03\x04'
206*da0073e9SAndroid Build Coastguard Worker    read_bytes = f.read(len(local_header_magic_number))
207*da0073e9SAndroid Build Coastguard Worker    f.seek(start)
208*da0073e9SAndroid Build Coastguard Worker    return read_bytes == local_header_magic_number
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Workerdef register_package(
212*da0073e9SAndroid Build Coastguard Worker    priority: int,
213*da0073e9SAndroid Build Coastguard Worker    tagger: Callable[[STORAGE], Optional[str]],
214*da0073e9SAndroid Build Coastguard Worker    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
215*da0073e9SAndroid Build Coastguard Worker):
216*da0073e9SAndroid Build Coastguard Worker    '''
217*da0073e9SAndroid Build Coastguard Worker    Registers callables for tagging and deserializing storage objects with an associated priority.
218*da0073e9SAndroid Build Coastguard Worker    Tagging associates a device with a storage object at save time while deserializing moves a
219*da0073e9SAndroid Build Coastguard Worker    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
220*da0073e9SAndroid Build Coastguard Worker    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
221*da0073e9SAndroid Build Coastguard Worker    value that is not `None`.
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker    To override the deserialization behavior for a device in the global registry, one can register a
224*da0073e9SAndroid Build Coastguard Worker    tagger with a higher priority than the existing tagger.
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker    This function can also be used to register a tagger and deserializer for new devices.
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker    Args:
229*da0073e9SAndroid Build Coastguard Worker        priority: Indicates the priority associated with the tagger and deserializer, where a lower
230*da0073e9SAndroid Build Coastguard Worker            value indicates higher priority.
231*da0073e9SAndroid Build Coastguard Worker        tagger: Callable that takes in a storage object and returns its tagged device as a string
232*da0073e9SAndroid Build Coastguard Worker            or None.
233*da0073e9SAndroid Build Coastguard Worker        deserializer: Callable that takes in storage object and a device string and returns a storage
234*da0073e9SAndroid Build Coastguard Worker            object on the appropriate device or None.
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker    Returns:
237*da0073e9SAndroid Build Coastguard Worker        `None`
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker    Example:
240*da0073e9SAndroid Build Coastguard Worker        >>> def ipu_tag(obj):
241*da0073e9SAndroid Build Coastguard Worker        >>>     if obj.device.type == 'ipu':
242*da0073e9SAndroid Build Coastguard Worker        >>>         return 'ipu'
243*da0073e9SAndroid Build Coastguard Worker        >>> def ipu_deserialize(obj, location):
244*da0073e9SAndroid Build Coastguard Worker        >>>     if location.startswith('ipu'):
245*da0073e9SAndroid Build Coastguard Worker        >>>         ipu = getattr(torch, "ipu", None)
246*da0073e9SAndroid Build Coastguard Worker        >>>         assert ipu is not None, "IPU device module is not loaded"
247*da0073e9SAndroid Build Coastguard Worker        >>>         assert torch.ipu.is_available(), "ipu is not available"
248*da0073e9SAndroid Build Coastguard Worker        >>>         return obj.ipu(location)
249*da0073e9SAndroid Build Coastguard Worker        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
250*da0073e9SAndroid Build Coastguard Worker    '''
251*da0073e9SAndroid Build Coastguard Worker    queue_elem = (priority, tagger, deserializer)
252*da0073e9SAndroid Build Coastguard Worker    _package_registry.append(queue_elem)
253*da0073e9SAndroid Build Coastguard Worker    _package_registry.sort()
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Workerdef check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
257*da0073e9SAndroid Build Coastguard Worker    '''
258*da0073e9SAndroid Build Coastguard Worker    Check if a module's version satisfies requirements
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker    Usually, a module's version string will be like 'x.y.z', which would be represented
261*da0073e9SAndroid Build Coastguard Worker    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
262*da0073e9SAndroid Build Coastguard Worker    string does not match the given tuple's format up to the length of the tuple, then
263*da0073e9SAndroid Build Coastguard Worker    error and exit or emit a warning.
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker    Args:
266*da0073e9SAndroid Build Coastguard Worker        module: the module to check the version of
267*da0073e9SAndroid Build Coastguard Worker        req_version_tuple: tuple (usually of ints) representing the required version
268*da0073e9SAndroid Build Coastguard Worker        error_if_malformed: whether we should exit if module version string is malformed
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker    Returns:
271*da0073e9SAndroid Build Coastguard Worker        requirement_is_met: bool
272*da0073e9SAndroid Build Coastguard Worker    '''
273*da0073e9SAndroid Build Coastguard Worker    try:
274*da0073e9SAndroid Build Coastguard Worker        version_strs = module.__version__.split('.')
275*da0073e9SAndroid Build Coastguard Worker        # Cast module version fields to match the types of the required version
276*da0073e9SAndroid Build Coastguard Worker        module_version = tuple(
277*da0073e9SAndroid Build Coastguard Worker            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
278*da0073e9SAndroid Build Coastguard Worker        )
279*da0073e9SAndroid Build Coastguard Worker        requirement_is_met = module_version >= req_version_tuple
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker    except Exception as e:
282*da0073e9SAndroid Build Coastguard Worker        message = (
283*da0073e9SAndroid Build Coastguard Worker            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
284*da0073e9SAndroid Build Coastguard Worker            f" with tuple {str(req_version_tuple)}"
285*da0073e9SAndroid Build Coastguard Worker        )
286*da0073e9SAndroid Build Coastguard Worker        if error_if_malformed:
287*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(message) from e
288*da0073e9SAndroid Build Coastguard Worker        else:
289*da0073e9SAndroid Build Coastguard Worker            warnings.warn(message + ', but continuing assuming that requirement is met')
290*da0073e9SAndroid Build Coastguard Worker            requirement_is_met = True
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker    return requirement_is_met
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Workerdef _cpu_tag(obj):
296*da0073e9SAndroid Build Coastguard Worker    if obj.device.type == 'cpu':
297*da0073e9SAndroid Build Coastguard Worker        return 'cpu'
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Workerdef _mps_tag(obj):
301*da0073e9SAndroid Build Coastguard Worker    if obj.device.type == 'mps':
302*da0073e9SAndroid Build Coastguard Worker        return 'mps'
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Workerdef _meta_tag(obj):
306*da0073e9SAndroid Build Coastguard Worker    if obj.device.type == 'meta':
307*da0073e9SAndroid Build Coastguard Worker        return 'meta'
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Workerdef _backend_tag(backend_name, obj):
311*da0073e9SAndroid Build Coastguard Worker    if backend_name == 'privateuse1':
312*da0073e9SAndroid Build Coastguard Worker        backend_name = torch._C._get_privateuse1_backend_name()
313*da0073e9SAndroid Build Coastguard Worker    if obj.device.type == backend_name:
314*da0073e9SAndroid Build Coastguard Worker        if obj.device.index is None:
315*da0073e9SAndroid Build Coastguard Worker            return backend_name
316*da0073e9SAndroid Build Coastguard Worker        else:
317*da0073e9SAndroid Build Coastguard Worker            return backend_name + ':' + str(obj.device.index)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Workerdef _cpu_deserialize(obj, location):
321*da0073e9SAndroid Build Coastguard Worker    if location == 'cpu':
322*da0073e9SAndroid Build Coastguard Worker        return obj
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Workerdef _mps_deserialize(obj, location):
326*da0073e9SAndroid Build Coastguard Worker    if location.startswith('mps'):
327*da0073e9SAndroid Build Coastguard Worker        return obj.mps()
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Workerdef _meta_deserialize(obj, location):
331*da0073e9SAndroid Build Coastguard Worker    if location == 'meta':
332*da0073e9SAndroid Build Coastguard Worker        return torch.UntypedStorage(obj.nbytes(), device='meta')
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Workerdef _validate_device(location, backend_name):
336*da0073e9SAndroid Build Coastguard Worker    '''
337*da0073e9SAndroid Build Coastguard Worker    Check whether the device index of specified backend is valid
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker    In case of privateuse1 backend, your must first register a device_module for
340*da0073e9SAndroid Build Coastguard Worker    privateuse1 using torch._register_device_module. Implement the following
341*da0073e9SAndroid Build Coastguard Worker    methods in device_module like cuda: device_module._utils._get_device_index(location, True),
342*da0073e9SAndroid Build Coastguard Worker    device_module.device_count().
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker    Args:
345*da0073e9SAndroid Build Coastguard Worker        location: string of device
346*da0073e9SAndroid Build Coastguard Worker        backend_name: the backend name or the name of privateuse1, which can be renamed
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker    Returns:
349*da0073e9SAndroid Build Coastguard Worker        device_index: int
350*da0073e9SAndroid Build Coastguard Worker    '''
351*da0073e9SAndroid Build Coastguard Worker    if not hasattr(torch, backend_name):
352*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
353*da0073e9SAndroid Build Coastguard Worker                           'If you are running on a CPU-only machine, '
354*da0073e9SAndroid Build Coastguard Worker                           'please use torch.load with map_location=torch.device(\'cpu\') '
355*da0073e9SAndroid Build Coastguard Worker                           'to map your storages to the CPU.')
356*da0073e9SAndroid Build Coastguard Worker    device_module = getattr(torch, backend_name)
357*da0073e9SAndroid Build Coastguard Worker    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
358*da0073e9SAndroid Build Coastguard Worker        device_index = device_module._utils._get_device_index(location, True)
359*da0073e9SAndroid Build Coastguard Worker        device = torch.device(backend_name, device_index)
360*da0073e9SAndroid Build Coastguard Worker    else:
361*da0073e9SAndroid Build Coastguard Worker        device = torch.device(location)
362*da0073e9SAndroid Build Coastguard Worker        device_index = device.index if device.index else 0
363*da0073e9SAndroid Build Coastguard Worker    if hasattr(device_module, 'is_available') and not device_module.is_available():
364*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
365*da0073e9SAndroid Build Coastguard Worker                           f'device but torch.{backend_name}.is_available() is False. '
366*da0073e9SAndroid Build Coastguard Worker                           'If you are running on a CPU-only machine, '
367*da0073e9SAndroid Build Coastguard Worker                           'please use torch.load with map_location=torch.device(\'cpu\') '
368*da0073e9SAndroid Build Coastguard Worker                           'to map your storages to the CPU.')
369*da0073e9SAndroid Build Coastguard Worker    if hasattr(device_module, 'device_count'):
370*da0073e9SAndroid Build Coastguard Worker        device_count = device_module.device_count()
371*da0073e9SAndroid Build Coastguard Worker        if device_index >= device_count:
372*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
373*da0073e9SAndroid Build Coastguard Worker                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
374*da0073e9SAndroid Build Coastguard Worker                               'Please use torch.load with map_location to map your storages '
375*da0073e9SAndroid Build Coastguard Worker                               'to an existing device.')
376*da0073e9SAndroid Build Coastguard Worker    return device
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Workerdef validate_cuda_device(location):
380*da0073e9SAndroid Build Coastguard Worker    return _validate_device(location, 'cuda').index
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Workerdef validate_hpu_device(location):
384*da0073e9SAndroid Build Coastguard Worker    return _validate_device(location, 'hpu').index
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Workerdef _deserialize(backend_name, obj, location):
388*da0073e9SAndroid Build Coastguard Worker    if backend_name == 'privateuse1':
389*da0073e9SAndroid Build Coastguard Worker        backend_name = torch._C._get_privateuse1_backend_name()
390*da0073e9SAndroid Build Coastguard Worker    if location.startswith(backend_name):
391*da0073e9SAndroid Build Coastguard Worker        device = _validate_device(location, backend_name)
392*da0073e9SAndroid Build Coastguard Worker        return obj.to(device=device)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Workerregister_package(10, _cpu_tag, _cpu_deserialize)
396*da0073e9SAndroid Build Coastguard Workerregister_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
397*da0073e9SAndroid Build Coastguard Workerregister_package(21, _mps_tag, _mps_deserialize)
398*da0073e9SAndroid Build Coastguard Workerregister_package(22, _meta_tag, _meta_deserialize)
399*da0073e9SAndroid Build Coastguard Workerregister_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
400*da0073e9SAndroid Build Coastguard Workerregister_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
401*da0073e9SAndroid Build Coastguard Workerregister_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
402*da0073e9SAndroid Build Coastguard Worker
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    """Return the location tag of ``storage``.

    Each tagger in ``_package_registry`` is called with ``storage`` in
    registry order, and the first truthy result is returned. Taggers after
    the first match are not called.

    Raises:
        RuntimeError: if no registered tagger returns a truthy value.
    """
    candidate_tags = (tagger(storage) for _, tagger, _ in _package_registry)
    # filter(None, ...) keeps only truthy tags; the default None means "no match".
    found_tag = next(filter(None, candidate_tags), None)
    if found_tag is None:
        raise RuntimeError("don't know how to determine data location of "
                           + torch.typename(storage))
    return found_tag
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker
def default_restore_location(storage, location):
    """
    Restores `storage` with the first registered deserializer that accepts `location`.

    Every deserializer in the package registry is called, in registry order, as
    ``fn(storage, location)``. The first result that is not ``None`` is returned.
    Deserializers after that first match are not called.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: the non-``None`` object produced by the first matching deserializer

    Raises:
        RuntimeError: If every registered deserializer returns `None` for
           `location` (including when the registry is empty).
    """
    attempts = (restore_fn(storage, location) for _, _, restore_fn in _package_registry)
    restored = next((outcome for outcome in attempts if outcome is not None), None)
    if restored is not None:
        return restored
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker
def normalize_storage_type(storage_type):
    """Return the attribute of the ``torch`` module whose name matches ``storage_type.__name__``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker
def storage_to_tensor_type(storage):
    """Return the tensor class that corresponds to the class of ``storage``.

    The tensor class is looked up in the module that defines ``type(storage)``,
    under the class name with ``'Storage'`` replaced by ``'Tensor'`` (e.g. a
    class named ``FloatStorage`` maps to ``FloatTensor`` in the same module).
    """
    storage_cls = type(storage)
    owning_module = _import_dotted_name(storage_cls.__module__)
    tensor_cls_name = storage_cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_cls_name)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Workerdef _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
452*da0073e9SAndroid Build Coastguard Worker    return isinstance(name_or_buffer, (str, os.PathLike))
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Workerclass _opener:
456*da0073e9SAndroid Build Coastguard Worker    def __init__(self, file_like):
457*da0073e9SAndroid Build Coastguard Worker        self.file_like = file_like
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    def __enter__(self):
460*da0073e9SAndroid Build Coastguard Worker        return self.file_like
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker    def __exit__(self, *args):
463*da0073e9SAndroid Build Coastguard Worker        pass
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker
class _open_file(_opener):
    """Opener that opens ``name`` with ``mode`` and closes the file on exit."""

    def __init__(self, name, mode):
        handle = open(name, mode)
        super().__init__(handle)

    def __exit__(self, *exc_info):
        opened = self.file_like
        opened.close()
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker
class _open_buffer_reader(_opener):
    """Opener for a caller-supplied readable buffer.

    The buffer is not closed on exit (the inherited ``__exit__`` is a no-op).
    """

    def __init__(self, buffer):
        super().__init__(buffer)
        # Fail early, with a helpful message, if the buffer cannot seek/tell.
        _check_seekable(buffer)
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker
class _open_buffer_writer(_opener):
    """Opener for a caller-supplied writable buffer; flushes it on exit without closing it."""

    def __exit__(self, *exc_info):
        target = self.file_like
        target.flush()
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker
def _open_file_like(name_or_buffer, mode):
    """Wrap ``name_or_buffer`` in the opener that matches its kind and ``mode``.

    Paths are opened as files with ``mode``. Buffers are wrapped as-is:
    ``'w'`` in ``mode`` selects the writer wrapper, ``'r'`` selects the
    (seekability-checked) reader wrapper.

    Raises:
        RuntimeError: for a buffer when ``mode`` contains neither ``'r'`` nor ``'w'``.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if 'w' in mode:
        return _open_buffer_writer(name_or_buffer)
    if 'r' in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker
class _open_zipfile_reader(_opener):
    """Opener that wraps ``name_or_buffer`` in a ``torch._C.PyTorchFileReader``."""

    def __init__(self, name_or_buffer) -> None:
        zip_reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(zip_reader)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker
class _open_zipfile_writer_file(_opener):
    """Opener that writes a zip archive to the file path ``name``.

    When ``str(name)`` is ASCII-encodable, ``torch._C.PyTorchFileWriter`` is
    given the path directly. Otherwise the file is opened here with
    ``io.FileIO`` and the writer is given that stream. This opener then owns
    that stream and closes it on exit (or if constructing the writer fails).
    """

    def __init__(self, name) -> None:
        # Stream opened by this class (non-ASCII path case only); None otherwise.
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            try:
                super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
            except BaseException:
                # __exit__ will never run if construction fails, so release
                # the descriptor we just opened before propagating.
                self.file_stream.close()
                raise
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        try:
            self.file_like.write_end_of_file()
        finally:
            # Close our own stream even if finalizing the archive raised.
            if self.file_stream is not None:
                self.file_stream.close()
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker
class _open_zipfile_writer_buffer(_opener):
    """Opener that writes a zip archive into a caller-supplied buffer.

    On exit the archive is finalized and the buffer is flushed; the buffer is
    not closed.

    Raises:
        AttributeError: if ``buffer`` has no ``write`` attribute at all.
        TypeError: if ``buffer.write`` exists but is not callable.
    """

    def __init__(self, buffer) -> None:
        write_attr = getattr(buffer, "write", None)
        if not callable(write_attr):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            error_cls = TypeError if hasattr(buffer, "write") else AttributeError
            raise error_cls(msg)
        # Kept separately so __exit__ can flush the caller's buffer.
        self.buffer = buffer
        zip_writer = torch._C.PyTorchFileWriter(buffer)
        super().__init__(zip_writer)

    def __exit__(self, *args) -> None:
        zip_writer = self.file_like
        zip_writer.write_end_of_file()
        self.buffer.flush()
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker
def _open_zipfile_writer(name_or_buffer):
    """Return a zip-writer opener: file-backed for paths, buffer-backed otherwise."""
    container: Type[_opener] = (
        _open_zipfile_writer_file if _is_path(name_or_buffer) else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Workerdef _is_compressed_file(f) -> bool:
548*da0073e9SAndroid Build Coastguard Worker    compress_modules = ['gzip']
549*da0073e9SAndroid Build Coastguard Worker    try:
550*da0073e9SAndroid Build Coastguard Worker        return f.__module__ in compress_modules
551*da0073e9SAndroid Build Coastguard Worker    except AttributeError:
552*da0073e9SAndroid Build Coastguard Worker        return False
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (``f.fileno()`` returns a
    non-negative descriptor) and is not a compressed file (e.g. gzip).
    Objects whose ``fileno`` is missing or unsupported are not read directly.
    """
    if _is_compressed_file(f):
        return False
    try:
        descriptor = f.fileno()
    except (io.UnsupportedOperation, AttributeError):
        return False
    return descriptor >= 0
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Workerdef _check_seekable(f) -> bool:
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker    def raise_err_msg(patterns, e):
574*da0073e9SAndroid Build Coastguard Worker        for p in patterns:
575*da0073e9SAndroid Build Coastguard Worker            if p in str(e):
576*da0073e9SAndroid Build Coastguard Worker                msg = (str(e) + ". You can only torch.load from a file that is seekable."
577*da0073e9SAndroid Build Coastguard Worker                                + " Please pre-load the data into a buffer like io.BytesIO and"
578*da0073e9SAndroid Build Coastguard Worker                                + " try to load from it instead.")
579*da0073e9SAndroid Build Coastguard Worker                raise type(e)(msg)
580*da0073e9SAndroid Build Coastguard Worker        raise e
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker    try:
583*da0073e9SAndroid Build Coastguard Worker        f.seek(f.tell())
584*da0073e9SAndroid Build Coastguard Worker        return True
585*da0073e9SAndroid Build Coastguard Worker    except (io.UnsupportedOperation, AttributeError) as e:
586*da0073e9SAndroid Build Coastguard Worker        raise_err_msg(["seek", "tell"], e)
587*da0073e9SAndroid Build Coastguard Worker    return False
588*da0073e9SAndroid Build Coastguard Worker
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Workerdef _check_dill_version(pickle_module) -> None:
591*da0073e9SAndroid Build Coastguard Worker    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
592*da0073e9SAndroid Build Coastguard Worker    If dill version is lower than 0.3.1, a ValueError is raised.
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker    Args:
595*da0073e9SAndroid Build Coastguard Worker        pickle_module: module used for pickling metadata and objects
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker    '''
598*da0073e9SAndroid Build Coastguard Worker    if pickle_module is not None and pickle_module.__name__ == 'dill':
599*da0073e9SAndroid Build Coastguard Worker        required_dill_version = (0, 3, 1)
600*da0073e9SAndroid Build Coastguard Worker        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
601*da0073e9SAndroid Build Coastguard Worker            raise ValueError((
602*da0073e9SAndroid Build Coastguard Worker                "'torch' supports dill >= {}, but you have dill {}."
603*da0073e9SAndroid Build Coastguard Worker                " Please upgrade dill or switch to 'pickle'"
604*da0073e9SAndroid Build Coastguard Worker            ).format(
605*da0073e9SAndroid Build Coastguard Worker                '.'.join([str(num) for num in required_dill_version]),
606*da0073e9SAndroid Build Coastguard Worker                pickle_module.__version__
607*da0073e9SAndroid Build Coastguard Worker            ))
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker
def _check_save_filelike(f):
    """Ensure ``f`` is a path or an object that has a ``write`` attribute.

    Raises:
        AttributeError: if ``f`` is neither.
    """
    if _is_path(f) or hasattr(f, 'write'):
        return
    raise AttributeError(
        "expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute")
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker
def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if not _use_new_zipfile_serialization:
        # Old (pre-1.6) format: a raw pickle stream written via _legacy_save.
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
        return

    # Default format: a zip archive written via _save.
    with _open_zipfile_writer(f) as opened_zipfile:
        _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Workerdef _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
679*da0073e9SAndroid Build Coastguard Worker    import torch.nn as nn
680*da0073e9SAndroid Build Coastguard Worker    serialized_container_types = {}
681*da0073e9SAndroid Build Coastguard Worker    serialized_storages = {}
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker    # Since loading storages that view the same data with different dtypes is
684*da0073e9SAndroid Build Coastguard Worker    # not supported, we need to keep track of the dtype associated with each
685*da0073e9SAndroid Build Coastguard Worker    # storage data_ptr and throw an error if the dtype is ever different.
686*da0073e9SAndroid Build Coastguard Worker    # TODO: This feature could be added in the future
687*da0073e9SAndroid Build Coastguard Worker    storage_dtypes: Dict[int, torch.dtype] = {}
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker    def persistent_id(obj: Any) -> Optional[Tuple]:
690*da0073e9SAndroid Build Coastguard Worker        # FIXME: the docs say that persistent_id should only return a string
691*da0073e9SAndroid Build Coastguard Worker        # but torch store returns tuples. This works only in the binary protocol
692*da0073e9SAndroid Build Coastguard Worker        # see
693*da0073e9SAndroid Build Coastguard Worker        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
694*da0073e9SAndroid Build Coastguard Worker        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
695*da0073e9SAndroid Build Coastguard Worker        if isinstance(obj, type) and issubclass(obj, nn.Module):
696*da0073e9SAndroid Build Coastguard Worker            if obj in serialized_container_types:
697*da0073e9SAndroid Build Coastguard Worker                return None
698*da0073e9SAndroid Build Coastguard Worker            serialized_container_types[obj] = True
699*da0073e9SAndroid Build Coastguard Worker            source_file = source = None
700*da0073e9SAndroid Build Coastguard Worker            try:
701*da0073e9SAndroid Build Coastguard Worker                source_lines, _, source_file = get_source_lines_and_file(obj)
702*da0073e9SAndroid Build Coastguard Worker                source = ''.join(source_lines)
703*da0073e9SAndroid Build Coastguard Worker            except Exception:  # saving the source is optional, so we can ignore any errors
704*da0073e9SAndroid Build Coastguard Worker                warnings.warn("Couldn't retrieve source code for container of "
705*da0073e9SAndroid Build Coastguard Worker                              "type " + obj.__name__ + ". It won't be checked "
706*da0073e9SAndroid Build Coastguard Worker                              "for correctness upon loading.")
707*da0073e9SAndroid Build Coastguard Worker            return ('module', obj, source_file, source)
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
710*da0073e9SAndroid Build Coastguard Worker            storage: torch.UntypedStorage
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker            if isinstance(obj, torch.storage.TypedStorage):
713*da0073e9SAndroid Build Coastguard Worker                # TODO: Once we decide to break serialization FC, this case
714*da0073e9SAndroid Build Coastguard Worker                # can be deleted
715*da0073e9SAndroid Build Coastguard Worker                storage = obj._untyped_storage
716*da0073e9SAndroid Build Coastguard Worker                storage_dtype = obj.dtype
717*da0073e9SAndroid Build Coastguard Worker                storage_type_str = obj._pickle_storage_type()
718*da0073e9SAndroid Build Coastguard Worker                storage_type = getattr(torch, storage_type_str)
719*da0073e9SAndroid Build Coastguard Worker                dtype = obj.dtype
720*da0073e9SAndroid Build Coastguard Worker                storage_numel = obj._size()
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker            elif isinstance(obj, torch.UntypedStorage):
723*da0073e9SAndroid Build Coastguard Worker                storage = obj
724*da0073e9SAndroid Build Coastguard Worker                storage_dtype = torch.uint8
725*da0073e9SAndroid Build Coastguard Worker                storage_type = normalize_storage_type(type(obj))
726*da0073e9SAndroid Build Coastguard Worker                dtype = torch.uint8
727*da0073e9SAndroid Build Coastguard Worker                storage_numel = storage.nbytes()
728*da0073e9SAndroid Build Coastguard Worker            else:
729*da0073e9SAndroid Build Coastguard Worker                raise TypeError(f'type not recognized: {type(obj)}')
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker            # If storage is allocated, ensure that any other saved storages
732*da0073e9SAndroid Build Coastguard Worker            # pointing to the same data all have the same dtype. If storage is
733*da0073e9SAndroid Build Coastguard Worker            # not allocated, don't perform this check
734*da0073e9SAndroid Build Coastguard Worker            if storage.data_ptr() != 0:
735*da0073e9SAndroid Build Coastguard Worker                if storage.data_ptr() in storage_dtypes:
736*da0073e9SAndroid Build Coastguard Worker                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
737*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError(
738*da0073e9SAndroid Build Coastguard Worker                            'Cannot save multiple tensors or storages that '
739*da0073e9SAndroid Build Coastguard Worker                            'view the same data as different types')
740*da0073e9SAndroid Build Coastguard Worker                else:
741*da0073e9SAndroid Build Coastguard Worker                    storage_dtypes[storage.data_ptr()] = storage_dtype
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker            view_metadata: Optional[Tuple[str, int, int]]
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker            # Offset is always 0, but we keep it for backwards compatibility
746*da0073e9SAndroid Build Coastguard Worker            # with the old serialization format (which supported storage views)
747*da0073e9SAndroid Build Coastguard Worker            offset = 0
748*da0073e9SAndroid Build Coastguard Worker            storage_key = str(storage._cdata)
749*da0073e9SAndroid Build Coastguard Worker            location = location_tag(storage)
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker            # TODO: There's an issue here with FC. It might be impossible to
752*da0073e9SAndroid Build Coastguard Worker            # solve, but it's worth noting. Imagine we save a list `[storage,
753*da0073e9SAndroid Build Coastguard Worker            # tensor]`, where `tensor.storage()` is the same as `storage`, and
754*da0073e9SAndroid Build Coastguard Worker            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
755*da0073e9SAndroid Build Coastguard Worker            # torch.float`.  The storage will be serialized with element size
756*da0073e9SAndroid Build Coastguard Worker            # of 1, since we're choosing to serialize the first occurance of
757*da0073e9SAndroid Build Coastguard Worker            # a duplicate storage. Since this legacy serialization format saves
758*da0073e9SAndroid Build Coastguard Worker            # the numel of the storage, rather than nbytes directly, we'll be
759*da0073e9SAndroid Build Coastguard Worker            # effectively saving nbytes in this case.  We'll be able to load it
760*da0073e9SAndroid Build Coastguard Worker            # and the tensor back up with no problems in _this_ and future
761*da0073e9SAndroid Build Coastguard Worker            # versions of pytorch, but in older versions, here's the problem:
762*da0073e9SAndroid Build Coastguard Worker            # the storage will be loaded up as a UntypedStorage, and then the
763*da0073e9SAndroid Build Coastguard Worker            # FloatTensor will loaded and the UntypedStorage will be assigned to
764*da0073e9SAndroid Build Coastguard Worker            # it. Since the storage dtype does not match the tensor dtype, this
765*da0073e9SAndroid Build Coastguard Worker            # will cause an error.  If we reverse the list, like `[tensor,
766*da0073e9SAndroid Build Coastguard Worker            # storage]`, then we will save the `tensor.storage()` as a faked
767*da0073e9SAndroid Build Coastguard Worker            # `FloatStorage`, and the saved size will be the correct
768*da0073e9SAndroid Build Coastguard Worker            # dtype-specific numel count that old versions expect. `tensor`
769*da0073e9SAndroid Build Coastguard Worker            # will be able to load up properly in old versions, pointing to
770*da0073e9SAndroid Build Coastguard Worker            # a FloatStorage. However, `storage` is still being translated to
771*da0073e9SAndroid Build Coastguard Worker            # a UntypedStorage, and it will try to resolve to the same
772*da0073e9SAndroid Build Coastguard Worker            # FloatStorage that `tensor` contains. This will also cause an
773*da0073e9SAndroid Build Coastguard Worker            # error. It doesn't seem like there's any way around this.
774*da0073e9SAndroid Build Coastguard Worker            # Probably, we just cannot maintain FC for the legacy format if the
775*da0073e9SAndroid Build Coastguard Worker            # saved list contains both a tensor and a storage that point to the
776*da0073e9SAndroid Build Coastguard Worker            # same data.  We should still be able to maintain FC for lists of
777*da0073e9SAndroid Build Coastguard Worker            # just tensors, as long as all views share the same dtype as the
778*da0073e9SAndroid Build Coastguard Worker            # tensor they are viewing.
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker            if storage_key not in serialized_storages:
781*da0073e9SAndroid Build Coastguard Worker                serialized_storages[storage_key] = (storage, dtype)
782*da0073e9SAndroid Build Coastguard Worker            is_view = storage._cdata != storage._cdata
783*da0073e9SAndroid Build Coastguard Worker            if is_view:
784*da0073e9SAndroid Build Coastguard Worker                view_metadata = (str(storage._cdata), offset, storage.nbytes())
785*da0073e9SAndroid Build Coastguard Worker            else:
786*da0073e9SAndroid Build Coastguard Worker                view_metadata = None
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker            res = ('storage',
789*da0073e9SAndroid Build Coastguard Worker                   storage_type,
790*da0073e9SAndroid Build Coastguard Worker                   storage_key,
791*da0073e9SAndroid Build Coastguard Worker                   location,
792*da0073e9SAndroid Build Coastguard Worker                   storage_numel,
793*da0073e9SAndroid Build Coastguard Worker                   view_metadata)
794*da0073e9SAndroid Build Coastguard Worker            return res
795*da0073e9SAndroid Build Coastguard Worker        return None
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker    sys_info = dict(
798*da0073e9SAndroid Build Coastguard Worker        protocol_version=PROTOCOL_VERSION,
799*da0073e9SAndroid Build Coastguard Worker        little_endian=sys.byteorder == 'little',
800*da0073e9SAndroid Build Coastguard Worker        type_sizes=dict(
801*da0073e9SAndroid Build Coastguard Worker            short=SHORT_SIZE,
802*da0073e9SAndroid Build Coastguard Worker            int=INT_SIZE,
803*da0073e9SAndroid Build Coastguard Worker            long=LONG_SIZE,
804*da0073e9SAndroid Build Coastguard Worker        ),
805*da0073e9SAndroid Build Coastguard Worker    )
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
808*da0073e9SAndroid Build Coastguard Worker    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
809*da0073e9SAndroid Build Coastguard Worker    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
810*da0073e9SAndroid Build Coastguard Worker    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
811*da0073e9SAndroid Build Coastguard Worker    pickler.persistent_id = persistent_id
812*da0073e9SAndroid Build Coastguard Worker    pickler.dump(obj)
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker    serialized_storage_keys = sorted(serialized_storages.keys())
815*da0073e9SAndroid Build Coastguard Worker    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
816*da0073e9SAndroid Build Coastguard Worker    f.flush()
817*da0073e9SAndroid Build Coastguard Worker    for key in serialized_storage_keys:
818*da0073e9SAndroid Build Coastguard Worker        storage, dtype = serialized_storages[key]
819*da0073e9SAndroid Build Coastguard Worker        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Workerdef _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
823*da0073e9SAndroid Build Coastguard Worker    serialized_storages = {}
824*da0073e9SAndroid Build Coastguard Worker    id_map: Dict[int, str] = {}
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker    # Since loading storages that view the same data with different dtypes is
827*da0073e9SAndroid Build Coastguard Worker    # not supported, we need to keep track of the dtype associated with each
828*da0073e9SAndroid Build Coastguard Worker    # storage data_ptr and throw an error if the dtype is ever different.
829*da0073e9SAndroid Build Coastguard Worker    # TODO: This feature could be added in the future
830*da0073e9SAndroid Build Coastguard Worker    storage_dtypes: Dict[int, torch.dtype] = {}
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker    def persistent_id(obj):
833*da0073e9SAndroid Build Coastguard Worker        # FIXME: the docs say that persistent_id should only return a string
834*da0073e9SAndroid Build Coastguard Worker        # but torch store returns tuples. This works only in the binary protocol
835*da0073e9SAndroid Build Coastguard Worker        # see
836*da0073e9SAndroid Build Coastguard Worker        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
837*da0073e9SAndroid Build Coastguard Worker        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
838*da0073e9SAndroid Build Coastguard Worker        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker            if isinstance(obj, torch.storage.TypedStorage):
841*da0073e9SAndroid Build Coastguard Worker                # TODO: Once we decide to break serialization FC, this case
842*da0073e9SAndroid Build Coastguard Worker                # can be deleted
843*da0073e9SAndroid Build Coastguard Worker                storage = obj._untyped_storage
844*da0073e9SAndroid Build Coastguard Worker                storage_dtype = obj.dtype
845*da0073e9SAndroid Build Coastguard Worker                storage_type_str = obj._pickle_storage_type()
846*da0073e9SAndroid Build Coastguard Worker                storage_type = getattr(torch, storage_type_str)
847*da0073e9SAndroid Build Coastguard Worker                storage_numel = obj._size()
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Worker            else:
850*da0073e9SAndroid Build Coastguard Worker                storage = obj
851*da0073e9SAndroid Build Coastguard Worker                storage_dtype = torch.uint8
852*da0073e9SAndroid Build Coastguard Worker                storage_type = normalize_storage_type(type(obj))
853*da0073e9SAndroid Build Coastguard Worker                storage_numel = storage.nbytes()
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker            # If storage is allocated, ensure that any other saved storages
856*da0073e9SAndroid Build Coastguard Worker            # pointing to the same data all have the same dtype. If storage is
857*da0073e9SAndroid Build Coastguard Worker            # not allocated, don't perform this check
858*da0073e9SAndroid Build Coastguard Worker            if storage.data_ptr() != 0:
859*da0073e9SAndroid Build Coastguard Worker                if storage.data_ptr() in storage_dtypes:
860*da0073e9SAndroid Build Coastguard Worker                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
861*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError(
862*da0073e9SAndroid Build Coastguard Worker                            'Cannot save multiple tensors or storages that '
863*da0073e9SAndroid Build Coastguard Worker                            'view the same data as different types')
864*da0073e9SAndroid Build Coastguard Worker                else:
865*da0073e9SAndroid Build Coastguard Worker                    storage_dtypes[storage.data_ptr()] = storage_dtype
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
868*da0073e9SAndroid Build Coastguard Worker            location = location_tag(storage)
869*da0073e9SAndroid Build Coastguard Worker            serialized_storages[storage_key] = storage
870*da0073e9SAndroid Build Coastguard Worker
871*da0073e9SAndroid Build Coastguard Worker            return ('storage',
872*da0073e9SAndroid Build Coastguard Worker                    storage_type,
873*da0073e9SAndroid Build Coastguard Worker                    storage_key,
874*da0073e9SAndroid Build Coastguard Worker                    location,
875*da0073e9SAndroid Build Coastguard Worker                    storage_numel)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker        return None
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker    # Write the pickle data for `obj`
880*da0073e9SAndroid Build Coastguard Worker    data_buf = io.BytesIO()
881*da0073e9SAndroid Build Coastguard Worker    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
882*da0073e9SAndroid Build Coastguard Worker    pickler.persistent_id = persistent_id
883*da0073e9SAndroid Build Coastguard Worker    pickler.dump(obj)
884*da0073e9SAndroid Build Coastguard Worker    data_value = data_buf.getvalue()
885*da0073e9SAndroid Build Coastguard Worker    zip_file.write_record('data.pkl', data_value, len(data_value))
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker    # Write byte order marker
888*da0073e9SAndroid Build Coastguard Worker    if not _disable_byteorder_record:
889*da0073e9SAndroid Build Coastguard Worker        if sys.byteorder not in ['little', 'big']:
890*da0073e9SAndroid Build Coastguard Worker            raise ValueError('Unknown endianness type: ' + sys.byteorder)
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
895*da0073e9SAndroid Build Coastguard Worker    for key in sorted(serialized_storages.keys()):
896*da0073e9SAndroid Build Coastguard Worker        name = f'data/{key}'
897*da0073e9SAndroid Build Coastguard Worker        storage = serialized_storages[key]
898*da0073e9SAndroid Build Coastguard Worker        # given that we copy things around anyway, we might use storage.cpu()
899*da0073e9SAndroid Build Coastguard Worker        # this means to that to get tensors serialized, you need to implement
900*da0073e9SAndroid Build Coastguard Worker        # .cpu() on the underlying Storage
901*da0073e9SAndroid Build Coastguard Worker        if storage.device.type != 'cpu':
902*da0073e9SAndroid Build Coastguard Worker            storage = storage.cpu()
903*da0073e9SAndroid Build Coastguard Worker        # Now that it is on the CPU we can directly copy it into the zip file
904*da0073e9SAndroid Build Coastguard Worker        num_bytes = storage.nbytes()
905*da0073e9SAndroid Build Coastguard Worker        zip_file.write_record(name, storage, num_bytes)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Workerdef load(
909*da0073e9SAndroid Build Coastguard Worker    f: FILE_LIKE,
910*da0073e9SAndroid Build Coastguard Worker    map_location: MAP_LOCATION = None,
911*da0073e9SAndroid Build Coastguard Worker    pickle_module: Any = None,
912*da0073e9SAndroid Build Coastguard Worker    *,
913*da0073e9SAndroid Build Coastguard Worker    weights_only: Optional[bool] = None,
914*da0073e9SAndroid Build Coastguard Worker    mmap: Optional[bool] = None,
915*da0073e9SAndroid Build Coastguard Worker    **pickle_load_args: Any
916*da0073e9SAndroid Build Coastguard Worker) -> Any:
917*da0073e9SAndroid Build Coastguard Worker    # Reference: https://github.com/pytorch/pytorch/issues/54354
918*da0073e9SAndroid Build Coastguard Worker    # The first line of this docstring overrides the one Sphinx generates for the
919*da0073e9SAndroid Build Coastguard Worker    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
920*da0073e9SAndroid Build Coastguard Worker    # the build environment (e.g. `<module 'pickle' from '/leaked/path').
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker    Loads an object saved with :func:`torch.save` from a file.
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker    :func:`torch.load` uses Python's unpickling facilities but treats storages,
927*da0073e9SAndroid Build Coastguard Worker    which underlie tensors, specially. They are first deserialized on the
928*da0073e9SAndroid Build Coastguard Worker    CPU and are then moved to the device they were saved from. If this fails
929*da0073e9SAndroid Build Coastguard Worker    (e.g. because the run time system doesn't have certain devices), an exception
930*da0073e9SAndroid Build Coastguard Worker    is raised. However, storages can be dynamically remapped to an alternative
931*da0073e9SAndroid Build Coastguard Worker    set of devices using the :attr:`map_location` argument.
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker    If :attr:`map_location` is a callable, it will be called once for each serialized
934*da0073e9SAndroid Build Coastguard Worker    storage with two arguments: storage and location. The storage argument
935*da0073e9SAndroid Build Coastguard Worker    will be the initial deserialization of the storage, residing on the CPU.
936*da0073e9SAndroid Build Coastguard Worker    Each serialized storage has a location tag associated with it which
937*da0073e9SAndroid Build Coastguard Worker    identifies the device it was saved from, and this tag is the second
938*da0073e9SAndroid Build Coastguard Worker    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
939*da0073e9SAndroid Build Coastguard Worker    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
940*da0073e9SAndroid Build Coastguard Worker    :attr:`map_location` should return either ``None`` or a storage. If
941*da0073e9SAndroid Build Coastguard Worker    :attr:`map_location` returns a storage, it will be used as the final deserialized
942*da0073e9SAndroid Build Coastguard Worker    object, already moved to the right device. Otherwise, :func:`torch.load` will
943*da0073e9SAndroid Build Coastguard Worker    fall back to the default behavior, as if :attr:`map_location` wasn't specified.
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker    If :attr:`map_location` is a :class:`torch.device` object or a string containing
946*da0073e9SAndroid Build Coastguard Worker    a device tag, it indicates the location where all tensors should be loaded.
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
949*da0073e9SAndroid Build Coastguard Worker    appearing in the file (keys), to ones that specify where to put the
950*da0073e9SAndroid Build Coastguard Worker    storages (values).
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker    User extensions can register their own location tags and tagging and
953*da0073e9SAndroid Build Coastguard Worker    deserialization methods using :func:`torch.serialization.register_package`.
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker    Args:
956*da0073e9SAndroid Build Coastguard Worker        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
957*da0073e9SAndroid Build Coastguard Worker            or a string or os.PathLike object containing a file name
958*da0073e9SAndroid Build Coastguard Worker        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
959*da0073e9SAndroid Build Coastguard Worker            locations
960*da0073e9SAndroid Build Coastguard Worker        pickle_module: module used for unpickling metadata and objects (has to
961*da0073e9SAndroid Build Coastguard Worker            match the :attr:`pickle_module` used to serialize file)
962*da0073e9SAndroid Build Coastguard Worker        weights_only: Indicates whether unpickler should be restricted to
963*da0073e9SAndroid Build Coastguard Worker            loading only tensors, primitive types, dictionaries
964*da0073e9SAndroid Build Coastguard Worker            and any types added via :func:`torch.serialization.add_safe_globals`.
965*da0073e9SAndroid Build Coastguard Worker        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
966*da0073e9SAndroid Build Coastguard Worker            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
967*da0073e9SAndroid Build Coastguard Worker            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
968*da0073e9SAndroid Build Coastguard Worker            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
969*da0073e9SAndroid Build Coastguard Worker            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
970*da0073e9SAndroid Build Coastguard Worker        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
971*da0073e9SAndroid Build Coastguard Worker            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
972*da0073e9SAndroid Build Coastguard Worker            :attr:`errors=...`.
973*da0073e9SAndroid Build Coastguard Worker
974*da0073e9SAndroid Build Coastguard Worker    .. warning::
975*da0073e9SAndroid Build Coastguard Worker        :func:`torch.load()` unless `weights_only` parameter is set to `True`,
976*da0073e9SAndroid Build Coastguard Worker        uses ``pickle`` module implicitly, which is known to be insecure.
977*da0073e9SAndroid Build Coastguard Worker        It is possible to construct malicious pickle data which will execute arbitrary code
978*da0073e9SAndroid Build Coastguard Worker        during unpickling. Never load data that could have come from an untrusted
979*da0073e9SAndroid Build Coastguard Worker        source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker    .. note::
982*da0073e9SAndroid Build Coastguard Worker        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
983*da0073e9SAndroid Build Coastguard Worker        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
984*da0073e9SAndroid Build Coastguard Worker        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker    .. note::
987*da0073e9SAndroid Build Coastguard Worker        By default, we decode byte strings as ``utf-8``.  This is to avoid a common error
988*da0073e9SAndroid Build Coastguard Worker        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
989*da0073e9SAndroid Build Coastguard Worker        when loading files saved by Python 2 in Python 3.  If this default
990*da0073e9SAndroid Build Coastguard Worker        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
991*da0073e9SAndroid Build Coastguard Worker        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
992*da0073e9SAndroid Build Coastguard Worker        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
993*da0073e9SAndroid Build Coastguard Worker        as byte arrays which can be decoded later with ``byte_array.decode(...)``.
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker    Example:
996*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP("undefined filepaths")
997*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('tensors.pt', weights_only=True)
998*da0073e9SAndroid Build Coastguard Worker        # Load all tensors onto the CPU
999*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
1000*da0073e9SAndroid Build Coastguard Worker        # Load all tensors onto the CPU, using a function
1001*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
1002*da0073e9SAndroid Build Coastguard Worker        # Load all tensors onto GPU 1
1003*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
1004*da0073e9SAndroid Build Coastguard Worker        # Map tensors from GPU 1 to GPU 0
1005*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
1006*da0073e9SAndroid Build Coastguard Worker        # Load tensor from io.BytesIO object
1007*da0073e9SAndroid Build Coastguard Worker        # Loading from a buffer setting weights_only=False, warning this can be unsafe
1008*da0073e9SAndroid Build Coastguard Worker        >>> with open('tensor.pt', 'rb') as f:
1009*da0073e9SAndroid Build Coastguard Worker        ...     buffer = io.BytesIO(f.read())
1010*da0073e9SAndroid Build Coastguard Worker        >>> torch.load(buffer, weights_only=False)
1011*da0073e9SAndroid Build Coastguard Worker        # Load a module with 'ascii' encoding for unpickling
1012*da0073e9SAndroid Build Coastguard Worker        # Loading from a module setting weights_only=False, warning this can be unsafe
1013*da0073e9SAndroid Build Coastguard Worker        >>> torch.load('module.pt', encoding='ascii', weights_only=False)
1014*da0073e9SAndroid Build Coastguard Worker    """
1015*da0073e9SAndroid Build Coastguard Worker    torch._C._log_api_usage_once("torch.load")
1016*da0073e9SAndroid Build Coastguard Worker    UNSAFE_MESSAGE = (
1017*da0073e9SAndroid Build Coastguard Worker        "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
1018*da0073e9SAndroid Build Coastguard Worker        "but it can result in arbitrary code execution. Do it only if you got the file from a "
1019*da0073e9SAndroid Build Coastguard Worker        "trusted source."
1020*da0073e9SAndroid Build Coastguard Worker    )
1021*da0073e9SAndroid Build Coastguard Worker    DOCS_MESSAGE = (
1022*da0073e9SAndroid Build Coastguard Worker        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
1023*da0073e9SAndroid Build Coastguard Worker        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
1024*da0073e9SAndroid Build Coastguard Worker    )
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker    def _get_wo_message(message: str) -> str:
1027*da0073e9SAndroid Build Coastguard Worker        pattern = r"GLOBAL (\S+) was not an allowed global by default."
1028*da0073e9SAndroid Build Coastguard Worker        has_unsafe_global = re.search(pattern, message) is not None
1029*da0073e9SAndroid Build Coastguard Worker        if has_unsafe_global:
1030*da0073e9SAndroid Build Coastguard Worker            updated_message = (
1031*da0073e9SAndroid Build Coastguard Worker                "Weights only load failed. This file can still be loaded, to do so you have two options "
1032*da0073e9SAndroid Build Coastguard Worker                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
1033*da0073e9SAndroid Build Coastguard Worker                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
1034*da0073e9SAndroid Build Coastguard Worker                + message
1035*da0073e9SAndroid Build Coastguard Worker            )
1036*da0073e9SAndroid Build Coastguard Worker        else:
1037*da0073e9SAndroid Build Coastguard Worker            updated_message = (
1038*da0073e9SAndroid Build Coastguard Worker                f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
1039*da0073e9SAndroid Build Coastguard Worker                "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
1040*da0073e9SAndroid Build Coastguard Worker                "error: " + message
1041*da0073e9SAndroid Build Coastguard Worker            )
1042*da0073e9SAndroid Build Coastguard Worker        return updated_message + DOCS_MESSAGE
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker    if weights_only is None:
1045*da0073e9SAndroid Build Coastguard Worker        weights_only, warn_weights_only = False, True
1046*da0073e9SAndroid Build Coastguard Worker    else:
1047*da0073e9SAndroid Build Coastguard Worker        warn_weights_only = False
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker    # Add ability to force safe only weight loads via environment variable
1050*da0073e9SAndroid Build Coastguard Worker    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
1051*da0073e9SAndroid Build Coastguard Worker        weights_only = True
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker    if weights_only:
1054*da0073e9SAndroid Build Coastguard Worker        if pickle_module is not None:
1055*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
1056*da0073e9SAndroid Build Coastguard Worker    else:
1057*da0073e9SAndroid Build Coastguard Worker        if pickle_module is None:
1058*da0073e9SAndroid Build Coastguard Worker            if warn_weights_only:
1059*da0073e9SAndroid Build Coastguard Worker                warnings.warn(
1060*da0073e9SAndroid Build Coastguard Worker                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
1061*da0073e9SAndroid Build Coastguard Worker                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
1062*da0073e9SAndroid Build Coastguard Worker                    "which will execute arbitrary code during unpickling (See "
1063*da0073e9SAndroid Build Coastguard Worker                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
1064*da0073e9SAndroid Build Coastguard Worker                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
1065*da0073e9SAndroid Build Coastguard Worker                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
1066*da0073e9SAndroid Build Coastguard Worker                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
1067*da0073e9SAndroid Build Coastguard Worker                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
1068*da0073e9SAndroid Build Coastguard Worker                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
1069*da0073e9SAndroid Build Coastguard Worker                    "Please open an issue on GitHub for any issues related to this experimental feature.",
1070*da0073e9SAndroid Build Coastguard Worker                    FutureWarning,
1071*da0073e9SAndroid Build Coastguard Worker                    stacklevel=2,
1072*da0073e9SAndroid Build Coastguard Worker                )
1073*da0073e9SAndroid Build Coastguard Worker            pickle_module = pickle
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker    # make flipping default BC-compatible
1076*da0073e9SAndroid Build Coastguard Worker    if mmap is None:
1077*da0073e9SAndroid Build Coastguard Worker        mmap = False
1078*da0073e9SAndroid Build Coastguard Worker
1079*da0073e9SAndroid Build Coastguard Worker    _check_dill_version(pickle_module)
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker    if 'encoding' not in pickle_load_args.keys():
1082*da0073e9SAndroid Build Coastguard Worker        pickle_load_args['encoding'] = 'utf-8'
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker    with _open_file_like(f, 'rb') as opened_file:
1085*da0073e9SAndroid Build Coastguard Worker        if _is_zipfile(opened_file):
1086*da0073e9SAndroid Build Coastguard Worker            # The zipfile reader is going to advance the current file position.
1087*da0073e9SAndroid Build Coastguard Worker            # If we want to actually tail call to torch.jit.load, we need to
1088*da0073e9SAndroid Build Coastguard Worker            # reset back to the original position.
1089*da0073e9SAndroid Build Coastguard Worker            orig_position = opened_file.tell()
1090*da0073e9SAndroid Build Coastguard Worker            overall_storage = None
1091*da0073e9SAndroid Build Coastguard Worker            with _open_zipfile_reader(opened_file) as opened_zipfile:
1092*da0073e9SAndroid Build Coastguard Worker                if _is_torchscript_zip(opened_zipfile):
1093*da0073e9SAndroid Build Coastguard Worker                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
1094*da0073e9SAndroid Build Coastguard Worker                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
1095*da0073e9SAndroid Build Coastguard Worker                                  " silence this warning)", UserWarning)
1096*da0073e9SAndroid Build Coastguard Worker                    opened_file.seek(orig_position)
1097*da0073e9SAndroid Build Coastguard Worker                    return torch.jit.load(opened_file, map_location=map_location)
1098*da0073e9SAndroid Build Coastguard Worker                if mmap:
1099*da0073e9SAndroid Build Coastguard Worker                    if not _is_path(f):
1100*da0073e9SAndroid Build Coastguard Worker                        raise ValueError("f must be a file path in order to use the mmap argument")
1101*da0073e9SAndroid Build Coastguard Worker                    size = os.path.getsize(f)
1102*da0073e9SAndroid Build Coastguard Worker                    if not IS_WINDOWS:
1103*da0073e9SAndroid Build Coastguard Worker                        shared = get_default_mmap_options() == MAP_SHARED
1104*da0073e9SAndroid Build Coastguard Worker                    else:
1105*da0073e9SAndroid Build Coastguard Worker                        shared = False
1106*da0073e9SAndroid Build Coastguard Worker                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
1107*da0073e9SAndroid Build Coastguard Worker                if weights_only:
1108*da0073e9SAndroid Build Coastguard Worker                    try:
1109*da0073e9SAndroid Build Coastguard Worker                        return _load(opened_zipfile,
1110*da0073e9SAndroid Build Coastguard Worker                                     map_location,
1111*da0073e9SAndroid Build Coastguard Worker                                     _weights_only_unpickler,
1112*da0073e9SAndroid Build Coastguard Worker                                     overall_storage=overall_storage,
1113*da0073e9SAndroid Build Coastguard Worker                                     **pickle_load_args)
1114*da0073e9SAndroid Build Coastguard Worker                    except RuntimeError as e:
1115*da0073e9SAndroid Build Coastguard Worker                        raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1116*da0073e9SAndroid Build Coastguard Worker                return _load(
1117*da0073e9SAndroid Build Coastguard Worker                    opened_zipfile,
1118*da0073e9SAndroid Build Coastguard Worker                    map_location,
1119*da0073e9SAndroid Build Coastguard Worker                    pickle_module,
1120*da0073e9SAndroid Build Coastguard Worker                    overall_storage=overall_storage,
1121*da0073e9SAndroid Build Coastguard Worker                    **pickle_load_args,
1122*da0073e9SAndroid Build Coastguard Worker                )
1123*da0073e9SAndroid Build Coastguard Worker        if mmap:
1124*da0073e9SAndroid Build Coastguard Worker            f_name = "" if not isinstance(f, str) else f"{f}, "
1125*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("mmap can only be used with files saved with "
1126*da0073e9SAndroid Build Coastguard Worker                               f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
1127*da0073e9SAndroid Build Coastguard Worker                               "please torch.save your checkpoint with this option in order to use mmap.")
1128*da0073e9SAndroid Build Coastguard Worker        if weights_only:
1129*da0073e9SAndroid Build Coastguard Worker            try:
1130*da0073e9SAndroid Build Coastguard Worker                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
1131*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
1132*da0073e9SAndroid Build Coastguard Worker                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1133*da0073e9SAndroid Build Coastguard Worker        return _legacy_load(
1134*da0073e9SAndroid Build Coastguard Worker            opened_file, map_location, pickle_module, **pickle_load_args
1135*da0073e9SAndroid Build Coastguard Worker        )
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker# Register pickling support for layout instances such as
1139*da0073e9SAndroid Build Coastguard Worker# torch.sparse_coo, etc
1140*da0073e9SAndroid Build Coastguard Workerdef _get_layout(name):
1141*da0073e9SAndroid Build Coastguard Worker    """Get layout extension object from its string representation.
1142*da0073e9SAndroid Build Coastguard Worker    """
1143*da0073e9SAndroid Build Coastguard Worker    cache = _get_layout.cache   # type: ignore[attr-defined]
1144*da0073e9SAndroid Build Coastguard Worker    if not cache:
1145*da0073e9SAndroid Build Coastguard Worker        for v in torch.__dict__.values():
1146*da0073e9SAndroid Build Coastguard Worker            if isinstance(v, torch.layout):
1147*da0073e9SAndroid Build Coastguard Worker                cache[str(v)] = v
1148*da0073e9SAndroid Build Coastguard Worker    return cache[name]
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
1151*da0073e9SAndroid Build Coastguard Worker_get_layout.cache = {}   # type: ignore[attr-defined]
1152*da0073e9SAndroid Build Coastguard Workercopyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Workerdef _legacy_load(f, map_location, pickle_module, **pickle_load_args):
1156*da0073e9SAndroid Build Coastguard Worker    deserialized_objects: Dict[int, Any] = {}
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker    restore_location = _get_restore_location(map_location)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker        def find_class(self, mod_name, name):
1163*da0073e9SAndroid Build Coastguard Worker            if type(name) is str and 'Storage' in name:
1164*da0073e9SAndroid Build Coastguard Worker                try:
1165*da0073e9SAndroid Build Coastguard Worker                    return StorageType(name)
1166*da0073e9SAndroid Build Coastguard Worker                except KeyError:
1167*da0073e9SAndroid Build Coastguard Worker                    pass
1168*da0073e9SAndroid Build Coastguard Worker            return super().find_class(mod_name, name)
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker    def _check_container_source(container_type, source_file, original_source):
1171*da0073e9SAndroid Build Coastguard Worker        try:
1172*da0073e9SAndroid Build Coastguard Worker            current_source = ''.join(get_source_lines_and_file(container_type)[0])
1173*da0073e9SAndroid Build Coastguard Worker        except Exception:  # saving the source is optional, so we can ignore any errors
1174*da0073e9SAndroid Build Coastguard Worker            warnings.warn("Couldn't retrieve source code for container of "
1175*da0073e9SAndroid Build Coastguard Worker                          "type " + container_type.__name__ + ". It won't be checked "
1176*da0073e9SAndroid Build Coastguard Worker                          "for correctness upon loading.")
1177*da0073e9SAndroid Build Coastguard Worker            return
1178*da0073e9SAndroid Build Coastguard Worker        if original_source != current_source:
1179*da0073e9SAndroid Build Coastguard Worker            if container_type.dump_patches:
1180*da0073e9SAndroid Build Coastguard Worker                file_name = container_type.__name__ + '.patch'
1181*da0073e9SAndroid Build Coastguard Worker                diff = difflib.unified_diff(current_source.split('\n'),
1182*da0073e9SAndroid Build Coastguard Worker                                            original_source.split('\n'),
1183*da0073e9SAndroid Build Coastguard Worker                                            source_file,
1184*da0073e9SAndroid Build Coastguard Worker                                            source_file, lineterm="")
1185*da0073e9SAndroid Build Coastguard Worker                lines = '\n'.join(diff)
1186*da0073e9SAndroid Build Coastguard Worker                try:
1187*da0073e9SAndroid Build Coastguard Worker                    with open(file_name, 'a+') as f:
1188*da0073e9SAndroid Build Coastguard Worker                        file_size = f.seek(0, 2)
1189*da0073e9SAndroid Build Coastguard Worker                        f.seek(0)
1190*da0073e9SAndroid Build Coastguard Worker                        if file_size == 0:
1191*da0073e9SAndroid Build Coastguard Worker                            f.write(lines)
1192*da0073e9SAndroid Build Coastguard Worker                        elif file_size != len(lines) or f.read() != lines:
1193*da0073e9SAndroid Build Coastguard Worker                            raise OSError
1194*da0073e9SAndroid Build Coastguard Worker                    msg = ("Saved a reverse patch to " + file_name + ". "
1195*da0073e9SAndroid Build Coastguard Worker                           "Run `patch -p0 < " + file_name + "` to revert your "
1196*da0073e9SAndroid Build Coastguard Worker                           "changes.")
1197*da0073e9SAndroid Build Coastguard Worker                except OSError:
1198*da0073e9SAndroid Build Coastguard Worker                    msg = ("Tried to save a patch, but couldn't create a "
1199*da0073e9SAndroid Build Coastguard Worker                           "writable file " + file_name + ". Make sure it "
1200*da0073e9SAndroid Build Coastguard Worker                           "doesn't exist and your working directory is "
1201*da0073e9SAndroid Build Coastguard Worker                           "writable.")
1202*da0073e9SAndroid Build Coastguard Worker            else:
1203*da0073e9SAndroid Build Coastguard Worker                msg = ("you can retrieve the original source code by "
1204*da0073e9SAndroid Build Coastguard Worker                       "accessing the object's source attribute or set "
1205*da0073e9SAndroid Build Coastguard Worker                       "`torch.nn.Module.dump_patches = True` and use the "
1206*da0073e9SAndroid Build Coastguard Worker                       "patch tool to revert the changes.")
1207*da0073e9SAndroid Build Coastguard Worker            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
1208*da0073e9SAndroid Build Coastguard Worker            warnings.warn(msg, SourceChangeWarning)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker    def legacy_load(f):
1211*da0073e9SAndroid Build Coastguard Worker        deserialized_objects: Dict[int, Any] = {}
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker        def persistent_load(saved_id):
1214*da0073e9SAndroid Build Coastguard Worker            if isinstance(saved_id, tuple):
1215*da0073e9SAndroid Build Coastguard Worker                # Ignore containers that don't have any sources saved
1216*da0073e9SAndroid Build Coastguard Worker                if all(saved_id[1:]):
1217*da0073e9SAndroid Build Coastguard Worker                    _check_container_source(*saved_id)
1218*da0073e9SAndroid Build Coastguard Worker                return saved_id[0]
1219*da0073e9SAndroid Build Coastguard Worker            return deserialized_objects[int(saved_id)]
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
1222*da0073e9SAndroid Build Coastguard Worker                mkdtemp() as tmpdir:
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker            tar.extract('storages', path=tmpdir)
1225*da0073e9SAndroid Build Coastguard Worker            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
1226*da0073e9SAndroid Build Coastguard Worker                num_storages = pickle_module.load(f, **pickle_load_args)
1227*da0073e9SAndroid Build Coastguard Worker                for i in range(num_storages):
1228*da0073e9SAndroid Build Coastguard Worker                    args = pickle_module.load(f, **pickle_load_args)
1229*da0073e9SAndroid Build Coastguard Worker                    key, location, storage_type = args
1230*da0073e9SAndroid Build Coastguard Worker                    dtype = storage_type._dtype
1231*da0073e9SAndroid Build Coastguard Worker                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
1232*da0073e9SAndroid Build Coastguard Worker                    obj = restore_location(obj, location)
1233*da0073e9SAndroid Build Coastguard Worker                    # TODO: Once we decide to break serialization FC, we can
1234*da0073e9SAndroid Build Coastguard Worker                    # stop wrapping with TypedStorage
1235*da0073e9SAndroid Build Coastguard Worker                    deserialized_objects[key] = torch.storage.TypedStorage(
1236*da0073e9SAndroid Build Coastguard Worker                        wrap_storage=obj,
1237*da0073e9SAndroid Build Coastguard Worker                        dtype=dtype,
1238*da0073e9SAndroid Build Coastguard Worker                        _internal=True)
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker                storage_views = pickle_module.load(f, **pickle_load_args)
1241*da0073e9SAndroid Build Coastguard Worker                for target_cdata, root_cdata, offset, numel in storage_views:
1242*da0073e9SAndroid Build Coastguard Worker                    root = deserialized_objects[root_cdata]
1243*da0073e9SAndroid Build Coastguard Worker                    element_size = torch._utils._element_size(root.dtype)
1244*da0073e9SAndroid Build Coastguard Worker                    offset_bytes = offset * element_size
1245*da0073e9SAndroid Build Coastguard Worker                    # TODO: Once we decide to break serialization FC, we can
1246*da0073e9SAndroid Build Coastguard Worker                    # stop wrapping with TypedStorage
1247*da0073e9SAndroid Build Coastguard Worker                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
1248*da0073e9SAndroid Build Coastguard Worker                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
1249*da0073e9SAndroid Build Coastguard Worker                        dtype=root.dtype,
1250*da0073e9SAndroid Build Coastguard Worker                        _internal=True)
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker            tar.extract('tensors', path=tmpdir)
1253*da0073e9SAndroid Build Coastguard Worker            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
1254*da0073e9SAndroid Build Coastguard Worker                num_tensors = pickle_module.load(f, **pickle_load_args)
1255*da0073e9SAndroid Build Coastguard Worker                for _ in range(num_tensors):
1256*da0073e9SAndroid Build Coastguard Worker                    args = pickle_module.load(f, **pickle_load_args)
1257*da0073e9SAndroid Build Coastguard Worker                    key, storage_id, original_tensor_type = args
1258*da0073e9SAndroid Build Coastguard Worker                    storage = deserialized_objects[storage_id]
1259*da0073e9SAndroid Build Coastguard Worker                    ndim, = struct.unpack('<i', f.read(4))
1260*da0073e9SAndroid Build Coastguard Worker                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
1261*da0073e9SAndroid Build Coastguard Worker                    f.read(4)
1262*da0073e9SAndroid Build Coastguard Worker                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
1263*da0073e9SAndroid Build Coastguard Worker                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
1264*da0073e9SAndroid Build Coastguard Worker                    storage_offset, = struct.unpack('<q', f.read(8))
1265*da0073e9SAndroid Build Coastguard Worker                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
1266*da0073e9SAndroid Build Coastguard Worker                        storage._untyped_storage, storage_offset, numel, stride)
1267*da0073e9SAndroid Build Coastguard Worker                    deserialized_objects[key] = tensor
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker            pickle_file = tar.extractfile('pickle')
1270*da0073e9SAndroid Build Coastguard Worker            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
1271*da0073e9SAndroid Build Coastguard Worker            unpickler.persistent_load = persistent_load
1272*da0073e9SAndroid Build Coastguard Worker            result = unpickler.load()
1273*da0073e9SAndroid Build Coastguard Worker            return result
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker    deserialized_objects = {}
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker    def persistent_load(saved_id):
1278*da0073e9SAndroid Build Coastguard Worker        assert isinstance(saved_id, tuple)
1279*da0073e9SAndroid Build Coastguard Worker        typename = _maybe_decode_ascii(saved_id[0])
1280*da0073e9SAndroid Build Coastguard Worker        data = saved_id[1:]
1281*da0073e9SAndroid Build Coastguard Worker
1282*da0073e9SAndroid Build Coastguard Worker        if typename == 'module':
1283*da0073e9SAndroid Build Coastguard Worker            # Ignore containers that don't have any sources saved
1284*da0073e9SAndroid Build Coastguard Worker            if all(data[1:]):
1285*da0073e9SAndroid Build Coastguard Worker                _check_container_source(*data)
1286*da0073e9SAndroid Build Coastguard Worker            return data[0]
1287*da0073e9SAndroid Build Coastguard Worker        elif typename == 'storage':
1288*da0073e9SAndroid Build Coastguard Worker            storage_type, root_key, location, numel, view_metadata = data
1289*da0073e9SAndroid Build Coastguard Worker            location = _maybe_decode_ascii(location)
1290*da0073e9SAndroid Build Coastguard Worker            dtype = storage_type.dtype
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker            nbytes = numel * torch._utils._element_size(dtype)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker            if root_key not in deserialized_objects:
1295*da0073e9SAndroid Build Coastguard Worker                if torch._guards.active_fake_mode() is not None:
1296*da0073e9SAndroid Build Coastguard Worker                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
1297*da0073e9SAndroid Build Coastguard Worker                else:
1298*da0073e9SAndroid Build Coastguard Worker                    obj = cast(Storage, torch.UntypedStorage(nbytes))
1299*da0073e9SAndroid Build Coastguard Worker                    obj._torch_load_uninitialized = True
1300*da0073e9SAndroid Build Coastguard Worker                    obj = restore_location(obj, location)
1301*da0073e9SAndroid Build Coastguard Worker                # TODO: Once we decide to break serialization FC, we can
1302*da0073e9SAndroid Build Coastguard Worker                # stop wrapping with TypedStorage
1303*da0073e9SAndroid Build Coastguard Worker                typed_storage = torch.storage.TypedStorage(
1304*da0073e9SAndroid Build Coastguard Worker                    wrap_storage=obj,
1305*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
1306*da0073e9SAndroid Build Coastguard Worker                    _internal=True)
1307*da0073e9SAndroid Build Coastguard Worker                deserialized_objects[root_key] = typed_storage
1308*da0073e9SAndroid Build Coastguard Worker            else:
1309*da0073e9SAndroid Build Coastguard Worker                typed_storage = deserialized_objects[root_key]
1310*da0073e9SAndroid Build Coastguard Worker                if typed_storage._data_ptr() == 0:
1311*da0073e9SAndroid Build Coastguard Worker                    typed_storage = torch.storage.TypedStorage(
1312*da0073e9SAndroid Build Coastguard Worker                        device=typed_storage._untyped_storage.device,
1313*da0073e9SAndroid Build Coastguard Worker                        dtype=dtype,
1314*da0073e9SAndroid Build Coastguard Worker                        _internal=True)
1315*da0073e9SAndroid Build Coastguard Worker
1316*da0073e9SAndroid Build Coastguard Worker            if view_metadata is not None:
1317*da0073e9SAndroid Build Coastguard Worker                view_key, offset, view_size = view_metadata
1318*da0073e9SAndroid Build Coastguard Worker                offset_bytes = offset * torch._utils._element_size(dtype)
1319*da0073e9SAndroid Build Coastguard Worker                view_size_bytes = view_size * torch._utils._element_size(dtype)
1320*da0073e9SAndroid Build Coastguard Worker                if view_key not in deserialized_objects:
1321*da0073e9SAndroid Build Coastguard Worker                    # TODO: Once we decide to break serialization FC, we can
1322*da0073e9SAndroid Build Coastguard Worker                    # stop wrapping with TypedStorage
1323*da0073e9SAndroid Build Coastguard Worker                    deserialized_objects[view_key] = torch.storage.TypedStorage(
1324*da0073e9SAndroid Build Coastguard Worker                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
1325*da0073e9SAndroid Build Coastguard Worker                        dtype=dtype,
1326*da0073e9SAndroid Build Coastguard Worker                        _internal=True)
1327*da0073e9SAndroid Build Coastguard Worker                res = deserialized_objects[view_key]
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker            else:
1330*da0073e9SAndroid Build Coastguard Worker                res = typed_storage
1331*da0073e9SAndroid Build Coastguard Worker            return res
1332*da0073e9SAndroid Build Coastguard Worker        else:
1333*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
1334*da0073e9SAndroid Build Coastguard Worker
1335*da0073e9SAndroid Build Coastguard Worker    _check_seekable(f)
1336*da0073e9SAndroid Build Coastguard Worker    f_should_read_directly = _should_read_directly(f)
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker    if f_should_read_directly and f.tell() == 0:
1339*da0073e9SAndroid Build Coastguard Worker        # legacy_load requires that f has fileno()
1340*da0073e9SAndroid Build Coastguard Worker        # only if offset is zero we can attempt the legacy tar file loader
1341*da0073e9SAndroid Build Coastguard Worker        try:
1342*da0073e9SAndroid Build Coastguard Worker            return legacy_load(f)
1343*da0073e9SAndroid Build Coastguard Worker        except tarfile.TarError:
1344*da0073e9SAndroid Build Coastguard Worker            if _is_zipfile(f):
1345*da0073e9SAndroid Build Coastguard Worker                # .zip is used for torch.jit.save and will throw an un-pickling error here
1346*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1347*da0073e9SAndroid Build Coastguard Worker                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
1348*da0073e9SAndroid Build Coastguard Worker            # if not a tarfile, reset file offset and proceed
1349*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
1352*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
1353*da0073e9SAndroid Build Coastguard Worker            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
1354*da0073e9SAndroid Build Coastguard Worker            f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
1355*da0073e9SAndroid Build Coastguard Worker            "functionality.")
1356*da0073e9SAndroid Build Coastguard Worker
1357*da0073e9SAndroid Build Coastguard Worker    magic_number = pickle_module.load(f, **pickle_load_args)
1358*da0073e9SAndroid Build Coastguard Worker    if magic_number != MAGIC_NUMBER:
1359*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("Invalid magic number; corrupt file?")
1360*da0073e9SAndroid Build Coastguard Worker    protocol_version = pickle_module.load(f, **pickle_load_args)
1361*da0073e9SAndroid Build Coastguard Worker    if protocol_version != PROTOCOL_VERSION:
1362*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f"Invalid protocol version: {protocol_version}")
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker    _sys_info = pickle_module.load(f, **pickle_load_args)
1365*da0073e9SAndroid Build Coastguard Worker    unpickler = UnpicklerWrapper(f, **pickle_load_args)
1366*da0073e9SAndroid Build Coastguard Worker    unpickler.persistent_load = persistent_load
1367*da0073e9SAndroid Build Coastguard Worker    result = unpickler.load()
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker    if torch._guards.active_fake_mode() is None:
1372*da0073e9SAndroid Build Coastguard Worker        offset = f.tell() if f_should_read_directly else None
1373*da0073e9SAndroid Build Coastguard Worker        for key in deserialized_storage_keys:
1374*da0073e9SAndroid Build Coastguard Worker            assert key in deserialized_objects
1375*da0073e9SAndroid Build Coastguard Worker            typed_storage = deserialized_objects[key]
1376*da0073e9SAndroid Build Coastguard Worker            typed_storage._untyped_storage._set_from_file(
1377*da0073e9SAndroid Build Coastguard Worker                f, offset, f_should_read_directly,
1378*da0073e9SAndroid Build Coastguard Worker                torch._utils._element_size(typed_storage.dtype))
1379*da0073e9SAndroid Build Coastguard Worker            if offset is not None:
1380*da0073e9SAndroid Build Coastguard Worker                offset = f.tell()
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker    torch._utils._validate_loaded_sparse_tensors()
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker    return result
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Workerdef _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
1388*da0073e9SAndroid Build Coastguard Worker    # When using encoding='bytes' in Py3, some **internal** keys stored as
1389*da0073e9SAndroid Build Coastguard Worker    # strings in Py2 are loaded as bytes. This function decodes them with
1390*da0073e9SAndroid Build Coastguard Worker    # ascii encoding, one that Py3 uses by default.
1391*da0073e9SAndroid Build Coastguard Worker    #
1392*da0073e9SAndroid Build Coastguard Worker    # NOTE: This should only be used on internal keys (e.g., `typename` and
1393*da0073e9SAndroid Build Coastguard Worker    #       `location` in `persistent_load` below!
1394*da0073e9SAndroid Build Coastguard Worker    if isinstance(bytes_str, bytes):
1395*da0073e9SAndroid Build Coastguard Worker        return bytes_str.decode('ascii')
1396*da0073e9SAndroid Build Coastguard Worker    return bytes_str
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Workerdef _get_restore_location(map_location):
1400*da0073e9SAndroid Build Coastguard Worker    if map_location is None:
1401*da0073e9SAndroid Build Coastguard Worker        restore_location = default_restore_location
1402*da0073e9SAndroid Build Coastguard Worker    elif isinstance(map_location, dict):
1403*da0073e9SAndroid Build Coastguard Worker        def restore_location(storage, location):
1404*da0073e9SAndroid Build Coastguard Worker            location = map_location.get(location, location)
1405*da0073e9SAndroid Build Coastguard Worker            return default_restore_location(storage, location)
1406*da0073e9SAndroid Build Coastguard Worker    elif isinstance(map_location, (str, bytes)):
1407*da0073e9SAndroid Build Coastguard Worker        def restore_location(storage, location):
1408*da0073e9SAndroid Build Coastguard Worker            return default_restore_location(storage, map_location)
1409*da0073e9SAndroid Build Coastguard Worker    elif isinstance(map_location, torch.device):
1410*da0073e9SAndroid Build Coastguard Worker        def restore_location(storage, location):
1411*da0073e9SAndroid Build Coastguard Worker            return default_restore_location(storage, str(map_location))
1412*da0073e9SAndroid Build Coastguard Worker    else:
1413*da0073e9SAndroid Build Coastguard Worker        def restore_location(storage, location):
1414*da0073e9SAndroid Build Coastguard Worker            result = map_location(storage, location)
1415*da0073e9SAndroid Build Coastguard Worker            if result is None:
1416*da0073e9SAndroid Build Coastguard Worker                result = default_restore_location(storage, location)
1417*da0073e9SAndroid Build Coastguard Worker            return result
1418*da0073e9SAndroid Build Coastguard Worker    return restore_location
1419*da0073e9SAndroid Build Coastguard Worker
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Workerclass StorageType:
1422*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name):
1423*da0073e9SAndroid Build Coastguard Worker        self._dtype = _get_dtype_from_pickle_storage_type(name)
1424*da0073e9SAndroid Build Coastguard Worker
1425*da0073e9SAndroid Build Coastguard Worker    @property
1426*da0073e9SAndroid Build Coastguard Worker    def dtype(self):
1427*da0073e9SAndroid Build Coastguard Worker        return self._dtype
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
1430*da0073e9SAndroid Build Coastguard Worker        return f'StorageType(dtype={self.dtype})'
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Workerdef _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
1434*da0073e9SAndroid Build Coastguard Worker    restore_location = _get_restore_location(map_location)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    loaded_storages = {}
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker    # check if byteswapping is needed
1439*da0073e9SAndroid Build Coastguard Worker    byteordername = 'byteorder'
1440*da0073e9SAndroid Build Coastguard Worker    byteorderdata = None
1441*da0073e9SAndroid Build Coastguard Worker    if zip_file.has_record(byteordername):
1442*da0073e9SAndroid Build Coastguard Worker        byteorderdata = zip_file.get_record(byteordername)
1443*da0073e9SAndroid Build Coastguard Worker        if byteorderdata not in [b'little', b'big']:
1444*da0073e9SAndroid Build Coastguard Worker            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
1445*da0073e9SAndroid Build Coastguard Worker    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
1446*da0073e9SAndroid Build Coastguard Worker            get_default_load_endianness() is None:
1447*da0073e9SAndroid Build Coastguard Worker        byteorderdata = b'little'
1448*da0073e9SAndroid Build Coastguard Worker    elif get_default_load_endianness() == LoadEndianness.BIG:
1449*da0073e9SAndroid Build Coastguard Worker        byteorderdata = b'big'
1450*da0073e9SAndroid Build Coastguard Worker    elif get_default_load_endianness() == LoadEndianness.NATIVE:
1451*da0073e9SAndroid Build Coastguard Worker        pass
1452*da0073e9SAndroid Build Coastguard Worker    else:
1453*da0073e9SAndroid Build Coastguard Worker        raise ValueError('Invalid load endianness type')
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker    if not zip_file.has_record(byteordername) and \
1456*da0073e9SAndroid Build Coastguard Worker            get_default_load_endianness() is None and \
1457*da0073e9SAndroid Build Coastguard Worker            sys.byteorder == 'big':
1458*da0073e9SAndroid Build Coastguard Worker        # Default behaviour was changed
1459*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/101688
1460*da0073e9SAndroid Build Coastguard Worker        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
1461*da0073e9SAndroid Build Coastguard Worker                      "on big endian machines was changed from 'native' to 'little' endian, "
1462*da0073e9SAndroid Build Coastguard Worker                      "to avoid this behavior please use "
1463*da0073e9SAndroid Build Coastguard Worker                      "torch.serialization.set_default_load_endianness to set "
1464*da0073e9SAndroid Build Coastguard Worker                      "the desired default load endianness",
1465*da0073e9SAndroid Build Coastguard Worker                      UserWarning)
1466*da0073e9SAndroid Build Coastguard Worker
1467*da0073e9SAndroid Build Coastguard Worker    def load_tensor(dtype, numel, key, location):
1468*da0073e9SAndroid Build Coastguard Worker        name = f'data/{key}'
1469*da0073e9SAndroid Build Coastguard Worker        if torch._guards.detect_fake_mode(None) is not None:
1470*da0073e9SAndroid Build Coastguard Worker            nbytes = numel * torch._utils._element_size(dtype)
1471*da0073e9SAndroid Build Coastguard Worker            storage = torch.UntypedStorage(nbytes, device='meta')
1472*da0073e9SAndroid Build Coastguard Worker        elif overall_storage is not None:
1473*da0073e9SAndroid Build Coastguard Worker            storage_offset = zip_file.get_record_offset(name)
1474*da0073e9SAndroid Build Coastguard Worker            storage = overall_storage[storage_offset:storage_offset + numel]
1475*da0073e9SAndroid Build Coastguard Worker        else:
1476*da0073e9SAndroid Build Coastguard Worker            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
1477*da0073e9SAndroid Build Coastguard Worker        # swap here if byteswapping is needed
1478*da0073e9SAndroid Build Coastguard Worker        if byteorderdata is not None:
1479*da0073e9SAndroid Build Coastguard Worker            if byteorderdata.decode() != sys.byteorder:
1480*da0073e9SAndroid Build Coastguard Worker                storage.byteswap(dtype)
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker        # TODO: Once we decide to break serialization FC, we can
1483*da0073e9SAndroid Build Coastguard Worker        # stop wrapping with TypedStorage
1484*da0073e9SAndroid Build Coastguard Worker        typed_storage = torch.storage.TypedStorage(
1485*da0073e9SAndroid Build Coastguard Worker            wrap_storage=restore_location(storage, location),
1486*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
1487*da0073e9SAndroid Build Coastguard Worker            _internal=True)
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        if typed_storage._data_ptr() != 0:
1490*da0073e9SAndroid Build Coastguard Worker            loaded_storages[key] = typed_storage
1491*da0073e9SAndroid Build Coastguard Worker
1492*da0073e9SAndroid Build Coastguard Worker        return typed_storage
1493*da0073e9SAndroid Build Coastguard Worker
1494*da0073e9SAndroid Build Coastguard Worker    def persistent_load(saved_id):
1495*da0073e9SAndroid Build Coastguard Worker        assert isinstance(saved_id, tuple)
1496*da0073e9SAndroid Build Coastguard Worker        typename = _maybe_decode_ascii(saved_id[0])
1497*da0073e9SAndroid Build Coastguard Worker        data = saved_id[1:]
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        assert typename == 'storage', \
1500*da0073e9SAndroid Build Coastguard Worker            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
1501*da0073e9SAndroid Build Coastguard Worker        storage_type, key, location, numel = data
1502*da0073e9SAndroid Build Coastguard Worker        if storage_type is torch.UntypedStorage:
1503*da0073e9SAndroid Build Coastguard Worker            dtype = torch.uint8
1504*da0073e9SAndroid Build Coastguard Worker        else:
1505*da0073e9SAndroid Build Coastguard Worker            dtype = storage_type.dtype
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker        if key in loaded_storages:
1508*da0073e9SAndroid Build Coastguard Worker            typed_storage = loaded_storages[key]
1509*da0073e9SAndroid Build Coastguard Worker        else:
1510*da0073e9SAndroid Build Coastguard Worker            nbytes = numel * torch._utils._element_size(dtype)
1511*da0073e9SAndroid Build Coastguard Worker            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker        return typed_storage
1514*da0073e9SAndroid Build Coastguard Worker
1515*da0073e9SAndroid Build Coastguard Worker    load_module_mapping: Dict[str, str] = {
1516*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/pull/51633
1517*da0073e9SAndroid Build Coastguard Worker        'torch.tensor': 'torch._tensor'
1518*da0073e9SAndroid Build Coastguard Worker    }
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
1521*da0073e9SAndroid Build Coastguard Worker    # because it's marked readonly in pickle.
1522*da0073e9SAndroid Build Coastguard Worker    # The type: ignore is because mypy can't statically determine the type of this class.
1523*da0073e9SAndroid Build Coastguard Worker    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
1524*da0073e9SAndroid Build Coastguard Worker        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
1525*da0073e9SAndroid Build Coastguard Worker        # Lets us override the imports that pickle uses when unpickling an object.
1526*da0073e9SAndroid Build Coastguard Worker        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
1527*da0073e9SAndroid Build Coastguard Worker        def find_class(self, mod_name, name):
1528*da0073e9SAndroid Build Coastguard Worker            if type(name) is str and 'Storage' in name:
1529*da0073e9SAndroid Build Coastguard Worker                try:
1530*da0073e9SAndroid Build Coastguard Worker                    return StorageType(name)
1531*da0073e9SAndroid Build Coastguard Worker                except KeyError:
1532*da0073e9SAndroid Build Coastguard Worker                    pass
1533*da0073e9SAndroid Build Coastguard Worker            mod_name = load_module_mapping.get(mod_name, mod_name)
1534*da0073e9SAndroid Build Coastguard Worker            return super().find_class(mod_name, name)
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker    # Load the data (which may in turn use `persistent_load` to load tensors)
1537*da0073e9SAndroid Build Coastguard Worker    data_file = io.BytesIO(zip_file.get_record(pickle_file))
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1540*da0073e9SAndroid Build Coastguard Worker    unpickler.persistent_load = persistent_load
1541*da0073e9SAndroid Build Coastguard Worker    # Needed for tensors where storage device and rebuild tensor device are
1542*da0073e9SAndroid Build Coastguard Worker    # not connected (wrapper subclasses and tensors rebuilt using numpy)
1543*da0073e9SAndroid Build Coastguard Worker    torch._utils._thread_local_state.map_location = map_location
1544*da0073e9SAndroid Build Coastguard Worker    result = unpickler.load()
1545*da0073e9SAndroid Build Coastguard Worker    del torch._utils._thread_local_state.map_location
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker    torch._utils._validate_loaded_sparse_tensors()
1548*da0073e9SAndroid Build Coastguard Worker    torch._C._log_api_usage_metadata(
1549*da0073e9SAndroid Build Coastguard Worker        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
1550*da0073e9SAndroid Build Coastguard Worker    )
1551*da0073e9SAndroid Build Coastguard Worker    return result
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Workerdef _is_torchscript_zip(zip_file):
1555*da0073e9SAndroid Build Coastguard Worker    return 'constants.pkl' in zip_file.get_all_records()
1556