1# mypy: allow-untyped-defs 2# Unpickler restricted to loading only state dicts 3# Restrict constructing types to a list defined in _get_allowed_globals() 4# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only 5# Restrict APPEND/APPENDS to `list` 6# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary 7# defined by `_get_allowed_globals()` method, that contains: 8# - torch types (Storage, dtypes, Tensor, `torch.Size`), 9# - `torch._utils._rebuild` functions. 10# - `torch.nn.Parameter` 11# - `collections.Counter` 12# - `collections.OrderedDict` 13# Additionally, users can use an allowlist for adding classes they have deemed as safe using 14# `_add_safe_globals()` (`torch.serialization.add_safe_globals`) 15# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) 16# `_get_safe_globals()` (`torch.serialization.get_safe_globals`) 17 18# Based of https://github.com/python/cpython/blob/main/Lib/pickle.py 19# Expected to be useful for loading PyTorch model weights 20# For example: 21# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read() 22# buf = io.BytesIO(data) 23# weights = torch.load(buf, weights_only = True) 24 25import functools as _functools 26import warnings 27from collections import Counter, OrderedDict 28from pickle import ( 29 APPEND, 30 APPENDS, 31 BINFLOAT, 32 BINGET, 33 BININT, 34 BININT1, 35 BININT2, 36 BINPERSID, 37 BINPUT, 38 BINUNICODE, 39 BUILD, 40 bytes_types, 41 decode_long, 42 EMPTY_DICT, 43 EMPTY_LIST, 44 EMPTY_SET, 45 EMPTY_TUPLE, 46 GLOBAL, 47 LONG1, 48 LONG_BINGET, 49 LONG_BINPUT, 50 MARK, 51 NEWFALSE, 52 NEWOBJ, 53 NEWTRUE, 54 NONE, 55 PROTO, 56 REDUCE, 57 SETITEM, 58 SETITEMS, 59 SHORT_BINSTRING, 60 STOP, 61 TUPLE, 62 TUPLE1, 63 TUPLE2, 64 TUPLE3, 65 UnpicklingError, 66) 67from struct import unpack 68from sys import maxsize 69from typing import Any, Dict, List 70 71import torch 72from torch._utils import IMPORT_MAPPING, NAME_MAPPING 73 74 75_marked_safe_globals_list: List[Any] = [] 76 77 78def _add_safe_globals(safe_globals: List[Any]): 79 global _marked_safe_globals_list 80 _marked_safe_globals_list += safe_globals 81 82 83def _get_safe_globals() -> List[Any]: 84 global _marked_safe_globals_list 85 return _marked_safe_globals_list 86 87 88def _clear_safe_globals(): 89 global _marked_safe_globals_list 90 _marked_safe_globals_list = [] 91 92 93# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals 94# For example if user had a script like 95# torch.load(file_a) 96# torch.serialization._add_safe_globals([torch.foo]) 97# torch.load(file_b) 98# the dynamic additions to safe_globals would not be picked up by 99# _get_allowed_globals due to the lru_cache 100def _get_user_allowed_globals(): 101 rc: Dict[str, Any] = {} 102 for f in _marked_safe_globals_list: 103 module, name = f.__module__, f.__name__ 104 rc[f"{module}.{name}"] = f 105 return rc 106 107 108def _tensor_rebuild_functions(): 109 return { 110 torch._utils._rebuild_parameter, 111 torch._utils._rebuild_parameter_with_state, 112 torch._utils._rebuild_qtensor, 113 torch._utils._rebuild_tensor, 114 torch._utils._rebuild_tensor_v2, 115 torch._utils._rebuild_tensor_v3, 116 torch._utils._rebuild_sparse_tensor, 117 torch._utils._rebuild_meta_tensor_no_storage, 118 torch._utils._rebuild_nested_tensor, 119 torch._utils._rebuild_wrapper_subclass, 120 } 121 122 123# Unpickling machinery 124@_functools.lru_cache(maxsize=1) 125def _get_allowed_globals(): 126 rc: Dict[str, Any] = { 127 "collections.OrderedDict": OrderedDict, 128 "collections.Counter": Counter, 129 "torch.nn.parameter.Parameter": torch.nn.Parameter, 130 "torch.serialization._get_layout": torch.serialization._get_layout, 131 "torch.Size": torch.Size, 132 "torch.Tensor": torch.Tensor, 133 "torch.device": torch.device, 134 } 135 # dtype 136 for t in torch.storage._dtype_to_storage_type_map().keys(): 137 rc[str(t)] = t 138 for t in torch.storage._new_dtypes(): 139 rc[str(t)] = t 140 # Tensor classes 141 for tt in torch._tensor_classes: 142 rc[f"{tt.__module__}.{tt.__name__}"] = tt 143 # Storage classes 144 for ts in torch._storage_classes: 145 if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage): 146 # Wrap legacy storage types in a dummy class 147 rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType( 148 ts.__name__ 149 ) 150 else: 151 rc[f"{ts.__module__}.{ts.__name__}"] = ts 152 # Quantization specific 153 for qt in [ 154 torch.per_tensor_affine, 155 torch.per_tensor_symmetric, 156 torch.per_channel_affine, 157 torch.per_channel_symmetric, 158 torch.per_channel_affine_float_qparams, 159 ]: 160 rc[str(qt)] = qt 161 # Rebuild functions 162 for f in _tensor_rebuild_functions(): 163 rc[f"torch._utils.{f.__name__}"] = f 164 165 # Handles Tensor Subclasses, Tensor's with attributes. 166 # NOTE: It calls into above rebuild functions for regular Tensor types. 167 rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 168 return rc 169 170 171class Unpickler: 172 def __init__(self, file, *, encoding: str = "bytes"): 173 self.encoding = encoding 174 self.readline = file.readline 175 self.read = file.read 176 self.memo: Dict[int, Any] = {} 177 self.proto: int = -1 178 179 def load(self): 180 """Read a pickled object representation from the open file. 181 182 Return the reconstituted object hierarchy specified in the file. 183 """ 184 self.metastack = [] 185 self.stack: List[Any] = [] 186 self.append = self.stack.append 187 read = self.read 188 readline = self.readline 189 while True: 190 key = read(1) 191 if not key: 192 raise EOFError 193 assert isinstance(key, bytes_types) 194 # Risky operators 195 if key[0] == GLOBAL[0]: 196 module = readline()[:-1].decode("utf-8") 197 name = readline()[:-1].decode("utf-8") 198 # Patch since torch.save default protocol is 2 199 # users will be running this code in python > 3 200 if self.proto == 2: 201 if (module, name) in NAME_MAPPING: 202 module, name = NAME_MAPPING[(module, name)] 203 elif module in IMPORT_MAPPING: 204 module = IMPORT_MAPPING[module] 205 full_path = f"{module}.{name}" 206 if full_path in _get_allowed_globals(): 207 self.append(_get_allowed_globals()[full_path]) 208 elif full_path in _get_user_allowed_globals(): 209 self.append(_get_user_allowed_globals()[full_path]) 210 else: 211 raise RuntimeError( 212 f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " 213 f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist " 214 "this global if you trust this class/function." 215 ) 216 elif key[0] == NEWOBJ[0]: 217 args = self.stack.pop() 218 cls = self.stack.pop() 219 if cls is torch.nn.Parameter: 220 self.append(torch.nn.Parameter(*args)) 221 elif cls in _get_user_allowed_globals().values(): 222 self.append(cls.__new__(cls, *args)) 223 else: 224 raise RuntimeError(f"Trying to instantiate unsupported class {cls}") 225 elif key[0] == REDUCE[0]: 226 args = self.stack.pop() 227 func = self.stack[-1] 228 if ( 229 func not in _get_allowed_globals().values() 230 and func not in _get_user_allowed_globals().values() 231 ): 232 raise RuntimeError( 233 f"Trying to call reduce for unrecognized function {func}" 234 ) 235 self.stack[-1] = func(*args) 236 elif key[0] == BUILD[0]: 237 state = self.stack.pop() 238 inst = self.stack[-1] 239 if type(inst) is torch.Tensor: 240 # Legacy unpickling 241 inst.set_(*state) 242 elif type(inst) is torch.nn.Parameter: 243 inst.__setstate__(state) 244 elif type(inst) is OrderedDict: 245 inst.__dict__.update(state) 246 elif type(inst) in _get_user_allowed_globals().values(): 247 if hasattr(inst, "__setstate__"): 248 inst.__setstate__(state) 249 else: 250 inst.__dict__.update(state) 251 else: 252 raise RuntimeError( 253 f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}" 254 ) 255 # Stack manipulation 256 elif key[0] == APPEND[0]: 257 item = self.stack.pop() 258 list_obj = self.stack[-1] 259 if type(list_obj) is not list: 260 raise RuntimeError( 261 f"Can only append to lists, but got {type(list_obj)}" 262 ) 263 list_obj.append(item) 264 elif key[0] == APPENDS[0]: 265 items = self.pop_mark() 266 list_obj = self.stack[-1] 267 if type(list_obj) is not list: 268 raise RuntimeError( 269 f"Can only extend lists, but got {type(list_obj)}" 270 ) 271 list_obj.extend(items) 272 elif key[0] == SETITEM[0]: 273 (v, k) = (self.stack.pop(), self.stack.pop()) 274 self.stack[-1][k] = v 275 elif key[0] == SETITEMS[0]: 276 items = self.pop_mark() 277 for i in range(0, len(items), 2): 278 self.stack[-1][items[i]] = items[i + 1] 279 elif key[0] == MARK[0]: 280 self.metastack.append(self.stack) 281 self.stack = [] 282 self.append = self.stack.append 283 elif key[0] == TUPLE[0]: 284 items = self.pop_mark() 285 self.append(tuple(items)) 286 elif key[0] == TUPLE1[0]: 287 self.stack[-1] = (self.stack[-1],) 288 elif key[0] == TUPLE2[0]: 289 self.stack[-2:] = [(self.stack[-2], self.stack[-1])] 290 elif key[0] == TUPLE3[0]: 291 self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] 292 # Basic types construction 293 elif key[0] == NONE[0]: 294 self.append(None) 295 elif key[0] == NEWFALSE[0]: 296 self.append(False) 297 elif key[0] == NEWTRUE[0]: 298 self.append(True) 299 elif key[0] == EMPTY_TUPLE[0]: 300 self.append(()) 301 elif key[0] == EMPTY_LIST[0]: 302 self.append([]) 303 elif key[0] == EMPTY_DICT[0]: 304 self.append({}) 305 elif key[0] == EMPTY_SET[0]: 306 self.append(set()) 307 elif key[0] == BININT[0]: 308 self.append(unpack("<i", read(4))[0]) 309 elif key[0] == BININT1[0]: 310 self.append(self.read(1)[0]) 311 elif key[0] == BININT2[0]: 312 self.append(unpack("<H", read(2))[0]) 313 elif key[0] == BINFLOAT[0]: 314 self.append(unpack(">d", self.read(8))[0]) 315 elif key[0] == BINUNICODE[0]: 316 strlen = unpack("<I", read(4))[0] 317 if strlen > maxsize: 318 raise RuntimeError("String is too long") 319 strval = str(read(strlen), "utf-8", "surrogatepass") 320 self.append(strval) 321 elif key[0] == SHORT_BINSTRING[0]: 322 strlen = read(1)[0] 323 strdata = read(strlen) 324 if self.encoding != "bytes": 325 strdata = strdata.decode(self.encoding, "strict") 326 self.append(strdata) 327 elif key[0] == BINPERSID[0]: 328 pid = self.stack.pop() 329 # Only allow persistent load of storage 330 if type(pid) is not tuple and not type(pid) is not int: 331 raise RuntimeError( 332 f"persistent_load id must be tuple or int, but got {type(pid)}" 333 ) 334 if ( 335 type(pid) is tuple 336 and len(pid) > 0 337 and torch.serialization._maybe_decode_ascii(pid[0]) != "storage" 338 ): 339 raise RuntimeError( 340 f"Only persistent_load of storage is allowed, but got {pid[0]}" 341 ) 342 self.append(self.persistent_load(pid)) 343 elif key[0] in [BINGET[0], LONG_BINGET[0]]: 344 idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0] 345 self.append(self.memo[idx]) 346 elif key[0] in [BINPUT[0], LONG_BINPUT[0]]: 347 i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0] 348 if i < 0: 349 raise ValueError("negative argument") 350 self.memo[i] = self.stack[-1] 351 elif key[0] == LONG1[0]: 352 n = read(1)[0] 353 data = read(n) 354 self.append(decode_long(data)) 355 # First and last deserializer ops 356 elif key[0] == PROTO[0]: 357 self.proto = read(1)[0] 358 if self.proto != 2: 359 warnings.warn( 360 f"Detected pickle protocol {self.proto} in the checkpoint, which was " 361 "not the default pickle protocol used by `torch.load` (2). The weights_only " 362 "Unpickler might not support all instructions implemented by this protocol, " 363 "please file an issue for adding support if you encounter this." 364 ) 365 elif key[0] == STOP[0]: 366 rc = self.stack.pop() 367 return rc 368 else: 369 raise RuntimeError(f"Unsupported operand {key[0]}") 370 371 # Return a list of items pushed in the stack after last MARK instruction. 372 def pop_mark(self): 373 items = self.stack 374 self.stack = self.metastack.pop() 375 self.append = self.stack.append 376 return items 377 378 def persistent_load(self, pid): 379 raise UnpicklingError("unsupported persistent id encountered") 380 381 382def load(file, *, encoding: str = "ASCII"): 383 return Unpickler(file, encoding=encoding).load() 384