# mypy: allow-untyped-defs
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional
from typing_extensions import ParamSpec

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)
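
# For illustration (a minimal sketch, not executed at import time; assumes
# `_type` is installed as the `type()` method of a storage class, as torch
# wires it up elsewhere):
#
#   s = torch.ones(4).storage()
#   s.type()                        # no dtype -> fully qualified class name
#   s.type("torch.DoubleStorage")   # string resolved via _import_dotted_name,
#                                   # then cast-with-copy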


def _to(self, device, non_blocking=False):
    """Returns a copy of this object in device memory.

    If this object is already on the correct device, then no copy is performed
    and the original object is returned.

    Args:
        device (torch.device): The destination device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
    """
    if self.device == device:
        return self

    device_module = getattr(torch, device.type, None)
    assert (
        device_module is not None
    ), f"{device.type.upper()} device module is not loaded"
    with device_module.device(device):
        if self.is_sparse and hasattr(device_module, "sparse"):
            new_type = getattr(device_module.sparse, self.__class__.__name__)
            indices = getattr(torch.Tensor._indices(self), device.type)(
                device, non_blocking
            )
            values = getattr(torch.Tensor._values(self), device.type)(
                device, non_blocking
            )
            return new_type(indices, values, self.size())
        else:
            assert (
                not self.is_sparse
            ), f"sparse storage is not supported for {device.type.upper()} tensors"
            untyped_storage = torch.UntypedStorage(self.size(), device=device)
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]
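
# For illustration, the three paths through the kwarg handling above
# (hypothetical calls, not executed at import time):
#
#   _get_async_or_non_blocking("type", False, {})               # -> False
#   _get_async_or_non_blocking("type", False, {"async": True})  # -> True,
#                                                               #    warns
#   _get_async_or_non_blocking("type", False, {"oops": 1})      # TypeError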


def _get_restore_location(device):
    """Return the map_location location.

    Used for rebuild functions where the tensor device is distinct from the
    storage device.
    """

    map_location = torch.serialization._serialization_tls.map_location
    if map_location is None:
        return device
    else:
        if isinstance(map_location, dict):
            return map_location.get(device, device)
        elif isinstance(map_location, (str, torch.device)):
            return map_location
        else:
            assert callable(map_location)
            raise RuntimeError(
                "Callable map_location not supported with _rebuild_wrapper_subclass "
                "or _rebuild_device_tensor_from_numpy"
            )
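
# Sketch of how map_location is interpreted here (hypothetical values, not
# executed at import time):
#
#   _get_restore_location(torch.device("cuda:0"))
#     -> the device itself     if map_location is None
#     -> map_location[device]  if map_location is a dict containing the device
#     -> map_location          if map_location is a str or torch.device
#     -> RuntimeError          if map_location is a callable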


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables.  This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#   - It's fragile.  If you serialize a backward hook into a saved
#     model, and then you rename the function associated with the hook,
#     now your saved model is broken and you can't load it anymore.
#
#   - It's not actually used.  The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks.  In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages re-adding the hooks!
#
#   - We didn't serialize them in many cases.  Prior to #10220, we
#     were dropping backward hooks in ForkingPickler.  We "fixed" this
#     to be consistent with other serialization sites, but the lack of
#     serialized backward hooks wasn't actually the root cause of
#     the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a
# hook has the property _torch_serialize_ignore, we will not emit a warning if
# we attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass a None you'll run afoul of #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
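
# For illustration (a minimal sketch, not executed at import time; assumes a
# TypedStorage `s` holding four float32 elements):
#
#   s = torch.ones(4).storage()               # TypedStorage of 4 floats
#   t = _rebuild_tensor(s, 0, (2, 2), (2, 1))
#   # t is a 2x2 view into `s` starting at offset 0, with strides (2, 1)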


def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # the `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    metadata=None,
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
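
# Sketch of where _rebuild_tensor_v2 fits (hypothetical, not executed at
# import time): for a plain tensor, Tensor.__reduce_ex__ typically emits this
# function plus its arguments, so unpickling calls it to reassemble the tensor:
#
#   fn, args = torch.ones(2, 2).__reduce_ex__(2)
#   # fn is torch._utils._rebuild_tensor_v2; fn(*args) recreates the tensor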


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary
# for validating sparse tensors: indices and values. That's why sparse tensors
# are first unpickled without any validation, and then this function is called
# just before _legacy_load() returns, so that all the sparse tensors can be
# validated in bulk.
#
# The same procedure must be followed by _load() in serialization.py because,
# due to Pickler semantics, we have to use the same (non-validating) function
# for unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (torch.layout): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
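
# For illustration (a minimal COO sketch, not executed at import time):
#
#   indices = torch.tensor([[0, 1], [1, 0]])
#   values = torch.tensor([1.0, 2.0])
#   t = _rebuild_sparse_tensor(torch.sparse_coo, (indices, values, (2, 2)))
#   # t is queued in _sparse_tensors_to_validate; _legacy_load()/_load() call
#   # _validate_loaded_sparse_tensors() afterwards to check the invariants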


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    device = _get_restore_location(device)
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls,
    dtype,
    size,
    stride,
    storage_offset,
    layout,
    device,
    requires_grad,
):
    device = _get_restore_location(device)
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on the Parameter, like python attrs.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensor
    # does not inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj
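
# For illustration (a minimal sketch with a hypothetical Tensor subclass, not
# executed at import time):
#
#   class MyParam(torch.nn.Parameter):
#       pass
#
#   p = MyParam(torch.ones(2))
#   p.note = "hi"                     # plain python attribute
#   state = _get_obj_state(p)         # dict state, or a (dict, slots) tuple
#   _set_obj_state(MyParam(torch.ones(2)), state)   # restores `note`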


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj
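
# For illustration (not executed at import time):
#
#   _import_dotted_name("torch.nn.Linear")   # -> <class 'torch.nn.Linear'>
#   # __import__("torch") returns the top-level module; the loop then walks
#   # the remaining attributes ("nn", then "Linear")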


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    the same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operations on this buffer will be equivalent to
    operating on the tensors individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of the same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    the same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
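
# For illustration, a round trip through the two dense helpers above (a
# minimal sketch, not executed at import time):
#
#   a, b = torch.ones(2, 2), torch.zeros(3)
#   flat = _flatten_dense_tensors([a, b])           # 1D buffer of 7 elements
#   a2, b2 = _unflatten_dense_tensors(flat, [a, b])
#   # a2 and b2 have the shapes of a and b, with values copied from `flat`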


def _unflatten_sparse_tensors(flat, tensors):
    """View a flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of the same sparse type, and that flat is
    given by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
          tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of the same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of the same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
          the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
          reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
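
# For illustration (a minimal sketch, not executed at import time):
#
#   f1, f2 = torch.ones(1), torch.ones(2)      # float tensors
#   l1 = torch.ones(1, dtype=torch.long)       # long tensor
#   # grouped-by-type order -> original interleaved order:
#   _reorder_tensors_as([f1, f2, l1], [f1, l1, f2])   # -> (f1, l1, f2)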


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields a chunk each time,
    each chunk containing tensors of the same type, up to a certain byte limit
    in total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of the same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within their types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
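
# For illustration (a minimal sketch, not executed at import time; float32
# tensors take 4 bytes per element):
#
#   ts = [torch.ones(2), torch.ones(2), torch.ones(2)]   # 8 bytes each
#   list(_take_tensors(ts, 16))   # -> [[ts[0], ts[1]], [ts[2]]]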


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec
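
# For illustration (a minimal sketch, not executed at import time):
#
#   @annotate(int, x=int, y=int)
#   def add(x, y):
#       return x + y
#
#   # add.__annotations__ == {'x': int, 'y': int, 'return': int}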


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding a reference to the
# frame, and the frame (which holds references to all the objects in its
# temporary scope) holding a reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have a non-str first argument but explicitly
            # expose a message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate it since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker
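# Illustrative sketch (not part of the original module): the typical round trip
# for ExceptionWrapper. A worker thread captures a failure, and the main thread
# reraises it with the original traceback embedded in the message. The helper
# name `_example_exception_wrapper_roundtrip` is hypothetical and is never
# called at import time.
def _example_exception_wrapper_roundtrip():
    import threading

    results = []

    def worker():
        try:
            {}["missing key"]  # raises KeyError
        except Exception:
            # Capture exc_type plus the *formatted* traceback string, so no
            # frame objects cross the thread boundary.
            results.append(ExceptionWrapper(where="in example worker thread"))

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    # Raises KeyError here; the message includes the worker's traceback.
    results[0].reraise()

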
def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    if hasattr(torch, "mtia") and torch.mtia.is_available():
        return "mtia"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type and device_type.lower() == "mtia":
        return get_member(torch.mtia)
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # index of the current device
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # indices of all available devices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # properties of all requested devices
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


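# Illustrative sketch (not part of the original module): the three helpers above
# all route through ``_get_device_attr``, so they agree on which backend they
# query. The helper name is hypothetical and is never called at import time.
def _example_device_helpers():
    if _get_available_device_type() is not None:
        idx = _get_current_device_index()         # e.g. 0
        all_idx = _get_all_device_indices()       # e.g. [0, 1]
        props = _get_devices_properties(all_idx)  # one properties object per index
        assert idx in all_idx and len(props) == len(all_idx)

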
def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 if there are no CUDA devices available.
    Takes no arguments.
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any,
    optional: bool = False,
    allow_cpu: bool = False,
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has an index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    e.g., the current default CUDA device will be returned if the CUDA runtime
    is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non-cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
    return device_idx


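# Illustrative sketch (not part of the original module): typical inputs and the
# indices `_get_device_index` resolves them to. The helper name is hypothetical
# and is never called at import time.
def _example_get_device_index():
    # A fully specified device returns its index as-is.
    assert _get_device_index(torch.device("cuda", 1)) == 1
    # Plain integers pass through unchanged.
    assert _get_device_index(3) == 3
    # CPU maps to -1, but only when explicitly allowed.
    assert _get_device_index(torch.device("cpu"), allow_cpu=True) == -1
    # ``None`` falls back to the current default device when ``optional=True``
    # (requires an available accelerator runtime, e.g. CUDA).
    if torch.cuda.is_available():
        assert _get_device_index(None, optional=True) == torch.cuda.current_device()

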
def _handle_complex(tensor):
    """
    Returns a real view of the tensor if it has a complex dtype; otherwise
    returns the tensor unchanged. The ``UninitializedParameter`` check is
    needed because calling ``is_complex()`` on the uninitialized parameter
    of a LazyModule raises an error.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


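# Illustrative sketch (not part of the original module): a complex tensor of
# shape (n,) becomes a real view of shape (n, 2) holding (real, imag) pairs,
# while real tensors pass through untouched. The helper name is hypothetical
# and is never called at import time.
def _example_handle_complex():
    z = torch.tensor([1 + 2j, 3 - 4j])  # complex64 by default
    r = _handle_complex(z)              # real view of shape (2, 2)
    assert r.shape == (2, 2) and r.dtype == torch.float32
    x = torch.ones(3)                   # real dtype: returned unchanged
    assert _handle_complex(x) is x

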
def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        # finfo reports the bit width of one real component, and a complex
        # element stores two of them: bits / 8 * 2 == bits >> 2.
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        # bits / 8 == bits >> 3
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
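

# Illustrative sketch (not part of the original module): element sizes for a
# few common dtypes, making the bit-shift arithmetic above concrete. The helper
# name is hypothetical and is never called at import time.
def _example_element_size():
    assert _element_size(torch.float32) == 4    # 32 bits >> 3
    assert _element_size(torch.complex64) == 8  # two float32 components
    assert _element_size(torch.int64) == 8      # 64 bits >> 3
    assert _element_size(torch.bool) == 1       # special-cased

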
class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        # Only reading is supported; ``fset`` is accepted for API symmetry
        # but ignored, so the resulting property is effectively read-only.
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)


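# Illustrative sketch (not part of the original module): ``classproperty`` makes
# a zero-argument method readable as an attribute on both the class and its
# instances. The names below are hypothetical and never called at import time.
def _example_classproperty():
    class Config:
        _name = "default"

        @classproperty
        def name(cls):
            return cls._name

    assert Config.name == "default"    # accessed on the class
    assert Config().name == "default"  # accessed on an instance

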
def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


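# Illustrative sketch (not part of the original module): branching between eager
# and compiled behavior, e.g. to keep debug-only checks out of a trace. The
# helper names are hypothetical and never called at import time.
def _example_is_compiling():
    def debug_check(x):
        if not is_compiling():
            # Only run the eager-mode sanity check; skip it inside a trace.
            assert torch.isfinite(x).all()
        return x * 2

    return debug_check(torch.ones(3))

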
def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate our updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


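# Illustrative sketch (not part of the original module): resolving a backend
# module by device-type name. The helper name is hypothetical and never called
# at import time.
def _example_get_device_module():
    mod = _get_device_module("cuda")  # returns the torch.cuda module
    return mod.is_available()

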
def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )


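# Illustrative sketch (not part of the original module): dummy types stand in
# for backend classes that were not compiled in, and raise on any attempt to
# construct them. The stand-in name is hypothetical; the helper is never called
# at import time.
def _example_dummy_type():
    _CudaStreamBase = _dummy_type("_CudaStreamBase")  # hypothetical stand-in
    try:
        _CudaStreamBase()
    except RuntimeError as e:
        assert "dummy base class" in str(e)

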
class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of the **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        # Entries may be ``None`` if the corresponding API was never called.
        return self.call_order


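# Illustrative sketch (not part of the original module): only the most recent
# call of each kind is kept, ordered by recency, so the last entry of
# ``get_calls()`` is always the latest call. The helper name is hypothetical
# and never called at import time.
def _example_lazy_seed_tracker():
    tracker = _LazySeedTracker()
    tracker.queue_seed_all(lambda: None, "tb1")
    tracker.queue_seed(lambda: None, "tb2")  # most recent call goes last
    calls = tracker.get_calls()
    assert calls[-1] == tracker.manual_seed_cb

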
logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception:
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )


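# Illustrative sketch (not part of the original module): callbacks fire in
# registration order, and one failing callback is logged rather than aborting
# the rest. The helper name is hypothetical and never called at import time.
def _example_callback_registry():
    registry = CallbackRegistry("example events")
    seen = []
    registry.add_callback(seen.append)
    registry.add_callback(lambda v: 1 / 0)  # ZeroDivisionError is logged, not raised
    registry.add_callback(lambda v: seen.append(v * 10))
    registry.fire_callbacks(7)
    assert seen == [7, 70]

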
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle.  We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}
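

# Illustrative sketch (not part of the original module): how a Python 2 global
# reference would be translated with these tables, applying NAME_MAPPING before
# IMPORT_MAPPING as the comment above prescribes. The helper name is
# hypothetical and never called at import time.
def _example_compat_mapping():
    module, name = "__builtin__", "xrange"
    module, name = NAME_MAPPING.get((module, name), (module, name))
    module = IMPORT_MAPPING.get(module, module)
    assert (module, name) == ("builtins", "range")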