# mypy: allow-untyped-defs
import difflib
import functools
import os
import io
import re
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler

DEFAULT_PROTOCOL = 2

# Byte widths of the native-order, standard-size struct codes 'l', 'i' and 'h'.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

# Type aliases shared by the public signatures in this module.
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

IS_WINDOWS = sys.platform == "win32"

# The mmap flags are only imported off Windows; on Windows both names are
# bound to None so that later references still resolve.
if IS_WINDOWS:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
else:
    from mmap import MAP_SHARED, MAP_PRIVATE

__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
    'clear_safe_globals',
    'get_safe_globals',
    'add_safe_globals',
]


class SourceChangeWarning(Warning):
    """Warning category exported by this module; adds nothing to ``Warning``."""


@contextmanager
def mkdtemp():
    """Yield the path of a new temporary directory and delete the tree on exit."""
    tmp_dir = tempfile.mkdtemp()
    try:
        yield tmp_dir
    finally:
        # Runs whether or not the ``with`` body raised.
        shutil.rmtree(tmp_dir)


# Registry of (priority, tagger, deserializer) entries.
_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []

class LoadEndianness(Enum):
    """Byte-order choices accepted by ``set_default_load_endianness``."""

    NATIVE = 1
    LITTLE = 2
    BIG = 3

_default_load_endian: Optional[LoadEndianness] = None

def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian

def set_default_load_endianness(endianness):
    '''
    Set fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order (a ``LoadEndianness`` member or ``None``)

    Raises:
        TypeError: if ``endianness`` is neither a ``LoadEndianness`` nor ``None``
    '''
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness

_default_mmap_options: int = MAP_PRIVATE

def get_default_mmap_options() -> int:
    '''
    Get default mmap options for :func:`torch.load` with ``mmap=True``.

    Defaults to ``mmap.MAP_PRIVATE``.


    Returns:
        default_mmap_options: int
    '''
    return _default_mmap_options

def set_default_mmap_options(flags: int):
    '''
    Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.

    For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
    Please open an issue if you need any other option to be added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``

    Raises:
        RuntimeError: when called on Windows
        ValueError: when ``flags`` is neither ``mmap.MAP_PRIVATE`` nor ``mmap.MAP_SHARED``
    '''
    global _default_mmap_options
    if IS_WINDOWS:
        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
    if (flags != MAP_PRIVATE and flags != MAP_SHARED):
        raise ValueError("Invalid argument in function set_default_mmap_options, "
                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
    _default_mmap_options = flags

def clear_safe_globals() -> None:
    '''
    Clears the list of globals that are safe for ``weights_only`` load.
    '''
    _weights_only_unpickler._clear_safe_globals()

def get_safe_globals() -> List[Any]:
    '''
    Returns the list of user-added globals that are safe for ``weights_only`` load.
    '''
    return _weights_only_unpickler._get_safe_globals()

def add_safe_globals(safe_globals: List[Any]) -> None:
    '''
    Marks the given globals as safe for ``weights_only`` load. For example, functions
    added to this list can be called during unpickling, classes could be instantiated
    and have state set.

    Args:
        safe_globals (List[Any]): list of globals to mark as safe

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     torch.serialization.add_safe_globals([MyTensor])
        ...     torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #          [-0.8234,  2.0500, -0.3657]])
    '''
    _weights_only_unpickler._add_safe_globals(safe_globals)

def _is_zipfile(f) -> bool:
    '''
    Return True when the bytes at the current position of ``f`` are the ZIP
    local-file-header signature. The stream position is restored before returning.
    '''
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to only check the start bytes and avoid
    # collisions and assume the zip has only 1 file.
    # See bugs.python.org/issue28494.

    start = f.tell()
    # Read the first few bytes and match against the ZIP file signature
    local_header_magic_number = b'PK\x03\x04'
    read_bytes = f.read(len(local_header_magic_number))
    f.seek(start)
    return read_bytes == local_header_magic_number


def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Entries registered with the same priority keep their registration order.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Order by priority only. Sorting the raw tuples falls through to comparing the
    # callables whenever two entries share a priority, which raises TypeError after
    # the entry has already been appended. list.sort is stable, so ties keep
    # registration order.
    _package_registry.sort(key=lambda entry: entry[0])


def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    error and exit or emit a warning. A module without a ``__version__`` attribute is
    handled the same way as a malformed version string.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool

    Raises:
        RuntimeError: if the version cannot be compared and ``error_if_malformed`` is True
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        # Read the attributes defensively: a missing ``__version__`` is one way to
        # land here, and touching it again would raise AttributeError from inside
        # the handler instead of the intended error/warning below.
        module_name = getattr(module, '__name__', repr(module))
        found_version = getattr(module, '__version__', None)
        message = (
            f"'{module_name}' module version string is malformed '{found_version}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met


def _cpu_tag(obj):
    '''Tagger: return 'cpu' when ``obj.device.type`` is 'cpu', otherwise None.'''
    if obj.device.type == 'cpu':
        return 'cpu'


def _mps_tag(obj):
    '''Tagger: return 'mps' when ``obj.device.type`` is 'mps', otherwise None.'''
    if obj.device.type == 'mps':
        return 'mps'


def _meta_tag(obj):
    '''Tagger: return 'meta' when ``obj.device.type`` is 'meta', otherwise None.'''
    if obj.device.type == 'meta':
        return 'meta'


def _backend_tag(backend_name, obj):
    '''
    Tagger for index-aware backends.

    Returns ``backend_name`` (or ``backend_name:index`` when the device has an index)
    if ``obj`` lives on that backend, otherwise None. The placeholder name
    'privateuse1' is first resolved to the currently configured backend name.
    '''
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if obj.device.type == backend_name:
        if obj.device.index is None:
            return backend_name
        else:
            return backend_name + ':' + str(obj.device.index)


def _cpu_deserialize(obj, location):
    '''Deserializer: return ``obj`` unchanged when ``location`` is exactly 'cpu', else None.'''
    if location == 'cpu':
        return obj


def _mps_deserialize(obj, location):
    '''Deserializer: return ``obj.mps()`` when ``location`` starts with 'mps', else None.'''
    if location.startswith('mps'):
        return obj.mps()


def _meta_deserialize(obj, location):
    '''
    Deserializer: when ``location`` is exactly 'meta', return a new meta-device
    ``torch.UntypedStorage`` of ``obj.nbytes()`` bytes; otherwise None.
    '''
    if location == 'meta':
        return torch.UntypedStorage(obj.nbytes(), device='meta')


def _validate_device(location, backend_name):
    '''
    Check whether the device index of specified backend is valid

    In case of privateuse1 backend, you must first register a device_module for
    privateuse1 using torch._register_device_module. Implement the following
    methods in device_module like cuda: device_module._utils._get_device_index(location, True),
    device_module.device_count().

    Args:
        location: string of device
        backend_name: the backend name or the name of privateuse1, which can be renamed

    Returns:
        device: the validated :class:`torch.device`. It is built from ``backend_name`` and the
            index reported by ``device_module._utils._get_device_index`` when that helper
            exists, and from ``location`` directly otherwise.

    Raises:
        RuntimeError: if ``torch.<backend_name>`` does not exist, if the device module reports
            that it is not available, or if the device index is not below
            ``device_module.device_count()``.
    '''
    if not hasattr(torch, backend_name):
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
        device_index = device_module._utils._get_device_index(location, True)
        device = torch.device(backend_name, device_index)
    else:
        device = torch.device(location)
        # A missing index (None) is checked against device_count() as index 0.
        device_index = device.index if device.index else 0
    if hasattr(device_module, 'is_available') and not device_module.is_available():
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
                           f'device but torch.{backend_name}.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    if hasattr(device_module, 'device_count'):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
                               'Please use torch.load with map_location to map your storages '
                               'to an existing device.')
    return device


def validate_cuda_device(location):
    '''Validate ``location`` against the 'cuda' backend and return the resulting device index.'''
    return _validate_device(location, 'cuda').index


def validate_hpu_device(location):
    '''Validate ``location`` against the 'hpu' backend and return the resulting device index.'''
    return _validate_device(location, 'hpu').index


def _deserialize(backend_name, obj, location):
    '''
    Deserializer for index-aware backends.

    When ``location`` starts with ``backend_name`` (with 'privateuse1' resolved to the
    configured backend name), validate the device and return ``obj`` moved to it;
    otherwise return None.
    '''
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if location.startswith(backend_name):
        device = _validate_device(location, backend_name)
        return obj.to(device=device)


# Built-in backends; a lower priority value is consulted first.
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))

def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    '''
    Return the location tag of ``storage`` from the first registered tagger, in
    priority order, that yields a truthy result.

    Raises:
        RuntimeError: if no registered tagger produces a tag for ``storage``.
    '''
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        # NOTE: truthiness test, so an empty-string tag is skipped just like None.
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))


