xref: /aosp_15_r20/external/pytorch/torch/testing/_creation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2This module contains tensor creation utilities.
3"""
4
5import collections.abc
6import math
7import warnings
8from typing import cast, List, Optional, Tuple, Union
9
10import torch
11
12
13_INTEGRAL_TYPES = [
14    torch.uint8,
15    torch.int8,
16    torch.int16,
17    torch.int32,
18    torch.int64,
19    torch.uint16,
20    torch.uint32,
21    torch.uint64,
22]
23_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
24_FLOATING_8BIT_TYPES = [
25    torch.float8_e4m3fn,
26    torch.float8_e5m2,
27    torch.float8_e4m3fnuz,
28    torch.float8_e5m2fnuz,
29]
30_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
31_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
32_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
33
34
def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
    """Fill ``t`` in place with uniform samples from ``[low, high)`` and return ``t``.

    ``Tensor.uniform_`` requires ``high - low`` to not exceed the largest finite value of
    ``t.dtype``. When the requested width reaches that limit, sample from the half-width
    interval ``[low / 2, high / 2)`` and double the samples in place afterwards.
    """
    dtype_max = torch.finfo(t.dtype).max
    if not (high - low >= dtype_max):
        # Common case: the interval is narrow enough to sample directly.
        return t.uniform_(low, high)
    halved = t.uniform_(low / 2, high / 2)
    return halved.mul_(2)
42
43
def make_tensor(
    *shape: Union[int, torch.Size, List[int], Tuple[int, ...]],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    low: Optional[float] = None,
    high: Optional[float] = None,
    requires_grad: bool = False,
    noncontiguous: bool = False,
    exclude_zero: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
    r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
    values uniformly drawn from ``[low, high)``.

    If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
    finite values then they are clamped to the lowest or highest representable finite value, respectively.
    If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
    which depend on :attr:`dtype`.

    +---------------------------+------------+----------+
    | ``dtype``                 | ``low``    | ``high`` |
    +===========================+============+==========+
    | boolean type              | ``0``      | ``2``    |
    +---------------------------+------------+----------+
    | unsigned integral type    | ``0``      | ``10``   |
    +---------------------------+------------+----------+
    | signed integral types     | ``-9``     | ``10``   |
    +---------------------------+------------+----------+
    | floating types            | ``-9``     | ``9``    |
    +---------------------------+------------+----------+
    | 8-bit floating types      | ``-9``     | ``9``    |
    +---------------------------+------------+----------+
    | complex types             | ``-9``     | ``9``    |
    +---------------------------+------------+----------+

    Args:
        shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
        dtype (:class:`torch.dtype`): The data type of the returned tensor.
        device (Union[str, torch.device]): The device of the returned tensor.
        low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
            clamped to the least representable finite value of the given dtype. When ``None`` (default),
            this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
        high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is
            clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value
            is determined based on the :attr:`dtype` (see the table above). Default: ``None``. For integral dtypes
            the limit is additionally capped at ``torch.iinfo(torch.int64).max``, because :func:`torch.randint`
            converts its bounds to 64-bit signed integers.

            .. deprecated:: 2.1

                Passing ``low==high`` to :func:`~torch.testing.make_tensor` for floating or complex types is deprecated
                since 2.1 and will be removed in 2.3. Use :func:`torch.full` instead.

        requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Only applied
            to floating (excluding the 8-bit floating types) and complex dtypes. Default: ``False``.
        noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
            ignored if the constructed tensor has fewer than two elements. Mutually exclusive with ``memory_format``.
        exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
            depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
            point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
            :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
            whose real and imaginary parts are both the smallest positive normal number representable by the complex
            type. Default ``False``.
        memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Mutually exclusive
            with ``noncontiguous``.

    Raises:
        ValueError: If ``requires_grad=True`` is passed for integral `dtype`
        ValueError: If ``low >= high``.
        ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
        ValueError: If the interval ``[low, high)`` lies entirely outside the range supported by :attr:`dtype`.
        ValueError: If both :attr:`noncontiguous` and :attr:`memory_format` are passed.
        TypeError: If :attr:`dtype` isn't supported by this function.

    Examples:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> from torch.testing import make_tensor
        >>> # Creates a float tensor with values in [-1, 1)
        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
        >>> # xdoctest: +SKIP
        tensor([ 0.1205, 0.2282, -0.6380])
        >>> # Creates a bool tensor on CUDA
        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
        tensor([[False, False],
                [False, True]], device='cuda:0')
    """

    def modify_low_high(
        low: Optional[float],
        high: Optional[float],
        *,
        lowest_inclusive: float,
        highest_exclusive: float,
        default_low: float,
        default_high: float,
    ) -> Tuple[float, float]:
        """Resolve the sampling interval ``[low, high)`` for the enclosing call's ``dtype``.

        ``None`` bounds are replaced by ``default_low`` / ``default_high``. The bounds are then validated and
        clamped into ``[lowest_inclusive, highest_exclusive]``. For boolean and integral dtypes both bounds are
        ceiled so they can be passed directly to :func:`torch.randint`.

        Raises:
            ValueError: If a bound is NaN, if ``low > high`` (or ``low == high`` for dtypes outside
                ``_FLOATING_OR_COMPLEX_TYPES``), or if the interval lies entirely outside
                ``[lowest_inclusive, highest_exclusive)``.
        """

        def clamp(a: float, l: float, h: float) -> float:
            return min(max(a, l), h)

        low = low if low is not None else default_low
        high = high if high is not None else default_high

        if any(isinstance(value, float) and math.isnan(value) for value in [low, high]):
            raise ValueError(
                f"`low` and `high` cannot be NaN, but got {low=} and {high=}"
            )
        elif low == high and dtype in _FLOATING_OR_COMPLEX_TYPES:
            # Deprecated path: warn instead of raising. The remaining checks in this chain are skipped.
            # stacklevel=3 points the warning at the caller of `make_tensor`.
            warnings.warn(
                "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types "
                "is deprecated since 2.1 and will be removed in 2.3. "
                "Use `torch.full(...)` instead.",
                FutureWarning,
                stacklevel=3,
            )
        elif low >= high:
            raise ValueError(f"`low` must be less than `high`, but got {low} >= {high}")
        elif high < lowest_inclusive or low >= highest_exclusive:
            raise ValueError(
                f"The value interval specified by `low` and `high` is [{low}, {high}), "
                f"but {dtype} only supports [{lowest_inclusive}, {highest_exclusive})"
            )

        low = clamp(low, lowest_inclusive, highest_exclusive)
        high = clamp(high, lowest_inclusive, highest_exclusive)

        if dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
            # 1. `low` is ceiled to avoid creating values smaller than `low` and thus outside the specified interval
            # 2. Following the same reasoning as for 1., `high` should be floored. However, the higher bound of
            #    `torch.randint` is exclusive, and thus we need to ceil here as well.
            return math.ceil(low), math.ceil(high)

        return low, high

    # Accept both `make_tensor(2, 3, ...)` and `make_tensor((2, 3), ...)`.
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence):
        shape = shape[0]  # type: ignore[assignment]
    shape = cast(Tuple[int, ...], tuple(shape))

    if noncontiguous and memory_format is not None:
        raise ValueError(
            f"The parameters `noncontiguous` and `memory_format` are mutually exclusive, "
            f"but got {noncontiguous=} and {memory_format=}"
        )

    if requires_grad and dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
        raise ValueError(
            f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}"
        )

    if dtype is torch.bool:
        low, high = cast(
            Tuple[int, int],
            modify_low_high(
                low,
                high,
                lowest_inclusive=0,
                highest_exclusive=2,
                default_low=0,
                default_high=2,
            ),
        )
        result = torch.randint(low, high, shape, device=device, dtype=dtype)
    elif dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
        # In theory, `highest_exclusive` should always be the maximum value + 1. However, `torch.randint` internally
        # converts the bounds to an int64 and would overflow. In other words: `torch.randint` cannot sample
        # 2**63 - 1 or anything above it. Capping at the int64 maximum handles `torch.int64` exactly as before
        # and also `torch.uint64`, whose maximum + 1 (2**64) would otherwise be handed to `torch.randint`.
        highest_exclusive = min(torch.iinfo(dtype).max + 1, torch.iinfo(torch.int64).max)
        low, high = cast(
            Tuple[int, int],
            modify_low_high(
                low,
                high,
                lowest_inclusive=torch.iinfo(dtype).min,
                highest_exclusive=highest_exclusive,
                # This is incorrect for unsigned types, but since we clamp to `lowest`, i.e. 0 for them,
                # _after_ we use the default value, we don't need to special case it here
                default_low=-9,
                default_high=10,
            ),
        )
        result = torch.randint(low, high, shape, device=device, dtype=dtype)
    elif dtype in _FLOATING_OR_COMPLEX_TYPES:
        low, high = modify_low_high(
            low,
            high,
            lowest_inclusive=torch.finfo(dtype).min,
            highest_exclusive=torch.finfo(dtype).max,
            default_low=-9,
            default_high=9,
        )
        result = torch.empty(shape, device=device, dtype=dtype)
        # Complex tensors are sampled through their real view, so real and imaginary parts are drawn independently.
        _uniform_random_(
            torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
        )
    elif dtype in _FLOATING_8BIT_TYPES:
        low, high = modify_low_high(
            low,
            high,
            lowest_inclusive=torch.finfo(dtype).min,
            highest_exclusive=torch.finfo(dtype).max,
            default_low=-9,
            default_high=9,
        )
        # Sample in float32 and cast down to the requested 8-bit floating type.
        result = torch.empty(shape, device=device, dtype=torch.float32)
        _uniform_random_(result, low, high)
        result = result.to(dtype)
    else:
        raise TypeError(
            f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
            " To request support, file an issue at: https://github.com/pytorch/pytorch/issues"
        )

    if noncontiguous and result.numel() > 1:
        # Duplicate every element along the last dim and keep every other one: same values, stride 2.
        result = torch.repeat_interleave(result, 2, dim=-1)
        result = result[..., ::2]
    elif memory_format is not None:
        result = result.clone(memory_format=memory_format)

    if exclude_zero:
        replacement: Union[int, float, complex]
        if dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
            replacement = 1
        else:
            tiny = torch.finfo(dtype).tiny
            # For complex dtypes fill both the real and imaginary part, as the docstring promises; assigning the
            # plain float would leave the imaginary part at zero.
            replacement = complex(tiny, tiny) if dtype in _COMPLEX_TYPES else tiny
        result[result == 0] = replacement

    if dtype in _FLOATING_OR_COMPLEX_TYPES:
        result.requires_grad = requires_grad

    return result
270