xref: /aosp_15_r20/external/pytorch/torch/_weights_only_unpickler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Unpickler restricted to loading only state dicts
3# Restrict constructing types to a list defined in _get_allowed_globals()
4# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
5# Restrict APPEND/APPENDS to `list`
6# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary
7# defined by `_get_allowed_globals()` method, that contains:
8# - torch types (Storage, dtypes, Tensor, `torch.Size`),
9# - `torch._utils._rebuild` functions.
10# - `torch.nn.Parameter`
11# - `collections.Counter`
12# - `collections.OrderedDict`
13# Additionally, users can use an allowlist for adding classes they have deemed as safe using
14# `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
15# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
16# `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
17
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
19# Expected to be useful for loading PyTorch model weights
20# For example:
21# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
22# buf = io.BytesIO(data)
23# weights = torch.load(buf, weights_only = True)
24
25import functools as _functools
26import warnings
27from collections import Counter, OrderedDict
28from pickle import (
29    APPEND,
30    APPENDS,
31    BINFLOAT,
32    BINGET,
33    BININT,
34    BININT1,
35    BININT2,
36    BINPERSID,
37    BINPUT,
38    BINUNICODE,
39    BUILD,
40    bytes_types,
41    decode_long,
42    EMPTY_DICT,
43    EMPTY_LIST,
44    EMPTY_SET,
45    EMPTY_TUPLE,
46    GLOBAL,
47    LONG1,
48    LONG_BINGET,
49    LONG_BINPUT,
50    MARK,
51    NEWFALSE,
52    NEWOBJ,
53    NEWTRUE,
54    NONE,
55    PROTO,
56    REDUCE,
57    SETITEM,
58    SETITEMS,
59    SHORT_BINSTRING,
60    STOP,
61    TUPLE,
62    TUPLE1,
63    TUPLE2,
64    TUPLE3,
65    UnpicklingError,
66)
67from struct import unpack
68from sys import maxsize
69from typing import Any, Dict, List
70
71import torch
72from torch._utils import IMPORT_MAPPING, NAME_MAPPING
73
74
75_marked_safe_globals_list: List[Any] = []
76
77
78def _add_safe_globals(safe_globals: List[Any]):
79    global _marked_safe_globals_list
80    _marked_safe_globals_list += safe_globals
81
82
83def _get_safe_globals() -> List[Any]:
84    global _marked_safe_globals_list
85    return _marked_safe_globals_list
86
87
88def _clear_safe_globals():
89    global _marked_safe_globals_list
90    _marked_safe_globals_list = []
91
92
93# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
94# For example if user had a script like
95#   torch.load(file_a)
96#   torch.serialization._add_safe_globals([torch.foo])
97#   torch.load(file_b)
98# the dynamic additions to safe_globals would not be picked up by
99# _get_allowed_globals due to the lru_cache
100def _get_user_allowed_globals():
101    rc: Dict[str, Any] = {}
102    for f in _marked_safe_globals_list:
103        module, name = f.__module__, f.__name__
104        rc[f"{module}.{name}"] = f
105    return rc
106
107
108def _tensor_rebuild_functions():
109    return {
110        torch._utils._rebuild_parameter,
111        torch._utils._rebuild_parameter_with_state,
112        torch._utils._rebuild_qtensor,
113        torch._utils._rebuild_tensor,
114        torch._utils._rebuild_tensor_v2,
115        torch._utils._rebuild_tensor_v3,
116        torch._utils._rebuild_sparse_tensor,
117        torch._utils._rebuild_meta_tensor_no_storage,
118        torch._utils._rebuild_nested_tensor,
119        torch._utils._rebuild_wrapper_subclass,
120    }
121
122
123# Unpickling machinery
124@_functools.lru_cache(maxsize=1)
125def _get_allowed_globals():
126    rc: Dict[str, Any] = {
127        "collections.OrderedDict": OrderedDict,
128        "collections.Counter": Counter,
129        "torch.nn.parameter.Parameter": torch.nn.Parameter,
130        "torch.serialization._get_layout": torch.serialization._get_layout,
131        "torch.Size": torch.Size,
132        "torch.Tensor": torch.Tensor,
133        "torch.device": torch.device,
134    }
135    # dtype
136    for t in torch.storage._dtype_to_storage_type_map().keys():
137        rc[str(t)] = t
138    for t in torch.storage._new_dtypes():
139        rc[str(t)] = t
140    # Tensor classes
141    for tt in torch._tensor_classes:
142        rc[f"{tt.__module__}.{tt.__name__}"] = tt
143    # Storage classes
144    for ts in torch._storage_classes:
145        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
146            # Wrap legacy storage types in a dummy class
147            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
148                ts.__name__
149            )
150        else:
151            rc[f"{ts.__module__}.{ts.__name__}"] = ts
152    # Quantization specific
153    for qt in [
154        torch.per_tensor_affine,
155        torch.per_tensor_symmetric,
156        torch.per_channel_affine,
157        torch.per_channel_symmetric,
158        torch.per_channel_affine_float_qparams,
159    ]:
160        rc[str(qt)] = qt
161    # Rebuild functions
162    for f in _tensor_rebuild_functions():
163        rc[f"torch._utils.{f.__name__}"] = f
164
165    # Handles Tensor Subclasses, Tensor's with attributes.
166    # NOTE: It calls into above rebuild functions for regular Tensor types.
167    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
168    return rc
169
170
171class Unpickler:
172    def __init__(self, file, *, encoding: str = "bytes"):
173        self.encoding = encoding
174        self.readline = file.readline
175        self.read = file.read
176        self.memo: Dict[int, Any] = {}
177        self.proto: int = -1
178
179    def load(self):
180        """Read a pickled object representation from the open file.
181
182        Return the reconstituted object hierarchy specified in the file.
183        """
184        self.metastack = []
185        self.stack: List[Any] = []
186        self.append = self.stack.append
187        read = self.read
188        readline = self.readline
189        while True:
190            key = read(1)
191            if not key:
192                raise EOFError
193            assert isinstance(key, bytes_types)
194            # Risky operators
195            if key[0] == GLOBAL[0]:
196                module = readline()[:-1].decode("utf-8")
197                name = readline()[:-1].decode("utf-8")
198                # Patch since torch.save default protocol is 2
199                # users will be running this code in python > 3
200                if self.proto == 2:
201                    if (module, name) in NAME_MAPPING:
202                        module, name = NAME_MAPPING[(module, name)]
203                    elif module in IMPORT_MAPPING:
204                        module = IMPORT_MAPPING[module]
205                full_path = f"{module}.{name}"
206                if full_path in _get_allowed_globals():
207                    self.append(_get_allowed_globals()[full_path])
208                elif full_path in _get_user_allowed_globals():
209                    self.append(_get_user_allowed_globals()[full_path])
210                else:
211                    raise RuntimeError(
212                        f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
213                        f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
214                        "this global if you trust this class/function."
215                    )
216            elif key[0] == NEWOBJ[0]:
217                args = self.stack.pop()
218                cls = self.stack.pop()
219                if cls is torch.nn.Parameter:
220                    self.append(torch.nn.Parameter(*args))
221                elif cls in _get_user_allowed_globals().values():
222                    self.append(cls.__new__(cls, *args))
223                else:
224                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
225            elif key[0] == REDUCE[0]:
226                args = self.stack.pop()
227                func = self.stack[-1]
228                if (
229                    func not in _get_allowed_globals().values()
230                    and func not in _get_user_allowed_globals().values()
231                ):
232                    raise RuntimeError(
233                        f"Trying to call reduce for unrecognized function {func}"
234                    )
235                self.stack[-1] = func(*args)
236            elif key[0] == BUILD[0]:
237                state = self.stack.pop()
238                inst = self.stack[-1]
239                if type(inst) is torch.Tensor:
240                    # Legacy unpickling
241                    inst.set_(*state)
242                elif type(inst) is torch.nn.Parameter:
243                    inst.__setstate__(state)
244                elif type(inst) is OrderedDict:
245                    inst.__dict__.update(state)
246                elif type(inst) in _get_user_allowed_globals().values():
247                    if hasattr(inst, "__setstate__"):
248                        inst.__setstate__(state)
249                    else:
250                        inst.__dict__.update(state)
251                else:
252                    raise RuntimeError(
253                        f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
254                    )
255            # Stack manipulation
256            elif key[0] == APPEND[0]:
257                item = self.stack.pop()
258                list_obj = self.stack[-1]
259                if type(list_obj) is not list:
260                    raise RuntimeError(
261                        f"Can only append to lists, but got {type(list_obj)}"
262                    )
263                list_obj.append(item)
264            elif key[0] == APPENDS[0]:
265                items = self.pop_mark()
266                list_obj = self.stack[-1]
267                if type(list_obj) is not list:
268                    raise RuntimeError(
269                        f"Can only extend lists, but got {type(list_obj)}"
270                    )
271                list_obj.extend(items)
272            elif key[0] == SETITEM[0]:
273                (v, k) = (self.stack.pop(), self.stack.pop())
274                self.stack[-1][k] = v
275            elif key[0] == SETITEMS[0]:
276                items = self.pop_mark()
277                for i in range(0, len(items), 2):
278                    self.stack[-1][items[i]] = items[i + 1]
279            elif key[0] == MARK[0]:
280                self.metastack.append(self.stack)
281                self.stack = []
282                self.append = self.stack.append
283            elif key[0] == TUPLE[0]:
284                items = self.pop_mark()
285                self.append(tuple(items))
286            elif key[0] == TUPLE1[0]:
287                self.stack[-1] = (self.stack[-1],)
288            elif key[0] == TUPLE2[0]:
289                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
290            elif key[0] == TUPLE3[0]:
291                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
292            # Basic types construction
293            elif key[0] == NONE[0]:
294                self.append(None)
295            elif key[0] == NEWFALSE[0]:
296                self.append(False)
297            elif key[0] == NEWTRUE[0]:
298                self.append(True)
299            elif key[0] == EMPTY_TUPLE[0]:
300                self.append(())
301            elif key[0] == EMPTY_LIST[0]:
302                self.append([])
303            elif key[0] == EMPTY_DICT[0]:
304                self.append({})
305            elif key[0] == EMPTY_SET[0]:
306                self.append(set())
307            elif key[0] == BININT[0]:
308                self.append(unpack("<i", read(4))[0])
309            elif key[0] == BININT1[0]:
310                self.append(self.read(1)[0])
311            elif key[0] == BININT2[0]:
312                self.append(unpack("<H", read(2))[0])
313            elif key[0] == BINFLOAT[0]:
314                self.append(unpack(">d", self.read(8))[0])
315            elif key[0] == BINUNICODE[0]:
316                strlen = unpack("<I", read(4))[0]
317                if strlen > maxsize:
318                    raise RuntimeError("String is too long")
319                strval = str(read(strlen), "utf-8", "surrogatepass")
320                self.append(strval)
321            elif key[0] == SHORT_BINSTRING[0]:
322                strlen = read(1)[0]
323                strdata = read(strlen)
324                if self.encoding != "bytes":
325                    strdata = strdata.decode(self.encoding, "strict")
326                self.append(strdata)
327            elif key[0] == BINPERSID[0]:
328                pid = self.stack.pop()
329                # Only allow persistent load of storage
330                if type(pid) is not tuple and not type(pid) is not int:
331                    raise RuntimeError(
332                        f"persistent_load id must be tuple or int, but got {type(pid)}"
333                    )
334                if (
335                    type(pid) is tuple
336                    and len(pid) > 0
337                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
338                ):
339                    raise RuntimeError(
340                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
341                    )
342                self.append(self.persistent_load(pid))
343            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
344                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
345                self.append(self.memo[idx])
346            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
347                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
348                if i < 0:
349                    raise ValueError("negative argument")
350                self.memo[i] = self.stack[-1]
351            elif key[0] == LONG1[0]:
352                n = read(1)[0]
353                data = read(n)
354                self.append(decode_long(data))
355            # First and last deserializer ops
356            elif key[0] == PROTO[0]:
357                self.proto = read(1)[0]
358                if self.proto != 2:
359                    warnings.warn(
360                        f"Detected pickle protocol {self.proto} in the checkpoint, which was "
361                        "not the default pickle protocol used by `torch.load` (2). The weights_only "
362                        "Unpickler might not support all instructions implemented by this protocol, "
363                        "please file an issue for adding support if you encounter this."
364                    )
365            elif key[0] == STOP[0]:
366                rc = self.stack.pop()
367                return rc
368            else:
369                raise RuntimeError(f"Unsupported operand {key[0]}")
370
371    # Return a list of items pushed in the stack after last MARK instruction.
372    def pop_mark(self):
373        items = self.stack
374        self.stack = self.metastack.pop()
375        self.append = self.stack.append
376        return items
377
378    def persistent_load(self, pid):
379        raise UnpicklingError("unsupported persistent id encountered")
380
381
def load(file, *, encoding: str = "ASCII"):
    """Deserialize *file* with the restricted weights-only ``Unpickler``."""
    unpickler = Unpickler(file, encoding=encoding)
    return unpickler.load()
384