def default_restore_location(storage, location):
    """
    Restores `storage` using a deserializer function registered for the `location`.

    This function looks in the registry for deserializer functions that match the `location`.
    If found, it attempts to use them, in priority order, to restore `storage` until one
    returns a not `None` result. If no deserializer can be found in the registry, or all found fail
    to bear a result, it raises a `RuntimeError`.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: Optional[STORAGE]

    Raises:
        RuntimeError: If no deserializer matching `location` is found in the registry or if
           all matching ones return `None`.
    """
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")


def normalize_storage_type(storage_type):
    '''Return the attribute of ``torch`` whose name equals ``storage_type.__name__``.'''
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    '''
    Return the attribute of the storage class's defining module whose name is the
    storage class name with 'Storage' replaced by 'Tensor'.
    '''
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    '''Return True when ``name_or_buffer`` is a ``str`` or ``os.PathLike``.'''
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    '''Context manager that hands back the wrapped file-like object and does nothing on exit.'''

    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    '''Context manager that opens ``name`` with ``mode`` and closes the file on exit.'''

    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    '''Wrap a caller-owned buffer for reading; ``_check_seekable`` is run on it at construction.'''

    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    '''Wrap a caller-owned buffer for writing; the buffer is flushed, not closed, on exit.'''

    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return 
_open_buffer_writer(name_or_buffer) 491*da0073e9SAndroid Build Coastguard Worker elif 'r' in mode: 492*da0073e9SAndroid Build Coastguard Worker return _open_buffer_reader(name_or_buffer) 493*da0073e9SAndroid Build Coastguard Worker else: 494*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Workerclass _open_zipfile_reader(_opener): 498*da0073e9SAndroid Build Coastguard Worker def __init__(self, name_or_buffer) -> None: 499*da0073e9SAndroid Build Coastguard Worker super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Workerclass _open_zipfile_writer_file(_opener): 503*da0073e9SAndroid Build Coastguard Worker def __init__(self, name) -> None: 504*da0073e9SAndroid Build Coastguard Worker self.file_stream = None 505*da0073e9SAndroid Build Coastguard Worker self.name = str(name) 506*da0073e9SAndroid Build Coastguard Worker try: 507*da0073e9SAndroid Build Coastguard Worker self.name.encode('ascii') 508*da0073e9SAndroid Build Coastguard Worker except UnicodeEncodeError: 509*da0073e9SAndroid Build Coastguard Worker # PyTorchFileWriter only supports ascii filename. 510*da0073e9SAndroid Build Coastguard Worker # For filenames with non-ascii characters, we rely on Python 511*da0073e9SAndroid Build Coastguard Worker # for writing out the file. 
512*da0073e9SAndroid Build Coastguard Worker self.file_stream = io.FileIO(self.name, mode='w') 513*da0073e9SAndroid Build Coastguard Worker super().__init__(torch._C.PyTorchFileWriter(self.file_stream)) 514*da0073e9SAndroid Build Coastguard Worker else: 515*da0073e9SAndroid Build Coastguard Worker super().__init__(torch._C.PyTorchFileWriter(self.name)) 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker def __exit__(self, *args) -> None: 518*da0073e9SAndroid Build Coastguard Worker self.file_like.write_end_of_file() 519*da0073e9SAndroid Build Coastguard Worker if self.file_stream is not None: 520*da0073e9SAndroid Build Coastguard Worker self.file_stream.close() 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Workerclass _open_zipfile_writer_buffer(_opener): 524*da0073e9SAndroid Build Coastguard Worker def __init__(self, buffer) -> None: 525*da0073e9SAndroid Build Coastguard Worker if not callable(getattr(buffer, "write", None)): 526*da0073e9SAndroid Build Coastguard Worker msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" 527*da0073e9SAndroid Build Coastguard Worker if not hasattr(buffer, "write"): 528*da0073e9SAndroid Build Coastguard Worker raise AttributeError(msg) 529*da0073e9SAndroid Build Coastguard Worker raise TypeError(msg) 530*da0073e9SAndroid Build Coastguard Worker self.buffer = buffer 531*da0073e9SAndroid Build Coastguard Worker super().__init__(torch._C.PyTorchFileWriter(buffer)) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker def __exit__(self, *args) -> None: 534*da0073e9SAndroid Build Coastguard Worker self.file_like.write_end_of_file() 535*da0073e9SAndroid Build Coastguard Worker self.buffer.flush() 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Workerdef 
def _open_zipfile_writer(name_or_buffer):
    """Return the zip-writer context manager matching ``name_or_buffer``.

    Paths get ``_open_zipfile_writer_file``; everything else is treated as a
    buffer and gets ``_open_zipfile_writer_buffer``.
    """
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)


def _is_compressed_file(f) -> bool:
    """Return True if ``f.__module__`` names a known compression module (``gzip``).

    Objects without a ``__module__`` attribute are reported as not compressed.
    """
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False


def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except (io.UnsupportedOperation, AttributeError):
        # Either the object has no ``fileno`` at all, or it has one that
        # reports it is not backed by a file descriptor.
        return False


def _check_seekable(f) -> bool:
    """Return True if ``f`` supports ``seek``/``tell``; raise a descriptive error otherwise.

    Raises:
        io.UnsupportedOperation, AttributeError: re-raised from
            ``f.seek(f.tell())``. When the original message mentions ``seek``
            or ``tell``, a new exception of the same type is raised with
            guidance on loading from a seekable buffer appended.
    """

    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
                       + " Please pre-load the data into a buffer like io.BytesIO and"
                       + " try to load from it instead.")
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    # Not reached: ``raise_err_msg`` always raises. Kept so every code path of
    # this ``-> bool`` function visibly ends in a return.
    return False


