# mypy: allow-untyped-defs
r"""This file provides a location for operators that help exporting models via onnx.

E.g. `shape_as_tensor` and `reshape_from_tensor_shape`
are to make all dynamic sizes operations traceable.

NOTE: at one point these functions were implemented differently.
Since then we have implemented these directly in ATen, so this
file is kept purely for backward-compatibility.
"""

import torch
import torch.onnx


def shape_as_tensor(x: torch.Tensor) -> torch.Tensor:
    """Get the shape of a tensor as a tensor.

    Thin backward-compatibility wrapper that forwards to
    ``torch._shape_as_tensor``.

    Args:
        x (Tensor): The input tensor.

    Returns:
        Tensor: A tensor of shape [len(x.shape)] containing the size of each
        dimension of x.

    Example:
        >>> x = torch.randn(2, 3)
        >>> shape_as_tensor(x)
        tensor([2, 3])

    """
    return torch._shape_as_tensor(x)


def reshape_from_tensor_shape(x: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
    """Reshape a tensor to the given shape.

    This function is used to make dynamic size operations traceable when
    exporting models via ONNX. It is kept for backward-compatibility and
    simply forwards to ``torch._reshape_from_tensor``, which is implemented
    directly in ATen.

    Args:
        x (Tensor): The tensor to be reshaped.
        shape (Tensor): The target shape, given as a tensor.

    Returns:
        Tensor: The reshaped tensor.
    """
    return torch._reshape_from_tensor(x, shape)