xref: /aosp_15_r20/external/pytorch/torch/onnx/operators.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""This file provides a location for operators that help exporting models via onnx.
3
E.g. `shape_as_tensor` and `reshape_from_tensor_shape`
exist to make operations on dynamic sizes traceable.
6
7NOTE: at one point these functions were implemented differently.
8Since then we have implemented these directly in ATen, so this
9file is kept purely for backward-compatibility.
10"""
11
12import torch
13import torch.onnx
14
15
def shape_as_tensor(x):
    """Return the sizes of ``x`` packed into a tensor.

    Thin backward-compatibility wrapper; the work is done by the ATen
    operator ``torch._shape_as_tensor``.

    Args:
        x (Tensor): The tensor whose shape is queried.

    Returns:
        Tensor: A tensor of shape [len(x.shape)] whose i-th entry is the
        size of dimension i of ``x``.

    Example:
        >>> x = torch.randn(2, 3)
        >>> shape_as_tensor(x)
        tensor([2, 3])

    """
    # Delegate straight to the ATen implementation; nothing is done in Python.
    aten_shape_op = torch._shape_as_tensor
    return aten_shape_op(x)
32
33
def reshape_from_tensor_shape(x, shape):
    """Reshape a tensor to a shape that is itself given as a tensor.

    This function is used to make dynamic size operations traceable when
    exporting models via ONNX. It is kept for backward-compatibility and
    simply delegates to the ATen operator ``torch._reshape_from_tensor``.

    Args:
        x (Tensor): The tensor to be reshaped.
        shape (Tensor): The target shape. Expected to be a 1-D integer
            tensor of sizes, such as the one returned by ``shape_as_tensor``
            (exact dtype/rank requirements are enforced by the ATen op).

    Returns:
        Tensor: ``x`` reshaped to the sizes held in ``shape``.

    Example:
        >>> x = torch.arange(6)
        >>> reshape_from_tensor_shape(x, torch.tensor([2, 3]))
        tensor([[0, 1, 2],
                [3, 4, 5]])

    """
    return torch._reshape_from_tensor(x, shape)
47    return torch._reshape_from_tensor(x, shape)
48