def _check_dill_version(pickle_module) -> None:
    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
    If dill version is lower than 0.3.1, a ValueError is raised.

    Args:
        pickle_module: module used for pickling metadata and objects

    '''
    if pickle_module is not None and pickle_module.__name__ == 'dill':
        required_dill_version = (0, 3, 1)
        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
            required = '.'.join(str(num) for num in required_dill_version)
            raise ValueError(
                f"'torch' supports dill >= {required}, but you have dill {pickle_module.__version__}."
                " Please upgrade dill or switch to 'pickle'"
            )


def _check_save_filelike(f):
    """Raise ``AttributeError`` unless ``f`` is a path or has a ``write`` attribute."""
    if not _is_path(f) and not hasattr(f, 'write'):
        raise AttributeError(
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute")


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if _use_new_zipfile_serialization:
        with _open_zipfile_writer(f) as opened_zipfile:
            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
        return

    with _open_file_like(f, 'wb') as opened_file:
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)


def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Serialize ``obj`` to the open file ``f`` in the legacy (non-zipfile) layout.

    Written to ``f`` in order: ``MAGIC_NUMBER``, ``PROTOCOL_VERSION``, a
    system-info dict, the pickle of ``obj`` (with storages and ``nn.Module``
    classes replaced by persistent ids), the sorted list of storage keys, and
    then each storage via ``storage._write_file`` in that key order.

    Args:
        obj: object to serialize
        f: open, writable file-like object
        pickle_module: module providing ``dump`` and ``Pickler``
        pickle_protocol: protocol number passed to every pickle call

    Raises:
        RuntimeError: if two saved storages share a data pointer but have
            different dtypes.
        TypeError: if a storage is neither a ``TypedStorage`` nor an
            ``UntypedStorage``.
    """
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type = getattr(torch, obj._pickle_storage_type())
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f'type not recognized: {type(obj)}')

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            data_ptr = storage.data_ptr()
            if data_ptr != 0:
                if data_ptr in storage_dtypes:
                    if storage_dtype != storage_dtypes[data_ptr]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[data_ptr] = storage_dtype

            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will be loaded and the UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, storage_dtype)

            # The old serialization format supported storage views and carried
            # (key, offset, nbytes) metadata for them. Storages are never
            # views here, so the slot is kept in the tuple for backwards
            # compatibility and is always None.
            view_metadata: Optional[Tuple[str, int, int]] = None

            res = ('storage',
                   storage_type,
                   storage_key,
                   location,
                   storage_numel,
                   view_metadata)
            return res
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    # Flush the pickled header before the storages are written after it.
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))


def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    serialized_storages = {}
    id_map: Dict[int, str] = {}
825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker # Since loading storages that view the same data with different dtypes is 827*da0073e9SAndroid Build Coastguard Worker # not supported, we need to keep track of the dtype associated with each 828*da0073e9SAndroid Build Coastguard Worker # storage data_ptr and throw an error if the dtype is ever different. 829*da0073e9SAndroid Build Coastguard Worker # TODO: This feature could be added in the future 830*da0073e9SAndroid Build Coastguard Worker storage_dtypes: Dict[int, torch.dtype] = {} 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker def persistent_id(obj): 833*da0073e9SAndroid Build Coastguard Worker # FIXME: the docs say that persistent_id should only return a string 834*da0073e9SAndroid Build Coastguard Worker # but torch store returns tuples. This works only in the binary protocol 835*da0073e9SAndroid Build Coastguard Worker # see 836*da0073e9SAndroid Build Coastguard Worker # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects 837*da0073e9SAndroid Build Coastguard Worker # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 838*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, torch.storage.TypedStorage): 841*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, this case 842*da0073e9SAndroid Build Coastguard Worker # can be deleted 843*da0073e9SAndroid Build Coastguard Worker storage = obj._untyped_storage 844*da0073e9SAndroid Build Coastguard Worker storage_dtype = obj.dtype 845*da0073e9SAndroid Build Coastguard Worker storage_type_str = obj._pickle_storage_type() 846*da0073e9SAndroid Build Coastguard Worker storage_type = getattr(torch, storage_type_str) 
847*da0073e9SAndroid Build Coastguard Worker storage_numel = obj._size() 848*da0073e9SAndroid Build Coastguard Worker 849*da0073e9SAndroid Build Coastguard Worker else: 850*da0073e9SAndroid Build Coastguard Worker storage = obj 851*da0073e9SAndroid Build Coastguard Worker storage_dtype = torch.uint8 852*da0073e9SAndroid Build Coastguard Worker storage_type = normalize_storage_type(type(obj)) 853*da0073e9SAndroid Build Coastguard Worker storage_numel = storage.nbytes() 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker # If storage is allocated, ensure that any other saved storages 856*da0073e9SAndroid Build Coastguard Worker # pointing to the same data all have the same dtype. If storage is 857*da0073e9SAndroid Build Coastguard Worker # not allocated, don't perform this check 858*da0073e9SAndroid Build Coastguard Worker if storage.data_ptr() != 0: 859*da0073e9SAndroid Build Coastguard Worker if storage.data_ptr() in storage_dtypes: 860*da0073e9SAndroid Build Coastguard Worker if storage_dtype != storage_dtypes[storage.data_ptr()]: 861*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 862*da0073e9SAndroid Build Coastguard Worker 'Cannot save multiple tensors or storages that ' 863*da0073e9SAndroid Build Coastguard Worker 'view the same data as different types') 864*da0073e9SAndroid Build Coastguard Worker else: 865*da0073e9SAndroid Build Coastguard Worker storage_dtypes[storage.data_ptr()] = storage_dtype 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) 868*da0073e9SAndroid Build Coastguard Worker location = location_tag(storage) 869*da0073e9SAndroid Build Coastguard Worker serialized_storages[storage_key] = storage 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker return ('storage', 872*da0073e9SAndroid Build Coastguard Worker storage_type, 873*da0073e9SAndroid Build 
Coastguard Worker storage_key, 874*da0073e9SAndroid Build Coastguard Worker location, 875*da0073e9SAndroid Build Coastguard Worker storage_numel) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker return None 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker # Write the pickle data for `obj` 880*da0073e9SAndroid Build Coastguard Worker data_buf = io.BytesIO() 881*da0073e9SAndroid Build Coastguard Worker pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) 882*da0073e9SAndroid Build Coastguard Worker pickler.persistent_id = persistent_id 883*da0073e9SAndroid Build Coastguard Worker pickler.dump(obj) 884*da0073e9SAndroid Build Coastguard Worker data_value = data_buf.getvalue() 885*da0073e9SAndroid Build Coastguard Worker zip_file.write_record('data.pkl', data_value, len(data_value)) 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker # Write byte order marker 888*da0073e9SAndroid Build Coastguard Worker if not _disable_byteorder_record: 889*da0073e9SAndroid Build Coastguard Worker if sys.byteorder not in ['little', 'big']: 890*da0073e9SAndroid Build Coastguard Worker raise ValueError('Unknown endianness type: ' + sys.byteorder) 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder)) 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker # Write each tensor to a file named tensor/the_tensor_key in the zip archive 895*da0073e9SAndroid Build Coastguard Worker for key in sorted(serialized_storages.keys()): 896*da0073e9SAndroid Build Coastguard Worker name = f'data/{key}' 897*da0073e9SAndroid Build Coastguard Worker storage = serialized_storages[key] 898*da0073e9SAndroid Build Coastguard Worker # given that we copy things around anyway, we might use storage.cpu() 899*da0073e9SAndroid Build Coastguard 
Worker # this means to that to get tensors serialized, you need to implement 900*da0073e9SAndroid Build Coastguard Worker # .cpu() on the underlying Storage 901*da0073e9SAndroid Build Coastguard Worker if storage.device.type != 'cpu': 902*da0073e9SAndroid Build Coastguard Worker storage = storage.cpu() 903*da0073e9SAndroid Build Coastguard Worker # Now that it is on the CPU we can directly copy it into the zip file 904*da0073e9SAndroid Build Coastguard Worker num_bytes = storage.nbytes() 905*da0073e9SAndroid Build Coastguard Worker zip_file.write_record(name, storage, num_bytes) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Workerdef load( 909*da0073e9SAndroid Build Coastguard Worker f: FILE_LIKE, 910*da0073e9SAndroid Build Coastguard Worker map_location: MAP_LOCATION = None, 911*da0073e9SAndroid Build Coastguard Worker pickle_module: Any = None, 912*da0073e9SAndroid Build Coastguard Worker *, 913*da0073e9SAndroid Build Coastguard Worker weights_only: Optional[bool] = None, 914*da0073e9SAndroid Build Coastguard Worker mmap: Optional[bool] = None, 915*da0073e9SAndroid Build Coastguard Worker **pickle_load_args: Any 916*da0073e9SAndroid Build Coastguard Worker) -> Any: 917*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/issues/54354 918*da0073e9SAndroid Build Coastguard Worker # The first line of this docstring overrides the one Sphinx generates for the 919*da0073e9SAndroid Build Coastguard Worker # documentation. We need it so that Sphinx doesn't leak `pickle`s path from 920*da0073e9SAndroid Build Coastguard Worker # the build environment (e.g. `<module 'pickle' from '/leaked/path'). 
921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args) 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker Loads an object saved with :func:`torch.save` from a file. 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker :func:`torch.load` uses Python's unpickling facilities but treats storages, 927*da0073e9SAndroid Build Coastguard Worker which underlie tensors, specially. They are first deserialized on the 928*da0073e9SAndroid Build Coastguard Worker CPU and are then moved to the device they were saved from. If this fails 929*da0073e9SAndroid Build Coastguard Worker (e.g. because the run time system doesn't have certain devices), an exception 930*da0073e9SAndroid Build Coastguard Worker is raised. However, storages can be dynamically remapped to an alternative 931*da0073e9SAndroid Build Coastguard Worker set of devices using the :attr:`map_location` argument. 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker If :attr:`map_location` is a callable, it will be called once for each serialized 934*da0073e9SAndroid Build Coastguard Worker storage with two arguments: storage and location. The storage argument 935*da0073e9SAndroid Build Coastguard Worker will be the initial deserialization of the storage, residing on the CPU. 936*da0073e9SAndroid Build Coastguard Worker Each serialized storage has a location tag associated with it which 937*da0073e9SAndroid Build Coastguard Worker identifies the device it was saved from, and this tag is the second 938*da0073e9SAndroid Build Coastguard Worker argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'`` 939*da0073e9SAndroid Build Coastguard Worker for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors. 
940*da0073e9SAndroid Build Coastguard Worker :attr:`map_location` should return either ``None`` or a storage. If 941*da0073e9SAndroid Build Coastguard Worker :attr:`map_location` returns a storage, it will be used as the final deserialized 942*da0073e9SAndroid Build Coastguard Worker object, already moved to the right device. Otherwise, :func:`torch.load` will 943*da0073e9SAndroid Build Coastguard Worker fall back to the default behavior, as if :attr:`map_location` wasn't specified. 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker If :attr:`map_location` is a :class:`torch.device` object or a string containing 946*da0073e9SAndroid Build Coastguard Worker a device tag, it indicates the location where all tensors should be loaded. 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags 949*da0073e9SAndroid Build Coastguard Worker appearing in the file (keys), to ones that specify where to put the 950*da0073e9SAndroid Build Coastguard Worker storages (values). 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker User extensions can register their own location tags and tagging and 953*da0073e9SAndroid Build Coastguard Worker deserialization methods using :func:`torch.serialization.register_package`. 
954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker Args: 956*da0073e9SAndroid Build Coastguard Worker f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), 957*da0073e9SAndroid Build Coastguard Worker or a string or os.PathLike object containing a file name 958*da0073e9SAndroid Build Coastguard Worker map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage 959*da0073e9SAndroid Build Coastguard Worker locations 960*da0073e9SAndroid Build Coastguard Worker pickle_module: module used for unpickling metadata and objects (has to 961*da0073e9SAndroid Build Coastguard Worker match the :attr:`pickle_module` used to serialize file) 962*da0073e9SAndroid Build Coastguard Worker weights_only: Indicates whether unpickler should be restricted to 963*da0073e9SAndroid Build Coastguard Worker loading only tensors, primitive types, dictionaries 964*da0073e9SAndroid Build Coastguard Worker and any types added via :func:`torch.serialization.add_safe_globals`. 965*da0073e9SAndroid Build Coastguard Worker mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. 966*da0073e9SAndroid Build Coastguard Worker Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they 967*da0073e9SAndroid Build Coastguard Worker are moved to the location that they were tagged with when saving, or specified by ``map_location``. This 968*da0073e9SAndroid Build Coastguard Worker second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the 969*da0073e9SAndroid Build Coastguard Worker tensor storages from disk to CPU memory in the first step, ``f`` is mmaped. 
970*da0073e9SAndroid Build Coastguard Worker pickle_load_args: (Python 3 only) optional keyword arguments passed over to 971*da0073e9SAndroid Build Coastguard Worker :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., 972*da0073e9SAndroid Build Coastguard Worker :attr:`errors=...`. 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker .. warning:: 975*da0073e9SAndroid Build Coastguard Worker :func:`torch.load()` unless `weights_only` parameter is set to `True`, 976*da0073e9SAndroid Build Coastguard Worker uses ``pickle`` module implicitly, which is known to be insecure. 977*da0073e9SAndroid Build Coastguard Worker It is possible to construct malicious pickle data which will execute arbitrary code 978*da0073e9SAndroid Build Coastguard Worker during unpickling. Never load data that could have come from an untrusted 979*da0073e9SAndroid Build Coastguard Worker source in an unsafe mode, or that could have been tampered with. **Only load data you trust**. 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker .. note:: 982*da0073e9SAndroid Build Coastguard Worker When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors 983*da0073e9SAndroid Build Coastguard Worker will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')`` 984*da0073e9SAndroid Build Coastguard Worker and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint. 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker .. note:: 987*da0073e9SAndroid Build Coastguard Worker By default, we decode byte strings as ``utf-8``. This is to avoid a common error 988*da0073e9SAndroid Build Coastguard Worker case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...`` 989*da0073e9SAndroid Build Coastguard Worker when loading files saved by Python 2 in Python 3. 
If this default 990*da0073e9SAndroid Build Coastguard Worker is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how 991*da0073e9SAndroid Build Coastguard Worker these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them 992*da0073e9SAndroid Build Coastguard Worker to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them 993*da0073e9SAndroid Build Coastguard Worker as byte arrays which can be decoded later with ``byte_array.decode(...)``. 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker Example: 996*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP("undefined filepaths") 997*da0073e9SAndroid Build Coastguard Worker >>> torch.load('tensors.pt', weights_only=True) 998*da0073e9SAndroid Build Coastguard Worker # Load all tensors onto the CPU 999*da0073e9SAndroid Build Coastguard Worker >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True) 1000*da0073e9SAndroid Build Coastguard Worker # Load all tensors onto the CPU, using a function 1001*da0073e9SAndroid Build Coastguard Worker >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True) 1002*da0073e9SAndroid Build Coastguard Worker # Load all tensors onto GPU 1 1003*da0073e9SAndroid Build Coastguard Worker >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True) 1004*da0073e9SAndroid Build Coastguard Worker # Map tensors from GPU 1 to GPU 0 1005*da0073e9SAndroid Build Coastguard Worker >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True) 1006*da0073e9SAndroid Build Coastguard Worker # Load tensor from io.BytesIO object 1007*da0073e9SAndroid Build Coastguard Worker # Loading from a buffer setting weights_only=False, warning this can be unsafe 1008*da0073e9SAndroid Build Coastguard Worker >>> with open('tensor.pt', 'rb') as f: 1009*da0073e9SAndroid Build Coastguard Worker 
... buffer = io.BytesIO(f.read()) 1010*da0073e9SAndroid Build Coastguard Worker >>> torch.load(buffer, weights_only=False) 1011*da0073e9SAndroid Build Coastguard Worker # Load a module with 'ascii' encoding for unpickling 1012*da0073e9SAndroid Build Coastguard Worker # Loading from a module setting weights_only=False, warning this can be unsafe 1013*da0073e9SAndroid Build Coastguard Worker >>> torch.load('module.pt', encoding='ascii', weights_only=False) 1014*da0073e9SAndroid Build Coastguard Worker """ 1015*da0073e9SAndroid Build Coastguard Worker torch._C._log_api_usage_once("torch.load") 1016*da0073e9SAndroid Build Coastguard Worker UNSAFE_MESSAGE = ( 1017*da0073e9SAndroid Build Coastguard Worker "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " 1018*da0073e9SAndroid Build Coastguard Worker "but it can result in arbitrary code execution. Do it only if you got the file from a " 1019*da0073e9SAndroid Build Coastguard Worker "trusted source." 1020*da0073e9SAndroid Build Coastguard Worker ) 1021*da0073e9SAndroid Build Coastguard Worker DOCS_MESSAGE = ( 1022*da0073e9SAndroid Build Coastguard Worker "\n\nCheck the documentation of torch.load to learn more about types accepted by default with " 1023*da0073e9SAndroid Build Coastguard Worker "weights_only https://pytorch.org/docs/stable/generated/torch.load.html." 1024*da0073e9SAndroid Build Coastguard Worker ) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker def _get_wo_message(message: str) -> str: 1027*da0073e9SAndroid Build Coastguard Worker pattern = r"GLOBAL (\S+) was not an allowed global by default." 1028*da0073e9SAndroid Build Coastguard Worker has_unsafe_global = re.search(pattern, message) is not None 1029*da0073e9SAndroid Build Coastguard Worker if has_unsafe_global: 1030*da0073e9SAndroid Build Coastguard Worker updated_message = ( 1031*da0073e9SAndroid Build Coastguard Worker "Weights only load failed. 
This file can still be loaded, to do so you have two options " 1032*da0073e9SAndroid Build Coastguard Worker f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check " 1033*da0073e9SAndroid Build Coastguard Worker "the recommended steps in the following error message.\n\tWeightsUnpickler error: " 1034*da0073e9SAndroid Build Coastguard Worker + message 1035*da0073e9SAndroid Build Coastguard Worker ) 1036*da0073e9SAndroid Build Coastguard Worker else: 1037*da0073e9SAndroid Build Coastguard Worker updated_message = ( 1038*da0073e9SAndroid Build Coastguard Worker f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following " 1039*da0073e9SAndroid Build Coastguard Worker "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler " 1040*da0073e9SAndroid Build Coastguard Worker "error: " + message 1041*da0073e9SAndroid Build Coastguard Worker ) 1042*da0073e9SAndroid Build Coastguard Worker return updated_message + DOCS_MESSAGE 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker if weights_only is None: 1045*da0073e9SAndroid Build Coastguard Worker weights_only, warn_weights_only = False, True 1046*da0073e9SAndroid Build Coastguard Worker else: 1047*da0073e9SAndroid Build Coastguard Worker warn_weights_only = False 1048*da0073e9SAndroid Build Coastguard Worker 1049*da0073e9SAndroid Build Coastguard Worker # Add ability to force safe only weight loads via environment variable 1050*da0073e9SAndroid Build Coastguard Worker if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: 1051*da0073e9SAndroid Build Coastguard Worker weights_only = True 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker if weights_only: 1054*da0073e9SAndroid Build Coastguard Worker if pickle_module is not None: 1055*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Can not safely load 
weights when explicit pickle_module is specified") 1056*da0073e9SAndroid Build Coastguard Worker else: 1057*da0073e9SAndroid Build Coastguard Worker if pickle_module is None: 1058*da0073e9SAndroid Build Coastguard Worker if warn_weights_only: 1059*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1060*da0073e9SAndroid Build Coastguard Worker "You are using `torch.load` with `weights_only=False` (the current default value), which uses " 1061*da0073e9SAndroid Build Coastguard Worker "the default pickle module implicitly. It is possible to construct malicious pickle data " 1062*da0073e9SAndroid Build Coastguard Worker "which will execute arbitrary code during unpickling (See " 1063*da0073e9SAndroid Build Coastguard Worker "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). " 1064*da0073e9SAndroid Build Coastguard Worker "In a future release, the default value for `weights_only` will be flipped to `True`. This " 1065*da0073e9SAndroid Build Coastguard Worker "limits the functions that could be executed during unpickling. Arbitrary objects will no " 1066*da0073e9SAndroid Build Coastguard Worker "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the " 1067*da0073e9SAndroid Build Coastguard Worker "user via `torch.serialization.add_safe_globals`. We recommend you start setting " 1068*da0073e9SAndroid Build Coastguard Worker "`weights_only=True` for any use case where you don't have full control of the loaded file. 
" 1069*da0073e9SAndroid Build Coastguard Worker "Please open an issue on GitHub for any issues related to this experimental feature.", 1070*da0073e9SAndroid Build Coastguard Worker FutureWarning, 1071*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 1072*da0073e9SAndroid Build Coastguard Worker ) 1073*da0073e9SAndroid Build Coastguard Worker pickle_module = pickle 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker # make flipping default BC-compatible 1076*da0073e9SAndroid Build Coastguard Worker if mmap is None: 1077*da0073e9SAndroid Build Coastguard Worker mmap = False 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Worker _check_dill_version(pickle_module) 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker if 'encoding' not in pickle_load_args.keys(): 1082*da0073e9SAndroid Build Coastguard Worker pickle_load_args['encoding'] = 'utf-8' 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker with _open_file_like(f, 'rb') as opened_file: 1085*da0073e9SAndroid Build Coastguard Worker if _is_zipfile(opened_file): 1086*da0073e9SAndroid Build Coastguard Worker # The zipfile reader is going to advance the current file position. 1087*da0073e9SAndroid Build Coastguard Worker # If we want to actually tail call to torch.jit.load, we need to 1088*da0073e9SAndroid Build Coastguard Worker # reset back to the original position. 
1089*da0073e9SAndroid Build Coastguard Worker orig_position = opened_file.tell() 1090*da0073e9SAndroid Build Coastguard Worker overall_storage = None 1091*da0073e9SAndroid Build Coastguard Worker with _open_zipfile_reader(opened_file) as opened_zipfile: 1092*da0073e9SAndroid Build Coastguard Worker if _is_torchscript_zip(opened_zipfile): 1093*da0073e9SAndroid Build Coastguard Worker warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive" 1094*da0073e9SAndroid Build Coastguard Worker " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" 1095*da0073e9SAndroid Build Coastguard Worker " silence this warning)", UserWarning) 1096*da0073e9SAndroid Build Coastguard Worker opened_file.seek(orig_position) 1097*da0073e9SAndroid Build Coastguard Worker return torch.jit.load(opened_file, map_location=map_location) 1098*da0073e9SAndroid Build Coastguard Worker if mmap: 1099*da0073e9SAndroid Build Coastguard Worker if not _is_path(f): 1100*da0073e9SAndroid Build Coastguard Worker raise ValueError("f must be a file path in order to use the mmap argument") 1101*da0073e9SAndroid Build Coastguard Worker size = os.path.getsize(f) 1102*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 1103*da0073e9SAndroid Build Coastguard Worker shared = get_default_mmap_options() == MAP_SHARED 1104*da0073e9SAndroid Build Coastguard Worker else: 1105*da0073e9SAndroid Build Coastguard Worker shared = False 1106*da0073e9SAndroid Build Coastguard Worker overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size) 1107*da0073e9SAndroid Build Coastguard Worker if weights_only: 1108*da0073e9SAndroid Build Coastguard Worker try: 1109*da0073e9SAndroid Build Coastguard Worker return _load(opened_zipfile, 1110*da0073e9SAndroid Build Coastguard Worker map_location, 1111*da0073e9SAndroid Build Coastguard Worker _weights_only_unpickler, 1112*da0073e9SAndroid Build Coastguard Worker overall_storage=overall_storage, 1113*da0073e9SAndroid Build 
Coastguard Worker **pickle_load_args) 1114*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1115*da0073e9SAndroid Build Coastguard Worker raise pickle.UnpicklingError(_get_wo_message(str(e))) from None 1116*da0073e9SAndroid Build Coastguard Worker return _load( 1117*da0073e9SAndroid Build Coastguard Worker opened_zipfile, 1118*da0073e9SAndroid Build Coastguard Worker map_location, 1119*da0073e9SAndroid Build Coastguard Worker pickle_module, 1120*da0073e9SAndroid Build Coastguard Worker overall_storage=overall_storage, 1121*da0073e9SAndroid Build Coastguard Worker **pickle_load_args, 1122*da0073e9SAndroid Build Coastguard Worker ) 1123*da0073e9SAndroid Build Coastguard Worker if mmap: 1124*da0073e9SAndroid Build Coastguard Worker f_name = "" if not isinstance(f, str) else f"{f}, " 1125*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("mmap can only be used with files saved with " 1126*da0073e9SAndroid Build Coastguard Worker f"`torch.save({f_name}_use_new_zipfile_serialization=True), " 1127*da0073e9SAndroid Build Coastguard Worker "please torch.save your checkpoint with this option in order to use mmap.") 1128*da0073e9SAndroid Build Coastguard Worker if weights_only: 1129*da0073e9SAndroid Build Coastguard Worker try: 1130*da0073e9SAndroid Build Coastguard Worker return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args) 1131*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1132*da0073e9SAndroid Build Coastguard Worker raise pickle.UnpicklingError(_get_wo_message(str(e))) from None 1133*da0073e9SAndroid Build Coastguard Worker return _legacy_load( 1134*da0073e9SAndroid Build Coastguard Worker opened_file, map_location, pickle_module, **pickle_load_args 1135*da0073e9SAndroid Build Coastguard Worker ) 1136*da0073e9SAndroid Build Coastguard Worker 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker# Register pickling support for layout instances such 
as 1139*da0073e9SAndroid Build Coastguard Worker# torch.sparse_coo, etc 1140*da0073e9SAndroid Build Coastguard Workerdef _get_layout(name): 1141*da0073e9SAndroid Build Coastguard Worker """Get layout extension object from its string representation. 1142*da0073e9SAndroid Build Coastguard Worker """ 1143*da0073e9SAndroid Build Coastguard Worker cache = _get_layout.cache # type: ignore[attr-defined] 1144*da0073e9SAndroid Build Coastguard Worker if not cache: 1145*da0073e9SAndroid Build Coastguard Worker for v in torch.__dict__.values(): 1146*da0073e9SAndroid Build Coastguard Worker if isinstance(v, torch.layout): 1147*da0073e9SAndroid Build Coastguard Worker cache[str(v)] = v 1148*da0073e9SAndroid Build Coastguard Worker return cache[name] 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 1151*da0073e9SAndroid Build Coastguard Worker_get_layout.cache = {} # type: ignore[attr-defined] 1152*da0073e9SAndroid Build Coastguard Workercopyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Workerdef _legacy_load(f, map_location, pickle_module, **pickle_load_args): 1156*da0073e9SAndroid Build Coastguard Worker deserialized_objects: Dict[int, Any] = {} 1157*da0073e9SAndroid Build Coastguard Worker 1158*da0073e9SAndroid Build Coastguard Worker restore_location = _get_restore_location(map_location) 1159*da0073e9SAndroid Build Coastguard Worker 1160*da0073e9SAndroid Build Coastguard Worker class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker def find_class(self, mod_name, name): 1163*da0073e9SAndroid Build Coastguard Worker if type(name) is str and 'Storage' in name: 
1164*da0073e9SAndroid Build Coastguard Worker try: 1165*da0073e9SAndroid Build Coastguard Worker return StorageType(name) 1166*da0073e9SAndroid Build Coastguard Worker except KeyError: 1167*da0073e9SAndroid Build Coastguard Worker pass 1168*da0073e9SAndroid Build Coastguard Worker return super().find_class(mod_name, name) 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker def _check_container_source(container_type, source_file, original_source): 1171*da0073e9SAndroid Build Coastguard Worker try: 1172*da0073e9SAndroid Build Coastguard Worker current_source = ''.join(get_source_lines_and_file(container_type)[0]) 1173*da0073e9SAndroid Build Coastguard Worker except Exception: # saving the source is optional, so we can ignore any errors 1174*da0073e9SAndroid Build Coastguard Worker warnings.warn("Couldn't retrieve source code for container of " 1175*da0073e9SAndroid Build Coastguard Worker "type " + container_type.__name__ + ". It won't be checked " 1176*da0073e9SAndroid Build Coastguard Worker "for correctness upon loading.") 1177*da0073e9SAndroid Build Coastguard Worker return 1178*da0073e9SAndroid Build Coastguard Worker if original_source != current_source: 1179*da0073e9SAndroid Build Coastguard Worker if container_type.dump_patches: 1180*da0073e9SAndroid Build Coastguard Worker file_name = container_type.__name__ + '.patch' 1181*da0073e9SAndroid Build Coastguard Worker diff = difflib.unified_diff(current_source.split('\n'), 1182*da0073e9SAndroid Build Coastguard Worker original_source.split('\n'), 1183*da0073e9SAndroid Build Coastguard Worker source_file, 1184*da0073e9SAndroid Build Coastguard Worker source_file, lineterm="") 1185*da0073e9SAndroid Build Coastguard Worker lines = '\n'.join(diff) 1186*da0073e9SAndroid Build Coastguard Worker try: 1187*da0073e9SAndroid Build Coastguard Worker with open(file_name, 'a+') as f: 1188*da0073e9SAndroid Build Coastguard Worker file_size = f.seek(0, 2) 1189*da0073e9SAndroid Build 
Coastguard Worker f.seek(0) 1190*da0073e9SAndroid Build Coastguard Worker if file_size == 0: 1191*da0073e9SAndroid Build Coastguard Worker f.write(lines) 1192*da0073e9SAndroid Build Coastguard Worker elif file_size != len(lines) or f.read() != lines: 1193*da0073e9SAndroid Build Coastguard Worker raise OSError 1194*da0073e9SAndroid Build Coastguard Worker msg = ("Saved a reverse patch to " + file_name + ". " 1195*da0073e9SAndroid Build Coastguard Worker "Run `patch -p0 < " + file_name + "` to revert your " 1196*da0073e9SAndroid Build Coastguard Worker "changes.") 1197*da0073e9SAndroid Build Coastguard Worker except OSError: 1198*da0073e9SAndroid Build Coastguard Worker msg = ("Tried to save a patch, but couldn't create a " 1199*da0073e9SAndroid Build Coastguard Worker "writable file " + file_name + ". Make sure it " 1200*da0073e9SAndroid Build Coastguard Worker "doesn't exist and your working directory is " 1201*da0073e9SAndroid Build Coastguard Worker "writable.") 1202*da0073e9SAndroid Build Coastguard Worker else: 1203*da0073e9SAndroid Build Coastguard Worker msg = ("you can retrieve the original source code by " 1204*da0073e9SAndroid Build Coastguard Worker "accessing the object's source attribute or set " 1205*da0073e9SAndroid Build Coastguard Worker "`torch.nn.Module.dump_patches = True` and use the " 1206*da0073e9SAndroid Build Coastguard Worker "patch tool to revert the changes.") 1207*da0073e9SAndroid Build Coastguard Worker msg = f"source code of class '{torch.typename(container_type)}' has changed. 
{msg}" 1208*da0073e9SAndroid Build Coastguard Worker warnings.warn(msg, SourceChangeWarning) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker def legacy_load(f): 1211*da0073e9SAndroid Build Coastguard Worker deserialized_objects: Dict[int, Any] = {} 1212*da0073e9SAndroid Build Coastguard Worker 1213*da0073e9SAndroid Build Coastguard Worker def persistent_load(saved_id): 1214*da0073e9SAndroid Build Coastguard Worker if isinstance(saved_id, tuple): 1215*da0073e9SAndroid Build Coastguard Worker # Ignore containers that don't have any sources saved 1216*da0073e9SAndroid Build Coastguard Worker if all(saved_id[1:]): 1217*da0073e9SAndroid Build Coastguard Worker _check_container_source(*saved_id) 1218*da0073e9SAndroid Build Coastguard Worker return saved_id[0] 1219*da0073e9SAndroid Build Coastguard Worker return deserialized_objects[int(saved_id)] 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ 1222*da0073e9SAndroid Build Coastguard Worker mkdtemp() as tmpdir: 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker tar.extract('storages', path=tmpdir) 1225*da0073e9SAndroid Build Coastguard Worker with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f: 1226*da0073e9SAndroid Build Coastguard Worker num_storages = pickle_module.load(f, **pickle_load_args) 1227*da0073e9SAndroid Build Coastguard Worker for i in range(num_storages): 1228*da0073e9SAndroid Build Coastguard Worker args = pickle_module.load(f, **pickle_load_args) 1229*da0073e9SAndroid Build Coastguard Worker key, location, storage_type = args 1230*da0073e9SAndroid Build Coastguard Worker dtype = storage_type._dtype 1231*da0073e9SAndroid Build Coastguard Worker obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) 1232*da0073e9SAndroid Build Coastguard Worker obj = 
restore_location(obj, location) 1233*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, we can 1234*da0073e9SAndroid Build Coastguard Worker # stop wrapping with TypedStorage 1235*da0073e9SAndroid Build Coastguard Worker deserialized_objects[key] = torch.storage.TypedStorage( 1236*da0073e9SAndroid Build Coastguard Worker wrap_storage=obj, 1237*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1238*da0073e9SAndroid Build Coastguard Worker _internal=True) 1239*da0073e9SAndroid Build Coastguard Worker 1240*da0073e9SAndroid Build Coastguard Worker storage_views = pickle_module.load(f, **pickle_load_args) 1241*da0073e9SAndroid Build Coastguard Worker for target_cdata, root_cdata, offset, numel in storage_views: 1242*da0073e9SAndroid Build Coastguard Worker root = deserialized_objects[root_cdata] 1243*da0073e9SAndroid Build Coastguard Worker element_size = torch._utils._element_size(root.dtype) 1244*da0073e9SAndroid Build Coastguard Worker offset_bytes = offset * element_size 1245*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, we can 1246*da0073e9SAndroid Build Coastguard Worker # stop wrapping with TypedStorage 1247*da0073e9SAndroid Build Coastguard Worker deserialized_objects[target_cdata] = torch.storage.TypedStorage( 1248*da0073e9SAndroid Build Coastguard Worker wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size], 1249*da0073e9SAndroid Build Coastguard Worker dtype=root.dtype, 1250*da0073e9SAndroid Build Coastguard Worker _internal=True) 1251*da0073e9SAndroid Build Coastguard Worker 1252*da0073e9SAndroid Build Coastguard Worker tar.extract('tensors', path=tmpdir) 1253*da0073e9SAndroid Build Coastguard Worker with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f: 1254*da0073e9SAndroid Build Coastguard Worker num_tensors = pickle_module.load(f, **pickle_load_args) 1255*da0073e9SAndroid Build Coastguard Worker for _ in range(num_tensors): 
1256*da0073e9SAndroid Build Coastguard Worker args = pickle_module.load(f, **pickle_load_args) 1257*da0073e9SAndroid Build Coastguard Worker key, storage_id, original_tensor_type = args 1258*da0073e9SAndroid Build Coastguard Worker storage = deserialized_objects[storage_id] 1259*da0073e9SAndroid Build Coastguard Worker ndim, = struct.unpack('<i', f.read(4)) 1260*da0073e9SAndroid Build Coastguard Worker # skip next 4 bytes; legacy encoding treated ndim as 8 bytes 1261*da0073e9SAndroid Build Coastguard Worker f.read(4) 1262*da0073e9SAndroid Build Coastguard Worker numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) 1263*da0073e9SAndroid Build Coastguard Worker stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) 1264*da0073e9SAndroid Build Coastguard Worker storage_offset, = struct.unpack('<q', f.read(8)) 1265*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty((0,), dtype=storage.dtype).set_( 1266*da0073e9SAndroid Build Coastguard Worker storage._untyped_storage, storage_offset, numel, stride) 1267*da0073e9SAndroid Build Coastguard Worker deserialized_objects[key] = tensor 1268*da0073e9SAndroid Build Coastguard Worker 1269*da0073e9SAndroid Build Coastguard Worker pickle_file = tar.extractfile('pickle') 1270*da0073e9SAndroid Build Coastguard Worker unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args) 1271*da0073e9SAndroid Build Coastguard Worker unpickler.persistent_load = persistent_load 1272*da0073e9SAndroid Build Coastguard Worker result = unpickler.load() 1273*da0073e9SAndroid Build Coastguard Worker return result 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker deserialized_objects = {} 1276*da0073e9SAndroid Build Coastguard Worker 1277*da0073e9SAndroid Build Coastguard Worker def persistent_load(saved_id): 1278*da0073e9SAndroid Build Coastguard Worker assert isinstance(saved_id, tuple) 1279*da0073e9SAndroid Build Coastguard Worker typename = _maybe_decode_ascii(saved_id[0]) 1280*da0073e9SAndroid 
Build Coastguard Worker data = saved_id[1:] 1281*da0073e9SAndroid Build Coastguard Worker 1282*da0073e9SAndroid Build Coastguard Worker if typename == 'module': 1283*da0073e9SAndroid Build Coastguard Worker # Ignore containers that don't have any sources saved 1284*da0073e9SAndroid Build Coastguard Worker if all(data[1:]): 1285*da0073e9SAndroid Build Coastguard Worker _check_container_source(*data) 1286*da0073e9SAndroid Build Coastguard Worker return data[0] 1287*da0073e9SAndroid Build Coastguard Worker elif typename == 'storage': 1288*da0073e9SAndroid Build Coastguard Worker storage_type, root_key, location, numel, view_metadata = data 1289*da0073e9SAndroid Build Coastguard Worker location = _maybe_decode_ascii(location) 1290*da0073e9SAndroid Build Coastguard Worker dtype = storage_type.dtype 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker nbytes = numel * torch._utils._element_size(dtype) 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker if root_key not in deserialized_objects: 1295*da0073e9SAndroid Build Coastguard Worker if torch._guards.active_fake_mode() is not None: 1296*da0073e9SAndroid Build Coastguard Worker obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta')) 1297*da0073e9SAndroid Build Coastguard Worker else: 1298*da0073e9SAndroid Build Coastguard Worker obj = cast(Storage, torch.UntypedStorage(nbytes)) 1299*da0073e9SAndroid Build Coastguard Worker obj._torch_load_uninitialized = True 1300*da0073e9SAndroid Build Coastguard Worker obj = restore_location(obj, location) 1301*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, we can 1302*da0073e9SAndroid Build Coastguard Worker # stop wrapping with TypedStorage 1303*da0073e9SAndroid Build Coastguard Worker typed_storage = torch.storage.TypedStorage( 1304*da0073e9SAndroid Build Coastguard Worker wrap_storage=obj, 1305*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 
1306*da0073e9SAndroid Build Coastguard Worker _internal=True) 1307*da0073e9SAndroid Build Coastguard Worker deserialized_objects[root_key] = typed_storage 1308*da0073e9SAndroid Build Coastguard Worker else: 1309*da0073e9SAndroid Build Coastguard Worker typed_storage = deserialized_objects[root_key] 1310*da0073e9SAndroid Build Coastguard Worker if typed_storage._data_ptr() == 0: 1311*da0073e9SAndroid Build Coastguard Worker typed_storage = torch.storage.TypedStorage( 1312*da0073e9SAndroid Build Coastguard Worker device=typed_storage._untyped_storage.device, 1313*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1314*da0073e9SAndroid Build Coastguard Worker _internal=True) 1315*da0073e9SAndroid Build Coastguard Worker 1316*da0073e9SAndroid Build Coastguard Worker if view_metadata is not None: 1317*da0073e9SAndroid Build Coastguard Worker view_key, offset, view_size = view_metadata 1318*da0073e9SAndroid Build Coastguard Worker offset_bytes = offset * torch._utils._element_size(dtype) 1319*da0073e9SAndroid Build Coastguard Worker view_size_bytes = view_size * torch._utils._element_size(dtype) 1320*da0073e9SAndroid Build Coastguard Worker if view_key not in deserialized_objects: 1321*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, we can 1322*da0073e9SAndroid Build Coastguard Worker # stop wrapping with TypedStorage 1323*da0073e9SAndroid Build Coastguard Worker deserialized_objects[view_key] = torch.storage.TypedStorage( 1324*da0073e9SAndroid Build Coastguard Worker wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes], 1325*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1326*da0073e9SAndroid Build Coastguard Worker _internal=True) 1327*da0073e9SAndroid Build Coastguard Worker res = deserialized_objects[view_key] 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker else: 1330*da0073e9SAndroid Build Coastguard Worker res = typed_storage 
1331*da0073e9SAndroid Build Coastguard Worker return res 1332*da0073e9SAndroid Build Coastguard Worker else: 1333*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Unknown saved id type: {saved_id[0]}") 1334*da0073e9SAndroid Build Coastguard Worker 1335*da0073e9SAndroid Build Coastguard Worker _check_seekable(f) 1336*da0073e9SAndroid Build Coastguard Worker f_should_read_directly = _should_read_directly(f) 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker if f_should_read_directly and f.tell() == 0: 1339*da0073e9SAndroid Build Coastguard Worker # legacy_load requires that f has fileno() 1340*da0073e9SAndroid Build Coastguard Worker # only if offset is zero we can attempt the legacy tar file loader 1341*da0073e9SAndroid Build Coastguard Worker try: 1342*da0073e9SAndroid Build Coastguard Worker return legacy_load(f) 1343*da0073e9SAndroid Build Coastguard Worker except tarfile.TarError: 1344*da0073e9SAndroid Build Coastguard Worker if _is_zipfile(f): 1345*da0073e9SAndroid Build Coastguard Worker # .zip is used for torch.jit.save and will throw an un-pickling error here 1346*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1347*da0073e9SAndroid Build Coastguard Worker f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None 1348*da0073e9SAndroid Build Coastguard Worker # if not a tarfile, reset file offset and proceed 1349*da0073e9SAndroid Build Coastguard Worker f.seek(0) 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2): 1352*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1353*da0073e9SAndroid Build Coastguard Worker "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. " 1354*da0073e9SAndroid Build Coastguard Worker f'Received object of type "{type(f)}". 
Please update to Python 3.8.2 or newer to restore this ' 1355*da0073e9SAndroid Build Coastguard Worker "functionality.") 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker magic_number = pickle_module.load(f, **pickle_load_args) 1358*da0073e9SAndroid Build Coastguard Worker if magic_number != MAGIC_NUMBER: 1359*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Invalid magic number; corrupt file?") 1360*da0073e9SAndroid Build Coastguard Worker protocol_version = pickle_module.load(f, **pickle_load_args) 1361*da0073e9SAndroid Build Coastguard Worker if protocol_version != PROTOCOL_VERSION: 1362*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Invalid protocol version: {protocol_version}") 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker _sys_info = pickle_module.load(f, **pickle_load_args) 1365*da0073e9SAndroid Build Coastguard Worker unpickler = UnpicklerWrapper(f, **pickle_load_args) 1366*da0073e9SAndroid Build Coastguard Worker unpickler.persistent_load = persistent_load 1367*da0073e9SAndroid Build Coastguard Worker result = unpickler.load() 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker if torch._guards.active_fake_mode() is None: 1372*da0073e9SAndroid Build Coastguard Worker offset = f.tell() if f_should_read_directly else None 1373*da0073e9SAndroid Build Coastguard Worker for key in deserialized_storage_keys: 1374*da0073e9SAndroid Build Coastguard Worker assert key in deserialized_objects 1375*da0073e9SAndroid Build Coastguard Worker typed_storage = deserialized_objects[key] 1376*da0073e9SAndroid Build Coastguard Worker typed_storage._untyped_storage._set_from_file( 1377*da0073e9SAndroid Build Coastguard Worker f, offset, f_should_read_directly, 
1378*da0073e9SAndroid Build Coastguard Worker torch._utils._element_size(typed_storage.dtype)) 1379*da0073e9SAndroid Build Coastguard Worker if offset is not None: 1380*da0073e9SAndroid Build Coastguard Worker offset = f.tell() 1381*da0073e9SAndroid Build Coastguard Worker 1382*da0073e9SAndroid Build Coastguard Worker torch._utils._validate_loaded_sparse_tensors() 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker return result 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Workerdef _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: 1388*da0073e9SAndroid Build Coastguard Worker # When using encoding='bytes' in Py3, some **internal** keys stored as 1389*da0073e9SAndroid Build Coastguard Worker # strings in Py2 are loaded as bytes. This function decodes them with 1390*da0073e9SAndroid Build Coastguard Worker # ascii encoding, one that Py3 uses by default. 1391*da0073e9SAndroid Build Coastguard Worker # 1392*da0073e9SAndroid Build Coastguard Worker # NOTE: This should only be used on internal keys (e.g., `typename` and 1393*da0073e9SAndroid Build Coastguard Worker # `location` in `persistent_load` below! 
1394*da0073e9SAndroid Build Coastguard Worker if isinstance(bytes_str, bytes): 1395*da0073e9SAndroid Build Coastguard Worker return bytes_str.decode('ascii') 1396*da0073e9SAndroid Build Coastguard Worker return bytes_str 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker 1399*da0073e9SAndroid Build Coastguard Workerdef _get_restore_location(map_location): 1400*da0073e9SAndroid Build Coastguard Worker if map_location is None: 1401*da0073e9SAndroid Build Coastguard Worker restore_location = default_restore_location 1402*da0073e9SAndroid Build Coastguard Worker elif isinstance(map_location, dict): 1403*da0073e9SAndroid Build Coastguard Worker def restore_location(storage, location): 1404*da0073e9SAndroid Build Coastguard Worker location = map_location.get(location, location) 1405*da0073e9SAndroid Build Coastguard Worker return default_restore_location(storage, location) 1406*da0073e9SAndroid Build Coastguard Worker elif isinstance(map_location, (str, bytes)): 1407*da0073e9SAndroid Build Coastguard Worker def restore_location(storage, location): 1408*da0073e9SAndroid Build Coastguard Worker return default_restore_location(storage, map_location) 1409*da0073e9SAndroid Build Coastguard Worker elif isinstance(map_location, torch.device): 1410*da0073e9SAndroid Build Coastguard Worker def restore_location(storage, location): 1411*da0073e9SAndroid Build Coastguard Worker return default_restore_location(storage, str(map_location)) 1412*da0073e9SAndroid Build Coastguard Worker else: 1413*da0073e9SAndroid Build Coastguard Worker def restore_location(storage, location): 1414*da0073e9SAndroid Build Coastguard Worker result = map_location(storage, location) 1415*da0073e9SAndroid Build Coastguard Worker if result is None: 1416*da0073e9SAndroid Build Coastguard Worker result = default_restore_location(storage, location) 1417*da0073e9SAndroid Build Coastguard Worker return result 1418*da0073e9SAndroid Build Coastguard Worker return 
restore_location 1419*da0073e9SAndroid Build Coastguard Worker 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Workerclass StorageType: 1422*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 1423*da0073e9SAndroid Build Coastguard Worker self._dtype = _get_dtype_from_pickle_storage_type(name) 1424*da0073e9SAndroid Build Coastguard Worker 1425*da0073e9SAndroid Build Coastguard Worker @property 1426*da0073e9SAndroid Build Coastguard Worker def dtype(self): 1427*da0073e9SAndroid Build Coastguard Worker return self._dtype 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker def __str__(self): 1430*da0073e9SAndroid Build Coastguard Worker return f'StorageType(dtype={self.dtype})' 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker 1433*da0073e9SAndroid Build Coastguard Workerdef _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args): 1434*da0073e9SAndroid Build Coastguard Worker restore_location = _get_restore_location(map_location) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker loaded_storages = {} 1437*da0073e9SAndroid Build Coastguard Worker 1438*da0073e9SAndroid Build Coastguard Worker # check if byteswapping is needed 1439*da0073e9SAndroid Build Coastguard Worker byteordername = 'byteorder' 1440*da0073e9SAndroid Build Coastguard Worker byteorderdata = None 1441*da0073e9SAndroid Build Coastguard Worker if zip_file.has_record(byteordername): 1442*da0073e9SAndroid Build Coastguard Worker byteorderdata = zip_file.get_record(byteordername) 1443*da0073e9SAndroid Build Coastguard Worker if byteorderdata not in [b'little', b'big']: 1444*da0073e9SAndroid Build Coastguard Worker raise ValueError('Unknown endianness type: ' + byteorderdata.decode()) 1445*da0073e9SAndroid Build Coastguard Worker elif get_default_load_endianness() == 
LoadEndianness.LITTLE or \ 1446*da0073e9SAndroid Build Coastguard Worker get_default_load_endianness() is None: 1447*da0073e9SAndroid Build Coastguard Worker byteorderdata = b'little' 1448*da0073e9SAndroid Build Coastguard Worker elif get_default_load_endianness() == LoadEndianness.BIG: 1449*da0073e9SAndroid Build Coastguard Worker byteorderdata = b'big' 1450*da0073e9SAndroid Build Coastguard Worker elif get_default_load_endianness() == LoadEndianness.NATIVE: 1451*da0073e9SAndroid Build Coastguard Worker pass 1452*da0073e9SAndroid Build Coastguard Worker else: 1453*da0073e9SAndroid Build Coastguard Worker raise ValueError('Invalid load endianness type') 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker if not zip_file.has_record(byteordername) and \ 1456*da0073e9SAndroid Build Coastguard Worker get_default_load_endianness() is None and \ 1457*da0073e9SAndroid Build Coastguard Worker sys.byteorder == 'big': 1458*da0073e9SAndroid Build Coastguard Worker # Default behaviour was changed 1459*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/101688 1460*da0073e9SAndroid Build Coastguard Worker warnings.warn("The default load endianness for checkpoints without a byteorder mark " 1461*da0073e9SAndroid Build Coastguard Worker "on big endian machines was changed from 'native' to 'little' endian, " 1462*da0073e9SAndroid Build Coastguard Worker "to avoid this behavior please use " 1463*da0073e9SAndroid Build Coastguard Worker "torch.serialization.set_default_load_endianness to set " 1464*da0073e9SAndroid Build Coastguard Worker "the desired default load endianness", 1465*da0073e9SAndroid Build Coastguard Worker UserWarning) 1466*da0073e9SAndroid Build Coastguard Worker 1467*da0073e9SAndroid Build Coastguard Worker def load_tensor(dtype, numel, key, location): 1468*da0073e9SAndroid Build Coastguard Worker name = f'data/{key}' 1469*da0073e9SAndroid Build Coastguard Worker if 
torch._guards.detect_fake_mode(None) is not None: 1470*da0073e9SAndroid Build Coastguard Worker nbytes = numel * torch._utils._element_size(dtype) 1471*da0073e9SAndroid Build Coastguard Worker storage = torch.UntypedStorage(nbytes, device='meta') 1472*da0073e9SAndroid Build Coastguard Worker elif overall_storage is not None: 1473*da0073e9SAndroid Build Coastguard Worker storage_offset = zip_file.get_record_offset(name) 1474*da0073e9SAndroid Build Coastguard Worker storage = overall_storage[storage_offset:storage_offset + numel] 1475*da0073e9SAndroid Build Coastguard Worker else: 1476*da0073e9SAndroid Build Coastguard Worker storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage 1477*da0073e9SAndroid Build Coastguard Worker # swap here if byteswapping is needed 1478*da0073e9SAndroid Build Coastguard Worker if byteorderdata is not None: 1479*da0073e9SAndroid Build Coastguard Worker if byteorderdata.decode() != sys.byteorder: 1480*da0073e9SAndroid Build Coastguard Worker storage.byteswap(dtype) 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker # TODO: Once we decide to break serialization FC, we can 1483*da0073e9SAndroid Build Coastguard Worker # stop wrapping with TypedStorage 1484*da0073e9SAndroid Build Coastguard Worker typed_storage = torch.storage.TypedStorage( 1485*da0073e9SAndroid Build Coastguard Worker wrap_storage=restore_location(storage, location), 1486*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1487*da0073e9SAndroid Build Coastguard Worker _internal=True) 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker if typed_storage._data_ptr() != 0: 1490*da0073e9SAndroid Build Coastguard Worker loaded_storages[key] = typed_storage 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker return typed_storage 1493*da0073e9SAndroid Build Coastguard Worker 1494*da0073e9SAndroid Build 
Coastguard Worker def persistent_load(saved_id): 1495*da0073e9SAndroid Build Coastguard Worker assert isinstance(saved_id, tuple) 1496*da0073e9SAndroid Build Coastguard Worker typename = _maybe_decode_ascii(saved_id[0]) 1497*da0073e9SAndroid Build Coastguard Worker data = saved_id[1:] 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker assert typename == 'storage', \ 1500*da0073e9SAndroid Build Coastguard Worker f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" 1501*da0073e9SAndroid Build Coastguard Worker storage_type, key, location, numel = data 1502*da0073e9SAndroid Build Coastguard Worker if storage_type is torch.UntypedStorage: 1503*da0073e9SAndroid Build Coastguard Worker dtype = torch.uint8 1504*da0073e9SAndroid Build Coastguard Worker else: 1505*da0073e9SAndroid Build Coastguard Worker dtype = storage_type.dtype 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker if key in loaded_storages: 1508*da0073e9SAndroid Build Coastguard Worker typed_storage = loaded_storages[key] 1509*da0073e9SAndroid Build Coastguard Worker else: 1510*da0073e9SAndroid Build Coastguard Worker nbytes = numel * torch._utils._element_size(dtype) 1511*da0073e9SAndroid Build Coastguard Worker typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker return typed_storage 1514*da0073e9SAndroid Build Coastguard Worker 1515*da0073e9SAndroid Build Coastguard Worker load_module_mapping: Dict[str, str] = { 1516*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/pull/51633 1517*da0073e9SAndroid Build Coastguard Worker 'torch.tensor': 'torch._tensor' 1518*da0073e9SAndroid Build Coastguard Worker } 1519*da0073e9SAndroid Build Coastguard Worker 1520*da0073e9SAndroid Build Coastguard Worker # Need to subclass Unpickler instead of directly 
monkey-patching the find_class method 1521*da0073e9SAndroid Build Coastguard Worker # because it's marked readonly in pickle. 1522*da0073e9SAndroid Build Coastguard Worker # The type: ignore is because mypy can't statically determine the type of this class. 1523*da0073e9SAndroid Build Coastguard Worker class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] 1524*da0073e9SAndroid Build Coastguard Worker # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 1525*da0073e9SAndroid Build Coastguard Worker # Lets us override the imports that pickle uses when unpickling an object. 1526*da0073e9SAndroid Build Coastguard Worker # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. 1527*da0073e9SAndroid Build Coastguard Worker def find_class(self, mod_name, name): 1528*da0073e9SAndroid Build Coastguard Worker if type(name) is str and 'Storage' in name: 1529*da0073e9SAndroid Build Coastguard Worker try: 1530*da0073e9SAndroid Build Coastguard Worker return StorageType(name) 1531*da0073e9SAndroid Build Coastguard Worker except KeyError: 1532*da0073e9SAndroid Build Coastguard Worker pass 1533*da0073e9SAndroid Build Coastguard Worker mod_name = load_module_mapping.get(mod_name, mod_name) 1534*da0073e9SAndroid Build Coastguard Worker return super().find_class(mod_name, name) 1535*da0073e9SAndroid Build Coastguard Worker 1536*da0073e9SAndroid Build Coastguard Worker # Load the data (which may in turn use `persistent_load` to load tensors) 1537*da0073e9SAndroid Build Coastguard Worker data_file = io.BytesIO(zip_file.get_record(pickle_file)) 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker unpickler = UnpicklerWrapper(data_file, **pickle_load_args) 1540*da0073e9SAndroid Build Coastguard Worker unpickler.persistent_load = persistent_load 1541*da0073e9SAndroid Build Coastguard Worker # Needed for tensors where storage 
device and rebuild tensor device are 1542*da0073e9SAndroid Build Coastguard Worker # not connected (wrapper subclasses and tensors rebuilt using numpy) 1543*da0073e9SAndroid Build Coastguard Worker torch._utils._thread_local_state.map_location = map_location 1544*da0073e9SAndroid Build Coastguard Worker result = unpickler.load() 1545*da0073e9SAndroid Build Coastguard Worker del torch._utils._thread_local_state.map_location 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker torch._utils._validate_loaded_sparse_tensors() 1548*da0073e9SAndroid Build Coastguard Worker torch._C._log_api_usage_metadata( 1549*da0073e9SAndroid Build Coastguard Worker "torch.load.metadata", {"serialization_id": zip_file.serialization_id()} 1550*da0073e9SAndroid Build Coastguard Worker ) 1551*da0073e9SAndroid Build Coastguard Worker return result 1552*da0073e9SAndroid Build Coastguard Worker 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Workerdef _is_torchscript_zip(zip_file): 1555*da0073e9SAndroid Build Coastguard Worker return 'constants.pkl' in zip_file.get_all_records() 1556