1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport copyreg 3*da0073e9SAndroid Build Coastguard Workerimport functools 4*da0073e9SAndroid Build Coastguard Workerimport logging 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport traceback 7*da0073e9SAndroid Build Coastguard Workerimport warnings 8*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict 9*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, DefaultDict, Generic, List, Optional 10*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import ParamSpec 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerimport torch 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerdef _type(self, dtype=None, non_blocking=False, **kwargs): 16*da0073e9SAndroid Build Coastguard Worker """Returns the type if `dtype` is not provided, else casts this object to 17*da0073e9SAndroid Build Coastguard Worker the specified type. 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker If this is already of the correct type, no copy is performed and the 20*da0073e9SAndroid Build Coastguard Worker original object is returned. 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker Args: 23*da0073e9SAndroid Build Coastguard Worker dtype (type or string): The desired type 24*da0073e9SAndroid Build Coastguard Worker non_blocking (bool): If ``True``, and the source is in pinned memory 25*da0073e9SAndroid Build Coastguard Worker and destination is on the GPU or vice versa, the copy is performed 26*da0073e9SAndroid Build Coastguard Worker asynchronously with respect to the host. Otherwise, the argument 27*da0073e9SAndroid Build Coastguard Worker has no effect. 28*da0073e9SAndroid Build Coastguard Worker **kwargs: For compatibility, may contain the key ``async`` in place of 29*da0073e9SAndroid Build Coastguard Worker the ``non_blocking`` argument. The ``async`` arg is deprecated. 30*da0073e9SAndroid Build Coastguard Worker """ 31*da0073e9SAndroid Build Coastguard Worker non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs) 32*da0073e9SAndroid Build Coastguard Worker if dtype is None: 33*da0073e9SAndroid Build Coastguard Worker return self.__module__ + "." + self.__class__.__name__ 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker if isinstance(dtype, str): 36*da0073e9SAndroid Build Coastguard Worker dtype = _import_dotted_name(dtype) 37*da0073e9SAndroid Build Coastguard Worker if dtype == type(self): 38*da0073e9SAndroid Build Coastguard Worker return self 39*da0073e9SAndroid Build Coastguard Worker if self.is_sparse: 40*da0073e9SAndroid Build Coastguard Worker if not dtype.is_sparse: 41*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Cannot cast sparse tensor to dense tensor") 42*da0073e9SAndroid Build Coastguard Worker new_module_name = dtype.__module__.replace(".sparse", "") 43*da0073e9SAndroid Build Coastguard Worker new_values_type_name = new_module_name + "." + dtype.__name__ 44*da0073e9SAndroid Build Coastguard Worker new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) 45*da0073e9SAndroid Build Coastguard Worker new_indices_type_name = new_module_name + ".LongTensor" 46*da0073e9SAndroid Build Coastguard Worker new_indices = torch.Tensor._indices(self).type( 47*da0073e9SAndroid Build Coastguard Worker new_indices_type_name, non_blocking 48*da0073e9SAndroid Build Coastguard Worker ) 49*da0073e9SAndroid Build Coastguard Worker return dtype(new_indices, new_values, self.size()) 50*da0073e9SAndroid Build Coastguard Worker if dtype.is_sparse: 51*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Cannot cast dense tensor to sparse tensor") 52*da0073e9SAndroid Build Coastguard Worker return dtype(self.size()).copy_(self, non_blocking) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Workerdef _to(self, device, non_blocking=False): 56*da0073e9SAndroid Build Coastguard Worker """Returns a copy of this object in device memory. 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker If this object is already on the correct device, then no copy is performed 59*da0073e9SAndroid Build Coastguard Worker and the original object is returned. 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker Args: 62*da0073e9SAndroid Build Coastguard Worker device (int): The destination device. 63*da0073e9SAndroid Build Coastguard Worker non_blocking (bool): If ``True`` and the source is in pinned memory, 64*da0073e9SAndroid Build Coastguard Worker the copy will be asynchronous with respect to the host. Otherwise, 65*da0073e9SAndroid Build Coastguard Worker the argument has no effect. 66*da0073e9SAndroid Build Coastguard Worker """ 67*da0073e9SAndroid Build Coastguard Worker if self.device == device: 68*da0073e9SAndroid Build Coastguard Worker return self 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker device_module = getattr(torch, device.type, None) 71*da0073e9SAndroid Build Coastguard Worker assert ( 72*da0073e9SAndroid Build Coastguard Worker device_module is not None 73*da0073e9SAndroid Build Coastguard Worker ), f"{device.type.upper()} device module is not loaded" 74*da0073e9SAndroid Build Coastguard Worker with device_module.device(device): 75*da0073e9SAndroid Build Coastguard Worker if self.is_sparse and hasattr(device_module, "sparse"): 76*da0073e9SAndroid Build Coastguard Worker new_type = getattr(device_module.sparse, self.__class__.__name__) 77*da0073e9SAndroid Build Coastguard Worker indices = getattr(torch.Tensor._indices(self), device.type)( 78*da0073e9SAndroid Build Coastguard Worker device, non_blocking 79*da0073e9SAndroid Build Coastguard Worker ) 80*da0073e9SAndroid Build Coastguard Worker values = getattr(torch.Tensor._values(self), device.type)( 81*da0073e9SAndroid Build Coastguard Worker device, non_blocking 82*da0073e9SAndroid Build Coastguard Worker ) 83*da0073e9SAndroid Build Coastguard Worker return new_type(indices, values, self.size()) 84*da0073e9SAndroid Build Coastguard Worker else: 85*da0073e9SAndroid Build Coastguard Worker assert ( 86*da0073e9SAndroid Build Coastguard Worker not self.is_sparse 87*da0073e9SAndroid Build Coastguard Worker ), f"sparse storage is not supported for {device.type.upper()} tensors" 88*da0073e9SAndroid Build Coastguard Worker untyped_storage = torch.UntypedStorage(self.size(), device=device) 89*da0073e9SAndroid Build Coastguard Worker untyped_storage.copy_(self, non_blocking) 90*da0073e9SAndroid Build Coastguard Worker return untyped_storage 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Workerdef _get_async_or_non_blocking(function_name, non_blocking, kwargs): 94*da0073e9SAndroid Build Coastguard Worker """Return the non-blocking flag given the function name and kwargs. 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker Args: 97*da0073e9SAndroid Build Coastguard Worker function_name (str): the name of the function being used. 98*da0073e9SAndroid Build Coastguard Worker non_blocking (bool): the default value. 99*da0073e9SAndroid Build Coastguard Worker **kwargs (dict): the kwargs passed to the function. 100*da0073e9SAndroid Build Coastguard Worker """ 101*da0073e9SAndroid Build Coastguard Worker if not kwargs: 102*da0073e9SAndroid Build Coastguard Worker return non_blocking 103*da0073e9SAndroid Build Coastguard Worker if len(kwargs) != 1 or "async" not in kwargs: 104*da0073e9SAndroid Build Coastguard Worker message = "{}() got an unexpected keyword argument '{}'" 105*da0073e9SAndroid Build Coastguard Worker argument = list(kwargs.keys()).pop() 106*da0073e9SAndroid Build Coastguard Worker raise TypeError(message.format(function_name, argument)) 107*da0073e9SAndroid Build Coastguard Worker warnings.warn("'async' is deprecated; use 'non_blocking'") 108*da0073e9SAndroid Build Coastguard Worker return kwargs["async"] 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Workerdef _get_restore_location(device): 112*da0073e9SAndroid Build Coastguard Worker """Return the map_location location. 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker Used for rebuild functions where the tensor device is distinct from the storage 115*da0073e9SAndroid Build Coastguard Worker """ 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker map_location = torch.serialization._serialization_tls.map_location 118*da0073e9SAndroid Build Coastguard Worker if map_location is None: 119*da0073e9SAndroid Build Coastguard Worker return device 120*da0073e9SAndroid Build Coastguard Worker else: 121*da0073e9SAndroid Build Coastguard Worker if isinstance(map_location, dict): 122*da0073e9SAndroid Build Coastguard Worker return map_location.get(device, device) 123*da0073e9SAndroid Build Coastguard Worker elif isinstance(map_location, (str, torch.device)): 124*da0073e9SAndroid Build Coastguard Worker return map_location 125*da0073e9SAndroid Build Coastguard Worker else: 126*da0073e9SAndroid Build Coastguard Worker assert callable(map_location) 127*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 128*da0073e9SAndroid Build Coastguard Worker "Callable map_location not supported with _rebuild_wrapper_subclass " 129*da0073e9SAndroid Build Coastguard Worker "or _rebuild_device_tensor_from_numpy" 130*da0073e9SAndroid Build Coastguard Worker ) 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker# Note [Don't serialize hooks] 134*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 135*da0073e9SAndroid Build Coastguard Worker# Since time immemorial, we have serialized the backward hooks associated with 136*da0073e9SAndroid Build Coastguard Worker# variables. This kind of half-worked--Python can pickle global functions 137*da0073e9SAndroid Build Coastguard Worker# (but not closures!)--but there were problems. 138*da0073e9SAndroid Build Coastguard Worker# 139*da0073e9SAndroid Build Coastguard Worker# - It's fragile. If you serialize a backward hook into a saved 140*da0073e9SAndroid Build Coastguard Worker# model, and then you rename the function associated with the hook, 141*da0073e9SAndroid Build Coastguard Worker# now your saved model is broken and you can't load it anymore. 142*da0073e9SAndroid Build Coastguard Worker# 143*da0073e9SAndroid Build Coastguard Worker# - It's not actually used. The standard recommendation is to 144*da0073e9SAndroid Build Coastguard Worker# serialize the *state_dict* of a model, not the model itself 145*da0073e9SAndroid Build Coastguard Worker# (since this is more stable to code changes affecting the model 146*da0073e9SAndroid Build Coastguard Worker# serialization), and the state dict saves "data" only, thus 147*da0073e9SAndroid Build Coastguard Worker# stripping the backward hooks. In some cases, hooks are 148*da0073e9SAndroid Build Coastguard Worker# essential to the well-functioning of a model (e.g., DDP), 149*da0073e9SAndroid Build Coastguard Worker# but DDP already manages readding the hooks! 150*da0073e9SAndroid Build Coastguard Worker# 151*da0073e9SAndroid Build Coastguard Worker# - We didn't serialize them in many cases. Prior to #10220, we 152*da0073e9SAndroid Build Coastguard Worker# were dropping backward hooks in ForkingPickler. We "fixed" this 153*da0073e9SAndroid Build Coastguard Worker# to be convenient with other serialization sites, but lack of 154*da0073e9SAndroid Build Coastguard Worker# serializing backward hooks wasn't actually the root cause of 155*da0073e9SAndroid Build Coastguard Worker# the bug. 156*da0073e9SAndroid Build Coastguard Worker# 157*da0073e9SAndroid Build Coastguard Worker# With these cases in mind, we have decided that a better strategy 158*da0073e9SAndroid Build Coastguard Worker# is to just NOT serialize hooks at all. 159*da0073e9SAndroid Build Coastguard Worker# 160*da0073e9SAndroid Build Coastguard Worker# Since this is a BC-breaking change, we should warn when we previously 161*da0073e9SAndroid Build Coastguard Worker# serialized a hook, but no longer do so. This will be done by adding a special 162*da0073e9SAndroid Build Coastguard Worker# sentinel property to hooks will be used to suppress this warning. If a hook 163*da0073e9SAndroid Build Coastguard Worker# has the property _torch_serialize_ignore, we will not emit a warning if we 164*da0073e9SAndroid Build Coastguard Worker# attempt to serialize a Tensor with this hook attached to it. 165*da0073e9SAndroid Build Coastguard Worker# 166*da0073e9SAndroid Build Coastguard Worker# By the way, when _backward_hooks is skipped, we must give an EMPTY 167*da0073e9SAndroid Build Coastguard Worker# OrderedDict(), if you pass a None you'll run afoul #12219. 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker# TODO: Once we decide to break serialization FC, `storage` no longer needs to 171*da0073e9SAndroid Build Coastguard Worker# be a TypedStorage 172*da0073e9SAndroid Build Coastguard Workerdef _rebuild_tensor(storage, storage_offset, size, stride): 173*da0073e9SAndroid Build Coastguard Worker # first construct a tensor with the correct dtype/device 174*da0073e9SAndroid Build Coastguard Worker t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device) 175*da0073e9SAndroid Build Coastguard Worker return t.set_(storage._untyped_storage, storage_offset, size, stride) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Workerdef get_tensor_metadata(tensor): 179*da0073e9SAndroid Build Coastguard Worker # Tensor's Metadata for serializing. 180*da0073e9SAndroid Build Coastguard Worker # Currently, this only returns a dict[string, bool] specifing whether 181*da0073e9SAndroid Build Coastguard Worker # `conj` or `neg` bit is set. 182*da0073e9SAndroid Build Coastguard Worker assert isinstance(tensor, torch.Tensor) 183*da0073e9SAndroid Build Coastguard Worker return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined] 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Workerdef set_tensor_metadata(tensor, metadata): 187*da0073e9SAndroid Build Coastguard Worker # See `get_tensor_metadata` above 188*da0073e9SAndroid Build Coastguard Worker assert isinstance(metadata, dict) 189*da0073e9SAndroid Build Coastguard Worker assert isinstance(tensor, torch.Tensor) 190*da0073e9SAndroid Build Coastguard Worker torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined] 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Workerdef _rebuild_tensor_v2( 194*da0073e9SAndroid Build Coastguard Worker storage, 195*da0073e9SAndroid Build Coastguard Worker storage_offset, 196*da0073e9SAndroid Build Coastguard Worker size, 197*da0073e9SAndroid Build Coastguard Worker stride, 198*da0073e9SAndroid Build Coastguard Worker requires_grad, 199*da0073e9SAndroid Build Coastguard Worker backward_hooks, 200*da0073e9SAndroid Build Coastguard Worker metadata=None, 201*da0073e9SAndroid Build Coastguard Worker): 202*da0073e9SAndroid Build Coastguard Worker tensor = _rebuild_tensor(storage, storage_offset, size, stride) 203*da0073e9SAndroid Build Coastguard Worker tensor.requires_grad = requires_grad 204*da0073e9SAndroid Build Coastguard Worker if metadata: 205*da0073e9SAndroid Build Coastguard Worker set_tensor_metadata(tensor, metadata) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker # NB: This line exists only for backwards compatibility; the 208*da0073e9SAndroid Build Coastguard Worker # general expectation is that backward_hooks is an empty 209*da0073e9SAndroid Build Coastguard Worker # OrderedDict. See Note [Don't serialize hooks] 210*da0073e9SAndroid Build Coastguard Worker tensor._backward_hooks = backward_hooks 211*da0073e9SAndroid Build Coastguard Worker return tensor 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Workerdef _rebuild_tensor_v3( 215*da0073e9SAndroid Build Coastguard Worker storage, 216*da0073e9SAndroid Build Coastguard Worker storage_offset, 217*da0073e9SAndroid Build Coastguard Worker size, 218*da0073e9SAndroid Build Coastguard Worker stride, 219*da0073e9SAndroid Build Coastguard Worker requires_grad, 220*da0073e9SAndroid Build Coastguard Worker backward_hooks, 221*da0073e9SAndroid Build Coastguard Worker dtype, 222*da0073e9SAndroid Build Coastguard Worker metadata=None, 223*da0073e9SAndroid Build Coastguard Worker): 224*da0073e9SAndroid Build Coastguard Worker t = torch.empty( 225*da0073e9SAndroid Build Coastguard Worker (0,), 226*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 227*da0073e9SAndroid Build Coastguard Worker device=storage._untyped_storage.device, 228*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 229*da0073e9SAndroid Build Coastguard Worker ) 230*da0073e9SAndroid Build Coastguard Worker t.set_(storage._untyped_storage, storage_offset, size, stride) 231*da0073e9SAndroid Build Coastguard Worker if metadata: 232*da0073e9SAndroid Build Coastguard Worker set_tensor_metadata(t, metadata) 233*da0073e9SAndroid Build Coastguard Worker t._backward_hooks = backward_hooks 234*da0073e9SAndroid Build Coastguard Worker return t 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker_sparse_tensors_to_validate: List["torch.Tensor"] = [] 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker# In _legacy_load() in serialization.py we unpickle storages after the sparse 241*da0073e9SAndroid Build Coastguard Worker# tensors have been already unpickled. Those storages contain data necessary for 242*da0073e9SAndroid Build Coastguard Worker# validating sparse tensors: indices and values. That's why sparse tensors are 243*da0073e9SAndroid Build Coastguard Worker# first unpickled without any validation, and then this function is called just 244*da0073e9SAndroid Build Coastguard Worker# before _legacy_load() returns, so that all the sparse tensors can be validated 245*da0073e9SAndroid Build Coastguard Worker# in bulk. 246*da0073e9SAndroid Build Coastguard Worker# 247*da0073e9SAndroid Build Coastguard Worker# The same procedure must be followed by _load() in serialization.py because due 248*da0073e9SAndroid Build Coastguard Worker# to Pickler semantics, we have to use the same (non-validating) function for 249*da0073e9SAndroid Build Coastguard Worker# unpickling sparse tensors, regardless of the caller. 250*da0073e9SAndroid Build Coastguard Workerdef _validate_loaded_sparse_tensors(): 251*da0073e9SAndroid Build Coastguard Worker try: 252*da0073e9SAndroid Build Coastguard Worker for t in _sparse_tensors_to_validate: 253*da0073e9SAndroid Build Coastguard Worker if t.layout is torch.sparse_coo: 254*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_coo_tensor_args( 255*da0073e9SAndroid Build Coastguard Worker t._indices(), t._values(), t.size(), t.is_coalesced() 256*da0073e9SAndroid Build Coastguard Worker ) 257*da0073e9SAndroid Build Coastguard Worker elif t.layout in { 258*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, 259*da0073e9SAndroid Build Coastguard Worker torch.sparse_csc, 260*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsr, 261*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsc, 262*da0073e9SAndroid Build Coastguard Worker }: 263*da0073e9SAndroid Build Coastguard Worker # TODO: Validation currently involves an expensive traversal 264*da0073e9SAndroid Build Coastguard Worker # on CPU, which may include a device transfer. 265*da0073e9SAndroid Build Coastguard Worker if t.layout in {torch.sparse_csr, torch.sparse_bsr}: 266*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = ( 267*da0073e9SAndroid Build Coastguard Worker t.crow_indices(), 268*da0073e9SAndroid Build Coastguard Worker t.col_indices(), 269*da0073e9SAndroid Build Coastguard Worker ) 270*da0073e9SAndroid Build Coastguard Worker else: 271*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = ( 272*da0073e9SAndroid Build Coastguard Worker t.ccol_indices(), 273*da0073e9SAndroid Build Coastguard Worker t.row_indices(), 274*da0073e9SAndroid Build Coastguard Worker ) 275*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_compressed_tensor_args( 276*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices, t.values(), t.size(), t.layout 277*da0073e9SAndroid Build Coastguard Worker ) 278*da0073e9SAndroid Build Coastguard Worker else: 279*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 280*da0073e9SAndroid Build Coastguard Worker f"_validate_loaded_sparse_tensors for layout `{t.layout}`" 281*da0073e9SAndroid Build Coastguard Worker ) 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker finally: 284*da0073e9SAndroid Build Coastguard Worker _sparse_tensors_to_validate.clear() 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Workerdef _rebuild_sparse_tensor(layout, data): 288*da0073e9SAndroid Build Coastguard Worker """ 289*da0073e9SAndroid Build Coastguard Worker Rebuilds a sparse tensor from its sparse storage representation. 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker Args: 292*da0073e9SAndroid Build Coastguard Worker layout (str): The sparse storage layout of the tensor. 293*da0073e9SAndroid Build Coastguard Worker data (tuple): The tensor's sparse storage representation. 294*da0073e9SAndroid Build Coastguard Worker """ 295*da0073e9SAndroid Build Coastguard Worker if layout == torch.sparse_coo: 296*da0073e9SAndroid Build Coastguard Worker if len(data) == 3: 297*da0073e9SAndroid Build Coastguard Worker # For BC: 298*da0073e9SAndroid Build Coastguard Worker indices, values, size = data 299*da0073e9SAndroid Build Coastguard Worker is_coalesced = None 300*da0073e9SAndroid Build Coastguard Worker else: 301*da0073e9SAndroid Build Coastguard Worker indices, values, size, is_coalesced = data 302*da0073e9SAndroid Build Coastguard Worker result = torch.sparse_coo_tensor( 303*da0073e9SAndroid Build Coastguard Worker indices, values, size, check_invariants=False, is_coalesced=is_coalesced 304*da0073e9SAndroid Build Coastguard Worker ) 305*da0073e9SAndroid Build Coastguard Worker _sparse_tensors_to_validate.append(result) 306*da0073e9SAndroid Build Coastguard Worker return result 307*da0073e9SAndroid Build Coastguard Worker 308*da0073e9SAndroid Build Coastguard Worker elif layout in { 309*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, 310*da0073e9SAndroid Build Coastguard Worker torch.sparse_csc, 311*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsr, 312*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsc, 313*da0073e9SAndroid Build Coastguard Worker }: 314*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices, values, size = data 315*da0073e9SAndroid Build Coastguard Worker result = torch.sparse_compressed_tensor( 316*da0073e9SAndroid Build Coastguard Worker compressed_indices, 317*da0073e9SAndroid Build Coastguard Worker plain_indices, 318*da0073e9SAndroid Build Coastguard Worker values, 319*da0073e9SAndroid Build Coastguard Worker size, 320*da0073e9SAndroid Build Coastguard Worker layout=layout, 321*da0073e9SAndroid Build Coastguard Worker check_invariants=False, 322*da0073e9SAndroid Build Coastguard Worker ) 323*da0073e9SAndroid Build Coastguard Worker _sparse_tensors_to_validate.append(result) 324*da0073e9SAndroid Build Coastguard Worker return result 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}") 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Workerdef _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets): 330*da0073e9SAndroid Build Coastguard Worker return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets) 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Workerdef _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): 334*da0073e9SAndroid Build Coastguard Worker device = _get_restore_location(device) 335*da0073e9SAndroid Build Coastguard Worker tensor = torch.from_numpy(data).to(dtype=dtype, device=device) 336*da0073e9SAndroid Build Coastguard Worker tensor.requires_grad = requires_grad 337*da0073e9SAndroid Build Coastguard Worker return tensor 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch 341*da0073e9SAndroid Build Coastguard Worker_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Workerdef _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): 345*da0073e9SAndroid Build Coastguard Worker return torch.empty_strided( 346*da0073e9SAndroid Build Coastguard Worker size, stride, dtype=dtype, device="meta", requires_grad=requires_grad 347*da0073e9SAndroid Build Coastguard Worker ) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Workerdef _rebuild_wrapper_subclass( 351*da0073e9SAndroid Build Coastguard Worker cls, 352*da0073e9SAndroid Build Coastguard Worker dtype, 353*da0073e9SAndroid Build Coastguard Worker size, 354*da0073e9SAndroid Build Coastguard Worker stride, 355*da0073e9SAndroid Build Coastguard Worker storage_offset, 356*da0073e9SAndroid Build Coastguard Worker layout, 357*da0073e9SAndroid Build Coastguard Worker device, 358*da0073e9SAndroid Build Coastguard Worker requires_grad, 359*da0073e9SAndroid Build Coastguard Worker): 360*da0073e9SAndroid Build Coastguard Worker device = _get_restore_location(device) 361*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 362*da0073e9SAndroid Build Coastguard Worker cls, 363*da0073e9SAndroid Build Coastguard Worker size, 364*da0073e9SAndroid Build Coastguard Worker strides=stride, 365*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 366*da0073e9SAndroid Build Coastguard Worker storage_offset=storage_offset, 367*da0073e9SAndroid Build Coastguard Worker layout=layout, 368*da0073e9SAndroid Build Coastguard Worker device=device, 369*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker# TODO: Once we decide to break serialization FC, `storage` no longer needs to 374*da0073e9SAndroid Build Coastguard Worker# be a TypedStorage 375*da0073e9SAndroid Build Coastguard Workerdef _rebuild_qtensor( 376*da0073e9SAndroid Build Coastguard Worker storage, 377*da0073e9SAndroid Build Coastguard Worker storage_offset, 378*da0073e9SAndroid Build Coastguard Worker size, 379*da0073e9SAndroid Build Coastguard Worker stride, 380*da0073e9SAndroid Build Coastguard Worker quantizer_params, 381*da0073e9SAndroid Build Coastguard Worker requires_grad, 382*da0073e9SAndroid Build Coastguard Worker backward_hooks, 383*da0073e9SAndroid Build Coastguard Worker): 384*da0073e9SAndroid Build Coastguard Worker qscheme = quantizer_params[0] 385*da0073e9SAndroid Build Coastguard Worker if qscheme == torch.per_tensor_affine: 386*da0073e9SAndroid Build Coastguard Worker _, scale, zero_point = quantizer_params 387*da0073e9SAndroid Build Coastguard Worker tensor = torch._empty_affine_quantized( 388*da0073e9SAndroid Build Coastguard Worker size, 389*da0073e9SAndroid Build Coastguard Worker scale=scale, 390*da0073e9SAndroid Build Coastguard Worker zero_point=zero_point, 391*da0073e9SAndroid Build Coastguard Worker dtype=storage.dtype, 392*da0073e9SAndroid Build Coastguard Worker device=storage.device, 393*da0073e9SAndroid Build Coastguard Worker ) 394*da0073e9SAndroid Build Coastguard Worker elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): 395*da0073e9SAndroid Build Coastguard Worker _, scales, zero_points, axis = quantizer_params 396*da0073e9SAndroid Build Coastguard Worker if type(scales) is list and type(zero_points) is list: 397*da0073e9SAndroid Build Coastguard Worker if qscheme == torch.per_channel_affine: 398*da0073e9SAndroid Build Coastguard Worker scales = torch.tensor(scales, dtype=torch.double, device=storage.device) 399*da0073e9SAndroid Build Coastguard Worker zero_points = torch.tensor( 400*da0073e9SAndroid Build Coastguard Worker zero_points, dtype=torch.long, device=storage.device 401*da0073e9SAndroid Build Coastguard Worker ) 402*da0073e9SAndroid Build Coastguard Worker else: 403*da0073e9SAndroid Build Coastguard Worker scales = torch.tensor(scales, dtype=torch.float, device=storage.device) 404*da0073e9SAndroid Build Coastguard Worker zero_points = torch.tensor( 405*da0073e9SAndroid Build Coastguard Worker zero_points, dtype=torch.float, device=storage.device 406*da0073e9SAndroid Build Coastguard Worker ) 407*da0073e9SAndroid Build Coastguard Worker tensor = torch._empty_per_channel_affine_quantized( 408*da0073e9SAndroid Build Coastguard Worker size, 409*da0073e9SAndroid Build Coastguard Worker scales=scales, 410*da0073e9SAndroid Build Coastguard Worker zero_points=zero_points, 411*da0073e9SAndroid Build Coastguard Worker axis=axis, 412*da0073e9SAndroid Build Coastguard Worker dtype=storage.dtype, 413*da0073e9SAndroid Build Coastguard Worker device=storage.device, 414*da0073e9SAndroid Build Coastguard Worker ) 415*da0073e9SAndroid Build Coastguard Worker else: 416*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}") 417*da0073e9SAndroid Build Coastguard Worker tensor.set_(storage, storage_offset, size, stride) 418*da0073e9SAndroid Build Coastguard Worker tensor.requires_grad = requires_grad 419*da0073e9SAndroid Build Coastguard Worker # NB: This line exists only for backwards compatibility; the 420*da0073e9SAndroid Build Coastguard Worker # general expectation is that backward_hooks is an empty 421*da0073e9SAndroid Build Coastguard Worker # OrderedDict. See Note [Don't serialize hooks] 422*da0073e9SAndroid Build Coastguard Worker tensor._backward_hooks = backward_hooks 423*da0073e9SAndroid Build Coastguard Worker return tensor 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Workerdef _rebuild_parameter(data, requires_grad, backward_hooks): 427*da0073e9SAndroid Build Coastguard Worker param = torch.nn.Parameter(data, requires_grad) 428*da0073e9SAndroid Build Coastguard Worker # NB: This line exists only for backwards compatibility; the 429*da0073e9SAndroid Build Coastguard Worker # general expectation is that backward_hooks is an empty 430*da0073e9SAndroid Build Coastguard Worker # OrderedDict. See Note [Don't serialize hooks] 431*da0073e9SAndroid Build Coastguard Worker param._backward_hooks = backward_hooks 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker return param 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Workerdef _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): 437*da0073e9SAndroid Build Coastguard Worker param = torch.nn.Parameter(data, requires_grad) 438*da0073e9SAndroid Build Coastguard Worker # NB: This line exists only for backwards compatibility; the 439*da0073e9SAndroid Build Coastguard Worker # general expectation is that backward_hooks is an empty 440*da0073e9SAndroid Build Coastguard Worker # OrderedDict. See Note [Don't serialize hooks] 441*da0073e9SAndroid Build Coastguard Worker param._backward_hooks = backward_hooks 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker # Restore state on Parameter like python attr. 444*da0073e9SAndroid Build Coastguard Worker param = _set_obj_state(param, state) 445*da0073e9SAndroid Build Coastguard Worker return param 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Workerdef _get_obj_state(obj): 449*da0073e9SAndroid Build Coastguard Worker # Get the state of the python subclass 450*da0073e9SAndroid Build Coastguard Worker # This loosely mimicks the function on the object class but since Tensor do not inherit 451*da0073e9SAndroid Build Coastguard Worker # from it, we cannot call that function directly 452*da0073e9SAndroid Build Coastguard Worker # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 453*da0073e9SAndroid Build Coastguard Worker # Note that starting with Python 3.11, this `__getstate__` is always defined and thus 454*da0073e9SAndroid Build Coastguard Worker # the else branch will never be taken. 455*da0073e9SAndroid Build Coastguard Worker getstate_fn = getattr(obj, "__getstate__", None) 456*da0073e9SAndroid Build Coastguard Worker if getstate_fn: 457*da0073e9SAndroid Build Coastguard Worker state = getstate_fn() 458*da0073e9SAndroid Build Coastguard Worker else: 459*da0073e9SAndroid Build Coastguard Worker slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] 460*da0073e9SAndroid Build Coastguard Worker if slots_to_save: 461*da0073e9SAndroid Build Coastguard Worker state = ( 462*da0073e9SAndroid Build Coastguard Worker obj.__dict__, 463*da0073e9SAndroid Build Coastguard Worker { 464*da0073e9SAndroid Build Coastguard Worker name: getattr(obj, name) 465*da0073e9SAndroid Build Coastguard Worker for name in slots_to_save 466*da0073e9SAndroid Build Coastguard Worker if hasattr(obj, name) 467*da0073e9SAndroid Build Coastguard Worker }, 468*da0073e9SAndroid Build Coastguard Worker ) 469*da0073e9SAndroid Build Coastguard Worker else: 470*da0073e9SAndroid Build Coastguard Worker state = obj.__dict__ 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker return state 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Workerdef _set_obj_state(obj, state): 476*da0073e9SAndroid Build Coastguard Worker if isinstance(state, tuple): 477*da0073e9SAndroid Build Coastguard Worker if not len(state) == 2: 478*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Invalid serialized state: {state}") 479*da0073e9SAndroid Build Coastguard Worker dict_state = state[0] 480*da0073e9SAndroid Build Coastguard Worker slots_state = state[1] 481*da0073e9SAndroid Build Coastguard Worker else: 482*da0073e9SAndroid Build Coastguard Worker dict_state = state 483*da0073e9SAndroid Build Coastguard Worker slots_state = None 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker # Starting with Python 3.11, the __dict__ attribute is lazily created 486*da0073e9SAndroid Build Coastguard Worker # and is serialized as None when not needed. 487*da0073e9SAndroid Build Coastguard Worker if dict_state: 488*da0073e9SAndroid Build Coastguard Worker for k, v in dict_state.items(): 489*da0073e9SAndroid Build Coastguard Worker setattr(obj, k, v) 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker if slots_state: 492*da0073e9SAndroid Build Coastguard Worker for k, v in slots_state.items(): 493*da0073e9SAndroid Build Coastguard Worker setattr(obj, k, v) 494*da0073e9SAndroid Build Coastguard Worker return obj 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Workerdef _import_dotted_name(name): 498*da0073e9SAndroid Build Coastguard Worker components = name.split(".") 499*da0073e9SAndroid Build Coastguard Worker obj = __import__(components[0]) 500*da0073e9SAndroid Build Coastguard Worker for component in components[1:]: 501*da0073e9SAndroid Build Coastguard Worker obj = getattr(obj, component) 502*da0073e9SAndroid Build Coastguard Worker return obj 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Workerdef _flatten_dense_tensors(tensors): 506*da0073e9SAndroid Build Coastguard Worker """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 507*da0073e9SAndroid Build Coastguard Worker same dense type. 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker Since inputs are dense, the resulting tensor will be a concatenated 1D 510*da0073e9SAndroid Build Coastguard Worker buffer. Element-wise operation on this buffer will be equivalent to 511*da0073e9SAndroid Build Coastguard Worker operating individually. 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker Args: 514*da0073e9SAndroid Build Coastguard Worker tensors (Iterable[Tensor]): dense tensors to flatten. 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker Returns: 517*da0073e9SAndroid Build Coastguard Worker A contiguous 1D buffer containing input tensors. 518*da0073e9SAndroid Build Coastguard Worker """ 519*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.flatten_dense_tensors(tensors) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Workerdef _flatten_sparse_tensors(tensors): 523*da0073e9SAndroid Build Coastguard Worker """Flatten sparse tensors into two contiguous 1D buffers, one of indices and 524*da0073e9SAndroid Build Coastguard Worker one of values. Assume tensors are of same sparse type. 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker Args: 527*da0073e9SAndroid Build Coastguard Worker tensors (Iterable[Tensor]): sparse tensors to flatten. 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker Returns: 530*da0073e9SAndroid Build Coastguard Worker A tuple of two contiguous 1D buffers, one containing input tensors' 531*da0073e9SAndroid Build Coastguard Worker indices and the other containing the values. 532*da0073e9SAndroid Build Coastguard Worker """ 533*da0073e9SAndroid Build Coastguard Worker flat_indices = torch._C._nn.flatten_dense_tensors( 534*da0073e9SAndroid Build Coastguard Worker [torch.Tensor._indices(t) for t in tensors] 535*da0073e9SAndroid Build Coastguard Worker ) 536*da0073e9SAndroid Build Coastguard Worker flat_values = torch._C._nn.flatten_dense_tensors( 537*da0073e9SAndroid Build Coastguard Worker [torch.Tensor._values(t) for t in tensors] 538*da0073e9SAndroid Build Coastguard Worker ) 539*da0073e9SAndroid Build Coastguard Worker return flat_indices, flat_values 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Workerdef _unflatten_dense_tensors(flat, tensors): 543*da0073e9SAndroid Build Coastguard Worker """View a flat buffer using the sizes of tensors. Assume that tensors are of 544*da0073e9SAndroid Build Coastguard Worker same dense type, and that flat is given by _flatten_dense_tensors. 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker Args: 547*da0073e9SAndroid Build Coastguard Worker flat (Tensor): flattened dense tensors to unflatten. 548*da0073e9SAndroid Build Coastguard Worker tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 549*da0073e9SAndroid Build Coastguard Worker unflatten flat. 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker Returns: 552*da0073e9SAndroid Build Coastguard Worker Unflattened dense tensors with sizes same as tensors and values from 553*da0073e9SAndroid Build Coastguard Worker flat. 554*da0073e9SAndroid Build Coastguard Worker """ 555*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.unflatten_dense_tensors(flat, tensors) 556*da0073e9SAndroid Build Coastguard Worker 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Workerdef _unflatten_sparse_tensors(flat, tensors): 559*da0073e9SAndroid Build Coastguard Worker """View flat buffer (containing indices and values) using the sizes of 560*da0073e9SAndroid Build Coastguard Worker tensors. Assume that tensors are of same sparse type, and that flat is given 561*da0073e9SAndroid Build Coastguard Worker by _flatten_sparse_tensors. 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker Args: 564*da0073e9SAndroid Build Coastguard Worker flat (tuple(Tensor, Tensor)): flattened indices and values of sparse 565*da0073e9SAndroid Build Coastguard Worker tensors to unflatten. 566*da0073e9SAndroid Build Coastguard Worker tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to 567*da0073e9SAndroid Build Coastguard Worker unflatten flat. 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker Returns: 570*da0073e9SAndroid Build Coastguard Worker Unflattened sparse tensors with sizes same as tensors and values from 571*da0073e9SAndroid Build Coastguard Worker flat. 572*da0073e9SAndroid Build Coastguard Worker """ 573*da0073e9SAndroid Build Coastguard Worker flat_indices, flat_values = flat 574*da0073e9SAndroid Build Coastguard Worker indices = torch._C._nn.unflatten_dense_tensors( 575*da0073e9SAndroid Build Coastguard Worker flat_indices, [torch.Tensor._indices(t) for t in tensors] 576*da0073e9SAndroid Build Coastguard Worker ) 577*da0073e9SAndroid Build Coastguard Worker values = torch._C._nn.unflatten_dense_tensors( 578*da0073e9SAndroid Build Coastguard Worker flat_values, [torch.Tensor._values(t) for t in tensors] 579*da0073e9SAndroid Build Coastguard Worker ) 580*da0073e9SAndroid Build Coastguard Worker outputs = [] 581*da0073e9SAndroid Build Coastguard Worker for t, i, v in zip(tensors, indices, values): 582*da0073e9SAndroid Build Coastguard Worker outputs.append(t.new(i, v, t.size())) 583*da0073e9SAndroid Build Coastguard Worker return tuple(outputs) 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Workerdef _reorder_tensors_as(tensors, ordered_tensors): 587*da0073e9SAndroid Build Coastguard Worker """Assume that tensors are of same order as ordered_tensors within their 588*da0073e9SAndroid Build Coastguard Worker types, e.g., from _take_tensors. Reorder them to be of same order as 589*da0073e9SAndroid Build Coastguard Worker ordered_tensors. 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker Args: 592*da0073e9SAndroid Build Coastguard Worker tensors (Iterable[Tensor]): tensors to be reordered. They should be of 593*da0073e9SAndroid Build Coastguard Worker the same order as ordered_tensors within their own types. 594*da0073e9SAndroid Build Coastguard Worker ordered_tensors (Iterable[Tensor]): tensors whose order will be the 595*da0073e9SAndroid Build Coastguard Worker reference. 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker Returns: 598*da0073e9SAndroid Build Coastguard Worker Ordered tuple of tensors with contents from tensors and order of 599*da0073e9SAndroid Build Coastguard Worker ordered_tensors. 600*da0073e9SAndroid Build Coastguard Worker """ 601*da0073e9SAndroid Build Coastguard Worker type_dict = defaultdict(list) 602*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 603*da0073e9SAndroid Build Coastguard Worker type_dict[tensor.type()].append(tensor) 604*da0073e9SAndroid Build Coastguard Worker type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} 605*da0073e9SAndroid Build Coastguard Worker return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Workerdef _take_tensors(tensors, size_limit): 609*da0073e9SAndroid Build Coastguard Worker """Group tensors into chunks. This generator yields a chunk at each time, 610*da0073e9SAndroid Build Coastguard Worker each containing tensors of same type up to certain byte limit in total size. 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker Args: 613*da0073e9SAndroid Build Coastguard Worker tensors (Sequence): A sequence of tensors to be separated into chunks. 614*da0073e9SAndroid Build Coastguard Worker size_limit (int): The limit of each chunk in bytes. 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker Yields: 617*da0073e9SAndroid Build Coastguard Worker Blocks of tensors of same type and within size_limit. The yielded 618*da0073e9SAndroid Build Coastguard Worker tensors are only ordered as the original sequence within its types. 619*da0073e9SAndroid Build Coastguard Worker """ 620*da0073e9SAndroid Build Coastguard Worker buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0]) 621*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 622*da0073e9SAndroid Build Coastguard Worker t = tensor.type() 623*da0073e9SAndroid Build Coastguard Worker if tensor.is_sparse: 624*da0073e9SAndroid Build Coastguard Worker indices = torch.Tensor._indices(tensor) 625*da0073e9SAndroid Build Coastguard Worker values = torch.Tensor._values(tensor) 626*da0073e9SAndroid Build Coastguard Worker size = ( 627*da0073e9SAndroid Build Coastguard Worker indices.numel() * indices.element_size() 628*da0073e9SAndroid Build Coastguard Worker + values.numel() * values.element_size() 629*da0073e9SAndroid Build Coastguard Worker ) 630*da0073e9SAndroid Build Coastguard Worker else: 631*da0073e9SAndroid Build Coastguard Worker size = tensor.numel() * tensor.element_size() 632*da0073e9SAndroid Build Coastguard Worker buf_and_size = buf_dict[t] 633*da0073e9SAndroid Build Coastguard Worker if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: 634*da0073e9SAndroid Build Coastguard Worker yield buf_and_size[0] 635*da0073e9SAndroid Build Coastguard Worker buf_and_size = buf_dict[t] = [[], 0] 636*da0073e9SAndroid Build Coastguard Worker buf_and_size[0].append(tensor) 637*da0073e9SAndroid Build Coastguard Worker buf_and_size[1] += size 638*da0073e9SAndroid Build Coastguard Worker for buf, _ in buf_dict.values(): 639*da0073e9SAndroid Build Coastguard Worker if len(buf) > 0: 640*da0073e9SAndroid Build Coastguard Worker yield buf 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker# annotation decorator to get annotations in a way that is compatible 644*da0073e9SAndroid Build Coastguard Worker# with both Python 2 and 3 645*da0073e9SAndroid Build Coastguard Workerdef annotate(ret, **kwargs): 646*da0073e9SAndroid Build Coastguard Worker def dec(fun): 647*da0073e9SAndroid Build Coastguard Worker fun.__annotations__ = dict(kwargs) 648*da0073e9SAndroid Build Coastguard Worker fun.__annotations__["return"] = ret 649*da0073e9SAndroid Build Coastguard Worker return fun 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker return dec 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker 654*da0073e9SAndroid Build Coastguard Workerdef render_call(fn, args, kwargs): 655*da0073e9SAndroid Build Coastguard Worker str_fn = torch.overrides.resolve_name(fn) 656*da0073e9SAndroid Build Coastguard Worker if str_fn is None: 657*da0073e9SAndroid Build Coastguard Worker str_fn = str(fn) 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker str_args: List[str] = [] 660*da0073e9SAndroid Build Coastguard Worker with torch._tensor_str.printoptions(threshold=0, edgeitems=0): 661*da0073e9SAndroid Build Coastguard Worker str_args.extend(repr(a) for a in args) 662*da0073e9SAndroid Build Coastguard Worker str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items()) 663*da0073e9SAndroid Build Coastguard Worker r = f"{str_fn}({', '.join(str_args)})" 664*da0073e9SAndroid Build Coastguard Worker return r 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker# NOTE [ Python Traceback Reference Cycle Problem ] 668*da0073e9SAndroid Build Coastguard Worker# 669*da0073e9SAndroid Build Coastguard Worker# When using sys.exc_info(), it is important to **not** store the exc_info[2], 670*da0073e9SAndroid Build Coastguard Worker# which is the traceback, because otherwise you will run into the traceback 671*da0073e9SAndroid Build Coastguard Worker# reference cycle problem, i.e., the traceback holding reference to the frame, 672*da0073e9SAndroid Build Coastguard Worker# and the frame (which holds reference to all the object in its temporary scope) 673*da0073e9SAndroid Build Coastguard Worker# holding reference the traceback. 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Workerclass KeyErrorMessage(str): 677*da0073e9SAndroid Build Coastguard Worker r"""str subclass that returns itself in repr""" 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 680*da0073e9SAndroid Build Coastguard Worker return self 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Workerclass ExceptionWrapper: 684*da0073e9SAndroid Build Coastguard Worker r"""Wraps an exception plus traceback to communicate across threads""" 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker def __init__(self, exc_info=None, where="in background"): 687*da0073e9SAndroid Build Coastguard Worker # It is important that we don't store exc_info, see 688*da0073e9SAndroid Build Coastguard Worker # NOTE [ Python Traceback Reference Cycle Problem ] 689*da0073e9SAndroid Build Coastguard Worker if exc_info is None: 690*da0073e9SAndroid Build Coastguard Worker exc_info = sys.exc_info() 691*da0073e9SAndroid Build Coastguard Worker self.exc_type = exc_info[0] 692*da0073e9SAndroid Build Coastguard Worker self.exc_msg = "".join(traceback.format_exception(*exc_info)) 693*da0073e9SAndroid Build Coastguard Worker self.where = where 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker def reraise(self): 696*da0073e9SAndroid Build Coastguard Worker r"""Reraises the wrapped exception in the current thread""" 697*da0073e9SAndroid Build Coastguard Worker # Format a message such as: "Caught ValueError in DataLoader worker 698*da0073e9SAndroid Build Coastguard Worker # process 2. Original Traceback:", followed by the traceback. 699*da0073e9SAndroid Build Coastguard Worker msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" 700*da0073e9SAndroid Build Coastguard Worker if self.exc_type == KeyError: 701*da0073e9SAndroid Build Coastguard Worker # KeyError calls repr() on its argument (usually a dict key). This 702*da0073e9SAndroid Build Coastguard Worker # makes stack traces unreadable. It will not be changed in Python 703*da0073e9SAndroid Build Coastguard Worker # (https://bugs.python.org/issue2651), so we work around it. 704*da0073e9SAndroid Build Coastguard Worker msg = KeyErrorMessage(msg) 705*da0073e9SAndroid Build Coastguard Worker elif getattr(self.exc_type, "message", None): 706*da0073e9SAndroid Build Coastguard Worker # Some exceptions have first argument as non-str but explicitly 707*da0073e9SAndroid Build Coastguard Worker # have message field 708*da0073e9SAndroid Build Coastguard Worker raise self.exc_type(message=msg) 709*da0073e9SAndroid Build Coastguard Worker try: 710*da0073e9SAndroid Build Coastguard Worker exception = self.exc_type(msg) 711*da0073e9SAndroid Build Coastguard Worker except TypeError: 712*da0073e9SAndroid Build Coastguard Worker # If the exception takes multiple arguments, don't try to 713*da0073e9SAndroid Build Coastguard Worker # instantiate since we don't know how to 714*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(msg) from None 715*da0073e9SAndroid Build Coastguard Worker raise exception 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Workerdef _get_available_device_type(): 719*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 720*da0073e9SAndroid Build Coastguard Worker return "cuda" 721*da0073e9SAndroid Build Coastguard Worker if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined] 722*da0073e9SAndroid Build Coastguard Worker return "xpu" 723*da0073e9SAndroid Build Coastguard Worker if hasattr(torch, "mtia") and torch.mtia.is_available(): 724*da0073e9SAndroid Build Coastguard Worker return "mtia" 725*da0073e9SAndroid Build Coastguard Worker custom_backend_name = torch._C._get_privateuse1_backend_name() 726*da0073e9SAndroid Build Coastguard Worker custom_device_mod = getattr(torch, custom_backend_name, None) 727*da0073e9SAndroid Build Coastguard Worker if custom_device_mod and custom_device_mod.is_available(): 728*da0073e9SAndroid Build Coastguard Worker return custom_backend_name 729*da0073e9SAndroid Build Coastguard Worker # add more available device types here 730*da0073e9SAndroid Build Coastguard Worker return None 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Workerdef _get_device_attr(get_member): 734*da0073e9SAndroid Build Coastguard Worker device_type = _get_available_device_type() 735*da0073e9SAndroid Build Coastguard Worker if device_type and device_type.lower() == "cuda": 736*da0073e9SAndroid Build Coastguard Worker return get_member(torch.cuda) 737*da0073e9SAndroid Build Coastguard Worker if device_type and device_type.lower() == "xpu": 738*da0073e9SAndroid Build Coastguard Worker return get_member(torch.xpu) # type: ignore[attr-defined] 739*da0073e9SAndroid Build Coastguard Worker if device_type and device_type.lower() == "mtia": 740*da0073e9SAndroid Build Coastguard Worker return get_member(torch.mtia) 741*da0073e9SAndroid Build Coastguard Worker if device_type == torch._C._get_privateuse1_backend_name(): 742*da0073e9SAndroid Build Coastguard Worker return get_member(getattr(torch, device_type)) 743*da0073e9SAndroid Build Coastguard Worker # add more available device types here 744*da0073e9SAndroid Build Coastguard Worker return None 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Workerdef _get_current_device_index(): 748*da0073e9SAndroid Build Coastguard Worker # current device index 749*da0073e9SAndroid Build Coastguard Worker return _get_device_attr(lambda m: m.current_device()) 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Workerdef _get_all_device_indices(): 753*da0073e9SAndroid Build Coastguard Worker # all device index 754*da0073e9SAndroid Build Coastguard Worker return _get_device_attr(lambda m: list(range(m.device_count()))) 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Workerdef _get_devices_properties(device_ids): 758*da0073e9SAndroid Build Coastguard Worker # all device properties 759*da0073e9SAndroid Build Coastguard Worker return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids] 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Workerdef get_current_device_index() -> int: 763*da0073e9SAndroid Build Coastguard Worker r"""Checks if there are CUDA devices available and 764*da0073e9SAndroid Build Coastguard Worker returns the device index of the current default CUDA device. 765*da0073e9SAndroid Build Coastguard Worker Returns -1 in case there are no CUDA devices available. 766*da0073e9SAndroid Build Coastguard Worker Arguments: ``None`` 767*da0073e9SAndroid Build Coastguard Worker """ 768*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 0: 769*da0073e9SAndroid Build Coastguard Worker return torch.cuda.current_device() 770*da0073e9SAndroid Build Coastguard Worker return -1 771*da0073e9SAndroid Build Coastguard Worker 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Workerdef _get_device_index( 774*da0073e9SAndroid Build Coastguard Worker device: Any, 775*da0073e9SAndroid Build Coastguard Worker optional: bool = False, 776*da0073e9SAndroid Build Coastguard Worker allow_cpu: bool = False, 777*da0073e9SAndroid Build Coastguard Worker) -> int: 778*da0073e9SAndroid Build Coastguard Worker r"""Gets the device index from :attr:`device`, which can be a torch.device 779*da0073e9SAndroid Build Coastguard Worker object, a Python integer, or ``None``. 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker If :attr:`device` is a torch.device object, returns the device index if it 782*da0073e9SAndroid Build Coastguard Worker has index. Note that for a device without a specified index, 783*da0073e9SAndroid Build Coastguard Worker i.e., ``torch.device('xxx')``, this will return the current default 784*da0073e9SAndroid Build Coastguard Worker device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, 785*da0073e9SAndroid Build Coastguard Worker CPU devices will be accepted and ``-1`` will be returned in this case. 786*da0073e9SAndroid Build Coastguard Worker 787*da0073e9SAndroid Build Coastguard Worker If :attr:`device` is a Python integer, it is returned as is. 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker If :attr:`device` is ``None``, this will return the current default 790*da0073e9SAndroid Build Coastguard Worker device of the supported runtime platform if :attr:`optional` is ``True``. 791*da0073e9SAndroid Build Coastguard Worker i.e., the current default CUDA device will be returned if CUDA runtime is supported. 792*da0073e9SAndroid Build Coastguard Worker """ 793*da0073e9SAndroid Build Coastguard Worker if isinstance(device, str): 794*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 795*da0073e9SAndroid Build Coastguard Worker device_idx: Optional[int] = None 796*da0073e9SAndroid Build Coastguard Worker if isinstance(device, torch.device): 797*da0073e9SAndroid Build Coastguard Worker if not allow_cpu and device.type == "cpu": 798*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Expected a non cpu device, but got: {device}") 799*da0073e9SAndroid Build Coastguard Worker device_idx = -1 if device.type == "cpu" else device.index 800*da0073e9SAndroid Build Coastguard Worker if isinstance(device, int): 801*da0073e9SAndroid Build Coastguard Worker device_idx = device 802*da0073e9SAndroid Build Coastguard Worker if device_idx is None: 803*da0073e9SAndroid Build Coastguard Worker if optional: 804*da0073e9SAndroid Build Coastguard Worker # The eager API _get_current_device_index uses `lambda` functions which are 805*da0073e9SAndroid Build Coastguard Worker # not supported in JIT and hence not scriptable. The JIT equivalent API to get 806*da0073e9SAndroid Build Coastguard Worker # the current device index is `get_current_device_index()` which can 807*da0073e9SAndroid Build Coastguard Worker # be scripted. We use is_scripting to check the mode we are in and call the 808*da0073e9SAndroid Build Coastguard Worker # appropriate API. 809*da0073e9SAndroid Build Coastguard Worker if torch.jit.is_scripting(): 810*da0073e9SAndroid Build Coastguard Worker device_idx = get_current_device_index() 811*da0073e9SAndroid Build Coastguard Worker else: 812*da0073e9SAndroid Build Coastguard Worker device_idx = _get_current_device_index() 813*da0073e9SAndroid Build Coastguard Worker else: 814*da0073e9SAndroid Build Coastguard Worker raise ValueError( 815*da0073e9SAndroid Build Coastguard Worker f"Expected a torch.device with a specified index or an integer, but got:{device}" 816*da0073e9SAndroid Build Coastguard Worker ) 817*da0073e9SAndroid Build Coastguard Worker return device_idx 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Workerdef _handle_complex(tensor): 821*da0073e9SAndroid Build Coastguard Worker """ 822*da0073e9SAndroid Build Coastguard Worker Returns a real view of a tensor if complex dtype else just the tensor 823*da0073e9SAndroid Build Coastguard Worker need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule 824*da0073e9SAndroid Build Coastguard Worker """ 825*da0073e9SAndroid Build Coastguard Worker return ( 826*da0073e9SAndroid Build Coastguard Worker torch.view_as_real(tensor) 827*da0073e9SAndroid Build Coastguard Worker if not isinstance(tensor, torch.nn.UninitializedParameter) 828*da0073e9SAndroid Build Coastguard Worker and tensor.is_complex() 829*da0073e9SAndroid Build Coastguard Worker else tensor 830*da0073e9SAndroid Build Coastguard Worker ) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Workerdef _element_size(dtype): 834*da0073e9SAndroid Build Coastguard Worker """ 835*da0073e9SAndroid Build Coastguard Worker Returns the element size for a dtype, in bytes 836*da0073e9SAndroid Build Coastguard Worker """ 837*da0073e9SAndroid Build Coastguard Worker if not isinstance(dtype, torch.dtype): 838*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}") 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 841*da0073e9SAndroid Build Coastguard Worker return torch.finfo(dtype).bits >> 2 842*da0073e9SAndroid Build Coastguard Worker elif dtype.is_floating_point: 843*da0073e9SAndroid Build Coastguard Worker return torch.finfo(dtype).bits >> 3 844*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.bool: 845*da0073e9SAndroid Build Coastguard Worker # NOTE: torch.bool is not supported in torch.iinfo() 846*da0073e9SAndroid Build Coastguard Worker return 1 847*da0073e9SAndroid Build Coastguard Worker else: 848*da0073e9SAndroid Build Coastguard Worker return torch.iinfo(dtype).bits >> 3 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Workerclass _ClassPropertyDescriptor: 852*da0073e9SAndroid Build Coastguard Worker def __init__(self, fget, fset=None): 853*da0073e9SAndroid Build Coastguard Worker self.fget = fget 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker def __get__(self, instance, owner=None): 856*da0073e9SAndroid Build Coastguard Worker if owner is None: 857*da0073e9SAndroid Build Coastguard Worker owner = type(instance) 858*da0073e9SAndroid Build Coastguard Worker return self.fget.__get__(instance, owner)() 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Workerdef classproperty(func): 862*da0073e9SAndroid Build Coastguard Worker if not isinstance(func, (classmethod, staticmethod)): 863*da0073e9SAndroid Build Coastguard Worker func = classmethod(func) 864*da0073e9SAndroid Build Coastguard Worker return _ClassPropertyDescriptor(func) 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Workerdef is_compiling() -> bool: 868*da0073e9SAndroid Build Coastguard Worker """ 869*da0073e9SAndroid Build Coastguard Worker Indicates whether we are tracing/compiling with torch.compile() or torch.export(). 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling(). 872*da0073e9SAndroid Build Coastguard Worker """ 873*da0073e9SAndroid Build Coastguard Worker return torch.compiler.is_compiling() 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Workerdef _functionalize_sync(t): 877*da0073e9SAndroid Build Coastguard Worker # This code lives in python instead of C++ since conditioning on a certain python subclass 878*da0073e9SAndroid Build Coastguard Worker # is much more of a pain in C++. 879*da0073e9SAndroid Build Coastguard Worker from torch._subclasses.functional_tensor import FunctionalTensor 880*da0073e9SAndroid Build Coastguard Worker 881*da0073e9SAndroid Build Coastguard Worker if isinstance(t, FunctionalTensor): 882*da0073e9SAndroid Build Coastguard Worker # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called 883*da0073e9SAndroid Build Coastguard Worker # when we sync our inner tensor. 884*da0073e9SAndroid Build Coastguard Worker # Why? 885*da0073e9SAndroid Build Coastguard Worker # (1) If there are input mutations in the graph, then they will be re-applied during 886*da0073e9SAndroid Build Coastguard Worker # AOTAutograd when we call _sync() from inside of our functionalization kernels. 887*da0073e9SAndroid Build Coastguard Worker # (2) _sync() causes us to regenerate our updated the tensor from the updated base, 888*da0073e9SAndroid Build Coastguard Worker # which dispatches to a bunch of view ops 889*da0073e9SAndroid Build Coastguard Worker # (3) The input to these view ops is our inner FunctionalTensorWrapper 890*da0073e9SAndroid Build Coastguard Worker # (since the sync was called from C++), not the python FunctionalTensor 891*da0073e9SAndroid Build Coastguard Worker # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts 892*da0073e9SAndroid Build Coastguard Worker # the view op, since it will see an input that is a C++ FunctionalTensorWrapper 893*da0073e9SAndroid Build Coastguard Worker # (aka a normal torch.Tensor) instead of a python `FunctionalTensor). 894*da0073e9SAndroid Build Coastguard Worker maybe_functional_mode = torch._C._unset_dispatch_mode( 895*da0073e9SAndroid Build Coastguard Worker torch._C._TorchDispatchModeKey.FUNCTIONAL 896*da0073e9SAndroid Build Coastguard Worker ) 897*da0073e9SAndroid Build Coastguard Worker try: 898*da0073e9SAndroid Build Coastguard Worker torch._functionalize_sync(t.elem) # type: ignore[attr-defined] 899*da0073e9SAndroid Build Coastguard Worker finally: 900*da0073e9SAndroid Build Coastguard Worker if maybe_functional_mode is not None: 901*da0073e9SAndroid Build Coastguard Worker torch._C._set_dispatch_mode(maybe_functional_mode) 902*da0073e9SAndroid Build Coastguard Worker else: 903*da0073e9SAndroid Build Coastguard Worker torch._functionalize_sync(t) # type: ignore[attr-defined] 904*da0073e9SAndroid Build Coastguard Worker 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(2) 907*da0073e9SAndroid Build Coastguard Workerdef _get_device_module(device_type: str): 908*da0073e9SAndroid Build Coastguard Worker device_module = getattr(torch, device_type, None) 909*da0073e9SAndroid Build Coastguard Worker if device_module is None: 910*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 911*da0073e9SAndroid Build Coastguard Worker f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'." 912*da0073e9SAndroid Build Coastguard Worker ) 913*da0073e9SAndroid Build Coastguard Worker return device_module 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Workerdef _dummy_type(name: str) -> type: 917*da0073e9SAndroid Build Coastguard Worker def get_err_fn(is_init: bool): 918*da0073e9SAndroid Build Coastguard Worker def err_fn(obj, *args, **kwargs): 919*da0073e9SAndroid Build Coastguard Worker if is_init: 920*da0073e9SAndroid Build Coastguard Worker class_name = obj.__class__.__name__ 921*da0073e9SAndroid Build Coastguard Worker else: 922*da0073e9SAndroid Build Coastguard Worker class_name = obj.__name__ 923*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker return err_fn 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker return type( 928*da0073e9SAndroid Build Coastguard Worker name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} 929*da0073e9SAndroid Build Coastguard Worker ) 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Workerclass _LazySeedTracker: 933*da0073e9SAndroid Build Coastguard Worker # Since seeding is memory-less, only track the latest seed. 934*da0073e9SAndroid Build Coastguard Worker # Note: `manual_seed_all` followed by `manual_seed` overwrites 935*da0073e9SAndroid Build Coastguard Worker # the seed on current device. We track the order of **latest** 936*da0073e9SAndroid Build Coastguard Worker # calls between these two API. 937*da0073e9SAndroid Build Coastguard Worker def __init__(self): 938*da0073e9SAndroid Build Coastguard Worker self.manual_seed_all_cb = None 939*da0073e9SAndroid Build Coastguard Worker self.manual_seed_cb = None 940*da0073e9SAndroid Build Coastguard Worker self.call_order = [] 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker def queue_seed_all(self, cb, traceback): 943*da0073e9SAndroid Build Coastguard Worker self.manual_seed_all_cb = (cb, traceback) 944*da0073e9SAndroid Build Coastguard Worker # update seed_all to be latest 945*da0073e9SAndroid Build Coastguard Worker self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker def queue_seed(self, cb, traceback): 948*da0073e9SAndroid Build Coastguard Worker self.manual_seed_cb = (cb, traceback) 949*da0073e9SAndroid Build Coastguard Worker # update seed to be latest 950*da0073e9SAndroid Build Coastguard Worker self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker def get_calls(self) -> List: 953*da0073e9SAndroid Build Coastguard Worker return self.call_order 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Workerlogger = logging.getLogger(__name__) 957*da0073e9SAndroid Build Coastguard WorkerP = ParamSpec("P") 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Workerclass CallbackRegistry(Generic[P]): 961*da0073e9SAndroid Build Coastguard Worker def __init__(self, name: str): 962*da0073e9SAndroid Build Coastguard Worker self.name = name 963*da0073e9SAndroid Build Coastguard Worker self.callback_list: List[Callable[P, None]] = [] 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker def add_callback(self, cb: Callable[P, None]) -> None: 966*da0073e9SAndroid Build Coastguard Worker self.callback_list.append(cb) 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None: 969*da0073e9SAndroid Build Coastguard Worker for cb in self.callback_list: 970*da0073e9SAndroid Build Coastguard Worker try: 971*da0073e9SAndroid Build Coastguard Worker cb(*args, **kwargs) 972*da0073e9SAndroid Build Coastguard Worker except Exception as e: 973*da0073e9SAndroid Build Coastguard Worker logger.exception( 974*da0073e9SAndroid Build Coastguard Worker "Exception in callback for %s registered with gpu trace", self.name 975*da0073e9SAndroid Build Coastguard Worker ) 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py 979*da0073e9SAndroid Build Coastguard Worker# for use in the weights_only Unpickler. 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard WorkerIMPORT_MAPPING = { 982*da0073e9SAndroid Build Coastguard Worker "__builtin__": "builtins", 983*da0073e9SAndroid Build Coastguard Worker "copy_reg": "copyreg", 984*da0073e9SAndroid Build Coastguard Worker "Queue": "queue", 985*da0073e9SAndroid Build Coastguard Worker "repr": "reprlib", 986*da0073e9SAndroid Build Coastguard Worker "_abcoll": "collections.abc", 987*da0073e9SAndroid Build Coastguard Worker # Non-mutual mappings. 988*da0073e9SAndroid Build Coastguard Worker "UserDict": "collections", 989*da0073e9SAndroid Build Coastguard Worker "UserList": "collections", 990*da0073e9SAndroid Build Coastguard Worker "UserString": "collections", 991*da0073e9SAndroid Build Coastguard Worker "whichdb": "dbm", 992*da0073e9SAndroid Build Coastguard Worker "StringIO": "io", 993*da0073e9SAndroid Build Coastguard Worker "cStringIO": "io", 994*da0073e9SAndroid Build Coastguard Worker} 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker 997*da0073e9SAndroid Build Coastguard Worker# This contains rename rules that are easy to handle. We ignore the more 998*da0073e9SAndroid Build Coastguard Worker# complex stuff (e.g. mapping the names in the urllib and types modules). 999*da0073e9SAndroid Build Coastguard Worker# These rules should be run before import names are fixed. 1000*da0073e9SAndroid Build Coastguard WorkerNAME_MAPPING = { 1001*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "xrange"): ("builtins", "range"), 1002*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "reduce"): ("functools", "reduce"), 1003*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "intern"): ("sys", "intern"), 1004*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "unichr"): ("builtins", "chr"), 1005*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "unicode"): ("builtins", "str"), 1006*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "long"): ("builtins", "int"), 1007*da0073e9SAndroid Build Coastguard Worker ("itertools", "izip"): ("builtins", "zip"), 1008*da0073e9SAndroid Build Coastguard Worker ("itertools", "imap"): ("builtins", "map"), 1009*da0073e9SAndroid Build Coastguard Worker ("itertools", "ifilter"): ("builtins", "filter"), 1010*da0073e9SAndroid Build Coastguard Worker ("itertools", "ifilterfalse"): ("itertools", "filterfalse"), 1011*da0073e9SAndroid Build Coastguard Worker ("itertools", "izip_longest"): ("itertools", "zip_longest"), 1012*da0073e9SAndroid Build Coastguard Worker ("UserDict", "IterableUserDict"): ("collections", "UserDict"), 1013*da0073e9SAndroid Build Coastguard Worker ("UserList", "UserList"): ("collections", "UserList"), 1014*da0073e9SAndroid Build Coastguard Worker ("UserString", "UserString"): ("collections", "UserString"), 1015*da0073e9SAndroid Build Coastguard Worker # Non-mutual mappings. 1016*da0073e9SAndroid Build Coastguard Worker ("__builtin__", "basestring"): ("builtins", "str"), 1017*da0073e9SAndroid Build Coastguard Worker ("exceptions", "StandardError"): ("builtins", "Exception"), 1018*da0073e9SAndroid Build Coastguard Worker ("UserDict", "UserDict"): ("collections", "UserDict"), 1019*da0073e9SAndroid Build Coastguard Worker} 1020