"""Functional interface."""

import importlib
import math
import warnings
from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import _VF, sym_int as _sym_int, Tensor
from torch._C import _add_docstr, _infer_size
from torch._jit_internal import (
    _overload,
    boolean_dispatch,
    BroadcastingList1,
    BroadcastingList2,
    BroadcastingList3,
)
from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes
from torch.nn import _reduction as _Reduction, grad  # noqa: F401
from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple
from torch.overrides import (
    handle_torch_function,
    has_torch_function,
    has_torch_function_unary,
    has_torch_function_variadic,
)


if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int

try:
    import numpy as np
except ModuleNotFoundError:
    np = None


conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e., ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or
      a one-element tuple `(sW,)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a one-element tuple `(padW,)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, reducing performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a one-element tuple `(dW,)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)
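    >>> # A minimal 'same'-padding sketch (illustrative shapes, stride 1):
    >>> # the output keeps the input's length of 30.
    >>> F.conv1d(inputs, filters, padding='same').shape
    torch.Size([33, 20, 30])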
""",
)

conv2d = _add_docstr(
    torch.conv2d,
    r"""
conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e., ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, reducing performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dH, dW)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
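    >>> # A minimal 'same'-padding sketch (illustrative shapes, stride 1):
    >>> # the 5x5 spatial size is preserved.
    >>> F.conv2d(inputs, filters, padding='same').shape
    torch.Size([1, 8, 5, 5])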
""",
)  # noqa: E501

conv3d = _add_docstr(
    torch.conv3d,
    r"""
conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 3D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e., ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padT, padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, reducing performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> filters = torch.randn(33, 16, 3, 3, 3)
    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> F.conv3d(inputs, filters)
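    >>> # A small shape check (illustrative): with a 3x3x3 kernel and no
    >>> # padding, each spatial dim shrinks by 2.
    >>> F.conv3d(inputs, filters).shape
    torch.Size([20, 33, 48, 8, 18])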
""",
)  # noqa: E501

conv_transpose1d = _add_docstr(
    torch.conv_transpose1d,
    r"""
conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sW,)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padW,)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padW,)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dW,)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50)
    >>> weights = torch.randn(16, 33, 5)
    >>> F.conv_transpose1d(inputs, weights)
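    >>> # A small shape check (illustrative): a width-5 kernel grows the
    >>> # length from 50 to 54.
    >>> F.conv_transpose1d(inputs, weights).shape
    torch.Size([20, 33, 54])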
""",
)

conv_transpose2d = _add_docstr(
    torch.conv_transpose2d,
    r"""
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
      Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dH, dW)``. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> weights = torch.randn(4, 8, 3, 3)
    >>> F.conv_transpose2d(inputs, weights, padding=1)
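    >>> # A small shape check (illustrative): with padding=1 and a 3x3 kernel
    >>> # the 5x5 spatial size is preserved.
    >>> F.conv_transpose2d(inputs, weights, padding=1).shape
    torch.Size([1, 8, 5, 5])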
""",
)  # noqa: E501

conv_transpose3d = _add_docstr(
    torch.conv_transpose3d,
    r"""
conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sT, sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padT, padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple
      ``(out_padT, out_padH, out_padW)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dT, dH, dW)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> weights = torch.randn(16, 33, 3, 3, 3)
    >>> F.conv_transpose3d(inputs, weights)
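    >>> # A small shape check (illustrative): a 3x3x3 kernel grows each
    >>> # spatial dim by 2.
    >>> F.conv_transpose3d(inputs, weights).shape
    torch.Size([20, 33, 52, 12, 22])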
""",
)  # noqa: E501

conv_tbc = _add_docstr(
    torch.conv_tbc,
    r"""
Applies a 1-dimensional sequence convolution over an input sequence.
Input and output dimensions are (Time, Batch, Channels), hence the name TBC.

Args:
    input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})`
    weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
    bias: bias of shape (:math:`\text{out\_channels}`)
    pad: number of timesteps to pad. Default: 0
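
Examples::

    >>> # A minimal sketch (illustrative shapes): 10 timesteps, batch of 2,
    >>> # 16 input channels, 32 output channels, kernel width 3
    >>> inputs = torch.randn(10, 2, 16)
    >>> filters = torch.randn(3, 16, 32)
    >>> bias = torch.zeros(32)
    >>> F.conv_tbc(inputs, filters, bias)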
""",
)


# Pooling
avg_pool1d = _add_docstr(
    torch.avg_pool1d,
    r"""
avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor

Applies a 1D average pooling over an input signal composed of several
input planes.

See :class:`~torch.nn.AvgPool1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    kernel_size: the size of the window. Can be a single number or a
      tuple `(kW,)`
    stride: the stride of the window. Can be a single number or a tuple
      `(sW,)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padW,)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` to compute the
        output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
        averaging calculation. Default: ``True``

Examples::

    >>> # pool of square window of size=3, stride=2
    >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
    >>> F.avg_pool1d(input, kernel_size=3, stride=2)
    tensor([[[ 2.,  4.,  6.]]])

""",
)


avg_pool2d = _add_docstr(
    torch._C._nn.avg_pool2d,
    r"""
avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.

See :class:`~torch.nn.AvgPool2d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
        to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
        averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as the divisor, otherwise
        the size of the pooling region will be used. Default: ``None``
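
Examples::

    >>> # A minimal sketch (illustrative values): average 2x2 windows over a
    >>> # 1x1x4x4 input.
    >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    >>> F.avg_pool2d(input, kernel_size=2)
    tensor([[[[ 2.5000,  4.5000],
              [10.5000, 12.5000]]]])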
""",
)

avg_pool3d = _add_docstr(
    torch._C._nn.avg_pool3d,
    r"""
avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
size :math:`sT \times sH \times sW` steps. The number of output features is equal to
the number of input planes.

See :class:`~torch.nn.AvgPool3d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kT, kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padT, padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
        to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
        averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as the divisor, otherwise
        the size of the pooling region will be used. Default: ``None``
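
Examples::

    >>> # A minimal sketch (illustrative shapes): batch of 20, 16 channels,
    >>> # a 50x44x31 volume pooled with a cubic window of size 3 and stride 2.
    >>> input = torch.randn(20, 16, 50, 44, 31)
    >>> F.avg_pool3d(input, kernel_size=3, stride=2)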
""",
)


def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
                     or a tuple `(kH, kW)`
        output_size: the target output size of the image of the form :math:`oH \times oW`.
                     Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.

    Examples::

        >>> input = torch.randn(20, 16, 50, 32)
        >>> # pool of square window of size=3, and target output size 13x12
        >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool2d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        if len(output_ratio) > 2:
            raise ValueError(
                "fractional_max_pool2d requires output_ratio to be either a single float or a tuple of floats"
            )
        _output_ratio = _pair(output_ratio)
        output_size = [
            int(input.size(-2) * _output_ratio[0]),
            int(input.size(-1) * _output_ratio[1]),
        ]

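    # If the caller did not supply random samples, draw the two pooling offsets
    # per (batch, channel) uniformly from [0, 1).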
    if _random_samples is None:
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool2d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool2d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool2d_with_indices,
    if_false=_fractional_max_pool2d,
    module_name=__name__,
    func_name="fractional_max_pool2d",
)


def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a cubic kernel of :math:`k \times k \times k`)
                     or a tuple `(kT, kH, kW)`
        output_size: the target output size of the form :math:`oT \times oH \times oW`.
                     Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
                     :math:`oH \times oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples::

        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool3d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        _output_ratio = _triple(output_ratio)
        output_size = [
            int(input.size(-3) * _output_ratio[0]),
            int(input.size(-2) * _output_ratio[1]),
            int(input.size(-1) * _output_ratio[2]),
        ]

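    # If the caller did not supply random samples, draw the three pooling offsets
    # per (batch, channel) uniformly from [0, 1).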
    if _random_samples is None:
        n_batch = 1 if input.dim() == 4 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool3d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool3d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool3d_with_indices,
    if_false=_fractional_max_pool3d,
    module_name=__name__,
    func_name="fractional_max_pool3d",
)


def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 1D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release.
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxPool1d` for details.
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker    Args:
686*da0073e9SAndroid Build Coastguard Worker        input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional.
687*da0073e9SAndroid Build Coastguard Worker        kernel_size: the size of the window. Can be a single number or a
688*da0073e9SAndroid Build Coastguard Worker            tuple `(kW,)`
689*da0073e9SAndroid Build Coastguard Worker        stride: the stride of the window. Can be a single number or a tuple
690*da0073e9SAndroid Build Coastguard Worker            `(sW,)`. Default: :attr:`kernel_size`
691*da0073e9SAndroid Build Coastguard Worker        padding: Implicit negative infinity padding to be added on both sides; must be >= 0 and <= kernel_size / 2.
692*da0073e9SAndroid Build Coastguard Worker        dilation: The stride between elements within a sliding window; must be > 0.
693*da0073e9SAndroid Build Coastguard Worker        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
694*da0073e9SAndroid Build Coastguard Worker                   ensures that every element in the input tensor is covered by a sliding window.
695*da0073e9SAndroid Build Coastguard Worker        return_indices: If ``True``, will return the argmax along with the max values.
696*da0073e9SAndroid Build Coastguard Worker                        Useful for :func:`torch.nn.functional.max_unpool1d` later.
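
    Examples::

        >>> # a minimal sketch; the (1, 1, 6) input values are illustrative
        >>> input = torch.tensor([[[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]]])
        >>> F.max_pool1d(input, kernel_size=2, stride=2)
        tensor([[[3., 5., 6.]]])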
697*da0073e9SAndroid Build Coastguard Worker    """
698*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
699*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
700*da0073e9SAndroid Build Coastguard Worker            max_pool1d_with_indices,
701*da0073e9SAndroid Build Coastguard Worker            (input,),
702*da0073e9SAndroid Build Coastguard Worker            input,
703*da0073e9SAndroid Build Coastguard Worker            kernel_size,
704*da0073e9SAndroid Build Coastguard Worker            stride=stride,
705*da0073e9SAndroid Build Coastguard Worker            padding=padding,
706*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
707*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
708*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
709*da0073e9SAndroid Build Coastguard Worker        )
710*da0073e9SAndroid Build Coastguard Worker    if stride is None:
711*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
712*da0073e9SAndroid Build Coastguard Worker    return torch.max_pool1d_with_indices(
713*da0073e9SAndroid Build Coastguard Worker        input, kernel_size, stride, padding, dilation, ceil_mode
714*da0073e9SAndroid Build Coastguard Worker    )
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Workerdef _max_pool1d(
718*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
719*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList1[int],
720*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList1[int]] = None,
721*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList1[int] = 0,
722*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList1[int] = 1,
723*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
724*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
725*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
726*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
727*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
728*da0073e9SAndroid Build Coastguard Worker            max_pool1d,
729*da0073e9SAndroid Build Coastguard Worker            (input,),
730*da0073e9SAndroid Build Coastguard Worker            input,
731*da0073e9SAndroid Build Coastguard Worker            kernel_size,
732*da0073e9SAndroid Build Coastguard Worker            stride=stride,
733*da0073e9SAndroid Build Coastguard Worker            padding=padding,
734*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
735*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
736*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
737*da0073e9SAndroid Build Coastguard Worker        )
738*da0073e9SAndroid Build Coastguard Worker    if stride is None:
739*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
740*da0073e9SAndroid Build Coastguard Worker    return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Workermax_pool1d = boolean_dispatch(
744*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
745*da0073e9SAndroid Build Coastguard Worker    arg_index=6,
746*da0073e9SAndroid Build Coastguard Worker    default=False,
747*da0073e9SAndroid Build Coastguard Worker    if_true=max_pool1d_with_indices,
748*da0073e9SAndroid Build Coastguard Worker    if_false=_max_pool1d,
749*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
750*da0073e9SAndroid Build Coastguard Worker    func_name="max_pool1d",
751*da0073e9SAndroid Build Coastguard Worker)
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Workerdef max_pool2d_with_indices(
755*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
756*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
757*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList2[int]] = None,
758*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList2[int] = 0,
759*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList2[int] = 1,
760*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
761*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
762*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:  # noqa: D400
763*da0073e9SAndroid Build Coastguard Worker    r"""
764*da0073e9SAndroid Build Coastguard Worker    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker    Applies a 2D max pooling over an input signal composed of several input
767*da0073e9SAndroid Build Coastguard Worker    planes.
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker    .. note::
770*da0073e9SAndroid Build Coastguard Worker        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
771*da0073e9SAndroid Build Coastguard Worker        what is seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release.
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxPool2d` for details.
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker    Args:
776*da0073e9SAndroid Build Coastguard Worker        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional.
777*da0073e9SAndroid Build Coastguard Worker        kernel_size: size of the pooling region. Can be a single number or a
778*da0073e9SAndroid Build Coastguard Worker            tuple `(kH, kW)`
779*da0073e9SAndroid Build Coastguard Worker        stride: stride of the pooling operation. Can be a single number or a
780*da0073e9SAndroid Build Coastguard Worker            tuple `(sH, sW)`. Default: :attr:`kernel_size`
781*da0073e9SAndroid Build Coastguard Worker        padding: Implicit negative infinity padding to be added on both sides; must be >= 0 and <= kernel_size / 2.
782*da0073e9SAndroid Build Coastguard Worker        dilation: The stride between elements within a sliding window; must be > 0.
783*da0073e9SAndroid Build Coastguard Worker        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
784*da0073e9SAndroid Build Coastguard Worker                   ensures that every element in the input tensor is covered by a sliding window.
785*da0073e9SAndroid Build Coastguard Worker        return_indices: If ``True``, will return the argmax along with the max values.
786*da0073e9SAndroid Build Coastguard Worker                        Useful for :func:`torch.nn.functional.max_unpool2d` later.
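
    Examples::

        >>> # a minimal sketch; the (1, 1, 4, 4) input values are illustrative
        >>> input = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> F.max_pool2d(input, kernel_size=2)
        tensor([[[[ 5.,  7.],
                  [13., 15.]]]])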
787*da0073e9SAndroid Build Coastguard Worker    """
788*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
789*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
790*da0073e9SAndroid Build Coastguard Worker            max_pool2d_with_indices,
791*da0073e9SAndroid Build Coastguard Worker            (input,),
792*da0073e9SAndroid Build Coastguard Worker            input,
793*da0073e9SAndroid Build Coastguard Worker            kernel_size,
794*da0073e9SAndroid Build Coastguard Worker            stride=stride,
795*da0073e9SAndroid Build Coastguard Worker            padding=padding,
796*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
797*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
798*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
799*da0073e9SAndroid Build Coastguard Worker        )
800*da0073e9SAndroid Build Coastguard Worker    if stride is None:
801*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
802*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.max_pool2d_with_indices(
803*da0073e9SAndroid Build Coastguard Worker        input, kernel_size, stride, padding, dilation, ceil_mode
804*da0073e9SAndroid Build Coastguard Worker    )
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Workerdef _max_pool2d(
808*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
809*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
810*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList2[int]] = None,
811*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList2[int] = 0,
812*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList2[int] = 1,
813*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
814*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
815*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
816*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
817*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
818*da0073e9SAndroid Build Coastguard Worker            max_pool2d,
819*da0073e9SAndroid Build Coastguard Worker            (input,),
820*da0073e9SAndroid Build Coastguard Worker            input,
821*da0073e9SAndroid Build Coastguard Worker            kernel_size,
822*da0073e9SAndroid Build Coastguard Worker            stride=stride,
823*da0073e9SAndroid Build Coastguard Worker            padding=padding,
824*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
825*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
826*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
827*da0073e9SAndroid Build Coastguard Worker        )
828*da0073e9SAndroid Build Coastguard Worker    if stride is None:
829*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
830*da0073e9SAndroid Build Coastguard Worker    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Workermax_pool2d = boolean_dispatch(
834*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
835*da0073e9SAndroid Build Coastguard Worker    arg_index=6,
836*da0073e9SAndroid Build Coastguard Worker    default=False,
837*da0073e9SAndroid Build Coastguard Worker    if_true=max_pool2d_with_indices,
838*da0073e9SAndroid Build Coastguard Worker    if_false=_max_pool2d,
839*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
840*da0073e9SAndroid Build Coastguard Worker    func_name="max_pool2d",
841*da0073e9SAndroid Build Coastguard Worker)
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Workerdef max_pool3d_with_indices(
845*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
846*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList3[int],
847*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList3[int]] = None,
848*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList3[int] = 0,
849*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList3[int] = 1,
850*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
851*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
852*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:  # noqa: D400
853*da0073e9SAndroid Build Coastguard Worker    r"""
854*da0073e9SAndroid Build Coastguard Worker    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker    Applies a 3D max pooling over an input signal composed of several input
857*da0073e9SAndroid Build Coastguard Worker    planes.
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker    .. note::
860*da0073e9SAndroid Build Coastguard Worker        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
861*da0073e9SAndroid Build Coastguard Worker        what is seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release.
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxPool3d` for details.
864*da0073e9SAndroid Build Coastguard Worker
865*da0073e9SAndroid Build Coastguard Worker    Args:
866*da0073e9SAndroid Build Coastguard Worker        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD, iH , iW)`, minibatch dim optional.
867*da0073e9SAndroid Build Coastguard Worker        kernel_size: size of the pooling region. Can be a single number or a
868*da0073e9SAndroid Build Coastguard Worker                     tuple `(kT, kH, kW)`
869*da0073e9SAndroid Build Coastguard Worker        stride: stride of the pooling operation. Can be a single number or a
870*da0073e9SAndroid Build Coastguard Worker                tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
871*da0073e9SAndroid Build Coastguard Worker        padding: Implicit negative infinity padding to be added on both sides; must be >= 0 and <= kernel_size / 2.
872*da0073e9SAndroid Build Coastguard Worker        dilation: The stride between elements within a sliding window; must be > 0.
873*da0073e9SAndroid Build Coastguard Worker        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
874*da0073e9SAndroid Build Coastguard Worker                   ensures that every element in the input tensor is covered by a sliding window.
875*da0073e9SAndroid Build Coastguard Worker        return_indices: If ``True``, will return the argmax along with the max values.
876*da0073e9SAndroid Build Coastguard Worker                        Useful for :func:`torch.nn.functional.max_unpool3d` later.
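
    Examples::

        >>> # a minimal sketch; the (1, 1, 2, 2, 2) input values are illustrative
        >>> input = torch.arange(8.0).reshape(1, 1, 2, 2, 2)
        >>> F.max_pool3d(input, kernel_size=2)
        tensor([[[[[7.]]]]])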
877*da0073e9SAndroid Build Coastguard Worker    """
878*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
879*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
880*da0073e9SAndroid Build Coastguard Worker            max_pool3d_with_indices,
881*da0073e9SAndroid Build Coastguard Worker            (input,),
882*da0073e9SAndroid Build Coastguard Worker            input,
883*da0073e9SAndroid Build Coastguard Worker            kernel_size,
884*da0073e9SAndroid Build Coastguard Worker            stride=stride,
885*da0073e9SAndroid Build Coastguard Worker            padding=padding,
886*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
887*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
888*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
889*da0073e9SAndroid Build Coastguard Worker        )
890*da0073e9SAndroid Build Coastguard Worker    if stride is None:
891*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
892*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.max_pool3d_with_indices(
893*da0073e9SAndroid Build Coastguard Worker        input, kernel_size, stride, padding, dilation, ceil_mode
894*da0073e9SAndroid Build Coastguard Worker    )
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Workerdef _max_pool3d(
898*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
899*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList3[int],
900*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList3[int]] = None,
901*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList3[int] = 0,
902*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList3[int] = 1,
903*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
904*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
905*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
906*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
907*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
908*da0073e9SAndroid Build Coastguard Worker            max_pool3d,
909*da0073e9SAndroid Build Coastguard Worker            (input,),
910*da0073e9SAndroid Build Coastguard Worker            input,
911*da0073e9SAndroid Build Coastguard Worker            kernel_size,
912*da0073e9SAndroid Build Coastguard Worker            stride=stride,
913*da0073e9SAndroid Build Coastguard Worker            padding=padding,
914*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
915*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
916*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
917*da0073e9SAndroid Build Coastguard Worker        )
918*da0073e9SAndroid Build Coastguard Worker    if stride is None:
919*da0073e9SAndroid Build Coastguard Worker        stride = torch.jit.annotate(List[int], [])
920*da0073e9SAndroid Build Coastguard Worker    return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Workermax_pool3d = boolean_dispatch(
924*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
925*da0073e9SAndroid Build Coastguard Worker    arg_index=6,
926*da0073e9SAndroid Build Coastguard Worker    default=False,
927*da0073e9SAndroid Build Coastguard Worker    if_true=max_pool3d_with_indices,
928*da0073e9SAndroid Build Coastguard Worker    if_false=_max_pool3d,
929*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
930*da0073e9SAndroid Build Coastguard Worker    func_name="max_pool3d",
931*da0073e9SAndroid Build Coastguard Worker)
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker
934*da0073e9SAndroid Build Coastguard Workerdef _unpool_output_size(
935*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
936*da0073e9SAndroid Build Coastguard Worker    kernel_size: List[int],
937*da0073e9SAndroid Build Coastguard Worker    stride: List[int],
938*da0073e9SAndroid Build Coastguard Worker    padding: List[int],
939*da0073e9SAndroid Build Coastguard Worker    output_size: Optional[List[int]],
940*da0073e9SAndroid Build Coastguard Worker) -> List[int]:
941*da0073e9SAndroid Build Coastguard Worker    input_size = input.size()
942*da0073e9SAndroid Build Coastguard Worker    default_size = torch.jit.annotate(List[int], [])
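    # The default size inverts the floor-mode pooling shape formula (unpooling
    # takes no dilation argument, so dilation is effectively 1):
    #   out_d = (in_d - 1) * stride_d + kernel_d - 2 * padding_d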
943*da0073e9SAndroid Build Coastguard Worker    for d in range(len(kernel_size)):
944*da0073e9SAndroid Build Coastguard Worker        default_size.append(
945*da0073e9SAndroid Build Coastguard Worker            (input_size[-len(kernel_size) + d] - 1) * stride[d]
946*da0073e9SAndroid Build Coastguard Worker            + kernel_size[d]
947*da0073e9SAndroid Build Coastguard Worker            - 2 * padding[d]
948*da0073e9SAndroid Build Coastguard Worker        )
949*da0073e9SAndroid Build Coastguard Worker    if output_size is None:
950*da0073e9SAndroid Build Coastguard Worker        ret = default_size
951*da0073e9SAndroid Build Coastguard Worker    else:
952*da0073e9SAndroid Build Coastguard Worker        if len(output_size) == len(kernel_size) + 2:
953*da0073e9SAndroid Build Coastguard Worker            output_size = output_size[2:]
954*da0073e9SAndroid Build Coastguard Worker        if len(output_size) != len(kernel_size):
955*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
956*da0073e9SAndroid Build Coastguard Worker                "output_size should be a sequence containing "
957*da0073e9SAndroid Build Coastguard Worker                f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of {len(output_size)}"
958*da0073e9SAndroid Build Coastguard Worker            )
959*da0073e9SAndroid Build Coastguard Worker        for d in range(len(kernel_size)):
960*da0073e9SAndroid Build Coastguard Worker            min_size = default_size[d] - stride[d]
961*da0073e9SAndroid Build Coastguard Worker            max_size = default_size[d] + stride[d]
962*da0073e9SAndroid Build Coastguard Worker            if not (min_size < output_size[d] < max_size):
963*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
964*da0073e9SAndroid Build Coastguard Worker                    f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})'
965*da0073e9SAndroid Build Coastguard Worker                )
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        ret = output_size
968*da0073e9SAndroid Build Coastguard Worker    return ret
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Workerdef max_unpool1d(
972*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
973*da0073e9SAndroid Build Coastguard Worker    indices: Tensor,
974*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList1[int],
975*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList1[int]] = None,
976*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList1[int] = 0,
977*da0073e9SAndroid Build Coastguard Worker    output_size: Optional[BroadcastingList1[int]] = None,
978*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
979*da0073e9SAndroid Build Coastguard Worker    r"""Compute a partial inverse of :class:`~torch.nn.MaxPool1d`.
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxUnpool1d` for details.
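
    Examples::

        >>> # a round trip with illustrative values: positions that lost the max
        >>> # comparison are restored as zeros
        >>> input = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]])
        >>> output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool1d(output, indices, 2, stride=2)
        tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]])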
982*da0073e9SAndroid Build Coastguard Worker    """
983*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
984*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
985*da0073e9SAndroid Build Coastguard Worker            max_unpool1d,
986*da0073e9SAndroid Build Coastguard Worker            (input,),
987*da0073e9SAndroid Build Coastguard Worker            input,
988*da0073e9SAndroid Build Coastguard Worker            indices,
989*da0073e9SAndroid Build Coastguard Worker            kernel_size,
990*da0073e9SAndroid Build Coastguard Worker            stride=stride,
991*da0073e9SAndroid Build Coastguard Worker            padding=padding,
992*da0073e9SAndroid Build Coastguard Worker            output_size=output_size,
993*da0073e9SAndroid Build Coastguard Worker        )
994*da0073e9SAndroid Build Coastguard Worker    kernel_size = _single(kernel_size)
995*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
996*da0073e9SAndroid Build Coastguard Worker        _stride = _single(stride)
997*da0073e9SAndroid Build Coastguard Worker    else:
998*da0073e9SAndroid Build Coastguard Worker        _stride = kernel_size
999*da0073e9SAndroid Build Coastguard Worker    padding = _single(padding)
1000*da0073e9SAndroid Build Coastguard Worker    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1001*da0073e9SAndroid Build Coastguard Worker    if isinstance(output_size, list):
1002*da0073e9SAndroid Build Coastguard Worker        output_size = output_size + [1]
1003*da0073e9SAndroid Build Coastguard Worker    else:
1004*da0073e9SAndroid Build Coastguard Worker        output_size = output_size + (1,)
1005*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.max_unpool2d(
1006*da0073e9SAndroid Build Coastguard Worker        input.unsqueeze(-1), indices.unsqueeze(-1), output_size
1007*da0073e9SAndroid Build Coastguard Worker    ).squeeze(-1)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Workerdef max_unpool2d(
1011*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1012*da0073e9SAndroid Build Coastguard Worker    indices: Tensor,
1013*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
1014*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList2[int]] = None,
1015*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList2[int] = 0,
1016*da0073e9SAndroid Build Coastguard Worker    output_size: Optional[BroadcastingList2[int]] = None,
1017*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1018*da0073e9SAndroid Build Coastguard Worker    r"""Compute a partial inverse of :class:`~torch.nn.MaxPool2d`.
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxUnpool2d` for details.
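
    Examples::

        >>> # a round trip with illustrative values
        >>> input = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> output, indices = F.max_pool2d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool2d(output, indices, 2, stride=2)
        tensor([[[[ 0.,  0.,  0.,  0.],
                  [ 0.,  5.,  0.,  7.],
                  [ 0.,  0.,  0.,  0.],
                  [ 0., 13.,  0., 15.]]]])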
1021*da0073e9SAndroid Build Coastguard Worker    """
1022*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1023*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1024*da0073e9SAndroid Build Coastguard Worker            max_unpool2d,
1025*da0073e9SAndroid Build Coastguard Worker            (input,),
1026*da0073e9SAndroid Build Coastguard Worker            input,
1027*da0073e9SAndroid Build Coastguard Worker            indices,
1028*da0073e9SAndroid Build Coastguard Worker            kernel_size,
1029*da0073e9SAndroid Build Coastguard Worker            stride=stride,
1030*da0073e9SAndroid Build Coastguard Worker            padding=padding,
1031*da0073e9SAndroid Build Coastguard Worker            output_size=output_size,
1032*da0073e9SAndroid Build Coastguard Worker        )
1033*da0073e9SAndroid Build Coastguard Worker    kernel_size = _pair(kernel_size)
1034*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
1035*da0073e9SAndroid Build Coastguard Worker        _stride = _pair(stride)
1036*da0073e9SAndroid Build Coastguard Worker    else:
1037*da0073e9SAndroid Build Coastguard Worker        _stride = kernel_size
1038*da0073e9SAndroid Build Coastguard Worker    padding = _pair(padding)
1039*da0073e9SAndroid Build Coastguard Worker    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1040*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.max_unpool2d(input, indices, output_size)
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Workerdef max_unpool3d(
1044*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1045*da0073e9SAndroid Build Coastguard Worker    indices: Tensor,
1046*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList3[int],
1047*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList3[int]] = None,
1048*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList3[int] = 0,
1049*da0073e9SAndroid Build Coastguard Worker    output_size: Optional[BroadcastingList3[int]] = None,
1050*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1051*da0073e9SAndroid Build Coastguard Worker    r"""Compute a partial inverse of :class:`~torch.nn.MaxPool3d`.
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.MaxUnpool3d` for details.
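
    Examples::

        >>> # a round trip with illustrative values; only the output shape is shown
        >>> input = torch.arange(8.0).reshape(1, 1, 2, 2, 2)
        >>> output, indices = F.max_pool3d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool3d(output, indices, 2, stride=2).shape
        torch.Size([1, 1, 2, 2, 2])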
1054*da0073e9SAndroid Build Coastguard Worker    """
1055*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1056*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1057*da0073e9SAndroid Build Coastguard Worker            max_unpool3d,
1058*da0073e9SAndroid Build Coastguard Worker            (input,),
1059*da0073e9SAndroid Build Coastguard Worker            input,
1060*da0073e9SAndroid Build Coastguard Worker            indices,
1061*da0073e9SAndroid Build Coastguard Worker            kernel_size,
1062*da0073e9SAndroid Build Coastguard Worker            stride=stride,
1063*da0073e9SAndroid Build Coastguard Worker            padding=padding,
1064*da0073e9SAndroid Build Coastguard Worker            output_size=output_size,
1065*da0073e9SAndroid Build Coastguard Worker        )
1066*da0073e9SAndroid Build Coastguard Worker    kernel_size = _triple(kernel_size)
1067*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
1068*da0073e9SAndroid Build Coastguard Worker        _stride = _triple(stride)
1069*da0073e9SAndroid Build Coastguard Worker    else:
1070*da0073e9SAndroid Build Coastguard Worker        _stride = kernel_size
1071*da0073e9SAndroid Build Coastguard Worker    padding = _triple(padding)
1072*da0073e9SAndroid Build Coastguard Worker    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1073*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Workerdef lp_pool3d(
1077*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1078*da0073e9SAndroid Build Coastguard Worker    norm_type: Union[int, float],
1079*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList3[int],
1080*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList3[int]] = None,
1081*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
1082*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1083*da0073e9SAndroid Build Coastguard Worker    r"""
1084*da0073e9SAndroid Build Coastguard Worker    Apply a 3D power-average pooling over an input signal composed of several input planes.
1085*da0073e9SAndroid Build Coastguard Worker
1086*da0073e9SAndroid Build Coastguard Worker    If the sum of all inputs to the power of `p` is
1087*da0073e9SAndroid Build Coastguard Worker    zero, the gradient is set to zero as well.
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LPPool3d` for details.
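
    Examples::

        >>> # with norm_type=1 each window yields its plain sum (illustrative values)
        >>> input = torch.ones(1, 1, 2, 2, 2)
        >>> F.lp_pool3d(input, norm_type=1, kernel_size=2)
        tensor([[[[[8.]]]]])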
1090*da0073e9SAndroid Build Coastguard Worker    """
1091*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1092*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1093*da0073e9SAndroid Build Coastguard Worker            lp_pool3d,
1094*da0073e9SAndroid Build Coastguard Worker            (input,),
1095*da0073e9SAndroid Build Coastguard Worker            input,
1096*da0073e9SAndroid Build Coastguard Worker            norm_type,
1097*da0073e9SAndroid Build Coastguard Worker            kernel_size,
1098*da0073e9SAndroid Build Coastguard Worker            stride=stride,
1099*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
1100*da0073e9SAndroid Build Coastguard Worker        )
1101*da0073e9SAndroid Build Coastguard Worker    kd, kw, kh = _triple(kernel_size)
1102*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
1103*da0073e9SAndroid Build Coastguard Worker        out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1104*da0073e9SAndroid Build Coastguard Worker    else:
1105*da0073e9SAndroid Build Coastguard Worker        out = avg_pool3d(
1106*da0073e9SAndroid Build Coastguard Worker            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1107*da0073e9SAndroid Build Coastguard Worker        )
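    # avg_pool3d computes the window mean of input ** norm_type; multiplying by
    # the window element count (kd * kw * kh) recovers the window sum before the
    # 1 / norm_type root is taken.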
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker    return (
1110*da0073e9SAndroid Build Coastguard Worker        (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type)
1111*da0073e9SAndroid Build Coastguard Worker    )
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker
1114*da0073e9SAndroid Build Coastguard Workerdef lp_pool2d(
1115*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1116*da0073e9SAndroid Build Coastguard Worker    norm_type: Union[int, float],
1117*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
1118*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList2[int]] = None,
1119*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
1120*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1121*da0073e9SAndroid Build Coastguard Worker    r"""
1122*da0073e9SAndroid Build Coastguard Worker    Apply a 2D power-average pooling over an input signal composed of several input planes.
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    If the sum of all inputs to the power of `p` is
1125*da0073e9SAndroid Build Coastguard Worker    zero, the gradient is set to zero as well.
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LPPool2d` for details.
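
    Examples::

        >>> # with norm_type=2 each window yields its Euclidean norm (illustrative)
        >>> input = torch.ones(1, 1, 2, 2)
        >>> F.lp_pool2d(input, norm_type=2, kernel_size=2)
        tensor([[[[2.]]]])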
1128*da0073e9SAndroid Build Coastguard Worker    """
1129*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1130*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1131*da0073e9SAndroid Build Coastguard Worker            lp_pool2d,
1132*da0073e9SAndroid Build Coastguard Worker            (input,),
1133*da0073e9SAndroid Build Coastguard Worker            input,
1134*da0073e9SAndroid Build Coastguard Worker            norm_type,
1135*da0073e9SAndroid Build Coastguard Worker            kernel_size,
1136*da0073e9SAndroid Build Coastguard Worker            stride=stride,
1137*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
1138*da0073e9SAndroid Build Coastguard Worker        )
1139*da0073e9SAndroid Build Coastguard Worker    kw, kh = _pair(kernel_size)
1140*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
1141*da0073e9SAndroid Build Coastguard Worker        out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1142*da0073e9SAndroid Build Coastguard Worker    else:
1143*da0073e9SAndroid Build Coastguard Worker        out = avg_pool2d(
1144*da0073e9SAndroid Build Coastguard Worker            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1145*da0073e9SAndroid Build Coastguard Worker        )
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type)
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Workerdef lp_pool1d(
1151*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1152*da0073e9SAndroid Build Coastguard Worker    norm_type: Union[int, float],
1153*da0073e9SAndroid Build Coastguard Worker    kernel_size: int,
1154*da0073e9SAndroid Build Coastguard Worker    stride: Optional[BroadcastingList1[int]] = None,
1155*da0073e9SAndroid Build Coastguard Worker    ceil_mode: bool = False,
1156*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1157*da0073e9SAndroid Build Coastguard Worker    r"""Apply a 1D power-average pooling over an input signal composed of several input planes.
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker    If the sum of all inputs to the power of `p` is
1160*da0073e9SAndroid Build Coastguard Worker    zero, the gradient is set to zero as well.
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LPPool1d` for details.
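
    Examples::

        >>> # with norm_type=2 each window yields its Euclidean norm (illustrative)
        >>> input = torch.tensor([[[3.0, 4.0, 0.0, 0.0]]])
        >>> F.lp_pool1d(input, norm_type=2, kernel_size=2)
        tensor([[[5., 0.]]])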
1163*da0073e9SAndroid Build Coastguard Worker    """
1164*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1165*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1166*da0073e9SAndroid Build Coastguard Worker            lp_pool1d,
1167*da0073e9SAndroid Build Coastguard Worker            (input,),
1168*da0073e9SAndroid Build Coastguard Worker            input,
1169*da0073e9SAndroid Build Coastguard Worker            norm_type,
1170*da0073e9SAndroid Build Coastguard Worker            kernel_size,
1171*da0073e9SAndroid Build Coastguard Worker            stride=stride,
1172*da0073e9SAndroid Build Coastguard Worker            ceil_mode=ceil_mode,
1173*da0073e9SAndroid Build Coastguard Worker        )
1174*da0073e9SAndroid Build Coastguard Worker    if stride is not None:
1175*da0073e9SAndroid Build Coastguard Worker        out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1176*da0073e9SAndroid Build Coastguard Worker    else:
1177*da0073e9SAndroid Build Coastguard Worker        out = avg_pool1d(
1178*da0073e9SAndroid Build Coastguard Worker            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1179*da0073e9SAndroid Build Coastguard Worker        )
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker    return (
1182*da0073e9SAndroid Build Coastguard Worker        (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type)
1183*da0073e9SAndroid Build Coastguard Worker    )
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Workerdef adaptive_max_pool1d_with_indices(
1187*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1188*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList1[int],
1189*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1190*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:  # noqa: D400
1191*da0073e9SAndroid Build Coastguard Worker    r"""
1192*da0073e9SAndroid Build Coastguard Worker    adaptive_max_pool1d(input, output_size, return_indices=False)
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker    Applies a 1D adaptive max pooling over an input signal composed of
1195*da0073e9SAndroid Build Coastguard Worker    several input planes.
1196*da0073e9SAndroid Build Coastguard Worker
1197*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape.
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker    Args:
1200*da0073e9SAndroid Build Coastguard Worker        output_size: the target output size (single integer)
1201*da0073e9SAndroid Build Coastguard Worker        return_indices: whether to return pooling indices. Default: ``False``
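
    Examples::

        >>> # illustrative: a length-6 signal is split into 3 bins of 2
        >>> input = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]])
        >>> F.adaptive_max_pool1d(input, output_size=3)
        tensor([[[2., 4., 6.]]])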
1202*da0073e9SAndroid Build Coastguard Worker    """
1203*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1204*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1205*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool1d_with_indices,
1206*da0073e9SAndroid Build Coastguard Worker            (input,),
1207*da0073e9SAndroid Build Coastguard Worker            input,
1208*da0073e9SAndroid Build Coastguard Worker            output_size,
1209*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1210*da0073e9SAndroid Build Coastguard Worker        )
1211*da0073e9SAndroid Build Coastguard Worker    return torch.adaptive_max_pool1d(input, output_size)
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Workerdef _adaptive_max_pool1d(
1215*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1216*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList1[int],
1217*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1218*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1219*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1220*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1221*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool1d,
1222*da0073e9SAndroid Build Coastguard Worker            (input,),
1223*da0073e9SAndroid Build Coastguard Worker            input,
1224*da0073e9SAndroid Build Coastguard Worker            output_size,
1225*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1226*da0073e9SAndroid Build Coastguard Worker        )
1227*da0073e9SAndroid Build Coastguard Worker    return adaptive_max_pool1d_with_indices(input, output_size)[0]
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Workeradaptive_max_pool1d = boolean_dispatch(
1231*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
1232*da0073e9SAndroid Build Coastguard Worker    arg_index=2,
1233*da0073e9SAndroid Build Coastguard Worker    default=False,
1234*da0073e9SAndroid Build Coastguard Worker    if_true=adaptive_max_pool1d_with_indices,
1235*da0073e9SAndroid Build Coastguard Worker    if_false=_adaptive_max_pool1d,
1236*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1237*da0073e9SAndroid Build Coastguard Worker    func_name="adaptive_max_pool1d",
1238*da0073e9SAndroid Build Coastguard Worker)
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Workerdef adaptive_max_pool2d_with_indices(
1242*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1243*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList2[int],
1244*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1245*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:  # noqa: D400
1246*da0073e9SAndroid Build Coastguard Worker    r"""adaptive_max_pool2d(input, output_size, return_indices=False)
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker    Applies a 2D adaptive max pooling over an input signal composed of
1249*da0073e9SAndroid Build Coastguard Worker    several input planes.
1250*da0073e9SAndroid Build Coastguard Worker
1251*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker    Args:
1254*da0073e9SAndroid Build Coastguard Worker        output_size: the target output size (single integer or
1255*da0073e9SAndroid Build Coastguard Worker            double-integer tuple)
1256*da0073e9SAndroid Build Coastguard Worker        return_indices: whether to return pooling indices. Default: ``False``
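
    Examples::

        >>> # illustrative: a 4x4 input reduced to a 2x2 output
        >>> input = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> F.adaptive_max_pool2d(input, output_size=2)
        tensor([[[[ 5.,  7.],
                  [13., 15.]]]])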
1257*da0073e9SAndroid Build Coastguard Worker    """
1258*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1259*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1260*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool2d_with_indices,
1261*da0073e9SAndroid Build Coastguard Worker            (input,),
1262*da0073e9SAndroid Build Coastguard Worker            input,
1263*da0073e9SAndroid Build Coastguard Worker            output_size,
1264*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1265*da0073e9SAndroid Build Coastguard Worker        )
1266*da0073e9SAndroid Build Coastguard Worker    output_size = _list_with_default(output_size, input.size())
1267*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.adaptive_max_pool2d(input, output_size)
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Workerdef _adaptive_max_pool2d(
1271*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1272*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList2[int],
1273*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1274*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1275*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1276*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1277*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool2d,
1278*da0073e9SAndroid Build Coastguard Worker            (input,),
1279*da0073e9SAndroid Build Coastguard Worker            input,
1280*da0073e9SAndroid Build Coastguard Worker            output_size,
1281*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1282*da0073e9SAndroid Build Coastguard Worker        )
1283*da0073e9SAndroid Build Coastguard Worker    return adaptive_max_pool2d_with_indices(input, output_size)[0]
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Workeradaptive_max_pool2d = boolean_dispatch(
1287*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
1288*da0073e9SAndroid Build Coastguard Worker    arg_index=2,
1289*da0073e9SAndroid Build Coastguard Worker    default=False,
1290*da0073e9SAndroid Build Coastguard Worker    if_true=adaptive_max_pool2d_with_indices,
1291*da0073e9SAndroid Build Coastguard Worker    if_false=_adaptive_max_pool2d,
1292*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1293*da0073e9SAndroid Build Coastguard Worker    func_name="adaptive_max_pool2d",
1294*da0073e9SAndroid Build Coastguard Worker)
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Workerdef adaptive_max_pool3d_with_indices(
1298*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1299*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList3[int],
1300*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1301*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor]:  # noqa: D400
1302*da0073e9SAndroid Build Coastguard Worker    r"""
1303*da0073e9SAndroid Build Coastguard Worker    adaptive_max_pool3d(input, output_size, return_indices=False)
1304*da0073e9SAndroid Build Coastguard Worker
1305*da0073e9SAndroid Build Coastguard Worker    Applies a 3D adaptive max pooling over an input signal composed of
1306*da0073e9SAndroid Build Coastguard Worker    several input planes.
1307*da0073e9SAndroid Build Coastguard Worker
1308*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker    Args:
1311*da0073e9SAndroid Build Coastguard Worker        output_size: the target output size (single integer or
1312*da0073e9SAndroid Build Coastguard Worker            triple-integer tuple)
1313*da0073e9SAndroid Build Coastguard Worker        return_indices: whether to return pooling indices. Default: ``False``
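
    Examples::

        >>> # illustrative: the whole volume collapses to its single maximum
        >>> input = torch.arange(8.0).reshape(1, 1, 2, 2, 2)
        >>> F.adaptive_max_pool3d(input, output_size=1)
        tensor([[[[[7.]]]]])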
1314*da0073e9SAndroid Build Coastguard Worker    """
1315*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1316*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1317*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool3d_with_indices,
1318*da0073e9SAndroid Build Coastguard Worker            (input,),
1319*da0073e9SAndroid Build Coastguard Worker            input,
1320*da0073e9SAndroid Build Coastguard Worker            output_size,
1321*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1322*da0073e9SAndroid Build Coastguard Worker        )
1323*da0073e9SAndroid Build Coastguard Worker    output_size = _list_with_default(output_size, input.size())
1324*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.adaptive_max_pool3d(input, output_size)
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Workerdef _adaptive_max_pool3d(
1328*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1329*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList3[int],
1330*da0073e9SAndroid Build Coastguard Worker    return_indices: bool = False,
1331*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1332*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1333*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1334*da0073e9SAndroid Build Coastguard Worker            adaptive_max_pool3d,
1335*da0073e9SAndroid Build Coastguard Worker            (input,),
1336*da0073e9SAndroid Build Coastguard Worker            input,
1337*da0073e9SAndroid Build Coastguard Worker            output_size,
1338*da0073e9SAndroid Build Coastguard Worker            return_indices=return_indices,
1339*da0073e9SAndroid Build Coastguard Worker        )
1340*da0073e9SAndroid Build Coastguard Worker    return adaptive_max_pool3d_with_indices(input, output_size)[0]
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Workeradaptive_max_pool3d = boolean_dispatch(
1344*da0073e9SAndroid Build Coastguard Worker    arg_name="return_indices",
1345*da0073e9SAndroid Build Coastguard Worker    arg_index=2,
1346*da0073e9SAndroid Build Coastguard Worker    default=False,
1347*da0073e9SAndroid Build Coastguard Worker    if_true=adaptive_max_pool3d_with_indices,
1348*da0073e9SAndroid Build Coastguard Worker    if_false=_adaptive_max_pool3d,
1349*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1350*da0073e9SAndroid Build Coastguard Worker    func_name="adaptive_max_pool3d",
1351*da0073e9SAndroid Build Coastguard Worker)
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Workeradaptive_avg_pool1d = _add_docstr(
1355*da0073e9SAndroid Build Coastguard Worker    torch.adaptive_avg_pool1d,
1356*da0073e9SAndroid Build Coastguard Worker    r"""
1357*da0073e9SAndroid Build Coastguard Workeradaptive_avg_pool1d(input, output_size) -> Tensor
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard WorkerApplies a 1D adaptive average pooling over an input signal composed of
1360*da0073e9SAndroid Build Coastguard Workerseveral input planes.
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard WorkerSee :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard WorkerArgs:
1365*da0073e9SAndroid Build Coastguard Worker    output_size: the target output size (single integer)
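
Examples::

    >>> # illustrative: a length-6 signal averaged into 3 bins of 2
    >>> input = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]])
    >>> F.adaptive_avg_pool1d(input, output_size=3)
    tensor([[[1.5000, 3.5000, 5.5000]]])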
1366*da0073e9SAndroid Build Coastguard Worker""",
1367*da0073e9SAndroid Build Coastguard Worker)
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Workerdef adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
1371*da0073e9SAndroid Build Coastguard Worker    r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes.
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
1374*da0073e9SAndroid Build Coastguard Worker
1375*da0073e9SAndroid Build Coastguard Worker    Args:
1376*da0073e9SAndroid Build Coastguard Worker        output_size: the target output size (single integer or
1377*da0073e9SAndroid Build Coastguard Worker            double-integer tuple)
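
    Examples::

        >>> # illustrative: a 4x4 input averaged into a 2x2 output
        >>> input = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> F.adaptive_avg_pool2d(input, output_size=2)
        tensor([[[[ 2.5000,  4.5000],
                  [10.5000, 12.5000]]]])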
1378*da0073e9SAndroid Build Coastguard Worker    """
1379*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1380*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
1381*da0073e9SAndroid Build Coastguard Worker    _output_size = _list_with_default(output_size, input.size())
1382*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Workerdef adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
1386*da0073e9SAndroid Build Coastguard Worker    r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes.
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker    Args:
1391*da0073e9SAndroid Build Coastguard Worker        output_size: the target output size (single integer or
1392*da0073e9SAndroid Build Coastguard Worker            triple-integer tuple)
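
    Examples::

        >>> # illustrative; only the output shape is shown
        >>> input = torch.randn(1, 1, 8, 8, 8)
        >>> F.adaptive_avg_pool3d(input, output_size=(4, 4, 4)).shape
        torch.Size([1, 1, 4, 4, 4])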
1393*da0073e9SAndroid Build Coastguard Worker    """
1394*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1395*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
1396*da0073e9SAndroid Build Coastguard Worker    _output_size = _list_with_default(output_size, input.size())
1397*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker# Activation functions
1401*da0073e9SAndroid Build Coastguard Workerdef dropout(
1402*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1403*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1404*da0073e9SAndroid Build Coastguard Worker    training: bool = True,
1405*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1406*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1407*da0073e9SAndroid Build Coastguard Worker    r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`.
1408*da0073e9SAndroid Build Coastguard Worker
1409*da0073e9SAndroid Build Coastguard Worker    Uses samples from a Bernoulli distribution.
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Dropout` for details.
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    Args:
1414*da0073e9SAndroid Build Coastguard Worker        p: probability of an element to be zeroed. Default: 0.5
1415*da0073e9SAndroid Build Coastguard Worker        training: apply dropout if ``True``. Default: ``True``
1416*da0073e9SAndroid Build Coastguard Worker        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
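
    Examples::

        >>> # illustrative: a no-op when training=False; in training mode the
        >>> # surviving elements are rescaled by 1 / (1 - p) and the output is random
        >>> input = torch.ones(5)
        >>> F.dropout(input, p=0.5, training=False)
        tensor([1., 1., 1., 1., 1.])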
1417*da0073e9SAndroid Build Coastguard Worker    """
1418*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1419*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1420*da0073e9SAndroid Build Coastguard Worker            dropout, (input,), input, p=p, training=training, inplace=inplace
1421*da0073e9SAndroid Build Coastguard Worker        )
1422*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1423*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1424*da0073e9SAndroid Build Coastguard Worker    return (
1425*da0073e9SAndroid Build Coastguard Worker        _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
1426*da0073e9SAndroid Build Coastguard Worker    )
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Workerdef alpha_dropout(
1430*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1431*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1432*da0073e9SAndroid Build Coastguard Worker    training: bool = False,
1433*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1434*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1435*da0073e9SAndroid Build Coastguard Worker    r"""Apply alpha dropout to the input.
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.AlphaDropout` for details.
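
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``;
    note that ``training`` defaults to ``False`` for this function)::

        >>> x = torch.randn(4, 8)
        >>> out = F.alpha_dropout(x, p=0.2, training=True)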
1438*da0073e9SAndroid Build Coastguard Worker    """
1439*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1440*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1441*da0073e9SAndroid Build Coastguard Worker            alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
1442*da0073e9SAndroid Build Coastguard Worker        )
1443*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1444*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1445*da0073e9SAndroid Build Coastguard Worker    return (
1446*da0073e9SAndroid Build Coastguard Worker        _VF.alpha_dropout_(input, p, training)
1447*da0073e9SAndroid Build Coastguard Worker        if inplace
1448*da0073e9SAndroid Build Coastguard Worker        else _VF.alpha_dropout(input, p, training)
1449*da0073e9SAndroid Build Coastguard Worker    )
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Workerdef dropout1d(
1453*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1454*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1455*da0073e9SAndroid Build Coastguard Worker    training: bool = True,
1456*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1457*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1458*da0073e9SAndroid Build Coastguard Worker    r"""Randomly zero out entire channels (a channel is a 1D feature map).
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1461*da0073e9SAndroid Build Coastguard Worker    batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor.
1462*da0073e9SAndroid Build Coastguard Worker    Each channel will be zeroed out independently on every forward call with
1463*da0073e9SAndroid Build Coastguard Worker    probability :attr:`p` using samples from a Bernoulli distribution.
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Dropout1d` for details.
1466*da0073e9SAndroid Build Coastguard Worker
1467*da0073e9SAndroid Build Coastguard Worker    Args:
1468*da0073e9SAndroid Build Coastguard Worker        p: probability of a channel to be zeroed. Default: 0.5
1469*da0073e9SAndroid Build Coastguard Worker        training: apply dropout if ``True``. Default: ``True``
1470*da0073e9SAndroid Build Coastguard Worker        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
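
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(2, 3, 5)  # (N, C, L)
        >>> out = F.dropout1d(x, p=0.5, training=True)  # zeroes whole channels x[i, j, :]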
1471*da0073e9SAndroid Build Coastguard Worker    """
1472*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1473*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1474*da0073e9SAndroid Build Coastguard Worker            dropout1d, (input,), input, p=p, training=training, inplace=inplace
1475*da0073e9SAndroid Build Coastguard Worker        )
1476*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1477*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1478*da0073e9SAndroid Build Coastguard Worker    inp_dim = input.dim()
1479*da0073e9SAndroid Build Coastguard Worker    if inp_dim not in (2, 3):
1480*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
1481*da0073e9SAndroid Build Coastguard Worker            f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. "
1482*da0073e9SAndroid Build Coastguard Worker            "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
1483*da0073e9SAndroid Build Coastguard Worker            "spatial dimension, a channel dimension, and an optional batch dimension "
1484*da0073e9SAndroid Build Coastguard Worker            "(i.e. 2D or 3D inputs)."
1485*da0073e9SAndroid Build Coastguard Worker        )
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker    is_batched = inp_dim == 3
1488*da0073e9SAndroid Build Coastguard Worker    if not is_batched:
1489*da0073e9SAndroid Build Coastguard Worker        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker    result = (
1492*da0073e9SAndroid Build Coastguard Worker        _VF.feature_dropout_(input, p, training)
1493*da0073e9SAndroid Build Coastguard Worker        if inplace
1494*da0073e9SAndroid Build Coastguard Worker        else _VF.feature_dropout(input, p, training)
1495*da0073e9SAndroid Build Coastguard Worker    )
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker    if not is_batched:
1498*da0073e9SAndroid Build Coastguard Worker        result = result.squeeze_(0) if inplace else result.squeeze(0)
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker    return result
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Workerdef dropout2d(
1504*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1505*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1506*da0073e9SAndroid Build Coastguard Worker    training: bool = True,
1507*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1508*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1509*da0073e9SAndroid Build Coastguard Worker    r"""Randomly zero out entire channels (a channel is a 2D feature map).
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1512*da0073e9SAndroid Build Coastguard Worker    batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor.
1513*da0073e9SAndroid Build Coastguard Worker    Each channel will be zeroed out independently on every forward call with
1514*da0073e9SAndroid Build Coastguard Worker    probability :attr:`p` using samples from a Bernoulli distribution.
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Dropout2d` for details.
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker    Args:
1519*da0073e9SAndroid Build Coastguard Worker        p: probability of a channel to be zeroed. Default: 0.5
1520*da0073e9SAndroid Build Coastguard Worker        training: apply dropout if ``True``. Default: ``True``
1521*da0073e9SAndroid Build Coastguard Worker        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
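
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(2, 3, 4, 4)  # (N, C, H, W)
        >>> out = F.dropout2d(x, p=0.5, training=True)  # zeroes whole feature maps x[i, j, :, :]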
1522*da0073e9SAndroid Build Coastguard Worker    """
1523*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1524*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1525*da0073e9SAndroid Build Coastguard Worker            dropout2d, (input,), input, p=p, training=training, inplace=inplace
1526*da0073e9SAndroid Build Coastguard Worker        )
1527*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1528*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1529*da0073e9SAndroid Build Coastguard Worker    inp_dim = input.dim()
1530*da0073e9SAndroid Build Coastguard Worker    if inp_dim not in (3, 4):
1531*da0073e9SAndroid Build Coastguard Worker        warn_msg = (
1532*da0073e9SAndroid Build Coastguard Worker            f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated "
1533*da0073e9SAndroid Build Coastguard Worker            "and will result in an error in a future release. To retain the behavior "
1534*da0073e9SAndroid Build Coastguard Worker            "and silence this warning, please use dropout instead. Note that dropout2d "
1535*da0073e9SAndroid Build Coastguard Worker            "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, "
1536*da0073e9SAndroid Build Coastguard Worker            "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."
1537*da0073e9SAndroid Build Coastguard Worker        )
1538*da0073e9SAndroid Build Coastguard Worker        warnings.warn(warn_msg)
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
1541*da0073e9SAndroid Build Coastguard Worker    # a 3D input will perform dropout1d behavior instead. This was done historically and the
1542*da0073e9SAndroid Build Coastguard Worker    # behavior is maintained here for now.
1543*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/77081
1544*da0073e9SAndroid Build Coastguard Worker    if inp_dim == 3:
1545*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
1546*da0073e9SAndroid Build Coastguard Worker            "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
1547*da0073e9SAndroid Build Coastguard Worker            "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
1548*da0073e9SAndroid Build Coastguard Worker            "is the channel dim. This behavior will change in a future release to interpret the "
1549*da0073e9SAndroid Build Coastguard Worker            "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
1550*da0073e9SAndroid Build Coastguard Worker            "channel-wise dropout behavior, please switch to using dropout1d instead."
1551*da0073e9SAndroid Build Coastguard Worker        )
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker    result = (
1554*da0073e9SAndroid Build Coastguard Worker        _VF.feature_dropout_(input, p, training)
1555*da0073e9SAndroid Build Coastguard Worker        if inplace
1556*da0073e9SAndroid Build Coastguard Worker        else _VF.feature_dropout(input, p, training)
1557*da0073e9SAndroid Build Coastguard Worker    )
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker    return result
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker
1562*da0073e9SAndroid Build Coastguard Workerdef dropout3d(
1563*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1564*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1565*da0073e9SAndroid Build Coastguard Worker    training: bool = True,
1566*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1567*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1568*da0073e9SAndroid Build Coastguard Worker    r"""Randomly zero out entire channels (a channel is a 3D feature map).
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1571*da0073e9SAndroid Build Coastguard Worker    batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor.
1572*da0073e9SAndroid Build Coastguard Worker    Each channel will be zeroed out independently on every forward call with
1573*da0073e9SAndroid Build Coastguard Worker    probability :attr:`p` using samples from a Bernoulli distribution.
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Dropout3d` for details.
1576*da0073e9SAndroid Build Coastguard Worker
1577*da0073e9SAndroid Build Coastguard Worker    Args:
1578*da0073e9SAndroid Build Coastguard Worker        p: probability of a channel to be zeroed. Default: 0.5
1579*da0073e9SAndroid Build Coastguard Worker        training: apply dropout if ``True``. Default: ``True``
1580*da0073e9SAndroid Build Coastguard Worker        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
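
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(2, 3, 4, 4, 4)  # (N, C, D, H, W)
        >>> out = F.dropout3d(x, p=0.5, training=True)  # zeroes whole volumes x[i, j, :, :, :]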
1581*da0073e9SAndroid Build Coastguard Worker    """
1582*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1583*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1584*da0073e9SAndroid Build Coastguard Worker            dropout3d, (input,), input, p=p, training=training, inplace=inplace
1585*da0073e9SAndroid Build Coastguard Worker        )
1586*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1587*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1588*da0073e9SAndroid Build Coastguard Worker    inp_dim = input.dim()
1589*da0073e9SAndroid Build Coastguard Worker    if inp_dim not in (4, 5):
1590*da0073e9SAndroid Build Coastguard Worker        warn_msg = (
1591*da0073e9SAndroid Build Coastguard Worker            f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated "
1592*da0073e9SAndroid Build Coastguard Worker            "and will result in an error in a future release. To retain the behavior "
1593*da0073e9SAndroid Build Coastguard Worker            "and silence this warning, please use dropout instead. Note that dropout3d "
1594*da0073e9SAndroid Build Coastguard Worker            "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, "
1595*da0073e9SAndroid Build Coastguard Worker            "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."
1596*da0073e9SAndroid Build Coastguard Worker        )
1597*da0073e9SAndroid Build Coastguard Worker        warnings.warn(warn_msg)
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker    is_batched = inp_dim == 5
1600*da0073e9SAndroid Build Coastguard Worker    if not is_batched:
1601*da0073e9SAndroid Build Coastguard Worker        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
1602*da0073e9SAndroid Build Coastguard Worker
1603*da0073e9SAndroid Build Coastguard Worker    result = (
1604*da0073e9SAndroid Build Coastguard Worker        _VF.feature_dropout_(input, p, training)
1605*da0073e9SAndroid Build Coastguard Worker        if inplace
1606*da0073e9SAndroid Build Coastguard Worker        else _VF.feature_dropout(input, p, training)
1607*da0073e9SAndroid Build Coastguard Worker    )
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker    if not is_batched:
1610*da0073e9SAndroid Build Coastguard Worker        result = result.squeeze_(0) if inplace else result.squeeze(0)
1611*da0073e9SAndroid Build Coastguard Worker    return result
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker
1614*da0073e9SAndroid Build Coastguard Workerdef feature_alpha_dropout(
1615*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1616*da0073e9SAndroid Build Coastguard Worker    p: float = 0.5,
1617*da0073e9SAndroid Build Coastguard Worker    training: bool = False,
1618*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1619*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1620*da0073e9SAndroid Build Coastguard Worker    r"""Randomly masks out entire channels (a channel is a feature map).
1621*da0073e9SAndroid Build Coastguard Worker
1622*da0073e9SAndroid Build Coastguard Worker    For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input
1623*da0073e9SAndroid Build Coastguard Worker    is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
1624*da0073e9SAndroid Build Coastguard Worker    setting activations to zero, as in regular Dropout, the activations are set
1625*da0073e9SAndroid Build Coastguard Worker    to the negative saturation value of the SELU activation function.
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker    Each channel will be masked independently on every forward call with
1628*da0073e9SAndroid Build Coastguard Worker    probability :attr:`p` using samples from a Bernoulli distribution.
1629*da0073e9SAndroid Build Coastguard Worker    The channels to be masked are randomized on every forward call, and the
1630*da0073e9SAndroid Build Coastguard Worker    output is scaled and shifted to maintain zero mean and unit variance.
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.FeatureAlphaDropout` for details.
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker    Args:
1635*da0073e9SAndroid Build Coastguard Worker        p: dropout probability of a channel to be masked. Default: 0.5
1636*da0073e9SAndroid Build Coastguard Worker        training: apply dropout if ``True``. Default: ``False``
1637*da0073e9SAndroid Build Coastguard Worker        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
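
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``;
    note that ``training`` defaults to ``False`` for this function)::

        >>> x = torch.randn(2, 3, 4, 4)
        >>> out = F.feature_alpha_dropout(x, p=0.2, training=True)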
1638*da0073e9SAndroid Build Coastguard Worker    """
1639*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1640*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1641*da0073e9SAndroid Build Coastguard Worker            feature_alpha_dropout,
1642*da0073e9SAndroid Build Coastguard Worker            (input,),
1643*da0073e9SAndroid Build Coastguard Worker            input,
1644*da0073e9SAndroid Build Coastguard Worker            p=p,
1645*da0073e9SAndroid Build Coastguard Worker            training=training,
1646*da0073e9SAndroid Build Coastguard Worker            inplace=inplace,
1647*da0073e9SAndroid Build Coastguard Worker        )
1648*da0073e9SAndroid Build Coastguard Worker    if p < 0.0 or p > 1.0:
1649*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1650*da0073e9SAndroid Build Coastguard Worker    return (
1651*da0073e9SAndroid Build Coastguard Worker        _VF.feature_alpha_dropout_(input, p, training)
1652*da0073e9SAndroid Build Coastguard Worker        if inplace
1653*da0073e9SAndroid Build Coastguard Worker        else _VF.feature_alpha_dropout(input, p, training)
1654*da0073e9SAndroid Build Coastguard Worker    )
1655*da0073e9SAndroid Build Coastguard Worker
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Workerdef _threshold(
1658*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1659*da0073e9SAndroid Build Coastguard Worker    threshold: float,
1660*da0073e9SAndroid Build Coastguard Worker    value: float,
1661*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1662*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
1663*da0073e9SAndroid Build Coastguard Worker    r"""Apply a threshold to each element of the input Tensor.
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Threshold` for more details.
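
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.tensor([-1.0, 0.5, 2.0])
        >>> F.threshold(x, threshold=1.0, value=0.0)  # values <= 1.0 are replaced by 0.0
        tensor([0., 0., 2.])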
1666*da0073e9SAndroid Build Coastguard Worker    """
1667*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1668*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1669*da0073e9SAndroid Build Coastguard Worker            _threshold, (input,), input, threshold, value, inplace=inplace
1670*da0073e9SAndroid Build Coastguard Worker        )
1671*da0073e9SAndroid Build Coastguard Worker    if inplace:
1672*da0073e9SAndroid Build Coastguard Worker        result = _VF.threshold_(input, threshold, value)
1673*da0073e9SAndroid Build Coastguard Worker    else:
1674*da0073e9SAndroid Build Coastguard Worker        result = _VF.threshold(input, threshold, value)
1675*da0073e9SAndroid Build Coastguard Worker    return result
1676*da0073e9SAndroid Build Coastguard Worker
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker# We define this function as _threshold because it takes an argument
1679*da0073e9SAndroid Build Coastguard Worker# named threshold, which clobbers the recursive reference to the
1680*da0073e9SAndroid Build Coastguard Worker# function needed for __torch_function__ support
1681*da0073e9SAndroid Build Coastguard Workerthreshold = _threshold
1682*da0073e9SAndroid Build Coastguard Worker
1683*da0073e9SAndroid Build Coastguard Workerthreshold_ = _add_docstr(
1684*da0073e9SAndroid Build Coastguard Worker    _VF.threshold_,
1685*da0073e9SAndroid Build Coastguard Worker    r"""
1686*da0073e9SAndroid Build Coastguard Workerthreshold_(input, threshold, value) -> Tensor
1687*da0073e9SAndroid Build Coastguard Worker
1688*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~threshold`.
1689*da0073e9SAndroid Build Coastguard Worker""",
1690*da0073e9SAndroid Build Coastguard Worker)
1691*da0073e9SAndroid Build Coastguard Worker
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Workerdef relu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1694*da0073e9SAndroid Build Coastguard Worker    r"""relu(input, inplace=False) -> Tensor
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker    Applies the rectified linear unit function element-wise. See
1697*da0073e9SAndroid Build Coastguard Worker    :class:`~torch.nn.ReLU` for more details.
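
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> F.relu(torch.tensor([-1.0, 0.0, 2.0]))
        tensor([0., 0., 2.])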
1698*da0073e9SAndroid Build Coastguard Worker    """
1699*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1700*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(relu, (input,), input, inplace=inplace)
1701*da0073e9SAndroid Build Coastguard Worker    if inplace:
1702*da0073e9SAndroid Build Coastguard Worker        result = torch.relu_(input)
1703*da0073e9SAndroid Build Coastguard Worker    else:
1704*da0073e9SAndroid Build Coastguard Worker        result = torch.relu(input)
1705*da0073e9SAndroid Build Coastguard Worker    return result
1706*da0073e9SAndroid Build Coastguard Worker
1707*da0073e9SAndroid Build Coastguard Worker
1708*da0073e9SAndroid Build Coastguard Workerrelu_ = _add_docstr(
1709*da0073e9SAndroid Build Coastguard Worker    torch.relu_,
1710*da0073e9SAndroid Build Coastguard Worker    r"""
1711*da0073e9SAndroid Build Coastguard Workerrelu_(input) -> Tensor
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~relu`.
1714*da0073e9SAndroid Build Coastguard Worker""",
1715*da0073e9SAndroid Build Coastguard Worker)
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker
1718*da0073e9SAndroid Build Coastguard Workerdef glu(input: Tensor, dim: int = -1) -> Tensor:  # noqa: D400,D402
1719*da0073e9SAndroid Build Coastguard Worker    r"""
1720*da0073e9SAndroid Build Coastguard Worker    glu(input, dim=-1) -> Tensor
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker    The gated linear unit. Computes:
1723*da0073e9SAndroid Build Coastguard Worker
1724*da0073e9SAndroid Build Coastguard Worker    .. math ::
1725*da0073e9SAndroid Build Coastguard Worker        \text{GLU}(a, b) = a \otimes \sigma(b)
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker    where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma`
1728*da0073e9SAndroid Build Coastguard Worker    is the sigmoid function and :math:`\otimes` is the element-wise product.
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker    Args:
1733*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input tensor
1734*da0073e9SAndroid Build Coastguard Worker        dim (int): dimension on which to split the input. Default: -1
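
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4, 6)
        >>> F.glu(x, dim=-1).shape  # the size of dim is halved
        torch.Size([4, 3])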
1735*da0073e9SAndroid Build Coastguard Worker    """
1736*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1737*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(glu, (input,), input, dim=dim)
1738*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 0:
1739*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
1740*da0073e9SAndroid Build Coastguard Worker            "glu does not support scalars because halving size must be even"
1741*da0073e9SAndroid Build Coastguard Worker        )
1742*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.glu(input, dim)
1743*da0073e9SAndroid Build Coastguard Worker
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Workerdef hardtanh(
1746*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1747*da0073e9SAndroid Build Coastguard Worker    min_val: float = -1.0,
1748*da0073e9SAndroid Build Coastguard Worker    max_val: float = 1.0,
1749*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1750*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: D400,D402
1751*da0073e9SAndroid Build Coastguard Worker    r"""
1752*da0073e9SAndroid Build Coastguard Worker    hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker    Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more
1755*da0073e9SAndroid Build Coastguard Worker    details.
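
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.tensor([-2.0, 0.5, 3.0])
        >>> out = F.hardtanh(x)  # clamps values to [min_val, max_val] = [-1., 1.]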
1756*da0073e9SAndroid Build Coastguard Worker    """
1757*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1758*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1759*da0073e9SAndroid Build Coastguard Worker            hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace
1760*da0073e9SAndroid Build Coastguard Worker        )
1761*da0073e9SAndroid Build Coastguard Worker    if min_val > max_val:
1762*da0073e9SAndroid Build Coastguard Worker        raise ValueError("min_val cannot be greater than max_val")
1763*da0073e9SAndroid Build Coastguard Worker    if inplace:
1764*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.hardtanh_(input, min_val, max_val)
1765*da0073e9SAndroid Build Coastguard Worker    else:
1766*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.hardtanh(input, min_val, max_val)
1767*da0073e9SAndroid Build Coastguard Worker    return result
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Workerhardtanh_ = _add_docstr(
1771*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.hardtanh_,
1772*da0073e9SAndroid Build Coastguard Worker    r"""
1773*da0073e9SAndroid Build Coastguard Workerhardtanh_(input, min_val=-1., max_val=1.) -> Tensor
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~hardtanh`.
1776*da0073e9SAndroid Build Coastguard Worker""",
1777*da0073e9SAndroid Build Coastguard Worker)
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker
1780*da0073e9SAndroid Build Coastguard Workerdef relu6(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1781*da0073e9SAndroid Build Coastguard Worker    r"""relu6(input, inplace=False) -> Tensor
1782*da0073e9SAndroid Build Coastguard Worker
1783*da0073e9SAndroid Build Coastguard Worker    Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.ReLU6` for more details.
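
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> F.relu6(torch.tensor([-1.0, 3.0, 8.0]))
        tensor([0., 3., 6.])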
1786*da0073e9SAndroid Build Coastguard Worker    """
1787*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1788*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(relu6, (input,), input, inplace=inplace)
1789*da0073e9SAndroid Build Coastguard Worker    if inplace:
1790*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.relu6_(input)
1791*da0073e9SAndroid Build Coastguard Worker    else:
1792*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.relu6(input)
1793*da0073e9SAndroid Build Coastguard Worker    return result
1794*da0073e9SAndroid Build Coastguard Worker
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Workerdef elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor:
1797*da0073e9SAndroid Build Coastguard Worker    r"""Apply the Exponential Linear Unit (ELU) function element-wise.
1798*da0073e9SAndroid Build Coastguard Worker
1799*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.ELU` for more details.
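
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4)
        >>> out = F.elu(x, alpha=1.0)  # negative inputs saturate smoothly at -alpha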
1800*da0073e9SAndroid Build Coastguard Worker    """
1801*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1802*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
1803*da0073e9SAndroid Build Coastguard Worker    if inplace:
1804*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.elu_(input, alpha)
1805*da0073e9SAndroid Build Coastguard Worker    else:
1806*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.elu(input, alpha)
1807*da0073e9SAndroid Build Coastguard Worker    return result
1808*da0073e9SAndroid Build Coastguard Worker
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Workerelu_ = _add_docstr(
1811*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.elu_,
1812*da0073e9SAndroid Build Coastguard Worker    r"""
1813*da0073e9SAndroid Build Coastguard Workerelu_(input, alpha=1.) -> Tensor
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~elu`.
1816*da0073e9SAndroid Build Coastguard Worker""",
1817*da0073e9SAndroid Build Coastguard Worker)
1818*da0073e9SAndroid Build Coastguard Worker
1819*da0073e9SAndroid Build Coastguard Worker
1820*da0073e9SAndroid Build Coastguard Workerdef selu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1821*da0073e9SAndroid Build Coastguard Worker    r"""selu(input, inplace=False) -> Tensor
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker    Applies element-wise,
1824*da0073e9SAndroid Build Coastguard Worker    :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`,
1825*da0073e9SAndroid Build Coastguard Worker    with :math:`\alpha=1.6732632423543772848170429916717` and
1826*da0073e9SAndroid Build Coastguard Worker    :math:`scale=1.0507009873554804934193349852946`.
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.SELU` for more details.
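
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4, 8)
        >>> out = F.selu(x)  # self-normalizing activation; often paired with alpha_dropout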
1829*da0073e9SAndroid Build Coastguard Worker    """
1830*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1831*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(selu, (input,), input, inplace=inplace)
1832*da0073e9SAndroid Build Coastguard Worker    if inplace:
1833*da0073e9SAndroid Build Coastguard Worker        result = torch.selu_(input)
1834*da0073e9SAndroid Build Coastguard Worker    else:
1835*da0073e9SAndroid Build Coastguard Worker        result = torch.selu(input)
1836*da0073e9SAndroid Build Coastguard Worker    return result
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Workerselu_ = _add_docstr(
1840*da0073e9SAndroid Build Coastguard Worker    torch.selu_,
1841*da0073e9SAndroid Build Coastguard Worker    r"""
1842*da0073e9SAndroid Build Coastguard Workerselu_(input) -> Tensor
1843*da0073e9SAndroid Build Coastguard Worker
1844*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~selu`.
1845*da0073e9SAndroid Build Coastguard Worker""",
1846*da0073e9SAndroid Build Coastguard Worker)
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Workerdef celu(
1850*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1851*da0073e9SAndroid Build Coastguard Worker    alpha: float = 1.0,
1852*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1853*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: D400,D402
1854*da0073e9SAndroid Build Coastguard Worker    r"""celu(input, alpha=1., inplace=False) -> Tensor
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker    Applies element-wise,
1857*da0073e9SAndroid Build Coastguard Worker    :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.CELU` for more details.
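
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4)
        >>> out = F.celu(x, alpha=2.0)  # continuously differentiable variant of ELU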
1860*da0073e9SAndroid Build Coastguard Worker    """
1861*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1862*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1863*da0073e9SAndroid Build Coastguard Worker            celu, (input,), input, alpha=alpha, inplace=inplace
1864*da0073e9SAndroid Build Coastguard Worker        )
1865*da0073e9SAndroid Build Coastguard Worker    if inplace:
1866*da0073e9SAndroid Build Coastguard Worker        result = torch.celu_(input, alpha)
1867*da0073e9SAndroid Build Coastguard Worker    else:
1868*da0073e9SAndroid Build Coastguard Worker        result = torch.celu(input, alpha)
1869*da0073e9SAndroid Build Coastguard Worker    return result
1870*da0073e9SAndroid Build Coastguard Worker
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Workercelu_ = _add_docstr(
1873*da0073e9SAndroid Build Coastguard Worker    torch.celu_,
1874*da0073e9SAndroid Build Coastguard Worker    r"""
1875*da0073e9SAndroid Build Coastguard Workercelu_(input, alpha=1.) -> Tensor
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~celu`.
1878*da0073e9SAndroid Build Coastguard Worker""",
1879*da0073e9SAndroid Build Coastguard Worker)
1880*da0073e9SAndroid Build Coastguard Worker
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Workerdef leaky_relu(
1883*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1884*da0073e9SAndroid Build Coastguard Worker    negative_slope: float = 0.01,
1885*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1886*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: D400,D402
1887*da0073e9SAndroid Build Coastguard Worker    r"""
1888*da0073e9SAndroid Build Coastguard Worker    leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker    Applies element-wise,
1891*da0073e9SAndroid Build Coastguard Worker    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
1892*da0073e9SAndroid Build Coastguard Worker
1893*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LeakyReLU` for more details.
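
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.tensor([-2.0, 3.0])
        >>> out = F.leaky_relu(x, negative_slope=0.1)  # gives tensor([-0.2000, 3.0000])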
1894*da0073e9SAndroid Build Coastguard Worker    """
1895*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1896*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1897*da0073e9SAndroid Build Coastguard Worker            leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace
1898*da0073e9SAndroid Build Coastguard Worker        )
1899*da0073e9SAndroid Build Coastguard Worker    if inplace:
1900*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.leaky_relu_(input, negative_slope)
1901*da0073e9SAndroid Build Coastguard Worker    else:
1902*da0073e9SAndroid Build Coastguard Worker        result = torch._C._nn.leaky_relu(input, negative_slope)
1903*da0073e9SAndroid Build Coastguard Worker    return result
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Workerleaky_relu_ = _add_docstr(
1907*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.leaky_relu_,
1908*da0073e9SAndroid Build Coastguard Worker    r"""
1909*da0073e9SAndroid Build Coastguard Workerleaky_relu_(input, negative_slope=0.01) -> Tensor
1910*da0073e9SAndroid Build Coastguard Worker
1911*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~leaky_relu`.
1912*da0073e9SAndroid Build Coastguard Worker""",
1913*da0073e9SAndroid Build Coastguard Worker)
1914*da0073e9SAndroid Build Coastguard Worker
1915*da0073e9SAndroid Build Coastguard Worker
1916*da0073e9SAndroid Build Coastguard Workerprelu = _add_docstr(
1917*da0073e9SAndroid Build Coastguard Worker    torch.prelu,
1918*da0073e9SAndroid Build Coastguard Worker    r"""prelu(input, weight) -> Tensor
1919*da0073e9SAndroid Build Coastguard Worker
1920*da0073e9SAndroid Build Coastguard WorkerApplies element-wise the function
1921*da0073e9SAndroid Build Coastguard Worker:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a
1922*da0073e9SAndroid Build Coastguard Workerlearnable parameter.
1923*da0073e9SAndroid Build Coastguard Worker
1924*da0073e9SAndroid Build Coastguard Worker.. note::
1925*da0073e9SAndroid Build Coastguard Worker    `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D,
1926*da0073e9SAndroid Build Coastguard Worker    its size must match the number of input channels, determined by
1927*da0073e9SAndroid Build Coastguard Worker    `input.size(1)` when `input.dim() >= 2`, otherwise 1.
1928*da0073e9SAndroid Build Coastguard Worker    In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded
1929*da0073e9SAndroid Build Coastguard Worker    to the shape of `input` in a way that is not possible using normal
1930*da0073e9SAndroid Build Coastguard Worker    :ref:`broadcasting semantics<broadcasting-semantics>`.
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard WorkerSee :class:`~torch.nn.PReLU` for more details.
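
Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

    >>> input = torch.randn(2, 3, 4)
    >>> weight = torch.full((3,), 0.25)  # one learnable slope per channel
    >>> output = F.prelu(input, weight)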
1933*da0073e9SAndroid Build Coastguard Worker""",
1934*da0073e9SAndroid Build Coastguard Worker)
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Workerdef rrelu(
1938*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
1939*da0073e9SAndroid Build Coastguard Worker    lower: float = 1.0 / 8,
1940*da0073e9SAndroid Build Coastguard Worker    upper: float = 1.0 / 3,
1941*da0073e9SAndroid Build Coastguard Worker    training: bool = False,
1942*da0073e9SAndroid Build Coastguard Worker    inplace: bool = False,
1943*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: D400,D402
1944*da0073e9SAndroid Build Coastguard Worker    r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker    Randomized leaky ReLU.
1947*da0073e9SAndroid Build Coastguard Worker
1948*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.RReLU` for more details.
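
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4)
        >>> out = F.rrelu(x, lower=0.1, upper=0.3, training=True)
        >>> # In training mode the negative slope is sampled from U(lower, upper).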
1949*da0073e9SAndroid Build Coastguard Worker    """
1950*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1951*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1952*da0073e9SAndroid Build Coastguard Worker            rrelu,
1953*da0073e9SAndroid Build Coastguard Worker            (input,),
1954*da0073e9SAndroid Build Coastguard Worker            input,
1955*da0073e9SAndroid Build Coastguard Worker            lower=lower,
1956*da0073e9SAndroid Build Coastguard Worker            upper=upper,
1957*da0073e9SAndroid Build Coastguard Worker            training=training,
1958*da0073e9SAndroid Build Coastguard Worker            inplace=inplace,
1959*da0073e9SAndroid Build Coastguard Worker        )
1960*da0073e9SAndroid Build Coastguard Worker    if inplace:
1961*da0073e9SAndroid Build Coastguard Worker        result = torch.rrelu_(input, lower, upper, training)
1962*da0073e9SAndroid Build Coastguard Worker    else:
1963*da0073e9SAndroid Build Coastguard Worker        result = torch.rrelu(input, lower, upper, training)
1964*da0073e9SAndroid Build Coastguard Worker    return result
1965*da0073e9SAndroid Build Coastguard Worker
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Workerrrelu_ = _add_docstr(
1968*da0073e9SAndroid Build Coastguard Worker    torch.rrelu_,
1969*da0073e9SAndroid Build Coastguard Worker    r"""
1970*da0073e9SAndroid Build Coastguard Workerrrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor
1971*da0073e9SAndroid Build Coastguard Worker
1972*da0073e9SAndroid Build Coastguard WorkerIn-place version of :func:`~rrelu`.
1973*da0073e9SAndroid Build Coastguard Worker""",
1974*da0073e9SAndroid Build Coastguard Worker)
1975*da0073e9SAndroid Build Coastguard Worker
1976*da0073e9SAndroid Build Coastguard Workerlogsigmoid = _add_docstr(
1977*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.log_sigmoid,
1978*da0073e9SAndroid Build Coastguard Worker    r"""
1979*da0073e9SAndroid Build Coastguard Workerlogsigmoid(input) -> Tensor
1980*da0073e9SAndroid Build Coastguard Worker
1981*da0073e9SAndroid Build Coastguard WorkerApplies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)`.
1982*da0073e9SAndroid Build Coastguard Worker
1983*da0073e9SAndroid Build Coastguard WorkerSee :class:`~torch.nn.LogSigmoid` for more details.
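
Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

    >>> x = torch.randn(4)
    >>> out = F.logsigmoid(x)  # more numerically stable than torch.log(torch.sigmoid(x))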
1984*da0073e9SAndroid Build Coastguard Worker""",
1985*da0073e9SAndroid Build Coastguard Worker)
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Workergelu = _add_docstr(
1988*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.gelu,
1989*da0073e9SAndroid Build Coastguard Worker    r"""
1990*da0073e9SAndroid Build Coastguard Workergelu(input, approximate='none') -> Tensor
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard WorkerWhen the approximate argument is 'none', it applies element-wise the function
1993*da0073e9SAndroid Build Coastguard Worker:math:`\text{GELU}(x) = x * \Phi(x)`
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Workerwhere :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard WorkerWhen the approximate argument is 'tanh', GELU is estimated with
1998*da0073e9SAndroid Build Coastguard Worker
1999*da0073e9SAndroid Build Coastguard Worker.. math::
2000*da0073e9SAndroid Build Coastguard Worker    \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard WorkerSee `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
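
Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

    >>> x = torch.randn(4)
    >>> exact = F.gelu(x)
    >>> approx = F.gelu(x, approximate='tanh')  # faster tanh-based estimate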
2003*da0073e9SAndroid Build Coastguard Worker""",
2004*da0073e9SAndroid Build Coastguard Worker)
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Workerhardshrink = _add_docstr(
2007*da0073e9SAndroid Build Coastguard Worker    torch.hardshrink,
2008*da0073e9SAndroid Build Coastguard Worker    r"""
2009*da0073e9SAndroid Build Coastguard Workerhardshrink(input, lambd=0.5) -> Tensor
2010*da0073e9SAndroid Build Coastguard Worker
2011*da0073e9SAndroid Build Coastguard WorkerApplies the hard shrinkage function element-wise.
2012*da0073e9SAndroid Build Coastguard Worker
2013*da0073e9SAndroid Build Coastguard WorkerSee :class:`~torch.nn.Hardshrink` for more details.
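
Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

    >>> x = torch.tensor([-1.0, 0.2, 0.9])
    >>> out = F.hardshrink(x, lambd=0.5)  # entries with |x| <= lambd are zeroed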
2014*da0073e9SAndroid Build Coastguard Worker""",
2015*da0073e9SAndroid Build Coastguard Worker)
2016*da0073e9SAndroid Build Coastguard Worker
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Workerdef tanhshrink(input: Tensor) -> Tensor:  # noqa: D400,D402
2019*da0073e9SAndroid Build Coastguard Worker    r"""tanhshrink(input) -> Tensor
2020*da0073e9SAndroid Build Coastguard Worker
2021*da0073e9SAndroid Build Coastguard Worker    Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`
2022*da0073e9SAndroid Build Coastguard Worker
2023*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Tanhshrink` for more details.
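
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(4)
        >>> torch.allclose(F.tanhshrink(x), x - torch.tanh(x))
        True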
2024*da0073e9SAndroid Build Coastguard Worker    """
2025*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
2026*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(tanhshrink, (input,), input)
2027*da0073e9SAndroid Build Coastguard Worker    return input - input.tanh()
2028*da0073e9SAndroid Build Coastguard Worker
2029*da0073e9SAndroid Build Coastguard Worker
2030*da0073e9SAndroid Build Coastguard Workerdef softsign(input: Tensor) -> Tensor:  # noqa: D400,D402
2031*da0073e9SAndroid Build Coastguard Worker    r"""softsign(input) -> Tensor
2032*da0073e9SAndroid Build Coastguard Worker
2033*da0073e9SAndroid Build Coastguard Worker    Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`
2034*da0073e9SAndroid Build Coastguard Worker
2035*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Softsign` for more details.
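
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.tensor([-3.0, 0.0, 1.0])
        >>> out = F.softsign(x)  # gives tensor([-0.7500, 0.0000, 0.5000])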
2036*da0073e9SAndroid Build Coastguard Worker    """
2037*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
2038*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(softsign, (input,), input)
2039*da0073e9SAndroid Build Coastguard Worker    return input / (input.abs() + 1)
2040*da0073e9SAndroid Build Coastguard Worker
2041*da0073e9SAndroid Build Coastguard Worker
2042*da0073e9SAndroid Build Coastguard Workersoftplus = _add_docstr(
2043*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.softplus,
2044*da0073e9SAndroid Build Coastguard Worker    r"""
2045*da0073e9SAndroid Build Coastguard Workersoftplus(input, beta=1, threshold=20) -> Tensor
2046*da0073e9SAndroid Build Coastguard Worker
2047*da0073e9SAndroid Build Coastguard WorkerApplies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`.
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard WorkerFor numerical stability the implementation reverts to the linear function
2050*da0073e9SAndroid Build Coastguard Workerwhen :math:`input \times \beta > threshold`.
2051*da0073e9SAndroid Build Coastguard Worker
2052*da0073e9SAndroid Build Coastguard WorkerSee :class:`~torch.nn.Softplus` for more details.
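
Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

    >>> x = torch.randn(4)
    >>> out = F.softplus(x, beta=1, threshold=20)  # smooth approximation of relu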
2053*da0073e9SAndroid Build Coastguard Worker""",
2054*da0073e9SAndroid Build Coastguard Worker)
2055*da0073e9SAndroid Build Coastguard Worker
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Workerdef _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
2058*da0073e9SAndroid Build Coastguard Worker    warnings.warn(
2059*da0073e9SAndroid Build Coastguard Worker        f"Implicit dimension choice for {name} has been deprecated. "
2060*da0073e9SAndroid Build Coastguard Worker        "Change the call to include dim=X as an argument.",
2061*da0073e9SAndroid Build Coastguard Worker        stacklevel=stacklevel,
2062*da0073e9SAndroid Build Coastguard Worker    )
2063*da0073e9SAndroid Build Coastguard Worker    if ndim in (0, 1, 3):
2064*da0073e9SAndroid Build Coastguard Worker        ret = 0
2065*da0073e9SAndroid Build Coastguard Worker    else:
2066*da0073e9SAndroid Build Coastguard Worker        ret = 1
2067*da0073e9SAndroid Build Coastguard Worker    return ret
2068*da0073e9SAndroid Build Coastguard Worker
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Workerdef softmin(
2071*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2072*da0073e9SAndroid Build Coastguard Worker    dim: Optional[int] = None,
2073*da0073e9SAndroid Build Coastguard Worker    _stacklevel: int = 3,
2074*da0073e9SAndroid Build Coastguard Worker    dtype: Optional[DType] = None,
2075*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2076*da0073e9SAndroid Build Coastguard Worker    r"""Apply a softmin function.
2077*da0073e9SAndroid Build Coastguard Worker
2078*da0073e9SAndroid Build Coastguard Worker    Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula.
2079*da0073e9SAndroid Build Coastguard Worker
2080*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Softmin` for more details.
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker    Args:
2083*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input
2084*da0073e9SAndroid Build Coastguard Worker        dim (int): A dimension along which softmin will be computed (so every slice
2085*da0073e9SAndroid Build Coastguard Worker            along dim will sum to 1).
2086*da0073e9SAndroid Build Coastguard Worker        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
2087*da0073e9SAndroid Build Coastguard Worker          If specified, the input tensor is cast to :attr:`dtype` before the operation
2088*da0073e9SAndroid Build Coastguard Worker          is performed. This is useful for preventing data type overflows. Default: None.
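
    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(2, 3)
        >>> out = F.softmin(x, dim=1)  # equals softmax(-x, dim=1); each row sums to 1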
2089*da0073e9SAndroid Build Coastguard Worker    """
2090*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
2091*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2092*da0073e9SAndroid Build Coastguard Worker            softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
2093*da0073e9SAndroid Build Coastguard Worker        )
2094*da0073e9SAndroid Build Coastguard Worker    if dim is None:
2095*da0073e9SAndroid Build Coastguard Worker        dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
2096*da0073e9SAndroid Build Coastguard Worker    if dtype is None:
2097*da0073e9SAndroid Build Coastguard Worker        ret = (-input).softmax(dim)
2098*da0073e9SAndroid Build Coastguard Worker    else:
2099*da0073e9SAndroid Build Coastguard Worker        ret = (-input).softmax(dim, dtype=dtype)
2100*da0073e9SAndroid Build Coastguard Worker    return ret
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Workerdef softmax(
2104*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2105*da0073e9SAndroid Build Coastguard Worker    dim: Optional[int] = None,
2106*da0073e9SAndroid Build Coastguard Worker    _stacklevel: int = 3,
2107*da0073e9SAndroid Build Coastguard Worker    dtype: Optional[DType] = None,
2108*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2109*da0073e9SAndroid Build Coastguard Worker    r"""Apply a softmax function.
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker    Softmax is defined as:
2112*da0073e9SAndroid Build Coastguard Worker
2113*da0073e9SAndroid Build Coastguard Worker    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Worker    It is applied to all slices along dim, and will re-scale them so that the elements
2116*da0073e9SAndroid Build Coastguard Worker    lie in the range `[0, 1]` and sum to 1.
2117*da0073e9SAndroid Build Coastguard Worker
2118*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.Softmax` for more details.
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    Args:
2121*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input
2122*da0073e9SAndroid Build Coastguard Worker        dim (int): A dimension along which softmax will be computed.
2123*da0073e9SAndroid Build Coastguard Worker        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
2124*da0073e9SAndroid Build Coastguard Worker          If specified, the input tensor is cast to :attr:`dtype` before the operation
2125*da0073e9SAndroid Build Coastguard Worker          is performed. This is useful for preventing data type overflows. Default: None.
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker    .. note::
2128*da0073e9SAndroid Build Coastguard Worker        This function doesn't work directly with NLLLoss,
2129*da0073e9SAndroid Build Coastguard Worker        which expects log-probabilities rather than softmax outputs as input.
2130*da0073e9SAndroid Build Coastguard Worker        Use log_softmax instead (it's faster and has better numerical properties).
2131*da0073e9SAndroid Build Coastguard Worker
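    Example (a minimal sketch, assuming ``torch.nn.functional`` is imported as ``F``)::

        >>> x = torch.randn(2, 3)
        >>> probs = F.softmax(x, dim=1)  # each row is a probability distribution over 3 classes
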
2132*da0073e9SAndroid Build Coastguard Worker    """
2133*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
2134*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2135*da0073e9SAndroid Build Coastguard Worker            softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
2136*da0073e9SAndroid Build Coastguard Worker        )
2137*da0073e9SAndroid Build Coastguard Worker    if dim is None:
2138*da0073e9SAndroid Build Coastguard Worker        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
2139*da0073e9SAndroid Build Coastguard Worker    if dtype is None:
2140*da0073e9SAndroid Build Coastguard Worker        ret = input.softmax(dim)
2141*da0073e9SAndroid Build Coastguard Worker    else:
2142*da0073e9SAndroid Build Coastguard Worker        ret = input.softmax(dim, dtype=dtype)
2143*da0073e9SAndroid Build Coastguard Worker    return ret
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker
2146*da0073e9SAndroid Build Coastguard Workerdef gumbel_softmax(
2147*da0073e9SAndroid Build Coastguard Worker    logits: Tensor,
2148*da0073e9SAndroid Build Coastguard Worker    tau: float = 1,
2149*da0073e9SAndroid Build Coastguard Worker    hard: bool = False,
2150*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-10,
2151*da0073e9SAndroid Build Coastguard Worker    dim: int = -1,
2152*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2153*da0073e9SAndroid Build Coastguard Worker    r"""
2154*da0073e9SAndroid Build Coastguard Worker    Sample from the Gumbel-Softmax distribution (`Link 1`_, `Link 2`_) and optionally discretize.
2155*da0073e9SAndroid Build Coastguard Worker
2156*da0073e9SAndroid Build Coastguard Worker    Args:
2157*da0073e9SAndroid Build Coastguard Worker      logits: `[..., num_features]` unnormalized log probabilities
2158*da0073e9SAndroid Build Coastguard Worker      tau: non-negative scalar temperature
2159*da0073e9SAndroid Build Coastguard Worker      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
2160*da0073e9SAndroid Build Coastguard Worker            but will be differentiated as if they were the soft samples in autograd
2161*da0073e9SAndroid Build Coastguard Worker      dim (int): A dimension along which softmax will be computed. Default: -1.
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker    Returns:
2164*da0073e9SAndroid Build Coastguard Worker      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
2165*da0073e9SAndroid Build Coastguard Worker      If ``hard=True``, the returned samples will be one-hot, otherwise they will
2166*da0073e9SAndroid Build Coastguard Worker      be probability distributions that sum to 1 across `dim`.
2167*da0073e9SAndroid Build Coastguard Worker
2168*da0073e9SAndroid Build Coastguard Worker    .. note::
2169*da0073e9SAndroid Build Coastguard Worker      This function is here for legacy reasons, may be removed from nn.Functional in the future.
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker    .. note::
2172*da0073e9SAndroid Build Coastguard Worker      The main trick for `hard` is to do  `y_hard - y_soft.detach() + y_soft`
2173*da0073e9SAndroid Build Coastguard Worker
2174*da0073e9SAndroid Build Coastguard Worker      It achieves two things:
2175*da0073e9SAndroid Build Coastguard Worker      - makes the output value exactly one-hot
2176*da0073e9SAndroid Build Coastguard Worker      (since we add then subtract y_soft value)
2177*da0073e9SAndroid Build Coastguard Worker      - makes the gradient equal to y_soft gradient
2178*da0073e9SAndroid Build Coastguard Worker      (since we strip all other gradients)

    Examples::

        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    .. _Link 2:
        https://arxiv.org/abs/1611.01144
    """
    if has_torch_function_unary(logits):
        return handle_torch_function(
            gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim
        )
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_()
        .log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret


def log_softmax(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmax followed by a logarithm.

    While mathematically equivalent to log(softmax(x)), doing these two
    operations separately is slower and numerically unstable. This function
    uses an alternative formulation to compute the output and gradient correctly.

    See :class:`~torch.nn.LogSoftmax` for more details.

    Args:
        input (Tensor): the input tensor
        dim (int): A dimension along which log_softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
          If specified, the input tensor is cast to :attr:`dtype` before the operation
          is performed. This is useful for preventing data type overflows. Default: None.
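
    Example (a minimal sketch; the input values are arbitrary)::

        >>> input = torch.randn(2, 3)
        >>> # equivalent to torch.log(F.softmax(input, dim=1)), but more stable
        >>> output = F.log_softmax(input, dim=1)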
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
    if dtype is None:
        ret = input.log_softmax(dim)
    else:
        ret = input.log_softmax(dim, dtype=dtype)
    return ret


softshrink = _add_docstr(
    torch._C._nn.softshrink,
    r"""
softshrink(input, lambd=0.5) -> Tensor

Applies the soft shrinkage function elementwise.

See :class:`~torch.nn.Softshrink` for more details.
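
The function is defined as (the standard definition; see :class:`~torch.nn.Softshrink`):

.. math::
    \text{SoftShrinkage}(x) =
    \begin{cases}
    x - \lambda, & \text{ if } x > \lambda \\
    x + \lambda, & \text{ if } x < -\lambda \\
    0, & \text{ otherwise }
    \end{cases}

Example (``lambd`` defaults to 0.5)::

    >>> F.softshrink(torch.tensor([-1.0, -0.25, 0.25, 1.0]))
    tensor([-0.5000,  0.0000,  0.0000,  0.5000])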
""",
)


def tanh(input):  # noqa: D400,D402
    r"""tanh(input) -> Tensor

    Applies element-wise,
    :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}`

    See :class:`~torch.nn.Tanh` for more details.
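
    Example (a minimal call; :math:`\tanh(0) = 0`)::

        >>> F.tanh(torch.tensor([0.0]))
        tensor([0.])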
    """
    return input.tanh()


def sigmoid(input):  # noqa: D400,D402
    r"""sigmoid(input) -> Tensor

    Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}`

    See :class:`~torch.nn.Sigmoid` for more details.
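
    Example (a minimal call; :math:`\text{Sigmoid}(0) = 0.5`)::

        >>> F.sigmoid(torch.tensor([0.0]))
        tensor([0.5000])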
    """
    return input.sigmoid()


def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Hardsigmoid function element-wise.

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    See :class:`~torch.nn.Hardsigmoid` for more details.
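
    Example (inputs chosen to hit all three cases of the definition above)::

        >>> F.hardsigmoid(torch.tensor([-4.0, 0.0, 4.0]))
        tensor([0.0000, 0.5000, 1.0000])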
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.hardsigmoid_(input)
    return torch._C._nn.hardsigmoid(input)


linear = _add_docstr(
    torch._C._nn.linear,
    r"""
linear(input, weight, bias=None) -> Tensor

Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>`.

{sparse_beta_warning}

This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

Shape:

    - Input: :math:`(*, in\_features)` where `*` means any number of
      additional dimensions, including none
    - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
    - Bias: :math:`(out\_features)` or :math:`()`
    - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
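
Example (a minimal shape check; the values are arbitrary)::

    >>> input = torch.randn(3, 5)
    >>> weight = torch.randn(4, 5)
    >>> F.linear(input, weight).shape
    torch.Size([3, 4])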
""".format(
        **sparse_support_notes
    ),
)


bilinear = _add_docstr(
    torch.bilinear,
    r"""
bilinear(input1, input2, weight, bias=None) -> Tensor

Applies a bilinear transformation to the incoming data:
:math:`y = x_1^T A x_2 + b`

Shape:

    - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}`
      and :math:`*` means any number of additional dimensions.
      All but the last dimension of the inputs should be the same.
    - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`
    - weight: :math:`(\text{out\_features}, \text{in1\_features},
      \text{in2\_features})`
    - bias: :math:`(\text{out\_features})`
    - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
      and all but the last dimension are the same shape as the input.
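
Example (a minimal shape check; the values are arbitrary)::

    >>> input1 = torch.randn(8, 10)
    >>> input2 = torch.randn(8, 20)
    >>> weight = torch.randn(5, 10, 20)
    >>> F.bilinear(input1, input2, weight).shape
    torch.Size([8, 5])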
""",
)


def silu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.

    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    See :class:`~torch.nn.SiLU` for more details.
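
    Example (checking the defining identity; ``x`` is an arbitrary input)::

        >>> x = torch.randn(4)
        >>> torch.allclose(F.silu(x), x * torch.sigmoid(x))
        True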
    """
    if has_torch_function_unary(input):
        return handle_torch_function(silu, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.silu_(input)
    return torch._C._nn.silu(input)


def mish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    See :class:`~torch.nn.Mish` for more details.
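
    Example (checking the defining identity; ``x`` is an arbitrary input)::

        >>> x = torch.randn(4)
        >>> torch.allclose(F.mish(x), x * torch.tanh(F.softplus(x)))
        True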
    """
    if has_torch_function_unary(input):
        return handle_torch_function(mish, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.mish_(input)
    return torch._C._nn.mish(input)


def hardswish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the hardswish function, element-wise.

    Follows the implementation described in the paper
    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) / 6 & \text{otherwise}
        \end{cases}

    See :class:`~torch.nn.Hardswish` for more details.

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
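
    Example (inputs chosen to hit all three cases; the results are
    0, 1 * 4 / 6, and 4 respectively)::

        >>> x = torch.tensor([-4.0, 1.0, 4.0])
        >>> y = F.hardswish(x)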
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardswish, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.hardswish_(input)
    return torch._C._nn.hardswish(input)


def _no_grad_embedding_renorm_(
    weight: Tensor,
    input: Tensor,
    max_norm: float,
    norm_type: float,
) -> None:
    torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)


def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary of fixed size.

    This module is often used to retrieve word embeddings using indices.
    The inputs to the module are a list of indices and the embedding matrix;
    the output is the corresponding word embeddings.

    See :class:`torch.nn.Embedding` for more details.

    .. note::
        The analytical gradients of this function with respect to
        entries in :attr:`weight` at the row specified by :attr:`padding_idx`
        are expected to differ from the numerical ones.

    .. note::
        :class:`torch.nn.Embedding` differs from this function in
        that it initializes the row of :attr:`weight` specified by
        :attr:`padding_idx` to all zeros on construction.

    Args:
        input (LongTensor): Tensor containing indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                     i.e. it remains as a fixed "pad".
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
                                    Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
                                 :class:`torch.nn.Embedding` for more details regarding sparse gradients.

    Shape:
        - Input: LongTensor of arbitrary shape containing the indices to extract
        - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
          where V = maximum index + 1 and embedding_dim = the embedding size
        - Output: `(*, embedding_dim)`, where `*` is the input shape

    Examples::

        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> # an embedding matrix containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding(input, embedding_matrix)
        tensor([[[ 0.8490,  0.9625,  0.6753],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.6246,  0.9751,  0.3618],
                 [ 0.4161,  0.2419,  0.7383]],

                [[ 0.6246,  0.9751,  0.3618],
                 [ 0.0237,  0.7794,  0.0528],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.3385,  0.8612,  0.1867]]])

        >>> # example with padding_idx
        >>> weights = torch.rand(10, 3)
        >>> weights[0, :].zero_()
        >>> embedding_matrix = weights
        >>> input = torch.tensor([[0, 2, 0, 5]])
        >>> F.embedding(input, embedding_matrix, padding_idx=0)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.5609,  0.5384,  0.8720],
                 [ 0.0000,  0.0000,  0.0000],
                 [ 0.6262,  0.2438,  0.7471]]])
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            embedding,
            (input, weight),
            input,
            weight,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < weight.size(
                0
            ), "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -weight.size(
                0
            ), "Padding_idx must be within num_embeddings"
            padding_idx = weight.size(0) + padding_idx
    else:
        padding_idx = -1
    if max_norm is not None:
        # Note [embedding_renorm contiguous]
        # `embedding_renorm_` will call .contiguous() on input anyways, so we
        # call it here and take advantage of the improved locality in the
        # `embedding` call below too.
        input = input.contiguous()
        # Note [embedding_renorm set_grad_enabled]
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)


def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2,
    scale_grad_by_freq: bool = False,
    mode: str = "mean",
    sparse: bool = False,
    per_sample_weights: Optional[Tensor] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> Tensor:
    r"""Compute sums, means or maxes of `bags` of embeddings.

    Calculation is done without instantiating the intermediate embeddings.
    See :class:`torch.nn.EmbeddingBag` for more details.

    Note:
        {backward_reproducibility_note}

    Args:
        input (LongTensor): Tensor containing bags of indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
                             the starting index position of each bag (sequence) in :attr:`input`.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
                                    Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option.
                                     Default ``2``.
        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
                                 :class:`torch.nn.Embedding` for more details regarding sparse gradients.
                                 Note: this option is not supported when ``mode="max"``.
        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not None.

        include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1.
            The last element is the size of the input, or the ending index position of the last bag (sequence).

        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
                                     gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
                                     during training, i.e. it remains as a fixed "pad". Note that the embedding
                                     vector at :attr:`padding_idx` is excluded from the reduction.

    Shape:
        - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)

          - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
            each of fixed length ``N``, and this will return ``B`` values aggregated in a way
            depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.

          - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
            multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing
            the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets`
            of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags.
            Empty bags (i.e., those of length 0) will have their returned vectors filled with zeros.

        - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`

        - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`.

        - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`

    Examples::

        >>> # an embedding matrix containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding_bag(input, embedding_matrix, offsets)
        tensor([[ 0.3397,  0.3552,  0.5545],
                [ 0.5893,  0.4386,  0.5882]])

        >>> # example with padding_idx
        >>> embedding_matrix = torch.rand(10, 3)
        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7082,  3.2145, -2.6251]])
    """
    if has_torch_function_variadic(input, weight, offsets, per_sample_weights):
        return handle_torch_function(
            embedding_bag,
            (input, weight, offsets, per_sample_weights),
            input,
            weight,
            offsets=offsets,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            padding_idx=padding_idx,
        )
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is     embedding_bag(input, weight, ...)
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn(
            "Argument order of nn.functional.embedding_bag was changed. "
            "Usage `embedding_bag(weight, input, ...)` is deprecated, "
            "and should now be `embedding_bag(input, weight, ...)`."
        )
        weight, input = input, weight

    if per_sample_weights is not None and input.size() != per_sample_weights.size():
        raise ValueError(
            f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, "
            f"then it must have the same shape as the input ({input.shape})"
        )

    if weight.dim() != 2:
        raise ValueError(
            f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}"
        )

    if input.dim() == 2:
        if offsets is not None:
            type_str = "<unknown>"
            # TODO: Remove this once script supports type() calls
            if not torch.jit.is_scripting():
                type_str = str(type(offsets))
            raise ValueError(
                "if input is 2D, then offsets has to be None"
                ", as input is treated as a mini-batch of"
                " fixed length sequences. However, found "
                f"offsets of type {type_str}"
            )
        offsets = torch.arange(
            0, input.numel(), input.size(1), dtype=input.dtype, device=input.device
        )

        input = input.reshape(-1)
        if per_sample_weights is not None:
            per_sample_weights = per_sample_weights.reshape(-1)
    elif input.dim() == 1:
        if offsets is None:
            raise ValueError("offsets has to be a 1D Tensor but got None")
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")
    else:
        raise ValueError(
            f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}"
        )
    if mode == "sum":
        mode_enum = 0
    elif mode == "mean":
        mode_enum = 1
    elif mode == "max":
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if max_norm is not None:
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)

    if per_sample_weights is not None and mode != "sum":
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            f"(got mode='{mode}'). Please open a feature request on GitHub."
        )

    ret, _, _, _ = torch.embedding_bag(
        weight,
        input,
        offsets,
        scale_grad_by_freq,
        mode_enum,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )
    return ret


if embedding_bag.__doc__:
    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)


def _verify_batch_size(size: List[int]) -> None:
    # XXX: JIT script does not support the reduce from functools, and mul op is a
    # builtin, which cannot be used as a value to a func yet, so this size check
    # is rewritten as a simple equivalent for loop
    #
    # TODO: make use of reduce like below when JIT is ready with the missing features:
    # from operator import mul
    # from functools import reduce
    #
    #   if reduce(mul, size[2:], size[0]) == 1
    size_prods = size[0]
    for i in range(len(size) - 2):
        size_prods *= size[i + 2]
    if size_prods == 1:
        raise ValueError(
            f"Expected more than 1 value per channel when training, got input size {size}"
        )


def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Batch Normalization for each channel across a batch of data.

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.
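
    Example (inference-mode normalization with fixed running statistics;
    the shapes here are illustrative)::

        >>> input = torch.randn(4, 3, 8, 8)
        >>> running_mean = torch.zeros(3)
        >>> running_var = torch.ones(3)
        >>> F.batch_norm(input, running_mean, running_var, training=False).shape
        torch.Size([4, 3, 8, 8])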
    """
    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
        return handle_torch_function(
            batch_norm,
            (input, running_mean, running_var, weight, bias),
            input,
            running_mean,
            running_var,
            weight=weight,
            bias=bias,
            training=training,
            momentum=momentum,
            eps=eps,
        )
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        torch.backends.cudnn.enabled,
    )


def _verify_spatial_size(size: List[int]) -> None:
    # Verify that there is > 1 spatial element for instance norm calculation.
    size_prods = 1
    for i in range(2, len(size)):
        size_prods *= size[i]
    if size_prods == 1:
        raise ValueError(
            f"Expected more than 1 spatial element when training, got input size {size}"
        )


def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Instance Normalization independently for each channel in every data sample within a batch.

    See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
    :class:`~torch.nn.InstanceNorm3d` for details.
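
    Example (a minimal call using the input's own statistics; the shapes
    here are illustrative)::

        >>> input = torch.randn(4, 3, 16, 16)
        >>> F.instance_norm(input).shape
        torch.Size([4, 3, 16, 16])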
2850*da0073e9SAndroid Build Coastguard Worker    """
2851*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
2852*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2853*da0073e9SAndroid Build Coastguard Worker            instance_norm,
2854*da0073e9SAndroid Build Coastguard Worker            (input, running_mean, running_var, weight, bias),
2855*da0073e9SAndroid Build Coastguard Worker            input,
2856*da0073e9SAndroid Build Coastguard Worker            running_mean=running_mean,
2857*da0073e9SAndroid Build Coastguard Worker            running_var=running_var,
2858*da0073e9SAndroid Build Coastguard Worker            weight=weight,
2859*da0073e9SAndroid Build Coastguard Worker            bias=bias,
2860*da0073e9SAndroid Build Coastguard Worker            use_input_stats=use_input_stats,
2861*da0073e9SAndroid Build Coastguard Worker            momentum=momentum,
2862*da0073e9SAndroid Build Coastguard Worker            eps=eps,
2863*da0073e9SAndroid Build Coastguard Worker        )
2864*da0073e9SAndroid Build Coastguard Worker    if use_input_stats:
2865*da0073e9SAndroid Build Coastguard Worker        _verify_spatial_size(input.size())
2866*da0073e9SAndroid Build Coastguard Worker    return torch.instance_norm(
2867*da0073e9SAndroid Build Coastguard Worker        input,
2868*da0073e9SAndroid Build Coastguard Worker        weight,
2869*da0073e9SAndroid Build Coastguard Worker        bias,
2870*da0073e9SAndroid Build Coastguard Worker        running_mean,
2871*da0073e9SAndroid Build Coastguard Worker        running_var,
2872*da0073e9SAndroid Build Coastguard Worker        use_input_stats,
2873*da0073e9SAndroid Build Coastguard Worker        momentum,
2874*da0073e9SAndroid Build Coastguard Worker        eps,
2875*da0073e9SAndroid Build Coastguard Worker        torch.backends.cudnn.enabled,
2876*da0073e9SAndroid Build Coastguard Worker    )
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker
2879*da0073e9SAndroid Build Coastguard Workerdef layer_norm(
2880*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2881*da0073e9SAndroid Build Coastguard Worker    normalized_shape: List[int],
2882*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
2883*da0073e9SAndroid Build Coastguard Worker    bias: Optional[Tensor] = None,
2884*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-5,
2885*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2886*da0073e9SAndroid Build Coastguard Worker    r"""Apply Layer Normalization over the trailing dimensions given by ``normalized_shape``.
2887*da0073e9SAndroid Build Coastguard Worker
2888*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LayerNorm` for details.
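*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; shapes are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # normalize over the last dimension of a (2, 5, 10) input
*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(2, 5, 10)
*da0073e9SAndroid Build Coastguard Worker        >>> output = F.layer_norm(input, [10])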
2889*da0073e9SAndroid Build Coastguard Worker    """
2890*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, weight, bias):
2891*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2892*da0073e9SAndroid Build Coastguard Worker            layer_norm,
2893*da0073e9SAndroid Build Coastguard Worker            (input, weight, bias),
2894*da0073e9SAndroid Build Coastguard Worker            input,
2895*da0073e9SAndroid Build Coastguard Worker            normalized_shape,
2896*da0073e9SAndroid Build Coastguard Worker            weight=weight,
2897*da0073e9SAndroid Build Coastguard Worker            bias=bias,
2898*da0073e9SAndroid Build Coastguard Worker            eps=eps,
2899*da0073e9SAndroid Build Coastguard Worker        )
2900*da0073e9SAndroid Build Coastguard Worker    return torch.layer_norm(
2901*da0073e9SAndroid Build Coastguard Worker        input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled
2902*da0073e9SAndroid Build Coastguard Worker    )
2903*da0073e9SAndroid Build Coastguard Worker
2904*da0073e9SAndroid Build Coastguard Worker
2905*da0073e9SAndroid Build Coastguard Workerdef rms_norm(
2906*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2907*da0073e9SAndroid Build Coastguard Worker    normalized_shape: List[int],
2908*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
2909*da0073e9SAndroid Build Coastguard Worker    eps: Optional[float] = None,
2910*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2911*da0073e9SAndroid Build Coastguard Worker    r"""Apply Root Mean Square Layer Normalization.
2912*da0073e9SAndroid Build Coastguard Worker
2913*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.RMSNorm` for details.
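*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; shapes are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # normalize over the last dimension of a (2, 5, 10) input
*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(2, 5, 10)
*da0073e9SAndroid Build Coastguard Worker        >>> output = F.rms_norm(input, [10])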
2914*da0073e9SAndroid Build Coastguard Worker    """
2915*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, weight):
2916*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2917*da0073e9SAndroid Build Coastguard Worker            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
2918*da0073e9SAndroid Build Coastguard Worker        )
2919*da0073e9SAndroid Build Coastguard Worker    return torch.rms_norm(input, normalized_shape, weight, eps)
2920*da0073e9SAndroid Build Coastguard Worker
2921*da0073e9SAndroid Build Coastguard Worker
2922*da0073e9SAndroid Build Coastguard Workerdef group_norm(
2923*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2924*da0073e9SAndroid Build Coastguard Worker    num_groups: int,
2925*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
2926*da0073e9SAndroid Build Coastguard Worker    bias: Optional[Tensor] = None,
2927*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-5,
2928*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2929*da0073e9SAndroid Build Coastguard Worker    r"""Apply Group Normalization over the channels of the input, separated into ``num_groups`` groups.
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.GroupNorm` for details.
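*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; sizes are arbitrary, but the channel
*da0073e9SAndroid Build Coastguard Worker    count must be divisible by ``num_groups``)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # 6 channels split into 3 groups
*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(4, 6, 8, 8)
*da0073e9SAndroid Build Coastguard Worker        >>> output = F.group_norm(input, 3)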
2932*da0073e9SAndroid Build Coastguard Worker    """
2933*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, weight, bias):
2934*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2935*da0073e9SAndroid Build Coastguard Worker            group_norm,
2936*da0073e9SAndroid Build Coastguard Worker            (
2937*da0073e9SAndroid Build Coastguard Worker                input,
2938*da0073e9SAndroid Build Coastguard Worker                weight,
2939*da0073e9SAndroid Build Coastguard Worker                bias,
2940*da0073e9SAndroid Build Coastguard Worker            ),
2941*da0073e9SAndroid Build Coastguard Worker            input,
2942*da0073e9SAndroid Build Coastguard Worker            num_groups,
2943*da0073e9SAndroid Build Coastguard Worker            weight=weight,
2944*da0073e9SAndroid Build Coastguard Worker            bias=bias,
2945*da0073e9SAndroid Build Coastguard Worker            eps=eps,
2946*da0073e9SAndroid Build Coastguard Worker        )
2947*da0073e9SAndroid Build Coastguard Worker    if input.dim() < 2:
2948*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
2949*da0073e9SAndroid Build Coastguard Worker            f"Expected at least 2 dimensions for input tensor but received {input.dim()}"
2950*da0073e9SAndroid Build Coastguard Worker        )
2951*da0073e9SAndroid Build Coastguard Worker    _verify_batch_size(
2952*da0073e9SAndroid Build Coastguard Worker        [input.size(0) * input.size(1) // num_groups, num_groups]
2953*da0073e9SAndroid Build Coastguard Worker        + list(input.size()[2:])
2954*da0073e9SAndroid Build Coastguard Worker    )
2955*da0073e9SAndroid Build Coastguard Worker    return torch.group_norm(
2956*da0073e9SAndroid Build Coastguard Worker        input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled
2957*da0073e9SAndroid Build Coastguard Worker    )
2958*da0073e9SAndroid Build Coastguard Worker
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Workerdef local_response_norm(
2961*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
2962*da0073e9SAndroid Build Coastguard Worker    size: int,
2963*da0073e9SAndroid Build Coastguard Worker    alpha: float = 1e-4,
2964*da0073e9SAndroid Build Coastguard Worker    beta: float = 0.75,
2965*da0073e9SAndroid Build Coastguard Worker    k: float = 1.0,
2966*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
2967*da0073e9SAndroid Build Coastguard Worker    r"""Apply local response normalization over an input signal.
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker    The input signal is composed of several input planes, where channels occupy the second dimension.
2970*da0073e9SAndroid Build Coastguard Worker    Normalization is applied across channels.
2971*da0073e9SAndroid Build Coastguard Worker
2972*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.LocalResponseNorm` for details.
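*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; shapes are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # each position is normalized across a window of `size` neighboring channels
*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(1, 16, 8, 8)
*da0073e9SAndroid Build Coastguard Worker        >>> output = F.local_response_norm(input, size=5)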
2973*da0073e9SAndroid Build Coastguard Worker    """
2974*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
2975*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
2976*da0073e9SAndroid Build Coastguard Worker            local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k
2977*da0073e9SAndroid Build Coastguard Worker        )
2978*da0073e9SAndroid Build Coastguard Worker    dim = input.dim()
2979*da0073e9SAndroid Build Coastguard Worker    if dim < 3:
2980*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
2981*da0073e9SAndroid Build Coastguard Worker            f"Expected 3D or higher dimensionality input (got {dim} dimensions)"
2982*da0073e9SAndroid Build Coastguard Worker        )
2983*da0073e9SAndroid Build Coastguard Worker
2984*da0073e9SAndroid Build Coastguard Worker    if input.numel() == 0:
2985*da0073e9SAndroid Build Coastguard Worker        return input
2986*da0073e9SAndroid Build Coastguard Worker
2987*da0073e9SAndroid Build Coastguard Worker    div = input.mul(input)  # squared activations
2988*da0073e9SAndroid Build Coastguard Worker    if dim == 3:
2989*da0073e9SAndroid Build Coastguard Worker        div = div.unsqueeze(1)  # (N, C, L) -> (N, 1, C, L): make channels a spatial dim
2990*da0073e9SAndroid Build Coastguard Worker        div = pad(div, (0, 0, size // 2, (size - 1) // 2))  # zero-pad along the channel axis
2991*da0073e9SAndroid Build Coastguard Worker        div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)  # mean of squares over `size` channels
2992*da0073e9SAndroid Build Coastguard Worker    else:
2993*da0073e9SAndroid Build Coastguard Worker        sizes = input.size()
2994*da0073e9SAndroid Build Coastguard Worker        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)  # fold trailing dims so the window slides over channels
2995*da0073e9SAndroid Build Coastguard Worker        div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
2996*da0073e9SAndroid Build Coastguard Worker        div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
2997*da0073e9SAndroid Build Coastguard Worker        div = div.view(sizes)  # restore the original shape
2998*da0073e9SAndroid Build Coastguard Worker    div = div.mul(alpha).add(k).pow(beta)  # denominator: (k + alpha * mean(x^2)) ** beta
2999*da0073e9SAndroid Build Coastguard Worker    return input / div
3000*da0073e9SAndroid Build Coastguard Worker
3001*da0073e9SAndroid Build Coastguard Worker
3002*da0073e9SAndroid Build Coastguard Worker# loss
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker
3005*da0073e9SAndroid Build Coastguard Workerdef ctc_loss(
3006*da0073e9SAndroid Build Coastguard Worker    log_probs: Tensor,
3007*da0073e9SAndroid Build Coastguard Worker    targets: Tensor,
3008*da0073e9SAndroid Build Coastguard Worker    input_lengths: Tensor,
3009*da0073e9SAndroid Build Coastguard Worker    target_lengths: Tensor,
3010*da0073e9SAndroid Build Coastguard Worker    blank: int = 0,
3011*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3012*da0073e9SAndroid Build Coastguard Worker    zero_infinity: bool = False,
3013*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3014*da0073e9SAndroid Build Coastguard Worker    r"""Apply the Connectionist Temporal Classification loss.
3015*da0073e9SAndroid Build Coastguard Worker
3016*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.CTCLoss` for details.
3017*da0073e9SAndroid Build Coastguard Worker
3018*da0073e9SAndroid Build Coastguard Worker    Note:
3019*da0073e9SAndroid Build Coastguard Worker        {cudnn_reproducibility_note}
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker    Note:
3022*da0073e9SAndroid Build Coastguard Worker        {backward_reproducibility_note}
3023*da0073e9SAndroid Build Coastguard Worker
3024*da0073e9SAndroid Build Coastguard Worker    Args:
3025*da0073e9SAndroid Build Coastguard Worker        log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`,
3026*da0073e9SAndroid Build Coastguard Worker            `T = input length`, and `N = batch size`.
3027*da0073e9SAndroid Build Coastguard Worker            The logarithmized probabilities of the outputs
3028*da0073e9SAndroid Build Coastguard Worker            (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
3029*da0073e9SAndroid Build Coastguard Worker        targets: :math:`(N, S)` or ``(sum(target_lengths))``.
3030*da0073e9SAndroid Build Coastguard Worker            Targets cannot be blank. In the second form, the targets are assumed to be concatenated.
3031*da0073e9SAndroid Build Coastguard Worker        input_lengths: :math:`(N)` or :math:`()`.
3032*da0073e9SAndroid Build Coastguard Worker            Lengths of the inputs (must each be :math:`\leq T`)
3033*da0073e9SAndroid Build Coastguard Worker        target_lengths: :math:`(N)` or :math:`()`.
3034*da0073e9SAndroid Build Coastguard Worker            Lengths of the targets
3035*da0073e9SAndroid Build Coastguard Worker        blank (int, optional):
3036*da0073e9SAndroid Build Coastguard Worker            Blank label. Default :math:`0`.
3037*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3038*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3039*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the output losses will be divided by the target lengths and
3040*da0073e9SAndroid Build Coastguard Worker            then the mean over the batch is taken, ``'sum'``: the output will be
3041*da0073e9SAndroid Build Coastguard Worker            summed. Default: ``'mean'``
3042*da0073e9SAndroid Build Coastguard Worker        zero_infinity (bool, optional):
3043*da0073e9SAndroid Build Coastguard Worker            Whether to zero infinite losses and the associated gradients.
3044*da0073e9SAndroid Build Coastguard Worker            Infinite losses mainly occur when the inputs are too short
3045*da0073e9SAndroid Build Coastguard Worker            to be aligned to the targets.
3046*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
3047*da0073e9SAndroid Build Coastguard Worker
3048*da0073e9SAndroid Build Coastguard Worker    Example::
3049*da0073e9SAndroid Build Coastguard Worker
3050*da0073e9SAndroid Build Coastguard Worker        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
3051*da0073e9SAndroid Build Coastguard Worker        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
3052*da0073e9SAndroid Build Coastguard Worker        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
3053*da0073e9SAndroid Build Coastguard Worker        >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
3054*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
3055*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()
3056*da0073e9SAndroid Build Coastguard Worker    """
3057*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths):
3058*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3059*da0073e9SAndroid Build Coastguard Worker            ctc_loss,
3060*da0073e9SAndroid Build Coastguard Worker            (log_probs, targets, input_lengths, target_lengths),
3061*da0073e9SAndroid Build Coastguard Worker            log_probs,
3062*da0073e9SAndroid Build Coastguard Worker            targets,
3063*da0073e9SAndroid Build Coastguard Worker            input_lengths,
3064*da0073e9SAndroid Build Coastguard Worker            target_lengths,
3065*da0073e9SAndroid Build Coastguard Worker            blank=blank,
3066*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3067*da0073e9SAndroid Build Coastguard Worker            zero_infinity=zero_infinity,
3068*da0073e9SAndroid Build Coastguard Worker        )
3069*da0073e9SAndroid Build Coastguard Worker    return torch.ctc_loss(
3070*da0073e9SAndroid Build Coastguard Worker        log_probs,
3071*da0073e9SAndroid Build Coastguard Worker        targets,
3072*da0073e9SAndroid Build Coastguard Worker        input_lengths,
3073*da0073e9SAndroid Build Coastguard Worker        target_lengths,
3074*da0073e9SAndroid Build Coastguard Worker        blank,
3075*da0073e9SAndroid Build Coastguard Worker        _Reduction.get_enum(reduction),
3076*da0073e9SAndroid Build Coastguard Worker        zero_infinity,
3077*da0073e9SAndroid Build Coastguard Worker    )
3078*da0073e9SAndroid Build Coastguard Worker
3079*da0073e9SAndroid Build Coastguard Worker
3080*da0073e9SAndroid Build Coastguard Workerif ctc_loss.__doc__:
3081*da0073e9SAndroid Build Coastguard Worker    ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes)
3082*da0073e9SAndroid Build Coastguard Worker
3083*da0073e9SAndroid Build Coastguard Worker
3084*da0073e9SAndroid Build Coastguard Workerdef nll_loss(
3085*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3086*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3087*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
3088*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
3089*da0073e9SAndroid Build Coastguard Worker    ignore_index: int = -100,
3090*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
3091*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3092*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3093*da0073e9SAndroid Build Coastguard Worker    r"""Compute the negative log likelihood loss.
3094*da0073e9SAndroid Build Coastguard Worker
3095*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.NLLLoss` for details.
3096*da0073e9SAndroid Build Coastguard Worker
3097*da0073e9SAndroid Build Coastguard Worker    Args:
3098*da0073e9SAndroid Build Coastguard Worker        input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
3099*da0073e9SAndroid Build Coastguard Worker            in the case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
3100*da0073e9SAndroid Build Coastguard Worker            in the case of K-dimensional loss. `input` is expected to be log-probabilities.
3101*da0073e9SAndroid Build Coastguard Worker        target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
3102*da0073e9SAndroid Build Coastguard Worker            or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
3103*da0073e9SAndroid Build Coastguard Worker            K-dimensional loss.
3104*da0073e9SAndroid Build Coastguard Worker        weight (Tensor, optional): a manual rescaling weight given to each
3105*da0073e9SAndroid Build Coastguard Worker            class. If given, has to be a Tensor of size `C`
3106*da0073e9SAndroid Build Coastguard Worker        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3107*da0073e9SAndroid Build Coastguard Worker            the losses are averaged over each loss element in the batch. Note that for
3108*da0073e9SAndroid Build Coastguard Worker            some losses, there multiple elements per sample. If the field :attr:`size_average`
3109*da0073e9SAndroid Build Coastguard Worker            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3110*da0073e9SAndroid Build Coastguard Worker            when reduce is ``False``. Default: ``True``
3111*da0073e9SAndroid Build Coastguard Worker        ignore_index (int, optional): Specifies a target value that is ignored
3112*da0073e9SAndroid Build Coastguard Worker            and does not contribute to the input gradient. When :attr:`size_average` is
3113*da0073e9SAndroid Build Coastguard Worker            ``True``, the loss is averaged over non-ignored targets. Default: -100
3114*da0073e9SAndroid Build Coastguard Worker        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3115*da0073e9SAndroid Build Coastguard Worker            losses are averaged or summed over observations for each minibatch depending
3116*da0073e9SAndroid Build Coastguard Worker            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3117*da0073e9SAndroid Build Coastguard Worker            batch element instead and ignores :attr:`size_average`. Default: ``True``
3118*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3119*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3120*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the sum of the output will be divided by the number of
3121*da0073e9SAndroid Build Coastguard Worker            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3122*da0073e9SAndroid Build Coastguard Worker            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3123*da0073e9SAndroid Build Coastguard Worker            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3124*da0073e9SAndroid Build Coastguard Worker
3125*da0073e9SAndroid Build Coastguard Worker    Example::
3126*da0073e9SAndroid Build Coastguard Worker
3127*da0073e9SAndroid Build Coastguard Worker        >>> # input is of size N x C = 3 x 5
3128*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(3, 5, requires_grad=True)
3129*da0073e9SAndroid Build Coastguard Worker        >>> # each element in target has to have 0 <= value < C
3130*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.tensor([1, 0, 4])
3131*da0073e9SAndroid Build Coastguard Worker        >>> output = F.nll_loss(F.log_softmax(input, dim=1), target)
3132*da0073e9SAndroid Build Coastguard Worker        >>> output.backward()
3133*da0073e9SAndroid Build Coastguard Worker    """
3134*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, target, weight):
3135*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3136*da0073e9SAndroid Build Coastguard Worker            nll_loss,
3137*da0073e9SAndroid Build Coastguard Worker            (input, target, weight),
3138*da0073e9SAndroid Build Coastguard Worker            input,
3139*da0073e9SAndroid Build Coastguard Worker            target,
3140*da0073e9SAndroid Build Coastguard Worker            weight=weight,
3141*da0073e9SAndroid Build Coastguard Worker            size_average=size_average,
3142*da0073e9SAndroid Build Coastguard Worker            ignore_index=ignore_index,
3143*da0073e9SAndroid Build Coastguard Worker            reduce=reduce,
3144*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3145*da0073e9SAndroid Build Coastguard Worker        )
3146*da0073e9SAndroid Build Coastguard Worker    if size_average is not None or reduce is not None:
3147*da0073e9SAndroid Build Coastguard Worker        reduction = _Reduction.legacy_get_string(size_average, reduce)
3148*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.nll_loss_nd(
3149*da0073e9SAndroid Build Coastguard Worker        input, target, weight, _Reduction.get_enum(reduction), ignore_index
3150*da0073e9SAndroid Build Coastguard Worker    )
3151*da0073e9SAndroid Build Coastguard Worker
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Workerdef poisson_nll_loss(
3154*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3155*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3156*da0073e9SAndroid Build Coastguard Worker    log_input: bool = True,
3157*da0073e9SAndroid Build Coastguard Worker    full: bool = False,
3158*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
3159*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-8,
3160*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
3161*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3162*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3163*da0073e9SAndroid Build Coastguard Worker    r"""Poisson negative log likelihood loss.
3164*da0073e9SAndroid Build Coastguard Worker
3165*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.PoissonNLLLoss` for details.
3166*da0073e9SAndroid Build Coastguard Worker
3167*da0073e9SAndroid Build Coastguard Worker    Args:
3168*da0073e9SAndroid Build Coastguard Worker        input: expectation of underlying Poisson distribution.
3169*da0073e9SAndroid Build Coastguard Worker        target: random sample :math:`target \sim \text{Poisson}(input)`.
3170*da0073e9SAndroid Build Coastguard Worker        log_input: if ``True`` the loss is computed as
3171*da0073e9SAndroid Build Coastguard Worker            :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is
3172*da0073e9SAndroid Build Coastguard Worker            :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True``
3173*da0073e9SAndroid Build Coastguard Worker        full: whether to compute full loss, i. e. to add the Stirling
3174*da0073e9SAndroid Build Coastguard Worker        full: whether to compute the full loss, i.e. to add the Stirling approximation term
3175*da0073e9SAndroid Build Coastguard Worker            :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`.
3176*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
3177*da0073e9SAndroid Build Coastguard Worker            the losses are averaged over each loss element in the batch. Note that for
3178*da0073e9SAndroid Build Coastguard Worker            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3179*da0073e9SAndroid Build Coastguard Worker            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3180*da0073e9SAndroid Build Coastguard Worker            when reduce is ``False``. Default: ``True``
3181*da0073e9SAndroid Build Coastguard Worker        eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
3182*da0073e9SAndroid Build Coastguard Worker            :attr:`log_input`\ =\ ``False``. Default: 1e-8
3183*da0073e9SAndroid Build Coastguard Worker        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3184*da0073e9SAndroid Build Coastguard Worker            losses are averaged or summed over observations for each minibatch depending
3185*da0073e9SAndroid Build Coastguard Worker            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3186*da0073e9SAndroid Build Coastguard Worker            batch element instead and ignores :attr:`size_average`. Default: ``True``
3187*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3188*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3189*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the sum of the output will be divided by the number of
3190*da0073e9SAndroid Build Coastguard Worker            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3191*da0073e9SAndroid Build Coastguard Worker            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3192*da0073e9SAndroid Build Coastguard Worker            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3193*da0073e9SAndroid Build Coastguard Worker
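*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; values are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # rates are passed in log space, matching the default log_input=True
*da0073e9SAndroid Build Coastguard Worker        >>> log_rate = torch.randn(5, 2, requires_grad=True)
*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.poisson(torch.rand(5, 2) * 4)
*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.poisson_nll_loss(log_rate, target)
*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()
*da0073e9SAndroid Build Coastguard Worker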
3194*da0073e9SAndroid Build Coastguard Worker    """
3195*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, target):
3196*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3197*da0073e9SAndroid Build Coastguard Worker            poisson_nll_loss,
3198*da0073e9SAndroid Build Coastguard Worker            (input, target),
3199*da0073e9SAndroid Build Coastguard Worker            input,
3200*da0073e9SAndroid Build Coastguard Worker            target,
3201*da0073e9SAndroid Build Coastguard Worker            log_input=log_input,
3202*da0073e9SAndroid Build Coastguard Worker            full=full,
3203*da0073e9SAndroid Build Coastguard Worker            size_average=size_average,
3204*da0073e9SAndroid Build Coastguard Worker            eps=eps,
3205*da0073e9SAndroid Build Coastguard Worker            reduce=reduce,
3206*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3207*da0073e9SAndroid Build Coastguard Worker        )
3208*da0073e9SAndroid Build Coastguard Worker    if size_average is not None or reduce is not None:
3209*da0073e9SAndroid Build Coastguard Worker        reduction = _Reduction.legacy_get_string(size_average, reduce)
3210*da0073e9SAndroid Build Coastguard Worker    if reduction != "none" and reduction != "mean" and reduction != "sum":
3211*da0073e9SAndroid Build Coastguard Worker        ret = input  # dead store; apparently retained so TorchScript sees `ret` bound on every path
3212*da0073e9SAndroid Build Coastguard Worker        raise ValueError(reduction + " is not a valid value for reduction")
3213*da0073e9SAndroid Build Coastguard Worker
3214*da0073e9SAndroid Build Coastguard Worker    ret = torch.poisson_nll_loss(
3215*da0073e9SAndroid Build Coastguard Worker        input, target, log_input, full, eps, _Reduction.get_enum(reduction)
3216*da0073e9SAndroid Build Coastguard Worker    )
3217*da0073e9SAndroid Build Coastguard Worker    return ret
3218*da0073e9SAndroid Build Coastguard Worker
3219*da0073e9SAndroid Build Coastguard Worker
3220*da0073e9SAndroid Build Coastguard Workerdef gaussian_nll_loss(
3221*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3222*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3223*da0073e9SAndroid Build Coastguard Worker    var: Tensor,
3224*da0073e9SAndroid Build Coastguard Worker    full: bool = False,
3225*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-6,
3226*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3227*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3228*da0073e9SAndroid Build Coastguard Worker    r"""Gaussian negative log likelihood loss.
3229*da0073e9SAndroid Build Coastguard Worker
3230*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.GaussianNLLLoss` for details.
3231*da0073e9SAndroid Build Coastguard Worker
3232*da0073e9SAndroid Build Coastguard Worker    Args:
3233*da0073e9SAndroid Build Coastguard Worker        input: expectation of the Gaussian distribution.
3234*da0073e9SAndroid Build Coastguard Worker        target: sample from the Gaussian distribution.
3235*da0073e9SAndroid Build Coastguard Worker        var: tensor of positive variance(s), one for each of the expectations
3236*da0073e9SAndroid Build Coastguard Worker            in the input (heteroscedastic), or a single one (homoscedastic).
3237*da0073e9SAndroid Build Coastguard Worker        full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
3238*da0073e9SAndroid Build Coastguard Worker        eps (float, optional): value added to var, for stability. Default: 1e-6.
3239*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): specifies the reduction to apply to the output:
3240*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3241*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the output is the average of all batch member losses,
3242*da0073e9SAndroid Build Coastguard Worker            ``'sum'``: the output is the sum of all batch member losses.
3243*da0073e9SAndroid Build Coastguard Worker            Default: ``'mean'``.
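*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; shapes are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # heteroscedastic case: one variance per predicted value
*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(5, 2, requires_grad=True)
*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.randn(5, 2)
*da0073e9SAndroid Build Coastguard Worker        >>> var = torch.ones(5, 2, requires_grad=True)
*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.gaussian_nll_loss(input, target, var)
*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()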
3244*da0073e9SAndroid Build Coastguard Worker    """
3245*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, target, var):
3246*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3247*da0073e9SAndroid Build Coastguard Worker            gaussian_nll_loss,
3248*da0073e9SAndroid Build Coastguard Worker            (input, target, var),
3249*da0073e9SAndroid Build Coastguard Worker            input,
3250*da0073e9SAndroid Build Coastguard Worker            target,
3251*da0073e9SAndroid Build Coastguard Worker            var,
3252*da0073e9SAndroid Build Coastguard Worker            full=full,
3253*da0073e9SAndroid Build Coastguard Worker            eps=eps,
3254*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3255*da0073e9SAndroid Build Coastguard Worker        )
3256*da0073e9SAndroid Build Coastguard Worker
3257*da0073e9SAndroid Build Coastguard Worker    # Check var size
3258*da0073e9SAndroid Build Coastguard Worker    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
3259*da0073e9SAndroid Build Coastguard Worker    # Otherwise:
3260*da0073e9SAndroid Build Coastguard Worker    if var.size() != input.size():
3261*da0073e9SAndroid Build Coastguard Worker        # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case.
3262*da0073e9SAndroid Build Coastguard Worker        # e.g. input.size = (10, 2, 3), var.size = (10, 2)
3263*da0073e9SAndroid Build Coastguard Worker        # -> unsqueeze var so that var.shape = (10, 2, 1)
3264*da0073e9SAndroid Build Coastguard Worker        # this is done so that broadcasting can happen in the loss calculation
3265*da0073e9SAndroid Build Coastguard Worker        if input.size()[:-1] == var.size():
3266*da0073e9SAndroid Build Coastguard Worker            var = torch.unsqueeze(var, -1)
3267*da0073e9SAndroid Build Coastguard Worker
3268*da0073e9SAndroid Build Coastguard Worker        # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
3269*da0073e9SAndroid Build Coastguard Worker        # This is also a homoscedastic case.
3270*da0073e9SAndroid Build Coastguard Worker        # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
3271*da0073e9SAndroid Build Coastguard Worker        elif (
3272*da0073e9SAndroid Build Coastguard Worker            input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
3273*da0073e9SAndroid Build Coastguard Worker        ):  # Homoscedastic case
3274*da0073e9SAndroid Build Coastguard Worker            pass
3275*da0073e9SAndroid Build Coastguard Worker
3276*da0073e9SAndroid Build Coastguard Worker        # If none of the above pass, then the size of var is incorrect.
3277*da0073e9SAndroid Build Coastguard Worker        else:
3278*da0073e9SAndroid Build Coastguard Worker            raise ValueError("var is of incorrect size")
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker    # Check validity of reduction mode
3281*da0073e9SAndroid Build Coastguard Worker    if reduction != "none" and reduction != "mean" and reduction != "sum":
3282*da0073e9SAndroid Build Coastguard Worker        raise ValueError(reduction + " is not a valid value for reduction")
3283*da0073e9SAndroid Build Coastguard Worker
3284*da0073e9SAndroid Build Coastguard Worker    # Entries of var must be non-negative
3285*da0073e9SAndroid Build Coastguard Worker    if torch.any(var < 0):
3286*da0073e9SAndroid Build Coastguard Worker        raise ValueError("var has negative entry/entries")
3287*da0073e9SAndroid Build Coastguard Worker
3288*da0073e9SAndroid Build Coastguard Worker    # Clamp for stability
3289*da0073e9SAndroid Build Coastguard Worker    var = var.clone()
3290*da0073e9SAndroid Build Coastguard Worker    with torch.no_grad():
3291*da0073e9SAndroid Build Coastguard Worker        var.clamp_(min=eps)
3292*da0073e9SAndroid Build Coastguard Worker
3293*da0073e9SAndroid Build Coastguard Worker    # Calculate the loss
3294*da0073e9SAndroid Build Coastguard Worker    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
3295*da0073e9SAndroid Build Coastguard Worker    if full:
3296*da0073e9SAndroid Build Coastguard Worker        loss += 0.5 * math.log(2 * math.pi)
3297*da0073e9SAndroid Build Coastguard Worker
3298*da0073e9SAndroid Build Coastguard Worker    if reduction == "mean":
3299*da0073e9SAndroid Build Coastguard Worker        return loss.mean()
3300*da0073e9SAndroid Build Coastguard Worker    elif reduction == "sum":
3301*da0073e9SAndroid Build Coastguard Worker        return loss.sum()
3302*da0073e9SAndroid Build Coastguard Worker    else:
3303*da0073e9SAndroid Build Coastguard Worker        return loss
3304*da0073e9SAndroid Build Coastguard Worker
3305*da0073e9SAndroid Build Coastguard Worker
3306*da0073e9SAndroid Build Coastguard Workerdef kl_div(
3307*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3308*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3309*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
3310*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
3311*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3312*da0073e9SAndroid Build Coastguard Worker    log_target: bool = False,
3313*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3314*da0073e9SAndroid Build Coastguard Worker    r"""Compute the KL Divergence loss.
3315*da0073e9SAndroid Build Coastguard Worker
3316*da0073e9SAndroid Build Coastguard Worker    Refer to the `Kullback-Leibler divergence
3317*da0073e9SAndroid Build Coastguard Worker    <https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`__ article for background.
3318*da0073e9SAndroid Build Coastguard Worker
3319*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.KLDivLoss` for details.
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker    Args:
3322*da0073e9SAndroid Build Coastguard Worker        input: Tensor of arbitrary shape in log-probabilities.
3323*da0073e9SAndroid Build Coastguard Worker        target: Tensor of the same shape as input. See :attr:`log_target` for
3324*da0073e9SAndroid Build Coastguard Worker            the target's interpretation.
3325*da0073e9SAndroid Build Coastguard Worker        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3326*da0073e9SAndroid Build Coastguard Worker            the losses are averaged over each loss element in the batch. Note that for
3327*da0073e9SAndroid Build Coastguard Worker            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3328*da0073e9SAndroid Build Coastguard Worker            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3329*da0073e9SAndroid Build Coastguard Worker            when reduce is ``False``. Default: ``True``
3330*da0073e9SAndroid Build Coastguard Worker        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3331*da0073e9SAndroid Build Coastguard Worker            losses are averaged or summed over observations for each minibatch depending
3332*da0073e9SAndroid Build Coastguard Worker            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3333*da0073e9SAndroid Build Coastguard Worker            batch element instead and ignores :attr:`size_average`. Default: ``True``
3334*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3335*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
3336*da0073e9SAndroid Build Coastguard Worker            ``'none'``: no reduction will be applied
3337*da0073e9SAndroid Build Coastguard Worker            ``'batchmean'``: the sum of the output will be divided by the batchsize
3338*da0073e9SAndroid Build Coastguard Worker            ``'sum'``: the output will be summed
3339*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the output will be divided by the number of elements in the output
3340*da0073e9SAndroid Build Coastguard Worker            Default: ``'mean'``
3341*da0073e9SAndroid Build Coastguard Worker        log_target (bool): A flag indicating whether ``target`` is passed in the log space.
3342*da0073e9SAndroid Build Coastguard Worker            It is recommended to pass certain distributions (like ``softmax``)
3343*da0073e9SAndroid Build Coastguard Worker            in the log space to avoid numerical issues caused by explicit ``log``.
3344*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker    .. note::
3347*da0073e9SAndroid Build Coastguard Worker        :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
3348*da0073e9SAndroid Build Coastguard Worker        and in the meantime, specifying either of those two args will override :attr:`reduction`.
3349*da0073e9SAndroid Build Coastguard Worker
3350*da0073e9SAndroid Build Coastguard Worker    .. warning::
3351*da0073e9SAndroid Build Coastguard Worker        :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use
3352*da0073e9SAndroid Build Coastguard Worker        :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.
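*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker    Example (a minimal illustrative sketch; shapes are arbitrary)::
*da0073e9SAndroid Build Coastguard Worker
*da0073e9SAndroid Build Coastguard Worker        >>> # input in log space, target as probabilities (log_target=False)
*da0073e9SAndroid Build Coastguard Worker        >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
*da0073e9SAndroid Build Coastguard Worker        >>> target = F.softmax(torch.randn(3, 5), dim=1)
*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.kl_div(input, target, reduction="batchmean")
*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()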
3353*da0073e9SAndroid Build Coastguard Worker    """
3354*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, target):
3355*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3356*da0073e9SAndroid Build Coastguard Worker            kl_div,
3357*da0073e9SAndroid Build Coastguard Worker            (input, target),
3358*da0073e9SAndroid Build Coastguard Worker            input,
3359*da0073e9SAndroid Build Coastguard Worker            target,
3360*da0073e9SAndroid Build Coastguard Worker            size_average=size_average,
3361*da0073e9SAndroid Build Coastguard Worker            reduce=reduce,
3362*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3363*da0073e9SAndroid Build Coastguard Worker            log_target=log_target,
3364*da0073e9SAndroid Build Coastguard Worker        )
3365*da0073e9SAndroid Build Coastguard Worker    if size_average is not None or reduce is not None:
3366*da0073e9SAndroid Build Coastguard Worker        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3367*da0073e9SAndroid Build Coastguard Worker    else:
3368*da0073e9SAndroid Build Coastguard Worker        if reduction == "mean":
3369*da0073e9SAndroid Build Coastguard Worker            warnings.warn(
3370*da0073e9SAndroid Build Coastguard Worker                "reduction: 'mean' divides the total loss by both the batch size and the support size."
3371*da0073e9SAndroid Build Coastguard Worker                "reduction: 'mean' divides the total loss by both the batch size and the support size. "
3372*da0073e9SAndroid Build Coastguard Worker                "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
3373*da0073e9SAndroid Build Coastguard Worker                "'mean' will be changed to behave the same as 'batchmean' in the next major release."
3374*da0073e9SAndroid Build Coastguard Worker
3375*da0073e9SAndroid Build Coastguard Worker        # special case for batchmean
3376*da0073e9SAndroid Build Coastguard Worker        if reduction == "batchmean":
3377*da0073e9SAndroid Build Coastguard Worker            reduction_enum = _Reduction.get_enum("sum")
3378*da0073e9SAndroid Build Coastguard Worker        else:
3379*da0073e9SAndroid Build Coastguard Worker            reduction_enum = _Reduction.get_enum(reduction)
3380*da0073e9SAndroid Build Coastguard Worker
3381*da0073e9SAndroid Build Coastguard Worker    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)
3382*da0073e9SAndroid Build Coastguard Worker
3383*da0073e9SAndroid Build Coastguard Worker    if reduction == "batchmean" and input.dim() != 0:
3384*da0073e9SAndroid Build Coastguard Worker        reduced = reduced / input.size()[0]
3385*da0073e9SAndroid Build Coastguard Worker
3386*da0073e9SAndroid Build Coastguard Worker    return reduced
3387*da0073e9SAndroid Build Coastguard Worker
3388*da0073e9SAndroid Build Coastguard Worker
3389*da0073e9SAndroid Build Coastguard Workerdef cross_entropy(
3390*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3391*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3392*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
3393*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
3394*da0073e9SAndroid Build Coastguard Worker    ignore_index: int = -100,
3395*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
3396*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3397*da0073e9SAndroid Build Coastguard Worker    label_smoothing: float = 0.0,
3398*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3399*da0073e9SAndroid Build Coastguard Worker    r"""Compute the cross entropy loss between input logits and target.
3400*da0073e9SAndroid Build Coastguard Worker
3401*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.CrossEntropyLoss` for details.
3402*da0073e9SAndroid Build Coastguard Worker
3403*da0073e9SAndroid Build Coastguard Worker    Args:
3404*da0073e9SAndroid Build Coastguard Worker        input (Tensor) : Predicted unnormalized logits;
3405*da0073e9SAndroid Build Coastguard Worker            see Shape section below for supported shapes.
3406*da0073e9SAndroid Build Coastguard Worker        target (Tensor) : Ground truth class indices or class probabilities;
3407*da0073e9SAndroid Build Coastguard Worker            see Shape section below for supported shapes.
3408*da0073e9SAndroid Build Coastguard Worker        weight (Tensor, optional): a manual rescaling weight given to each
3409*da0073e9SAndroid Build Coastguard Worker            class. If given, has to be a Tensor of size `C`
3410*da0073e9SAndroid Build Coastguard Worker        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3411*da0073e9SAndroid Build Coastguard Worker            the losses are averaged over each loss element in the batch. Note that for
3412*da0073e9SAndroid Build Coastguard Worker            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3413*da0073e9SAndroid Build Coastguard Worker            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3414*da0073e9SAndroid Build Coastguard Worker            when reduce is ``False``. Default: ``True``
3415*da0073e9SAndroid Build Coastguard Worker        ignore_index (int, optional): Specifies a target value that is ignored
3416*da0073e9SAndroid Build Coastguard Worker            and does not contribute to the input gradient. When :attr:`size_average` is
3417*da0073e9SAndroid Build Coastguard Worker            ``True``, the loss is averaged over non-ignored targets. Note that
3418*da0073e9SAndroid Build Coastguard Worker            :attr:`ignore_index` is only applicable when the target contains class indices.
3419*da0073e9SAndroid Build Coastguard Worker            Default: -100
3420*da0073e9SAndroid Build Coastguard Worker        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3421*da0073e9SAndroid Build Coastguard Worker            losses are averaged or summed over observations for each minibatch depending
3422*da0073e9SAndroid Build Coastguard Worker            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3423*da0073e9SAndroid Build Coastguard Worker            batch element instead and ignores :attr:`size_average`. Default: ``True``
3424*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3425*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3426*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the sum of the output will be divided by the number of
3427*da0073e9SAndroid Build Coastguard Worker            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3428*da0073e9SAndroid Build Coastguard Worker            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3429*da0073e9SAndroid Build Coastguard Worker            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3430*da0073e9SAndroid Build Coastguard Worker        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
3431*da0073e9SAndroid Build Coastguard Worker            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
3432*da0073e9SAndroid Build Coastguard Worker            become a mixture of the original ground truth and a uniform distribution as described in
3433*da0073e9SAndroid Build Coastguard Worker            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
3434*da0073e9SAndroid Build Coastguard Worker
3435*da0073e9SAndroid Build Coastguard Worker    Shape:
3436*da0073e9SAndroid Build Coastguard Worker        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
3437*da0073e9SAndroid Build Coastguard Worker          in the case of `K`-dimensional loss.
3438*da0073e9SAndroid Build Coastguard Worker        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
3439*da0073e9SAndroid Build Coastguard Worker          :math:`K \geq 1` in the case of K-dimensional loss where each value should lie in :math:`[0, C)`.
3440*da0073e9SAndroid Build Coastguard Worker          If containing class probabilities, same shape as the input and each value should lie in :math:`[0, 1]`.
3441*da0073e9SAndroid Build Coastguard Worker
3442*da0073e9SAndroid Build Coastguard Worker        where:
3443*da0073e9SAndroid Build Coastguard Worker
3444*da0073e9SAndroid Build Coastguard Worker        .. math::
3445*da0073e9SAndroid Build Coastguard Worker            \begin{aligned}
3446*da0073e9SAndroid Build Coastguard Worker                C ={} & \text{number of classes} \\
3447*da0073e9SAndroid Build Coastguard Worker                N ={} & \text{batch size} \\
3448*da0073e9SAndroid Build Coastguard Worker            \end{aligned}
3449*da0073e9SAndroid Build Coastguard Worker
3450*da0073e9SAndroid Build Coastguard Worker    Examples::
3451*da0073e9SAndroid Build Coastguard Worker
3452*da0073e9SAndroid Build Coastguard Worker        >>> # Example of target with class indices
3453*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(3, 5, requires_grad=True)
3454*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.randint(5, (3,), dtype=torch.int64)
3455*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.cross_entropy(input, target)
3456*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()
3457*da0073e9SAndroid Build Coastguard Worker        >>>
3458*da0073e9SAndroid Build Coastguard Worker        >>> # Example of target with class probabilities
3459*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(3, 5, requires_grad=True)
3460*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.randn(3, 5).softmax(dim=1)
3461*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.cross_entropy(input, target)
3462*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()
3463*da0073e9SAndroid Build Coastguard Worker    """
3464*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, target, weight):
3465*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
3466*da0073e9SAndroid Build Coastguard Worker            cross_entropy,
3467*da0073e9SAndroid Build Coastguard Worker            (input, target, weight),
3468*da0073e9SAndroid Build Coastguard Worker            input,
3469*da0073e9SAndroid Build Coastguard Worker            target,
3470*da0073e9SAndroid Build Coastguard Worker            weight=weight,
3471*da0073e9SAndroid Build Coastguard Worker            size_average=size_average,
3472*da0073e9SAndroid Build Coastguard Worker            ignore_index=ignore_index,
3473*da0073e9SAndroid Build Coastguard Worker            reduce=reduce,
3474*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
3475*da0073e9SAndroid Build Coastguard Worker            label_smoothing=label_smoothing,
3476*da0073e9SAndroid Build Coastguard Worker        )
3477*da0073e9SAndroid Build Coastguard Worker    if size_average is not None or reduce is not None:
3478*da0073e9SAndroid Build Coastguard Worker        reduction = _Reduction.legacy_get_string(size_average, reduce)
3479*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.cross_entropy_loss(
3480*da0073e9SAndroid Build Coastguard Worker        input,
3481*da0073e9SAndroid Build Coastguard Worker        target,
3482*da0073e9SAndroid Build Coastguard Worker        weight,
3483*da0073e9SAndroid Build Coastguard Worker        _Reduction.get_enum(reduction),
3484*da0073e9SAndroid Build Coastguard Worker        ignore_index,
3485*da0073e9SAndroid Build Coastguard Worker        label_smoothing,
3486*da0073e9SAndroid Build Coastguard Worker    )
3487*da0073e9SAndroid Build Coastguard Worker
3488*da0073e9SAndroid Build Coastguard Worker
3489*da0073e9SAndroid Build Coastguard Workerdef binary_cross_entropy(
3490*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
3491*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
3492*da0073e9SAndroid Build Coastguard Worker    weight: Optional[Tensor] = None,
3493*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
3494*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
3495*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
3496*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
3497*da0073e9SAndroid Build Coastguard Worker    r"""Measure Binary Cross Entropy between the target and input probabilities.
3498*da0073e9SAndroid Build Coastguard Worker
3499*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.BCELoss` for details.
3500*da0073e9SAndroid Build Coastguard Worker
3501*da0073e9SAndroid Build Coastguard Worker    Args:
3502*da0073e9SAndroid Build Coastguard Worker        input: Tensor of arbitrary shape as probabilities.
3503*da0073e9SAndroid Build Coastguard Worker        target: Tensor of the same shape as input with values between 0 and 1.
3504*da0073e9SAndroid Build Coastguard Worker        weight (Tensor, optional): a manual rescaling weight;
3505*da0073e9SAndroid Build Coastguard Worker            if provided, it is repeated to match the input tensor shape.
3506*da0073e9SAndroid Build Coastguard Worker        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3507*da0073e9SAndroid Build Coastguard Worker            the losses are averaged over each loss element in the batch. Note that for
3508*da0073e9SAndroid Build Coastguard Worker            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3509*da0073e9SAndroid Build Coastguard Worker            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3510*da0073e9SAndroid Build Coastguard Worker            when reduce is ``False``. Default: ``True``
3511*da0073e9SAndroid Build Coastguard Worker        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3512*da0073e9SAndroid Build Coastguard Worker            losses are averaged or summed over observations for each minibatch depending
3513*da0073e9SAndroid Build Coastguard Worker            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3514*da0073e9SAndroid Build Coastguard Worker            batch element instead and ignores :attr:`size_average`. Default: ``True``
3515*da0073e9SAndroid Build Coastguard Worker        reduction (str, optional): Specifies the reduction to apply to the output:
3516*da0073e9SAndroid Build Coastguard Worker            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3517*da0073e9SAndroid Build Coastguard Worker            ``'mean'``: the sum of the output will be divided by the number of
3518*da0073e9SAndroid Build Coastguard Worker            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3519*da0073e9SAndroid Build Coastguard Worker            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3520*da0073e9SAndroid Build Coastguard Worker            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3521*da0073e9SAndroid Build Coastguard Worker
3522*da0073e9SAndroid Build Coastguard Worker    Examples::
3523*da0073e9SAndroid Build Coastguard Worker
3524*da0073e9SAndroid Build Coastguard Worker        >>> input = torch.randn(3, 2, requires_grad=True)
3525*da0073e9SAndroid Build Coastguard Worker        >>> target = torch.rand(3, 2, requires_grad=False)
3526*da0073e9SAndroid Build Coastguard Worker        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
3527*da0073e9SAndroid Build Coastguard Worker        >>> loss.backward()
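        >>> # Illustrative check (hedged): with probabilities produced by ``sigmoid``,
        >>> # this agrees, up to floating-point error, with
        >>> # ``binary_cross_entropy_with_logits`` applied to the raw scores.
        >>> bool(torch.isclose(loss, F.binary_cross_entropy_with_logits(input, target)))
        True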
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            binary_cross_entropy,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if target.size() != input.size():
        raise ValueError(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. "
            "Please ensure they have the same size."
        )

    if weight is not None:
        new_size = _infer_size(target.size(), weight.size())
        weight = weight.expand(new_size)

    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)


def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    pos_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Calculate Binary Cross Entropy between target and input logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight;
            if provided, it is repeated to match the input tensor shape.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
            Must be a tensor with equal size along the class dimension to the number of classes.
            Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
            operations. For a target of size [B, C, H, W] (where B is the batch size), a
            pos_weight of size [B, C, H, W] will apply a different positive weight to each
            element of the batch, while one of size [C, H, W] will apply the same positive
            weights across the batch. To apply the same positive weight along all spatial
            dimensions for a 2D multi-class target [C, H, W], use: [C, 1, 1].
            Default: ``None``

    Examples::

        >>> input = torch.randn(3, requires_grad=True)
        >>> target = torch.empty(3).random_(2)
        >>> loss = F.binary_cross_entropy_with_logits(input, target)
        >>> loss.backward()
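
        >>> # An illustrative sketch of ``pos_weight`` (shapes and values are arbitrary):
        >>> target = torch.ones(10, 64)  # 64 classes, batch size = 10
        >>> output = torch.full((10, 64), 1.5)  # raw scores (logits)
        >>> pos_weight = torch.ones(64)  # all positive weights equal 1
        >>> F.binary_cross_entropy_with_logits(output, target, pos_weight=pos_weight)
        tensor(0.2014)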
    """
    if has_torch_function_variadic(input, target, weight, pos_weight):
        return handle_torch_function(
            binary_cross_entropy_with_logits,
            (input, target, weight, pos_weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    if not (target.size() == input.size()):
        raise ValueError(
            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
        )

    return torch.binary_cross_entropy_with_logits(
        input, target, weight, pos_weight, reduction_enum
    )


def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> Tensor:
    r"""Compute the Smooth L1 loss.

    Function uses a squared term if the absolute
    element-wise error falls below beta and an L1 term otherwise.

    See :class:`~torch.nn.SmoothL1Loss` for details.
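
    Example (an illustrative sketch; shapes and values are arbitrary)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.smooth_l1_loss(input, target, beta=0.5)
        >>> loss.backward()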
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            smooth_l1_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            beta=beta,
        )
    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)

    if beta == 0.0:
        return torch._C._nn.l1_loss(
            expanded_input, expanded_target, _Reduction.get_enum(reduction)
        )
    else:
        return torch._C._nn.smooth_l1_loss(
            expanded_input, expanded_target, _Reduction.get_enum(reduction), beta
        )


def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = "mean",
    delta: float = 1.0,
) -> Tensor:
    r"""Compute the Huber loss.

    Function uses a squared term if the absolute
    element-wise error falls below delta and a delta-scaled L1 term otherwise.

    When delta equals 1, this loss is equivalent to SmoothL1Loss.
    In general, Huber loss differs from SmoothL1Loss by a factor of delta
    (known as beta in SmoothL1Loss).

    See :class:`~torch.nn.HuberLoss` for details.
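
    Example (an illustrative sketch; it also checks the delta-scaling relation noted above)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.huber_loss(input, target, delta=0.8)
        >>> bool(torch.isclose(loss, 0.8 * F.smooth_l1_loss(input, target, beta=0.8)))
        True
        >>> loss.backward()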
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            huber_loss,
            (input, target),
            input,
            target,
            reduction=reduction,
            delta=delta,
        )
    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    return torch._C._nn.huber_loss(
        expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
    )


def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    Function that takes the mean element-wise absolute value difference.

    See :class:`~torch.nn.L1Loss` for details.
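
    Example (an illustrative sketch; shapes are arbitrary)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.l1_loss(input, target)
        >>> loss.backward()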
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            l1_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    return torch._C._nn.l1_loss(
        expanded_input, expanded_target, _Reduction.get_enum(reduction)
    )


def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    Measures the element-wise mean squared error.

    See :class:`~torch.nn.MSELoss` for details.
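
    Example (an illustrative sketch; shapes are arbitrary)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.mse_loss(input, target)
        >>> loss.backward()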
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            mse_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    return torch._C._nn.mse_loss(
        expanded_input, expanded_target, _Reduction.get_enum(reduction)
    )


def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MarginRankingLoss` for details.
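
    Example (illustrative; ``target`` is 1 when ``input1`` should be ranked higher, -1 otherwise)::

        >>> input1 = torch.randn(4, requires_grad=True)
        >>> input2 = torch.randn(4, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])
        >>> loss = F.margin_ranking_loss(input1, input2, target, margin=0.5)
        >>> loss.backward()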
    """
    if has_torch_function_variadic(input1, input2, target):
        return handle_torch_function(
            margin_ranking_loss,
            (input1, input2, target),
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if input1.dim() != input2.dim() or input1.dim() != target.dim():
        raise RuntimeError(
            f"margin_ranking_loss : All input tensors should have the same dimension but got sizes: "
            f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} "
        )
    return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)


def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = 1.0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.HingeEmbeddingLoss` for details.
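
    Example (illustrative; ``target`` contains 1 or -1)::

        >>> input = torch.randn(4, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])
        >>> loss = F.hinge_embedding_loss(input, target, margin=1.0)
        >>> loss.backward()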
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            hinge_embedding_loss,
            (input, target),
            input,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch.hinge_embedding_loss(input, target, margin, reduction_enum)


def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiLabelMarginLoss` for details.
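
    Example (illustrative; ``target`` holds class indices, padded with -1)::

        >>> input = torch.tensor([[0.1, 0.2, 0.4, 0.8]], requires_grad=True)
        >>> target = torch.tensor([[3, 0, -1, 1]])
        >>> loss = F.multilabel_margin_loss(input, target)
        >>> loss.backward()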
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            multilabel_margin_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)


def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""
    soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.SoftMarginLoss` for details.
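
    Example (illustrative; ``target`` contains 1 or -1)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randint(0, 2, (3, 5), dtype=torch.float32) * 2 - 1
        >>> loss = F.soft_margin_loss(input, target)
        >>> loss.backward()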
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            soft_margin_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch._C._nn.soft_margin_loss(input, target, reduction_enum)


def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
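
    Example (illustrative; ``target`` holds 0/1 labels per class)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, 5).random_(2)
        >>> loss = F.multilabel_soft_margin_loss(input, target)
        >>> loss.backward()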
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            multilabel_soft_margin_loss,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))

    if weight is not None:
        loss = loss * weight

    class_dim = input.dim() - 1
    C = input.size(class_dim)
    loss = loss.sum(dim=class_dim) / C  # only return N loss values

    if reduction == "none":
        ret = loss
    elif reduction == "mean":
        ret = loss.mean()
    elif reduction == "sum":
        ret = loss.sum()
    else:
        ret = input  # unreachable; the assignment keeps the JIT type checker satisfied
        raise ValueError(reduction + " is not valid")
    return ret


def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.CosineEmbeddingLoss` for details.
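
    Example (illustrative; ``target`` is 1 for similar pairs and -1 for dissimilar ones)::

        >>> input1 = torch.randn(3, 5, requires_grad=True)
        >>> input2 = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0])
        >>> loss = F.cosine_embedding_loss(input1, input2, target, margin=0.5)
        >>> loss.backward()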
    """
    if has_torch_function_variadic(input1, input2, target):
        return handle_torch_function(
            cosine_embedding_loss,
            (input1, input2, target),
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)


def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = 1,
    margin: float = 1.0,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiMarginLoss` for details.
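
    Example (illustrative; ``target`` holds one class index per sample)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.tensor([1, 0, 4])
        >>> loss = F.multi_margin_loss(input, target)
        >>> loss.backward()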
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            multi_margin_loss,
            (input, target, weight),
            input,
            target,
            p=p,
            margin=margin,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if p != 1 and p != 2:
        raise ValueError("only p == 1 and p == 2 supported")
    if weight is not None:
        if weight.dim() != 1:
            raise ValueError("weight must be one-dimensional")

    return torch._C._nn.multi_margin_loss(
        input, target, p, margin, weight, reduction_enum
    )


pixel_shuffle = _add_docstr(
    torch.pixel_shuffle,
    r"""
pixel_shuffle(input, upscale_factor) -> Tensor

Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.

See :class:`~torch.nn.PixelShuffle` for details.

Args:
    input (Tensor): the input tensor
    upscale_factor (int): factor to increase spatial resolution by

Examples::

    >>> input = torch.randn(1, 9, 4, 4)
    >>> output = torch.nn.functional.pixel_shuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 1, 12, 12])
""",
)

pixel_unshuffle = _add_docstr(
    torch.pixel_unshuffle,
    r"""
pixel_unshuffle(input, downscale_factor) -> Tensor

Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.

See :class:`~torch.nn.PixelUnshuffle` for details.

Args:
    input (Tensor): the input tensor
    downscale_factor (int): factor to decrease spatial resolution by

Examples::

    >>> input = torch.randn(1, 1, 12, 12)
    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 9, 4, 4])
""",
)

channel_shuffle = _add_docstr(
    torch.channel_shuffle,
    r"""
channel_shuffle(input, groups) -> Tensor

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels into and rearrange.

Examples::

    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
     ]]
    >>> output = torch.nn.functional.channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
     ]]
""",
)

native_channel_shuffle = _add_docstr(
    torch.native_channel_shuffle,
    r"""
native_channel_shuffle(input, groups) -> Tensor

Native kernel-level implementation of ``channel_shuffle``.
This function might become private in future releases; use with caution.

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels into and rearrange.

Examples::

    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
     ]]
    >>> output = torch.nn.functional.native_channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
     ]]
""",
)


@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


4199*da0073e9SAndroid Build Coastguard Workerdef upsample(  # noqa: F811
4200*da0073e9SAndroid Build Coastguard Worker    input,
4201*da0073e9SAndroid Build Coastguard Worker    size=None,
4202*da0073e9SAndroid Build Coastguard Worker    scale_factor=None,
4203*da0073e9SAndroid Build Coastguard Worker    mode="nearest",
4204*da0073e9SAndroid Build Coastguard Worker    align_corners=None,
4205*da0073e9SAndroid Build Coastguard Worker):
4206*da0073e9SAndroid Build Coastguard Worker    r"""Upsample input.
4207*da0073e9SAndroid Build Coastguard Worker
4208*da0073e9SAndroid Build Coastguard Worker    Provided tensor is upsampled to either the given :attr:`size` or the given
4209*da0073e9SAndroid Build Coastguard Worker    :attr:`scale_factor`
4210*da0073e9SAndroid Build Coastguard Worker
4211*da0073e9SAndroid Build Coastguard Worker    .. warning::
4212*da0073e9SAndroid Build Coastguard Worker        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4213*da0073e9SAndroid Build Coastguard Worker        This is equivalent with ``nn.functional.interpolate(...)``.
4214*da0073e9SAndroid Build Coastguard Worker
4215*da0073e9SAndroid Build Coastguard Worker    Note:
4216*da0073e9SAndroid Build Coastguard Worker        {backward_reproducibility_note}
4217*da0073e9SAndroid Build Coastguard Worker
4218*da0073e9SAndroid Build Coastguard Worker    The algorithm used for upsampling is determined by :attr:`mode`.
4219*da0073e9SAndroid Build Coastguard Worker
4220*da0073e9SAndroid Build Coastguard Worker    Currently temporal, spatial and volumetric upsampling are supported, i.e.
4221*da0073e9SAndroid Build Coastguard Worker    expected inputs are 3-D, 4-D or 5-D in shape.
4222*da0073e9SAndroid Build Coastguard Worker
4223*da0073e9SAndroid Build Coastguard Worker    The input dimensions are interpreted in the form:
4224*da0073e9SAndroid Build Coastguard Worker    `mini-batch x channels x [optional depth] x [optional height] x width`.
4225*da0073e9SAndroid Build Coastguard Worker
4226*da0073e9SAndroid Build Coastguard Worker    The modes available for upsampling are: `nearest`, `linear` (3D-only),
4227*da0073e9SAndroid Build Coastguard Worker    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
4228*da0073e9SAndroid Build Coastguard Worker
4229*da0073e9SAndroid Build Coastguard Worker    Args:
4230*da0073e9SAndroid Build Coastguard Worker        input (Tensor): the input tensor
4231*da0073e9SAndroid Build Coastguard Worker        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
4232*da0073e9SAndroid Build Coastguard Worker            output spatial size.
4233*da0073e9SAndroid Build Coastguard Worker        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
4234*da0073e9SAndroid Build Coastguard Worker        mode (str): algorithm used for upsampling:
4235*da0073e9SAndroid Build Coastguard Worker            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
4236*da0073e9SAndroid Build Coastguard Worker            ``'trilinear'``. Default: ``'nearest'``
4237*da0073e9SAndroid Build Coastguard Worker        align_corners (bool, optional): Geometrically, we consider the pixels of the
4238*da0073e9SAndroid Build Coastguard Worker            input and output as squares rather than points.
4239*da0073e9SAndroid Build Coastguard Worker            If set to ``True``, the input and output tensors are aligned by the
4240*da0073e9SAndroid Build Coastguard Worker            center points of their corner pixels, preserving the values at the corner pixels.
4241*da0073e9SAndroid Build Coastguard Worker            If set to ``False``, the input and output tensors are aligned by the corner
4242*da0073e9SAndroid Build Coastguard Worker            points of their corner pixels, and the interpolation uses edge value padding
4243*da0073e9SAndroid Build Coastguard Worker            for out-of-boundary values, making this operation *independent* of input size
4244*da0073e9SAndroid Build Coastguard Worker            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
4245*da0073e9SAndroid Build Coastguard Worker            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
4246*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
4247*da0073e9SAndroid Build Coastguard Worker
4248*da0073e9SAndroid Build Coastguard Worker    .. note::
4249*da0073e9SAndroid Build Coastguard Worker        With ``mode='bicubic'``, interpolation can overshoot, i.e. it can produce
4250*da0073e9SAndroid Build Coastguard Worker        negative values or values greater than 255 for images.
4251*da0073e9SAndroid Build Coastguard Worker        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
4252*da0073e9SAndroid Build Coastguard Worker        when displaying the image.
4253*da0073e9SAndroid Build Coastguard Worker
4254*da0073e9SAndroid Build Coastguard Worker    .. warning::
4255*da0073e9SAndroid Build Coastguard Worker        With ``align_corners = True``, the linearly interpolating modes
4256*da0073e9SAndroid Build Coastguard Worker        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
4257*da0073e9SAndroid Build Coastguard Worker        output and input pixels, and thus the output values can depend on the
4258*da0073e9SAndroid Build Coastguard Worker        input size. This was the default behavior for these modes up to version
4259*da0073e9SAndroid Build Coastguard Worker        0.3.1. Since then, the default behavior is ``align_corners = False``.
4260*da0073e9SAndroid Build Coastguard Worker        See :class:`~torch.nn.Upsample` for concrete examples on how this
4261*da0073e9SAndroid Build Coastguard Worker        affects the outputs.
4262*da0073e9SAndroid Build Coastguard Worker
4263*da0073e9SAndroid Build Coastguard Worker    """
4264*da0073e9SAndroid Build Coastguard Worker    warnings.warn(
4265*da0073e9SAndroid Build Coastguard Worker        "`nn.functional.upsample` is deprecated. "
4266*da0073e9SAndroid Build Coastguard Worker        "Use `nn.functional.interpolate` instead.",
4267*da0073e9SAndroid Build Coastguard Worker        stacklevel=2,
4268*da0073e9SAndroid Build Coastguard Worker    )
4269*da0073e9SAndroid Build Coastguard Worker    return interpolate(input, size, scale_factor, mode, align_corners)
4270*da0073e9SAndroid Build Coastguard Worker
4271*da0073e9SAndroid Build Coastguard Worker
4272*da0073e9SAndroid Build Coastguard Workerif upsample.__doc__:
4273*da0073e9SAndroid Build Coastguard Worker    upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes)
4274*da0073e9SAndroid Build Coastguard Worker
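# A minimal usage sketch for the deprecated wrapper above; the tensor shape
# and scale factor here are illustrative, not prescriptive. Modulo the
# deprecation warning, the wrapper and a direct `interpolate` call agree:
#
#   >>> x = torch.randn(1, 3, 8)
#   >>> torch.equal(upsample(x, scale_factor=2.0), interpolate(x, scale_factor=2.0))
#   True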
4275*da0073e9SAndroid Build Coastguard Worker
4276*da0073e9SAndroid Build Coastguard Workerdef _is_integer(x) -> bool:
4277*da0073e9SAndroid Build Coastguard Worker    r"""Type check the input number is an integer.
4278*da0073e9SAndroid Build Coastguard Worker
4279*da0073e9SAndroid Build Coastguard Worker    Will return True for int, SymInt, Numpy integers and Tensors with integer elements.
4280*da0073e9SAndroid Build Coastguard Worker    """
4281*da0073e9SAndroid Build Coastguard Worker    if isinstance(x, (int, torch.SymInt)):
4282*da0073e9SAndroid Build Coastguard Worker        return True
4283*da0073e9SAndroid Build Coastguard Worker    if np is not None and isinstance(x, np.integer):
4284*da0073e9SAndroid Build Coastguard Worker        return True
4285*da0073e9SAndroid Build Coastguard Worker    return isinstance(x, Tensor) and not x.is_floating_point()
4286*da0073e9SAndroid Build Coastguard Worker
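# Behaviour sketch for the private helper above (values are illustrative):
#
#   >>> _is_integer(3), _is_integer(3.0)
#   (True, False)
#   >>> _is_integer(torch.tensor(3)), _is_integer(torch.tensor(3.0))
#   (True, False)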
4287*da0073e9SAndroid Build Coastguard Worker
4288*da0073e9SAndroid Build Coastguard Worker@_overload
4289*da0073e9SAndroid Build Coastguard Workerdef interpolate(  # noqa: F811
4290*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4291*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4292*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[List[float]] = None,
4293*da0073e9SAndroid Build Coastguard Worker    mode: str = "nearest",
4294*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4295*da0073e9SAndroid Build Coastguard Worker    recompute_scale_factor: Optional[bool] = None,
4296*da0073e9SAndroid Build Coastguard Worker    antialias: bool = False,
4297*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: B950
4298*da0073e9SAndroid Build Coastguard Worker    pass
4299*da0073e9SAndroid Build Coastguard Worker
4300*da0073e9SAndroid Build Coastguard Worker
4301*da0073e9SAndroid Build Coastguard Worker@_overload
4302*da0073e9SAndroid Build Coastguard Workerdef interpolate(  # noqa: F811
4303*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4304*da0073e9SAndroid Build Coastguard Worker    size: Optional[List[int]] = None,
4305*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[List[float]] = None,
4306*da0073e9SAndroid Build Coastguard Worker    mode: str = "nearest",
4307*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4308*da0073e9SAndroid Build Coastguard Worker    recompute_scale_factor: Optional[bool] = None,
4309*da0073e9SAndroid Build Coastguard Worker    antialias: bool = False,
4310*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: B950
4311*da0073e9SAndroid Build Coastguard Worker    pass
4312*da0073e9SAndroid Build Coastguard Worker
4313*da0073e9SAndroid Build Coastguard Worker
4314*da0073e9SAndroid Build Coastguard Worker@_overload
4315*da0073e9SAndroid Build Coastguard Workerdef interpolate(  # noqa: F811
4316*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4317*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4318*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4319*da0073e9SAndroid Build Coastguard Worker    mode: str = "nearest",
4320*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4321*da0073e9SAndroid Build Coastguard Worker    recompute_scale_factor: Optional[bool] = None,
4322*da0073e9SAndroid Build Coastguard Worker    antialias: bool = False,
4323*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: B950
4324*da0073e9SAndroid Build Coastguard Worker    pass
4325*da0073e9SAndroid Build Coastguard Worker
4326*da0073e9SAndroid Build Coastguard Worker
4327*da0073e9SAndroid Build Coastguard Worker@_overload
4328*da0073e9SAndroid Build Coastguard Workerdef interpolate(  # noqa: F811
4329*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4330*da0073e9SAndroid Build Coastguard Worker    size: Optional[List[int]] = None,
4331*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4332*da0073e9SAndroid Build Coastguard Worker    mode: str = "nearest",
4333*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4334*da0073e9SAndroid Build Coastguard Worker    recompute_scale_factor: Optional[bool] = None,
4335*da0073e9SAndroid Build Coastguard Worker    antialias: bool = False,
4336*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4337*da0073e9SAndroid Build Coastguard Worker    pass
4338*da0073e9SAndroid Build Coastguard Worker
4339*da0073e9SAndroid Build Coastguard Worker
4340*da0073e9SAndroid Build Coastguard Workerdef interpolate(  # noqa: F811
4341*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4342*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4343*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[List[float]] = None,
4344*da0073e9SAndroid Build Coastguard Worker    mode: str = "nearest",
4345*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4346*da0073e9SAndroid Build Coastguard Worker    recompute_scale_factor: Optional[bool] = None,
4347*da0073e9SAndroid Build Coastguard Worker    antialias: bool = False,
4348*da0073e9SAndroid Build Coastguard Worker) -> Tensor:  # noqa: B950
4349*da0073e9SAndroid Build Coastguard Worker    r"""Down/up samples the input.
4350*da0073e9SAndroid Build Coastguard Worker
4351*da0073e9SAndroid Build Coastguard Worker    Tensor interpolated to either the given :attr:`size` or the given
4352*da0073e9SAndroid Build Coastguard Worker    :attr:`scale_factor`
4353*da0073e9SAndroid Build Coastguard Worker
4354*da0073e9SAndroid Build Coastguard Worker    The algorithm used for interpolation is determined by :attr:`mode`.
4355*da0073e9SAndroid Build Coastguard Worker
4356*da0073e9SAndroid Build Coastguard Worker    Currently temporal, spatial and volumetric sampling are supported, i.e.
4357*da0073e9SAndroid Build Coastguard Worker    expected inputs are 3-D, 4-D or 5-D in shape.
4358*da0073e9SAndroid Build Coastguard Worker
4359*da0073e9SAndroid Build Coastguard Worker    The input dimensions are interpreted in the form:
4360*da0073e9SAndroid Build Coastguard Worker    `mini-batch x channels x [optional depth] x [optional height] x width`.
4361*da0073e9SAndroid Build Coastguard Worker
4362*da0073e9SAndroid Build Coastguard Worker    The modes available for resizing are: `nearest`, `linear` (3D-only),
4363*da0073e9SAndroid Build Coastguard Worker    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`.
4364*da0073e9SAndroid Build Coastguard Worker
4365*da0073e9SAndroid Build Coastguard Worker    Args:
4366*da0073e9SAndroid Build Coastguard Worker        input (Tensor): the input tensor
4367*da0073e9SAndroid Build Coastguard Worker        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
4368*da0073e9SAndroid Build Coastguard Worker            output spatial size.
4369*da0073e9SAndroid Build Coastguard Worker        scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
4370*da0073e9SAndroid Build Coastguard Worker            its length has to match the number of spatial dimensions, i.e. `input.dim() - 2`.
4371*da0073e9SAndroid Build Coastguard Worker        mode (str): algorithm used for upsampling:
4372*da0073e9SAndroid Build Coastguard Worker            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
4373*da0073e9SAndroid Build Coastguard Worker            ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
4374*da0073e9SAndroid Build Coastguard Worker        align_corners (bool, optional): Geometrically, we consider the pixels of the
4375*da0073e9SAndroid Build Coastguard Worker            input and output as squares rather than points.
4376*da0073e9SAndroid Build Coastguard Worker            If set to ``True``, the input and output tensors are aligned by the
4377*da0073e9SAndroid Build Coastguard Worker            center points of their corner pixels, preserving the values at the corner pixels.
4378*da0073e9SAndroid Build Coastguard Worker            If set to ``False``, the input and output tensors are aligned by the corner
4379*da0073e9SAndroid Build Coastguard Worker            points of their corner pixels, and the interpolation uses edge value padding
4380*da0073e9SAndroid Build Coastguard Worker            for out-of-boundary values, making this operation *independent* of input size
4381*da0073e9SAndroid Build Coastguard Worker            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
4382*da0073e9SAndroid Build Coastguard Worker            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
4383*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
4384*da0073e9SAndroid Build Coastguard Worker        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
4385*da0073e9SAndroid Build Coastguard Worker            interpolation calculation. If `recompute_scale_factor` is ``True``, then
4386*da0073e9SAndroid Build Coastguard Worker            `scale_factor` must be passed in and `scale_factor` is used to compute the
4387*da0073e9SAndroid Build Coastguard Worker            output `size`. The computed output `size` will be used to infer new scales for
4388*da0073e9SAndroid Build Coastguard Worker            the interpolation. Note that when `scale_factor` is floating-point, it may differ
4389*da0073e9SAndroid Build Coastguard Worker            from the recomputed `scale_factor` due to rounding and precision issues.
4390*da0073e9SAndroid Build Coastguard Worker            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
4391*da0073e9SAndroid Build Coastguard Worker            be used directly for interpolation. Default: ``None``.
4392*da0073e9SAndroid Build Coastguard Worker        antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. When the anti-alias
4393*da0073e9SAndroid Build Coastguard Worker            option is used together with ``align_corners=False``, the interpolation result matches the
4394*da0073e9SAndroid Build Coastguard Worker            Pillow result for downsampling operations. Supported modes: ``'bilinear'``, ``'bicubic'``.
4395*da0073e9SAndroid Build Coastguard Worker
4396*da0073e9SAndroid Build Coastguard Worker    .. note::
4397*da0073e9SAndroid Build Coastguard Worker        With ``mode='bicubic'``, interpolation can overshoot, i.e. it can produce
4398*da0073e9SAndroid Build Coastguard Worker        negative values or values greater than 255 for images.
4399*da0073e9SAndroid Build Coastguard Worker        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
4400*da0073e9SAndroid Build Coastguard Worker        when displaying the image.
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker    .. note::
4403*da0073e9SAndroid Build Coastguard Worker        Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation
4404*da0073e9SAndroid Build Coastguard Worker        algorithms and fixes known issues with ``mode='nearest'``.
4405*da0073e9SAndroid Build Coastguard Worker        Mode ``mode='nearest'`` is kept for backward compatibility; it matches
4406*da0073e9SAndroid Build Coastguard Worker        OpenCV's buggy ``INTER_NEAREST`` interpolation algorithm.
4407*da0073e9SAndroid Build Coastguard Worker
4408*da0073e9SAndroid Build Coastguard Worker    .. note::
4409*da0073e9SAndroid Build Coastguard Worker        The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation
4410*da0073e9SAndroid Build Coastguard Worker        when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``.
4411*da0073e9SAndroid Build Coastguard Worker        For more details, please refer to the discussion in
4412*da0073e9SAndroid Build Coastguard Worker        `issue#104157 <https://github.com/pytorch/pytorch/issues/104157>`_.
4413*da0073e9SAndroid Build Coastguard Worker
4414*da0073e9SAndroid Build Coastguard Worker    Note:
4415*da0073e9SAndroid Build Coastguard Worker        {backward_reproducibility_note}
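
    Examples::

        >>> # A minimal sketch; shapes and values are illustrative.
        >>> inp = torch.arange(1.0, 5.0).view(1, 1, 2, 2)
        >>> F.interpolate(inp, scale_factor=2.0, mode="nearest").shape
        torch.Size([1, 1, 4, 4])
        >>> F.interpolate(inp, size=(3, 3), mode="bilinear", align_corners=False).shape
        torch.Size([1, 1, 3, 3])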
4416*da0073e9SAndroid Build Coastguard Worker    """
4417*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
4418*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
4419*da0073e9SAndroid Build Coastguard Worker            interpolate,
4420*da0073e9SAndroid Build Coastguard Worker            (input,),
4421*da0073e9SAndroid Build Coastguard Worker            input,
4422*da0073e9SAndroid Build Coastguard Worker            size=size,
4423*da0073e9SAndroid Build Coastguard Worker            scale_factor=scale_factor,
4424*da0073e9SAndroid Build Coastguard Worker            mode=mode,
4425*da0073e9SAndroid Build Coastguard Worker            align_corners=align_corners,
4426*da0073e9SAndroid Build Coastguard Worker            recompute_scale_factor=recompute_scale_factor,
4427*da0073e9SAndroid Build Coastguard Worker            antialias=antialias,
4428*da0073e9SAndroid Build Coastguard Worker        )
4429*da0073e9SAndroid Build Coastguard Worker
4430*da0073e9SAndroid Build Coastguard Worker    if mode in ("nearest", "area", "nearest-exact"):
4431*da0073e9SAndroid Build Coastguard Worker        if align_corners is not None:
4432*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
4433*da0073e9SAndroid Build Coastguard Worker                "align_corners option can only be set with the "
4434*da0073e9SAndroid Build Coastguard Worker                "interpolating modes: linear | bilinear | bicubic | trilinear"
4435*da0073e9SAndroid Build Coastguard Worker            )
4436*da0073e9SAndroid Build Coastguard Worker    else:
4437*da0073e9SAndroid Build Coastguard Worker        if align_corners is None:
4438*da0073e9SAndroid Build Coastguard Worker            align_corners = False
4439*da0073e9SAndroid Build Coastguard Worker
4440*da0073e9SAndroid Build Coastguard Worker    dim = input.dim() - 2  # Number of spatial dimensions.
4441*da0073e9SAndroid Build Coastguard Worker
4442*da0073e9SAndroid Build Coastguard Worker    # Process size and scale_factor.  Validate that exactly one is set.
4443*da0073e9SAndroid Build Coastguard Worker    # Validate its length if it is a list, or expand it if it is a scalar.
4444*da0073e9SAndroid Build Coastguard Worker    # After this block, exactly one of output_size and scale_factors will
4445*da0073e9SAndroid Build Coastguard Worker    # be non-None, and it will be a list (or tuple).
4446*da0073e9SAndroid Build Coastguard Worker    if size is not None and scale_factor is not None:
4447*da0073e9SAndroid Build Coastguard Worker        raise ValueError("only one of size or scale_factor should be defined")
4448*da0073e9SAndroid Build Coastguard Worker    elif size is not None:
4449*da0073e9SAndroid Build Coastguard Worker        assert scale_factor is None
4450*da0073e9SAndroid Build Coastguard Worker        scale_factors = None
4451*da0073e9SAndroid Build Coastguard Worker        if isinstance(size, (list, tuple)):
4452*da0073e9SAndroid Build Coastguard Worker            if len(size) != dim:
4453*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
4454*da0073e9SAndroid Build Coastguard Worker                    "Input and output must have the same number of spatial dimensions, but got "
4455*da0073e9SAndroid Build Coastguard Worker                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
4456*da0073e9SAndroid Build Coastguard Worker                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
4457*da0073e9SAndroid Build Coastguard Worker                    "output size in (o1, o2, ...,oK) format."
4458*da0073e9SAndroid Build Coastguard Worker                )
4459*da0073e9SAndroid Build Coastguard Worker            if not torch.jit.is_scripting():
4460*da0073e9SAndroid Build Coastguard Worker                if not all(_is_integer(x) for x in size):
4461*da0073e9SAndroid Build Coastguard Worker                    raise TypeError(
4462*da0073e9SAndroid Build Coastguard Worker                        "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
4463*da0073e9SAndroid Build Coastguard Worker                        f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
4464*da0073e9SAndroid Build Coastguard Worker                    )
4465*da0073e9SAndroid Build Coastguard Worker            output_size = size
4466*da0073e9SAndroid Build Coastguard Worker        else:
4467*da0073e9SAndroid Build Coastguard Worker            output_size = [size for _ in range(dim)]
4468*da0073e9SAndroid Build Coastguard Worker    elif scale_factor is not None:
4469*da0073e9SAndroid Build Coastguard Worker        assert size is None
4470*da0073e9SAndroid Build Coastguard Worker        output_size = None
4471*da0073e9SAndroid Build Coastguard Worker        if isinstance(scale_factor, (list, tuple)):
4472*da0073e9SAndroid Build Coastguard Worker            if len(scale_factor) != dim:
4473*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
4474*da0073e9SAndroid Build Coastguard Worker                    "Input and scale_factor must have the same number of spatial dimensions, but "
4475*da0073e9SAndroid Build Coastguard Worker                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
4476*da0073e9SAndroid Build Coastguard Worker                    f"scale_factor of shape {scale_factor}. "
4477*da0073e9SAndroid Build Coastguard Worker                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
4478*da0073e9SAndroid Build Coastguard Worker                    "scale_factor in (s1, s2, ...,sK) format."
4479*da0073e9SAndroid Build Coastguard Worker                )
4480*da0073e9SAndroid Build Coastguard Worker            scale_factors = scale_factor
4481*da0073e9SAndroid Build Coastguard Worker        else:
4482*da0073e9SAndroid Build Coastguard Worker            scale_factors = [scale_factor for _ in range(dim)]
4483*da0073e9SAndroid Build Coastguard Worker    else:
4484*da0073e9SAndroid Build Coastguard Worker        raise ValueError("either size or scale_factor should be defined")
4485*da0073e9SAndroid Build Coastguard Worker
4486*da0073e9SAndroid Build Coastguard Worker    if (
4487*da0073e9SAndroid Build Coastguard Worker        recompute_scale_factor is not None
4488*da0073e9SAndroid Build Coastguard Worker        and recompute_scale_factor
4489*da0073e9SAndroid Build Coastguard Worker        and size is not None
4490*da0073e9SAndroid Build Coastguard Worker    ):
4491*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
4492*da0073e9SAndroid Build Coastguard Worker            "recompute_scale_factor is not meaningful with an explicit size."
4493*da0073e9SAndroid Build Coastguard Worker        )
4494*da0073e9SAndroid Build Coastguard Worker
4495*da0073e9SAndroid Build Coastguard Worker    # "area" mode always requires an explicit size rather than scale factor.
4496*da0073e9SAndroid Build Coastguard Worker    # Re-use the recompute_scale_factor code path.
4497*da0073e9SAndroid Build Coastguard Worker    if mode == "area" and output_size is None:
4498*da0073e9SAndroid Build Coastguard Worker        recompute_scale_factor = True
4499*da0073e9SAndroid Build Coastguard Worker
4500*da0073e9SAndroid Build Coastguard Worker    if recompute_scale_factor is not None and recompute_scale_factor:
4501*da0073e9SAndroid Build Coastguard Worker        # We compute output_size here, then un-set scale_factors.
4502*da0073e9SAndroid Build Coastguard Worker        # The C++ code will recompute it based on the (integer) output size.
4503*da0073e9SAndroid Build Coastguard Worker        assert scale_factors is not None
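        # Illustrative arithmetic (hypothetical sizes): a spatial size of 5 with
        # scale_factor 1.5 yields an output size of 7 for that dimension in each
        # of the three branches below (floor/trunc of 7.5).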
4504*da0073e9SAndroid Build Coastguard Worker        if not torch.jit.is_scripting() and torch._C._get_tracing_state():
4505*da0073e9SAndroid Build Coastguard Worker            # make scale_factor a tensor in tracing so constant doesn't get baked in
4506*da0073e9SAndroid Build Coastguard Worker            output_size = [
4507*da0073e9SAndroid Build Coastguard Worker                (
4508*da0073e9SAndroid Build Coastguard Worker                    torch.floor(
4509*da0073e9SAndroid Build Coastguard Worker                        (
4510*da0073e9SAndroid Build Coastguard Worker                            input.size(i + 2).float()
4511*da0073e9SAndroid Build Coastguard Worker                            * torch.tensor(scale_factors[i], dtype=torch.float32)
4512*da0073e9SAndroid Build Coastguard Worker                        ).float()
4513*da0073e9SAndroid Build Coastguard Worker                    )
4514*da0073e9SAndroid Build Coastguard Worker                )
4515*da0073e9SAndroid Build Coastguard Worker                for i in range(dim)
4516*da0073e9SAndroid Build Coastguard Worker            ]
4517*da0073e9SAndroid Build Coastguard Worker        elif torch.jit.is_scripting():
4518*da0073e9SAndroid Build Coastguard Worker            output_size = [
4519*da0073e9SAndroid Build Coastguard Worker                int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
4520*da0073e9SAndroid Build Coastguard Worker                for i in range(dim)
4521*da0073e9SAndroid Build Coastguard Worker            ]
4522*da0073e9SAndroid Build Coastguard Worker        else:
4523*da0073e9SAndroid Build Coastguard Worker            output_size = [
4524*da0073e9SAndroid Build Coastguard Worker                _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim)
4525*da0073e9SAndroid Build Coastguard Worker            ]
4526*da0073e9SAndroid Build Coastguard Worker        scale_factors = None
4527*da0073e9SAndroid Build Coastguard Worker
4528*da0073e9SAndroid Build Coastguard Worker    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
4529*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
4530*da0073e9SAndroid Build Coastguard Worker            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
4531*da0073e9SAndroid Build Coastguard Worker        )
4532*da0073e9SAndroid Build Coastguard Worker
4533*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "nearest":
4534*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
4535*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "nearest":
4536*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
4537*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "nearest":
4538*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
4539*da0073e9SAndroid Build Coastguard Worker
4540*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "nearest-exact":
4541*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
4542*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "nearest-exact":
4543*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
4544*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "nearest-exact":
4545*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
4546*da0073e9SAndroid Build Coastguard Worker
4547*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "area":
4548*da0073e9SAndroid Build Coastguard Worker        assert output_size is not None
4549*da0073e9SAndroid Build Coastguard Worker        return adaptive_avg_pool1d(input, output_size)
4550*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "area":
4551*da0073e9SAndroid Build Coastguard Worker        assert output_size is not None
4552*da0073e9SAndroid Build Coastguard Worker        return adaptive_avg_pool2d(input, output_size)
4553*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "area":
4554*da0073e9SAndroid Build Coastguard Worker        assert output_size is not None
4555*da0073e9SAndroid Build Coastguard Worker        return adaptive_avg_pool3d(input, output_size)
4556*da0073e9SAndroid Build Coastguard Worker
4557*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "linear":
4558*da0073e9SAndroid Build Coastguard Worker        assert align_corners is not None
4559*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_linear1d(
4560*da0073e9SAndroid Build Coastguard Worker            input, output_size, align_corners, scale_factors
4561*da0073e9SAndroid Build Coastguard Worker        )
4562*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "bilinear":
4563*da0073e9SAndroid Build Coastguard Worker        assert align_corners is not None
4564*da0073e9SAndroid Build Coastguard Worker        if antialias:
4565*da0073e9SAndroid Build Coastguard Worker            return torch._C._nn._upsample_bilinear2d_aa(
4566*da0073e9SAndroid Build Coastguard Worker                input, output_size, align_corners, scale_factors
4567*da0073e9SAndroid Build Coastguard Worker            )
4568*da0073e9SAndroid Build Coastguard Worker        # Two levels are necessary to prevent TorchScript from touching
4569*da0073e9SAndroid Build Coastguard Worker        # are_deterministic_algorithms_enabled.
4570*da0073e9SAndroid Build Coastguard Worker        if not torch.jit.is_scripting():
4571*da0073e9SAndroid Build Coastguard Worker            if torch.are_deterministic_algorithms_enabled() and (
4572*da0073e9SAndroid Build Coastguard Worker                input.is_cuda or input.is_xpu
4573*da0073e9SAndroid Build Coastguard Worker            ):
4574*da0073e9SAndroid Build Coastguard Worker                # Use slow decomp whose backward will be in terms of index_put
4575*da0073e9SAndroid Build Coastguard Worker                # importlib is required because the import cannot be top level
4576*da0073e9SAndroid Build Coastguard Worker                # (cycle) and cannot be nested (TS doesn't support)
4577*da0073e9SAndroid Build Coastguard Worker                return importlib.import_module(
4578*da0073e9SAndroid Build Coastguard Worker                    "torch._decomp.decompositions"
4579*da0073e9SAndroid Build Coastguard Worker                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
4580*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_bilinear2d(
4581*da0073e9SAndroid Build Coastguard Worker            input, output_size, align_corners, scale_factors
4582*da0073e9SAndroid Build Coastguard Worker        )
4583*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "trilinear":
4584*da0073e9SAndroid Build Coastguard Worker        assert align_corners is not None
4585*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_trilinear3d(
4586*da0073e9SAndroid Build Coastguard Worker            input, output_size, align_corners, scale_factors
4587*da0073e9SAndroid Build Coastguard Worker        )
4588*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "bicubic":
4589*da0073e9SAndroid Build Coastguard Worker        assert align_corners is not None
4590*da0073e9SAndroid Build Coastguard Worker        if antialias:
4591*da0073e9SAndroid Build Coastguard Worker            return torch._C._nn._upsample_bicubic2d_aa(
4592*da0073e9SAndroid Build Coastguard Worker                input, output_size, align_corners, scale_factors
4593*da0073e9SAndroid Build Coastguard Worker            )
4594*da0073e9SAndroid Build Coastguard Worker        return torch._C._nn.upsample_bicubic2d(
4595*da0073e9SAndroid Build Coastguard Worker            input, output_size, align_corners, scale_factors
4596*da0073e9SAndroid Build Coastguard Worker        )
4597*da0073e9SAndroid Build Coastguard Worker
4598*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "bilinear":
4599*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
4600*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 3 and mode == "trilinear":
4601*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
4602*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "linear":
4603*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
4604*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 4 and mode == "trilinear":
4605*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
4606*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "linear":
4607*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
4608*da0073e9SAndroid Build Coastguard Worker    if input.dim() == 5 and mode == "bilinear":
4609*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
4610*da0073e9SAndroid Build Coastguard Worker
4611*da0073e9SAndroid Build Coastguard Worker    raise NotImplementedError(
4612*da0073e9SAndroid Build Coastguard Worker        "Input Error: Only 3D, 4D and 5D input Tensors supported"
4613*da0073e9SAndroid Build Coastguard Worker        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
4614*da0073e9SAndroid Build Coastguard Worker        f" (got {mode})"
4615*da0073e9SAndroid Build Coastguard Worker    )
4616*da0073e9SAndroid Build Coastguard Worker
4617*da0073e9SAndroid Build Coastguard Worker
4618*da0073e9SAndroid Build Coastguard Workerif interpolate.__doc__:
4619*da0073e9SAndroid Build Coastguard Worker    interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes)
4620*da0073e9SAndroid Build Coastguard Worker
4621*da0073e9SAndroid Build Coastguard Worker
4622*da0073e9SAndroid Build Coastguard Worker@_overload
4623*da0073e9SAndroid Build Coastguard Workerdef upsample_nearest(  # noqa: F811
4624*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4625*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4626*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4627*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4628*da0073e9SAndroid Build Coastguard Worker    pass
4629*da0073e9SAndroid Build Coastguard Worker
4630*da0073e9SAndroid Build Coastguard Worker
4631*da0073e9SAndroid Build Coastguard Worker@_overload
4632*da0073e9SAndroid Build Coastguard Workerdef upsample_nearest(  # noqa: F811
4633*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4634*da0073e9SAndroid Build Coastguard Worker    size: Optional[List[int]] = None,
4635*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4636*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4637*da0073e9SAndroid Build Coastguard Worker    pass
4638*da0073e9SAndroid Build Coastguard Worker
4639*da0073e9SAndroid Build Coastguard Worker
4640*da0073e9SAndroid Build Coastguard Workerdef upsample_nearest(input, size=None, scale_factor=None):  # noqa: F811
4641*da0073e9SAndroid Build Coastguard Worker    r"""Upsamples the input, using nearest neighbours' pixel values.
4642*da0073e9SAndroid Build Coastguard Worker
4643*da0073e9SAndroid Build Coastguard Worker    .. warning::
4644*da0073e9SAndroid Build Coastguard Worker        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4645*da0073e9SAndroid Build Coastguard Worker        This is equivalent to ``nn.functional.interpolate(..., mode='nearest')``.
4646*da0073e9SAndroid Build Coastguard Worker
4647*da0073e9SAndroid Build Coastguard Worker    Currently spatial and volumetric upsampling are supported (i.e. expected
4648*da0073e9SAndroid Build Coastguard Worker    inputs are 4 or 5 dimensional).
4649*da0073e9SAndroid Build Coastguard Worker
4650*da0073e9SAndroid Build Coastguard Worker    Args:
4651*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input
4652*da0073e9SAndroid Build Coastguard Worker        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
4653*da0073e9SAndroid Build Coastguard Worker            size.
4654*da0073e9SAndroid Build Coastguard Worker        scale_factor (int): multiplier for spatial size. Has to be an integer.
4655*da0073e9SAndroid Build Coastguard Worker
4656*da0073e9SAndroid Build Coastguard Worker    Note:
4657*da0073e9SAndroid Build Coastguard Worker        {backward_reproducibility_note}
4658*da0073e9SAndroid Build Coastguard Worker    """
4659*da0073e9SAndroid Build Coastguard Worker    # DeprecationWarning is ignored by default
4660*da0073e9SAndroid Build Coastguard Worker    warnings.warn(
4661*da0073e9SAndroid Build Coastguard Worker        "`nn.functional.upsample_nearest` is deprecated. "
4662*da0073e9SAndroid Build Coastguard Worker        "Use `nn.functional.interpolate` instead.",
4663*da0073e9SAndroid Build Coastguard Worker        stacklevel=2,
4664*da0073e9SAndroid Build Coastguard Worker    )
4665*da0073e9SAndroid Build Coastguard Worker    return interpolate(input, size, scale_factor, mode="nearest")
4666*da0073e9SAndroid Build Coastguard Worker
4667*da0073e9SAndroid Build Coastguard Worker
4668*da0073e9SAndroid Build Coastguard Workerif upsample_nearest.__doc__:
4669*da0073e9SAndroid Build Coastguard Worker    upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes)
4670*da0073e9SAndroid Build Coastguard Worker
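# A short sketch of the deprecated nearest-neighbour path; the shapes here are
# illustrative. It forwards to `interpolate(..., mode="nearest")`:
#
#   >>> x = torch.randn(2, 3, 4, 4)
#   >>> upsample_nearest(x, scale_factor=2).shape
#   torch.Size([2, 3, 8, 8])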
4671*da0073e9SAndroid Build Coastguard Worker
4672*da0073e9SAndroid Build Coastguard Worker@_overload
4673*da0073e9SAndroid Build Coastguard Workerdef upsample_bilinear(  # noqa: F811
4674*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4675*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4676*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4677*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4678*da0073e9SAndroid Build Coastguard Worker    pass
4679*da0073e9SAndroid Build Coastguard Worker
4680*da0073e9SAndroid Build Coastguard Worker
4681*da0073e9SAndroid Build Coastguard Worker@_overload
4682*da0073e9SAndroid Build Coastguard Workerdef upsample_bilinear(  # noqa: F811
4683*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4684*da0073e9SAndroid Build Coastguard Worker    size: Optional[List[int]] = None,
4685*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[float] = None,
4686*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4687*da0073e9SAndroid Build Coastguard Worker    pass
4688*da0073e9SAndroid Build Coastguard Worker
4689*da0073e9SAndroid Build Coastguard Worker
4690*da0073e9SAndroid Build Coastguard Worker@_overload
4691*da0073e9SAndroid Build Coastguard Workerdef upsample_bilinear(  # noqa: F811
4692*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4693*da0073e9SAndroid Build Coastguard Worker    size: Optional[int] = None,
4694*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[List[float]] = None,
4695*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4696*da0073e9SAndroid Build Coastguard Worker    pass
4697*da0073e9SAndroid Build Coastguard Worker
4698*da0073e9SAndroid Build Coastguard Worker
4699*da0073e9SAndroid Build Coastguard Worker@_overload
4700*da0073e9SAndroid Build Coastguard Workerdef upsample_bilinear(  # noqa: F811
4701*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4702*da0073e9SAndroid Build Coastguard Worker    size: Optional[List[int]] = None,
4703*da0073e9SAndroid Build Coastguard Worker    scale_factor: Optional[List[float]] = None,
4704*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4705*da0073e9SAndroid Build Coastguard Worker    pass
4706*da0073e9SAndroid Build Coastguard Worker
4707*da0073e9SAndroid Build Coastguard Worker
4708*da0073e9SAndroid Build Coastguard Workerdef upsample_bilinear(input, size=None, scale_factor=None):  # noqa: F811
4709*da0073e9SAndroid Build Coastguard Worker    r"""Upsamples the input, using bilinear upsampling.
4710*da0073e9SAndroid Build Coastguard Worker
4711*da0073e9SAndroid Build Coastguard Worker    .. warning::
4712*da0073e9SAndroid Build Coastguard Worker        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4713*da0073e9SAndroid Build Coastguard Worker        This is equivalent to
4714*da0073e9SAndroid Build Coastguard Worker        ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
4715*da0073e9SAndroid Build Coastguard Worker
4716*da0073e9SAndroid Build Coastguard Worker    Expected inputs are spatial (4-dimensional). Use `upsample_trilinear` for
4717*da0073e9SAndroid Build Coastguard Worker    volumetric (5-dimensional) inputs.
4718*da0073e9SAndroid Build Coastguard Worker
4719*da0073e9SAndroid Build Coastguard Worker    Args:
4720*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input
4721*da0073e9SAndroid Build Coastguard Worker        size (int or Tuple[int, int]): output spatial size.
4722*da0073e9SAndroid Build Coastguard Worker        scale_factor (int or Tuple[int, int]): multiplier for spatial size.
4723*da0073e9SAndroid Build Coastguard Worker
4724*da0073e9SAndroid Build Coastguard Worker    Note:
4725*da0073e9SAndroid Build Coastguard Worker        {backward_reproducibility_note}
4726*da0073e9SAndroid Build Coastguard Worker    """
4727*da0073e9SAndroid Build Coastguard Worker    # DeprecationWarning is ignored by default
4728*da0073e9SAndroid Build Coastguard Worker    warnings.warn(
4729*da0073e9SAndroid Build Coastguard Worker        "`nn.functional.upsample_bilinear` is deprecated. "
4730*da0073e9SAndroid Build Coastguard Worker        "Use `nn.functional.interpolate` instead.",
4731*da0073e9SAndroid Build Coastguard Worker        stacklevel=2,
4732*da0073e9SAndroid Build Coastguard Worker    )
4733*da0073e9SAndroid Build Coastguard Worker    return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
4734*da0073e9SAndroid Build Coastguard Worker
4735*da0073e9SAndroid Build Coastguard Worker
4736*da0073e9SAndroid Build Coastguard Workerif upsample_bilinear.__doc__:
4737*da0073e9SAndroid Build Coastguard Worker    upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(
4738*da0073e9SAndroid Build Coastguard Worker        **reproducibility_notes
4739*da0073e9SAndroid Build Coastguard Worker    )
4740*da0073e9SAndroid Build Coastguard Worker
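# A short sketch of the deprecated bilinear path; the shapes here are
# illustrative. It forwards to `interpolate(..., mode="bilinear", align_corners=True)`:
#
#   >>> x = torch.randn(1, 1, 2, 2)
#   >>> upsample_bilinear(x, size=(4, 4)).shape
#   torch.Size([1, 1, 4, 4])
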
4741*da0073e9SAndroid Build Coastguard WorkerGRID_SAMPLE_INTERPOLATION_MODES = {
4742*da0073e9SAndroid Build Coastguard Worker    "bilinear": 0,
4743*da0073e9SAndroid Build Coastguard Worker    "nearest": 1,
4744*da0073e9SAndroid Build Coastguard Worker    "bicubic": 2,
4745*da0073e9SAndroid Build Coastguard Worker}
4746*da0073e9SAndroid Build Coastguard Worker
4747*da0073e9SAndroid Build Coastguard WorkerGRID_SAMPLE_PADDING_MODES = {
4748*da0073e9SAndroid Build Coastguard Worker    "zeros": 0,
4749*da0073e9SAndroid Build Coastguard Worker    "border": 1,
4750*da0073e9SAndroid Build Coastguard Worker    "reflection": 2,
4751*da0073e9SAndroid Build Coastguard Worker}
4752*da0073e9SAndroid Build Coastguard Worker
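# The two tables above mirror the integer enums expected by
# ``torch.grid_sampler``; ``grid_sample`` below inlines the same mapping with
# explicit if/elif chains (e.g. "bicubic" maps to mode_enum 2).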
4753*da0073e9SAndroid Build Coastguard Worker
4754*da0073e9SAndroid Build Coastguard Workerdef grid_sample(
4755*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
4756*da0073e9SAndroid Build Coastguard Worker    grid: Tensor,
4757*da0073e9SAndroid Build Coastguard Worker    mode: str = "bilinear",
4758*da0073e9SAndroid Build Coastguard Worker    padding_mode: str = "zeros",
4759*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4760*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4761*da0073e9SAndroid Build Coastguard Worker    r"""Compute grid sample.
4762*da0073e9SAndroid Build Coastguard Worker
4763*da0073e9SAndroid Build Coastguard Worker    Given an :attr:`input` and a flow-field :attr:`grid`, computes the
4764*da0073e9SAndroid Build Coastguard Worker    ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.
4765*da0073e9SAndroid Build Coastguard Worker
4766*da0073e9SAndroid Build Coastguard Worker    Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
4767*da0073e9SAndroid Build Coastguard Worker    supported.
4768*da0073e9SAndroid Build Coastguard Worker
4769*da0073e9SAndroid Build Coastguard Worker    In the spatial (4-D) case, for :attr:`input` with shape
4770*da0073e9SAndroid Build Coastguard Worker    :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
4771*da0073e9SAndroid Build Coastguard Worker    :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
4772*da0073e9SAndroid Build Coastguard Worker    :math:`(N, C, H_\text{out}, W_\text{out})`.
4773*da0073e9SAndroid Build Coastguard Worker
4774*da0073e9SAndroid Build Coastguard Worker    For each output location ``output[n, :, h, w]``, the size-2 vector
4775*da0073e9SAndroid Build Coastguard Worker    ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
4776*da0073e9SAndroid Build Coastguard Worker    which are used to interpolate the output value ``output[n, :, h, w]``.
4777*da0073e9SAndroid Build Coastguard Worker    In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
4778*da0073e9SAndroid Build Coastguard Worker    ``x``, ``y``, ``z`` pixel locations for interpolating
4779*da0073e9SAndroid Build Coastguard Worker    ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
4780*da0073e9SAndroid Build Coastguard Worker    ``bilinear`` interpolation method to sample the input pixels.
4781*da0073e9SAndroid Build Coastguard Worker
4782*da0073e9SAndroid Build Coastguard Worker    :attr:`grid` specifies the sampling pixel locations normalized by the
4783*da0073e9SAndroid Build Coastguard Worker    :attr:`input` spatial dimensions. Therefore, it should have most values in
4784*da0073e9SAndroid Build Coastguard Worker    the range of ``[-1, 1]``. For example, ``x = -1, y = -1`` corresponds to the
4785*da0073e9SAndroid Build Coastguard Worker    left-top pixel of :attr:`input`, and ``x = 1, y = 1`` corresponds to the
4786*da0073e9SAndroid Build Coastguard Worker    right-bottom pixel of :attr:`input`.
4787*da0073e9SAndroid Build Coastguard Worker
4788*da0073e9SAndroid Build Coastguard Worker    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
4789*da0073e9SAndroid Build Coastguard Worker    outputs are handled as defined by :attr:`padding_mode`. Options are
4790*da0073e9SAndroid Build Coastguard Worker
4791*da0073e9SAndroid Build Coastguard Worker        * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
4792*da0073e9SAndroid Build Coastguard Worker        * ``padding_mode="border"``: use border values for out-of-bound grid locations,
4793*da0073e9SAndroid Build Coastguard Worker        * ``padding_mode="reflection"``: use values at locations reflected by
4794*da0073e9SAndroid Build Coastguard Worker          the border for out-of-bound grid locations. For locations far away
4795*da0073e9SAndroid Build Coastguard Worker          from the border, the location keeps being reflected until it is in bounds,
4796*da0073e9SAndroid Build Coastguard Worker          e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
4797*da0073e9SAndroid Build Coastguard Worker          and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
4798*da0073e9SAndroid Build Coastguard Worker          ``x'' = -0.5``.
4799*da0073e9SAndroid Build Coastguard Worker
4800*da0073e9SAndroid Build Coastguard Worker    Note:
4801*da0073e9SAndroid Build Coastguard Worker        This function is often used in conjunction with :func:`affine_grid`
4802*da0073e9SAndroid Build Coastguard Worker        to build `Spatial Transformer Networks`_ .
4803*da0073e9SAndroid Build Coastguard Worker
4804*da0073e9SAndroid Build Coastguard Worker    Note:
4805*da0073e9SAndroid Build Coastguard Worker        When using the CUDA backend, this operation may induce nondeterministic
4806*da0073e9SAndroid Build Coastguard Worker        behaviour in its backward pass that is not easily switched off.
4807*da0073e9SAndroid Build Coastguard Worker        Please see the notes on :doc:`/notes/randomness` for background.
4808*da0073e9SAndroid Build Coastguard Worker
4809*da0073e9SAndroid Build Coastguard Worker    Note:
4810*da0073e9SAndroid Build Coastguard Worker        NaN values in :attr:`grid` are interpreted as ``-1``.
4811*da0073e9SAndroid Build Coastguard Worker
4812*da0073e9SAndroid Build Coastguard Worker    Args:
4813*da0073e9SAndroid Build Coastguard Worker        input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
4814*da0073e9SAndroid Build Coastguard Worker                        or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
4815*da0073e9SAndroid Build Coastguard Worker        grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
4816*da0073e9SAndroid Build Coastguard Worker                       or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
4817*da0073e9SAndroid Build Coastguard Worker        mode (str): interpolation mode to calculate output values
4818*da0073e9SAndroid Build Coastguard Worker            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
4819*da0073e9SAndroid Build Coastguard Worker            Note: ``mode='bicubic'`` supports only 4-D input.
4820*da0073e9SAndroid Build Coastguard Worker            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
4821*da0073e9SAndroid Build Coastguard Worker            used internally will actually be trilinear. However, when the input is 4-D,
4822*da0073e9SAndroid Build Coastguard Worker            the interpolation mode will legitimately be bilinear.
4823*da0073e9SAndroid Build Coastguard Worker        padding_mode (str): padding mode for outside grid values
4824*da0073e9SAndroid Build Coastguard Worker            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
4825*da0073e9SAndroid Build Coastguard Worker        align_corners (bool, optional): Geometrically, we consider the pixels of the
4826*da0073e9SAndroid Build Coastguard Worker            input as squares rather than points.
4827*da0073e9SAndroid Build Coastguard Worker            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
4828*da0073e9SAndroid Build Coastguard Worker            to the center points of the input's corner pixels. If set to ``False``, they
4829*da0073e9SAndroid Build Coastguard Worker            are instead considered as referring to the corner points of the input's corner
4830*da0073e9SAndroid Build Coastguard Worker            pixels, making the sampling more resolution agnostic.
4831*da0073e9SAndroid Build Coastguard Worker            This option parallels the ``align_corners`` option in
4832*da0073e9SAndroid Build Coastguard Worker            :func:`interpolate`, and so whichever option is used here
4833*da0073e9SAndroid Build Coastguard Worker            should also be used there to resize the input image before grid sampling.
4834*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
4835*da0073e9SAndroid Build Coastguard Worker
4836*da0073e9SAndroid Build Coastguard Worker    Returns:
4837*da0073e9SAndroid Build Coastguard Worker        output (Tensor): output Tensor
4838*da0073e9SAndroid Build Coastguard Worker
4839*da0073e9SAndroid Build Coastguard Worker    .. _`Spatial Transformer Networks`:
4840*da0073e9SAndroid Build Coastguard Worker        https://arxiv.org/abs/1506.02025
4841*da0073e9SAndroid Build Coastguard Worker
4842*da0073e9SAndroid Build Coastguard Worker    .. warning::
4843*da0073e9SAndroid Build Coastguard Worker        When ``align_corners = True``, the grid positions depend on the pixel
4844*da0073e9SAndroid Build Coastguard Worker        size relative to the input image size, and so the locations sampled by
4845*da0073e9SAndroid Build Coastguard Worker        :func:`grid_sample` will differ for the same input given at different
4846*da0073e9SAndroid Build Coastguard Worker        resolutions (that is, after being upsampled or downsampled).
4847*da0073e9SAndroid Build Coastguard Worker        The default behavior up to version 1.2.0 was ``align_corners = True``.
4848*da0073e9SAndroid Build Coastguard Worker        Since then, the default behavior has been changed to ``align_corners = False``,
4849*da0073e9SAndroid Build Coastguard Worker        in order to bring it in line with the default for :func:`interpolate`.
4850*da0073e9SAndroid Build Coastguard Worker
4851*da0073e9SAndroid Build Coastguard Worker    .. note::
4852*da0073e9SAndroid Build Coastguard Worker        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`.
4853*da0073e9SAndroid Build Coastguard Worker        The constant :math:`\alpha` may differ from package to package.
4854*da0073e9SAndroid Build Coastguard Worker        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
4855*da0073e9SAndroid Build Coastguard Worker        This algorithm may "overshoot" the range of values it's interpolating.
4856*da0073e9SAndroid Build Coastguard Worker        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
4857*da0073e9SAndroid Build Coastguard Worker        Clamp the results with :func:`torch.clamp` to ensure they are within the valid range.
4858*da0073e9SAndroid Build Coastguard Worker    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
4859*da0073e9SAndroid Build Coastguard Worker    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
4860*da0073e9SAndroid Build Coastguard Worker    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
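
    Examples::

        >>> # Identity-warp sketch; the input values are illustrative. An identity
        >>> # affine grid sampled with matching align_corners reproduces the input.
        >>> inp = torch.arange(16.0).view(1, 1, 4, 4)
        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> g = F.affine_grid(theta, [1, 1, 4, 4], align_corners=False)
        >>> torch.allclose(F.grid_sample(inp, g, align_corners=False), inp)
        True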
4861*da0073e9SAndroid Build Coastguard Worker    """
4862*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, grid):
4863*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
4864*da0073e9SAndroid Build Coastguard Worker            grid_sample,
4865*da0073e9SAndroid Build Coastguard Worker            (input, grid),
4866*da0073e9SAndroid Build Coastguard Worker            input,
4867*da0073e9SAndroid Build Coastguard Worker            grid,
4868*da0073e9SAndroid Build Coastguard Worker            mode=mode,
4869*da0073e9SAndroid Build Coastguard Worker            padding_mode=padding_mode,
4870*da0073e9SAndroid Build Coastguard Worker            align_corners=align_corners,
4871*da0073e9SAndroid Build Coastguard Worker        )
4872*da0073e9SAndroid Build Coastguard Worker    if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
4873*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
4874*da0073e9SAndroid Build Coastguard Worker            f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'"
4875*da0073e9SAndroid Build Coastguard Worker        )
4876*da0073e9SAndroid Build Coastguard Worker    if (
4877*da0073e9SAndroid Build Coastguard Worker        padding_mode != "zeros"
4878*da0073e9SAndroid Build Coastguard Worker        and padding_mode != "border"
4879*da0073e9SAndroid Build Coastguard Worker        and padding_mode != "reflection"
4880*da0073e9SAndroid Build Coastguard Worker    ):
4881*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
4882*da0073e9SAndroid Build Coastguard Worker            "nn.functional.grid_sample(): expected padding_mode "
4883*da0073e9SAndroid Build Coastguard Worker            "to be 'zeros', 'border', or 'reflection', "
4884*da0073e9SAndroid Build Coastguard Worker            f"but got: '{padding_mode}'"
4885*da0073e9SAndroid Build Coastguard Worker        )
4886*da0073e9SAndroid Build Coastguard Worker
4887*da0073e9SAndroid Build Coastguard Worker    if mode == "bilinear":
4888*da0073e9SAndroid Build Coastguard Worker        mode_enum = 0
4889*da0073e9SAndroid Build Coastguard Worker    elif mode == "nearest":
4890*da0073e9SAndroid Build Coastguard Worker        mode_enum = 1
4891*da0073e9SAndroid Build Coastguard Worker    else:  # mode == 'bicubic'
4892*da0073e9SAndroid Build Coastguard Worker        mode_enum = 2
4893*da0073e9SAndroid Build Coastguard Worker
4894*da0073e9SAndroid Build Coastguard Worker    if padding_mode == "zeros":
4895*da0073e9SAndroid Build Coastguard Worker        padding_mode_enum = 0
4896*da0073e9SAndroid Build Coastguard Worker    elif padding_mode == "border":
4897*da0073e9SAndroid Build Coastguard Worker        padding_mode_enum = 1
4898*da0073e9SAndroid Build Coastguard Worker    else:  # padding_mode == 'reflection'
4899*da0073e9SAndroid Build Coastguard Worker        padding_mode_enum = 2
4900*da0073e9SAndroid Build Coastguard Worker
4901*da0073e9SAndroid Build Coastguard Worker    if align_corners is None:
4902*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
4903*da0073e9SAndroid Build Coastguard Worker            "Default grid_sample and affine_grid behavior has changed "
4904*da0073e9SAndroid Build Coastguard Worker            "to align_corners=False since 1.3.0. Please specify "
4905*da0073e9SAndroid Build Coastguard Worker            "align_corners=True if the old behavior is desired. "
4906*da0073e9SAndroid Build Coastguard Worker            "See the documentation of grid_sample for details."
4907*da0073e9SAndroid Build Coastguard Worker        )
4908*da0073e9SAndroid Build Coastguard Worker        align_corners = False
4909*da0073e9SAndroid Build Coastguard Worker
4910*da0073e9SAndroid Build Coastguard Worker    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
4911*da0073e9SAndroid Build Coastguard Worker
4912*da0073e9SAndroid Build Coastguard Worker
4913*da0073e9SAndroid Build Coastguard Workerdef affine_grid(
4914*da0073e9SAndroid Build Coastguard Worker    theta: Tensor,
4915*da0073e9SAndroid Build Coastguard Worker    size: List[int],
4916*da0073e9SAndroid Build Coastguard Worker    align_corners: Optional[bool] = None,
4917*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
4918*da0073e9SAndroid Build Coastguard Worker    r"""Generate a 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.
4919*da0073e9SAndroid Build Coastguard Worker
4920*da0073e9SAndroid Build Coastguard Worker    .. note::
4921*da0073e9SAndroid Build Coastguard Worker        This function is often used in conjunction with :func:`grid_sample`
4922*da0073e9SAndroid Build Coastguard Worker        to build `Spatial Transformer Networks`_ .
4923*da0073e9SAndroid Build Coastguard Worker
4924*da0073e9SAndroid Build Coastguard Worker    Args:
4925*da0073e9SAndroid Build Coastguard Worker        theta (Tensor): input batch of affine matrices with shape
4926*da0073e9SAndroid Build Coastguard Worker            (:math:`N \times 2 \times 3`) for 2D or
4927*da0073e9SAndroid Build Coastguard Worker            (:math:`N \times 3 \times 4`) for 3D
4928*da0073e9SAndroid Build Coastguard Worker        size (torch.Size): the target output image size.
4929*da0073e9SAndroid Build Coastguard Worker            (:math:`N \times C \times H \times W` for 2D or
4930*da0073e9SAndroid Build Coastguard Worker            :math:`N \times C \times D \times H \times W` for 3D)
4931*da0073e9SAndroid Build Coastguard Worker            Example: torch.Size((32, 3, 24, 24))
4932*da0073e9SAndroid Build Coastguard Worker        align_corners (bool, optional): if ``True``, consider ``-1`` and ``1``
4933*da0073e9SAndroid Build Coastguard Worker            to refer to the centers of the corner pixels rather than the image corners.
4934*da0073e9SAndroid Build Coastguard Worker            Refer to :func:`grid_sample` for a more complete description.
4935*da0073e9SAndroid Build Coastguard Worker            A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`
4936*da0073e9SAndroid Build Coastguard Worker            with the same setting for this option.
4937*da0073e9SAndroid Build Coastguard Worker            Default: ``False``
4938*da0073e9SAndroid Build Coastguard Worker
4939*da0073e9SAndroid Build Coastguard Worker    Returns:
4940*da0073e9SAndroid Build Coastguard Worker        output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) for 2D, or (:math:`N \times D \times H \times W \times 3`) for 3D
4941*da0073e9SAndroid Build Coastguard Worker
4942*da0073e9SAndroid Build Coastguard Worker    .. _`Spatial Transformer Networks`:
4943*da0073e9SAndroid Build Coastguard Worker        https://arxiv.org/abs/1506.02025
4944*da0073e9SAndroid Build Coastguard Worker
4945*da0073e9SAndroid Build Coastguard Worker    .. warning::
4946*da0073e9SAndroid Build Coastguard Worker        When ``align_corners = True``, the grid positions depend on the pixel
4947*da0073e9SAndroid Build Coastguard Worker        size relative to the input image size, and so the locations sampled by
4948*da0073e9SAndroid Build Coastguard Worker        :func:`grid_sample` will differ for the same input given at different
4949*da0073e9SAndroid Build Coastguard Worker        resolutions (that is, after being upsampled or downsampled).
4950*da0073e9SAndroid Build Coastguard Worker        The default behavior up to version 1.2.0 was ``align_corners = True``.
4951*da0073e9SAndroid Build Coastguard Worker        Since then, the default behavior has been changed to ``align_corners = False``,
4952*da0073e9SAndroid Build Coastguard Worker        in order to bring it in line with the default for :func:`interpolate`.
4953*da0073e9SAndroid Build Coastguard Worker    .. warning::
4954*da0073e9SAndroid Build Coastguard Worker        When ``align_corners = True``, 2D affine transforms on 1D data and
4955*da0073e9SAndroid Build Coastguard Worker        3D affine transforms on 2D data (that is, when one of the spatial
4956*da0073e9SAndroid Build Coastguard Worker        dimensions has unit size) are ill-defined, and not an intended use case.
4957*da0073e9SAndroid Build Coastguard Worker        This is not a problem when ``align_corners = False``.
4958*da0073e9SAndroid Build Coastguard Worker        Up to version 1.2.0, all grid points along a unit dimension were
4959*da0073e9SAndroid Build Coastguard Worker        considered arbitrarily to be at ``-1``.
4960*da0073e9SAndroid Build Coastguard Worker        From version 1.3.0, under ``align_corners = True`` all grid points
4961*da0073e9SAndroid Build Coastguard Worker        along a unit dimension are considered to be at ``0``
4962*da0073e9SAndroid Build Coastguard Worker        (the center of the input image).
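
    A minimal usage sketch (hypothetical values), pairing :func:`affine_grid` with
    :func:`grid_sample` under the same ``align_corners`` setting:

    Example::

        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity, N x 2 x 3
        >>> grid = F.affine_grid(theta, [1, 1, 4, 4], align_corners=False)
        >>> grid.shape
        torch.Size([1, 4, 4, 2])
        >>> img = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> out = F.grid_sample(img, grid, align_corners=False)  # recovers img (identity warp)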
4963*da0073e9SAndroid Build Coastguard Worker    """
4964*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(theta):
4965*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
4966*da0073e9SAndroid Build Coastguard Worker            affine_grid, (theta,), theta, size, align_corners=align_corners
4967*da0073e9SAndroid Build Coastguard Worker        )
4968*da0073e9SAndroid Build Coastguard Worker    if align_corners is None:
4969*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
4970*da0073e9SAndroid Build Coastguard Worker            "Default grid_sample and affine_grid behavior has changed "
4971*da0073e9SAndroid Build Coastguard Worker            "to align_corners=False since 1.3.0. Please specify "
4972*da0073e9SAndroid Build Coastguard Worker            "align_corners=True if the old behavior is desired. "
4973*da0073e9SAndroid Build Coastguard Worker            "See the documentation of grid_sample for details."
4974*da0073e9SAndroid Build Coastguard Worker        )
4975*da0073e9SAndroid Build Coastguard Worker        align_corners = False
4976*da0073e9SAndroid Build Coastguard Worker
4977*da0073e9SAndroid Build Coastguard Worker    # enforce floating point dtype on theta
4978*da0073e9SAndroid Build Coastguard Worker    if not theta.is_floating_point():
4979*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
4980*da0073e9SAndroid Build Coastguard Worker            f"Expected theta to have floating point type, but got {theta.dtype}"
4981*da0073e9SAndroid Build Coastguard Worker        )
4982*da0073e9SAndroid Build Coastguard Worker    # check that shapes and sizes match
4983*da0073e9SAndroid Build Coastguard Worker    if len(size) == 4:
4984*da0073e9SAndroid Build Coastguard Worker        if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3:
4985*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
4986*da0073e9SAndroid Build Coastguard Worker                f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}."
4987*da0073e9SAndroid Build Coastguard Worker            )
4988*da0073e9SAndroid Build Coastguard Worker        spatial_size = size[-2:]  # spatial dimension sizes
4989*da0073e9SAndroid Build Coastguard Worker    elif len(size) == 5:
4990*da0073e9SAndroid Build Coastguard Worker        if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4:
4991*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
4992*da0073e9SAndroid Build Coastguard Worker                f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}."
4993*da0073e9SAndroid Build Coastguard Worker            )
4994*da0073e9SAndroid Build Coastguard Worker        spatial_size = size[-3:]  # spatial dimension sizes
4995*da0073e9SAndroid Build Coastguard Worker    else:
4996*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError(
4997*da0073e9SAndroid Build Coastguard Worker            "affine_grid only supports 4D and 5D sizes, "
4998*da0073e9SAndroid Build Coastguard Worker            "for 2D and 3D affine transforms, respectively. "
4999*da0073e9SAndroid Build Coastguard Worker            f"Got size {size}."
5000*da0073e9SAndroid Build Coastguard Worker        )
5001*da0073e9SAndroid Build Coastguard Worker    # check for empty span
5002*da0073e9SAndroid Build Coastguard Worker    if align_corners and min(spatial_size) == 1:
5003*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
5004*da0073e9SAndroid Build Coastguard Worker            "Since version 1.3.0, affine_grid behavior has changed "
5005*da0073e9SAndroid Build Coastguard Worker            "for unit-size grids when align_corners=True. "
5006*da0073e9SAndroid Build Coastguard Worker            "This is not an intended use case of affine_grid. "
5007*da0073e9SAndroid Build Coastguard Worker            "See the documentation of affine_grid for details."
5008*da0073e9SAndroid Build Coastguard Worker        )
5009*da0073e9SAndroid Build Coastguard Worker    elif min(size) <= 0:
5010*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"Expected non-zero, positive output size. Got {size}")
5011*da0073e9SAndroid Build Coastguard Worker
5012*da0073e9SAndroid Build Coastguard Worker    return torch.affine_grid_generator(theta, size, align_corners)
5013*da0073e9SAndroid Build Coastguard Worker
5014*da0073e9SAndroid Build Coastguard Worker
5015*da0073e9SAndroid Build Coastguard Workerdef pad(
5016*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
5017*da0073e9SAndroid Build Coastguard Worker    pad: List[int],
5018*da0073e9SAndroid Build Coastguard Worker    mode: str = "constant",
5019*da0073e9SAndroid Build Coastguard Worker    value: Optional[float] = None,
5020*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5021*da0073e9SAndroid Build Coastguard Worker    r"""
5022*da0073e9SAndroid Build Coastguard Worker    pad(input, pad, mode="constant", value=None) -> Tensor
5023*da0073e9SAndroid Build Coastguard Worker
5024*da0073e9SAndroid Build Coastguard Worker    Pads a tensor.
5025*da0073e9SAndroid Build Coastguard Worker
5026*da0073e9SAndroid Build Coastguard Worker    Padding size:
5027*da0073e9SAndroid Build Coastguard Worker        The padding size by which to pad some dimensions of :attr:`input`
5028*da0073e9SAndroid Build Coastguard Worker        is described starting from the last dimension and moving forward.
5029*da0073e9SAndroid Build Coastguard Worker        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
5030*da0073e9SAndroid Build Coastguard Worker        of ``input`` will be padded.
5031*da0073e9SAndroid Build Coastguard Worker        For example, to pad only the last dimension of the input tensor, then
5032*da0073e9SAndroid Build Coastguard Worker        :attr:`pad` has the form
5033*da0073e9SAndroid Build Coastguard Worker        :math:`(\text{padding\_left}, \text{padding\_right})`;
5034*da0073e9SAndroid Build Coastguard Worker        to pad the last 2 dimensions of the input tensor, then use
5035*da0073e9SAndroid Build Coastguard Worker        :math:`(\text{padding\_left}, \text{padding\_right},`
5036*da0073e9SAndroid Build Coastguard Worker        :math:`\text{padding\_top}, \text{padding\_bottom})`;
5037*da0073e9SAndroid Build Coastguard Worker        to pad the last 3 dimensions, use
5038*da0073e9SAndroid Build Coastguard Worker        :math:`(\text{padding\_left}, \text{padding\_right},`
5039*da0073e9SAndroid Build Coastguard Worker        :math:`\text{padding\_top}, \text{padding\_bottom},`
5040*da0073e9SAndroid Build Coastguard Worker        :math:`\text{padding\_front}, \text{padding\_back})`.
5041*da0073e9SAndroid Build Coastguard Worker
5042*da0073e9SAndroid Build Coastguard Worker    Padding mode:
5043*da0073e9SAndroid Build Coastguard Worker        See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`,
5044*da0073e9SAndroid Build Coastguard Worker        :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d`
5045*da0073e9SAndroid Build Coastguard Worker        for concrete examples on how each of the padding modes works. Constant
5046*da0073e9SAndroid Build Coastguard Worker        padding is implemented for arbitrary dimensions. Circular, replicate and
5047*da0073e9SAndroid Build Coastguard Worker        reflection padding are implemented for padding the last 3 dimensions of a
5048*da0073e9SAndroid Build Coastguard Worker        4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor,
5049*da0073e9SAndroid Build Coastguard Worker        or the last dimension of a 2D or 3D input tensor.
5050*da0073e9SAndroid Build Coastguard Worker
5051*da0073e9SAndroid Build Coastguard Worker    Note:
5052*da0073e9SAndroid Build Coastguard Worker        When using the CUDA backend, this operation may induce nondeterministic
5053*da0073e9SAndroid Build Coastguard Worker        behaviour in its backward pass that is not easily switched off.
5054*da0073e9SAndroid Build Coastguard Worker        Please see the notes on :doc:`/notes/randomness` for background.
5055*da0073e9SAndroid Build Coastguard Worker
5056*da0073e9SAndroid Build Coastguard Worker    Args:
5057*da0073e9SAndroid Build Coastguard Worker        input (Tensor): N-dimensional tensor
5058*da0073e9SAndroid Build Coastguard Worker        pad (tuple): m-element tuple, where
5059*da0073e9SAndroid Build Coastguard Worker            :math:`\frac{m}{2} \leq` the number of input dimensions and :math:`m` is even.
5060*da0073e9SAndroid Build Coastguard Worker        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
5061*da0073e9SAndroid Build Coastguard Worker            Default: ``'constant'``
5062*da0073e9SAndroid Build Coastguard Worker        value: fill value for ``'constant'`` padding. Default: ``0``
5063*da0073e9SAndroid Build Coastguard Worker
5064*da0073e9SAndroid Build Coastguard Worker    Examples::
5065*da0073e9SAndroid Build Coastguard Worker
5066*da0073e9SAndroid Build Coastguard Worker        >>> t4d = torch.empty(3, 3, 4, 2)
5067*da0073e9SAndroid Build Coastguard Worker        >>> p1d = (1, 1) # pad last dim by 1 on each side
5068*da0073e9SAndroid Build Coastguard Worker        >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
5069*da0073e9SAndroid Build Coastguard Worker        >>> print(out.size())
5070*da0073e9SAndroid Build Coastguard Worker        torch.Size([3, 3, 4, 4])
5071*da0073e9SAndroid Build Coastguard Worker        >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
5072*da0073e9SAndroid Build Coastguard Worker        >>> out = F.pad(t4d, p2d, "constant", 0)
5073*da0073e9SAndroid Build Coastguard Worker        >>> print(out.size())
5074*da0073e9SAndroid Build Coastguard Worker        torch.Size([3, 3, 8, 4])
5075*da0073e9SAndroid Build Coastguard Worker        >>> t4d = torch.empty(3, 3, 4, 2)
5076*da0073e9SAndroid Build Coastguard Worker        >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
5077*da0073e9SAndroid Build Coastguard Worker        >>> out = F.pad(t4d, p3d, "constant", 0)
5078*da0073e9SAndroid Build Coastguard Worker        >>> print(out.size())
5079*da0073e9SAndroid Build Coastguard Worker        torch.Size([3, 9, 7, 3])
5080*da0073e9SAndroid Build Coastguard Worker    """
5081*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
5082*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5083*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value
5084*da0073e9SAndroid Build Coastguard Worker        )
5085*da0073e9SAndroid Build Coastguard Worker    if not torch.jit.is_scripting():
5086*da0073e9SAndroid Build Coastguard Worker        if torch.are_deterministic_algorithms_enabled() and (
5087*da0073e9SAndroid Build Coastguard Worker            input.is_cuda or input.is_xpu
5088*da0073e9SAndroid Build Coastguard Worker        ):
5089*da0073e9SAndroid Build Coastguard Worker            if mode == "replicate":
5090*da0073e9SAndroid Build Coastguard Worker                # Use the slow decomposition, whose backward is expressed in terms of index_put.
5091*da0073e9SAndroid Build Coastguard Worker                # importlib is required because a top-level import would create an
5092*da0073e9SAndroid Build Coastguard Worker                # import cycle and a nested import is not supported by TorchScript.
5093*da0073e9SAndroid Build Coastguard Worker                return importlib.import_module(
5094*da0073e9SAndroid Build Coastguard Worker                    "torch._decomp.decompositions"
5095*da0073e9SAndroid Build Coastguard Worker                )._replication_pad(input, pad)
5096*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.pad(input, pad, mode, value)
5097*da0073e9SAndroid Build Coastguard Worker
5098*da0073e9SAndroid Build Coastguard Worker
5099*da0073e9SAndroid Build Coastguard Worker# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
5100*da0073e9SAndroid Build Coastguard Workerpad.__module__ = "torch.nn.functional"
5101*da0073e9SAndroid Build Coastguard Worker
5102*da0073e9SAndroid Build Coastguard Worker# distance
5103*da0073e9SAndroid Build Coastguard Worker
5104*da0073e9SAndroid Build Coastguard Worker
5105*da0073e9SAndroid Build Coastguard Workerpairwise_distance = _add_docstr(
5106*da0073e9SAndroid Build Coastguard Worker    torch.pairwise_distance,
5107*da0073e9SAndroid Build Coastguard Worker    r"""
5108*da0073e9SAndroid Build Coastguard Workerpairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor
5109*da0073e9SAndroid Build Coastguard Worker
5110*da0073e9SAndroid Build Coastguard WorkerSee :class:`torch.nn.PairwiseDistance` for details
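
A short usage sketch (the shapes below are illustrative):

Example::

    >>> x1 = torch.randn(100, 128)
    >>> x2 = torch.randn(100, 128)
    >>> output = F.pairwise_distance(x1, x2)
    >>> output.shape
    torch.Size([100])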
5111*da0073e9SAndroid Build Coastguard Worker""",
5112*da0073e9SAndroid Build Coastguard Worker)
5113*da0073e9SAndroid Build Coastguard Worker
5114*da0073e9SAndroid Build Coastguard Worker
5115*da0073e9SAndroid Build Coastguard Workerpdist = _add_docstr(
5116*da0073e9SAndroid Build Coastguard Worker    torch.pdist,
5117*da0073e9SAndroid Build Coastguard Worker    r"""
5118*da0073e9SAndroid Build Coastguard Workerpdist(input, p=2) -> Tensor
5119*da0073e9SAndroid Build Coastguard Worker
5120*da0073e9SAndroid Build Coastguard WorkerComputes the p-norm distance between every pair of row vectors in the input.
5121*da0073e9SAndroid Build Coastguard WorkerThis is identical to the upper triangular portion, excluding the diagonal, of
5122*da0073e9SAndroid Build Coastguard Worker`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
5123*da0073e9SAndroid Build Coastguard Workerif the rows are contiguous.
5124*da0073e9SAndroid Build Coastguard Worker
5125*da0073e9SAndroid Build Coastguard WorkerIf input has shape :math:`N \times M` then the output will have shape
5126*da0073e9SAndroid Build Coastguard Worker:math:`\frac{1}{2} N (N - 1)`.
5127*da0073e9SAndroid Build Coastguard Worker
5128*da0073e9SAndroid Build Coastguard WorkerThis function is equivalent to ``scipy.spatial.distance.pdist(input,
5129*da0073e9SAndroid Build Coastguard Worker'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
5130*da0073e9SAndroid Build Coastguard Workerequivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``.
5131*da0073e9SAndroid Build Coastguard WorkerWhen :math:`p = \infty`, the closest scipy function is
5132*da0073e9SAndroid Build Coastguard Worker``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``.
5133*da0073e9SAndroid Build Coastguard Worker
5134*da0073e9SAndroid Build Coastguard WorkerArgs:
5135*da0073e9SAndroid Build Coastguard Worker    input: input tensor of shape :math:`N \times M`.
5136*da0073e9SAndroid Build Coastguard Worker    p: p value for the p-norm distance to calculate between each vector pair
5137*da0073e9SAndroid Build Coastguard Worker        :math:`\in [0, \infty]`.
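
A short usage sketch (the input is illustrative); with :math:`N = 4` rows there
are :math:`\frac{1}{2} \cdot 4 \cdot 3 = 6` pairs:

Example::

    >>> input = torch.randn(4, 8)
    >>> F.pdist(input).shape
    torch.Size([6])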
5138*da0073e9SAndroid Build Coastguard Worker""",
5139*da0073e9SAndroid Build Coastguard Worker)
5140*da0073e9SAndroid Build Coastguard Worker
5141*da0073e9SAndroid Build Coastguard Worker
5142*da0073e9SAndroid Build Coastguard Workercosine_similarity = _add_docstr(
5143*da0073e9SAndroid Build Coastguard Worker    torch.cosine_similarity,
5144*da0073e9SAndroid Build Coastguard Worker    r"""
5145*da0073e9SAndroid Build Coastguard Workercosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
5146*da0073e9SAndroid Build Coastguard Worker
5147*da0073e9SAndroid Build Coastguard WorkerReturns cosine similarity between ``x1`` and ``x2``, computed along ``dim``. ``x1`` and ``x2`` must be broadcastable
5148*da0073e9SAndroid Build Coastguard Workerto a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
5149*da0073e9SAndroid Build Coastguard Workersqueezed (see :func:`torch.squeeze`), resulting in the
5150*da0073e9SAndroid Build Coastguard Workeroutput tensor having 1 fewer dimension.
5151*da0073e9SAndroid Build Coastguard Worker
5152*da0073e9SAndroid Build Coastguard Worker.. math ::
5153*da0073e9SAndroid Build Coastguard Worker    \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)}
5154*da0073e9SAndroid Build Coastguard Worker
5155*da0073e9SAndroid Build Coastguard WorkerSupports :ref:`type promotion <type-promotion-doc>`.
5156*da0073e9SAndroid Build Coastguard Worker
5157*da0073e9SAndroid Build Coastguard WorkerArgs:
5158*da0073e9SAndroid Build Coastguard Worker    x1 (Tensor): First input.
5159*da0073e9SAndroid Build Coastguard Worker    x2 (Tensor): Second input.
5160*da0073e9SAndroid Build Coastguard Worker    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
5161*da0073e9SAndroid Build Coastguard Worker    eps (float, optional): Small value to avoid division by zero.
5162*da0073e9SAndroid Build Coastguard Worker        Default: 1e-8
5163*da0073e9SAndroid Build Coastguard Worker
5164*da0073e9SAndroid Build Coastguard WorkerExample::
5165*da0073e9SAndroid Build Coastguard Worker
5166*da0073e9SAndroid Build Coastguard Worker    >>> input1 = torch.randn(100, 128)
5167*da0073e9SAndroid Build Coastguard Worker    >>> input2 = torch.randn(100, 128)
5168*da0073e9SAndroid Build Coastguard Worker    >>> output = F.cosine_similarity(input1, input2)
5169*da0073e9SAndroid Build Coastguard Worker    >>> print(output)
5170*da0073e9SAndroid Build Coastguard Worker""",
5171*da0073e9SAndroid Build Coastguard Worker)
5172*da0073e9SAndroid Build Coastguard Worker
5173*da0073e9SAndroid Build Coastguard Worker
5174*da0073e9SAndroid Build Coastguard Workerone_hot = _add_docstr(
5175*da0073e9SAndroid Build Coastguard Worker    torch._C._nn.one_hot,
5176*da0073e9SAndroid Build Coastguard Worker    r"""
5177*da0073e9SAndroid Build Coastguard Workerone_hot(tensor, num_classes=-1) -> LongTensor
5178*da0073e9SAndroid Build Coastguard Worker
5179*da0073e9SAndroid Build Coastguard WorkerTakes a LongTensor with index values of shape ``(*)`` and returns a tensor
5180*da0073e9SAndroid Build Coastguard Workerof shape ``(*, num_classes)`` that has zeros everywhere except where the
5181*da0073e9SAndroid Build Coastguard Workerindex of the last dimension matches the corresponding value of the input tensor,
5182*da0073e9SAndroid Build Coastguard Workerin which case it will be 1.
5183*da0073e9SAndroid Build Coastguard Worker
5184*da0073e9SAndroid Build Coastguard WorkerSee also `One-hot on Wikipedia`_ .
5185*da0073e9SAndroid Build Coastguard Worker
5186*da0073e9SAndroid Build Coastguard Worker.. _One-hot on Wikipedia:
5187*da0073e9SAndroid Build Coastguard Worker    https://en.wikipedia.org/wiki/One-hot
5188*da0073e9SAndroid Build Coastguard Worker
5189*da0073e9SAndroid Build Coastguard WorkerArguments:
5190*da0073e9SAndroid Build Coastguard Worker    tensor (LongTensor): class values of any shape.
5191*da0073e9SAndroid Build Coastguard Worker    num_classes (int):  Total number of classes. If set to -1, the number
5192*da0073e9SAndroid Build Coastguard Worker        of classes will be inferred as one greater than the largest class
5193*da0073e9SAndroid Build Coastguard Worker        value in the input tensor.
5194*da0073e9SAndroid Build Coastguard Worker
5195*da0073e9SAndroid Build Coastguard WorkerReturns:
5196*da0073e9SAndroid Build Coastguard Worker    LongTensor that has one more dimension than the input, with a value of 1
5197*da0073e9SAndroid Build Coastguard Worker    at the index of the last dimension indicated by the input and 0
5198*da0073e9SAndroid Build Coastguard Worker    everywhere else.
5199*da0073e9SAndroid Build Coastguard Worker
5200*da0073e9SAndroid Build Coastguard WorkerExamples:
5201*da0073e9SAndroid Build Coastguard Worker    >>> F.one_hot(torch.arange(0, 5) % 3)
5202*da0073e9SAndroid Build Coastguard Worker    tensor([[1, 0, 0],
5203*da0073e9SAndroid Build Coastguard Worker            [0, 1, 0],
5204*da0073e9SAndroid Build Coastguard Worker            [0, 0, 1],
5205*da0073e9SAndroid Build Coastguard Worker            [1, 0, 0],
5206*da0073e9SAndroid Build Coastguard Worker            [0, 1, 0]])
5207*da0073e9SAndroid Build Coastguard Worker    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
5208*da0073e9SAndroid Build Coastguard Worker    tensor([[1, 0, 0, 0, 0],
5209*da0073e9SAndroid Build Coastguard Worker            [0, 1, 0, 0, 0],
5210*da0073e9SAndroid Build Coastguard Worker            [0, 0, 1, 0, 0],
5211*da0073e9SAndroid Build Coastguard Worker            [1, 0, 0, 0, 0],
5212*da0073e9SAndroid Build Coastguard Worker            [0, 1, 0, 0, 0]])
5213*da0073e9SAndroid Build Coastguard Worker    >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
5214*da0073e9SAndroid Build Coastguard Worker    tensor([[[1, 0, 0],
5215*da0073e9SAndroid Build Coastguard Worker             [0, 1, 0]],
5216*da0073e9SAndroid Build Coastguard Worker            [[0, 0, 1],
5217*da0073e9SAndroid Build Coastguard Worker             [1, 0, 0]],
5218*da0073e9SAndroid Build Coastguard Worker            [[0, 1, 0],
5219*da0073e9SAndroid Build Coastguard Worker             [0, 0, 1]]])
5220*da0073e9SAndroid Build Coastguard Worker""",
5221*da0073e9SAndroid Build Coastguard Worker)
5222*da0073e9SAndroid Build Coastguard Worker
5223*da0073e9SAndroid Build Coastguard Worker
5224*da0073e9SAndroid Build Coastguard Workerdef triplet_margin_loss(
5225*da0073e9SAndroid Build Coastguard Worker    anchor: Tensor,
5226*da0073e9SAndroid Build Coastguard Worker    positive: Tensor,
5227*da0073e9SAndroid Build Coastguard Worker    negative: Tensor,
5228*da0073e9SAndroid Build Coastguard Worker    margin: float = 1.0,
5229*da0073e9SAndroid Build Coastguard Worker    p: float = 2,
5230*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-6,
5231*da0073e9SAndroid Build Coastguard Worker    swap: bool = False,
5232*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = None,
5233*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = None,
5234*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
5235*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5236*da0073e9SAndroid Build Coastguard Worker    r"""Compute the triplet margin loss between the given input tensors, using a margin greater than 0.
5237*da0073e9SAndroid Build Coastguard Worker
5238*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.TripletMarginLoss` for details.
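
    A brief usage sketch (random embeddings, purely illustrative):

    Example::

        >>> anchor = torch.randn(16, 128, requires_grad=True)
        >>> positive = torch.randn(16, 128, requires_grad=True)
        >>> negative = torch.randn(16, 128, requires_grad=True)
        >>> loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
        >>> loss.backward()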
5239*da0073e9SAndroid Build Coastguard Worker    """
5240*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(anchor, positive, negative):
5241*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5242*da0073e9SAndroid Build Coastguard Worker            triplet_margin_loss,
5243*da0073e9SAndroid Build Coastguard Worker            (anchor, positive, negative),
5244*da0073e9SAndroid Build Coastguard Worker            anchor,
5245*da0073e9SAndroid Build Coastguard Worker            positive,
5246*da0073e9SAndroid Build Coastguard Worker            negative,
5247*da0073e9SAndroid Build Coastguard Worker            margin=margin,
5248*da0073e9SAndroid Build Coastguard Worker            p=p,
5249*da0073e9SAndroid Build Coastguard Worker            eps=eps,
5250*da0073e9SAndroid Build Coastguard Worker            swap=swap,
5251*da0073e9SAndroid Build Coastguard Worker            size_average=size_average,
5252*da0073e9SAndroid Build Coastguard Worker            reduce=reduce,
5253*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
5254*da0073e9SAndroid Build Coastguard Worker        )
5255*da0073e9SAndroid Build Coastguard Worker    if size_average is not None or reduce is not None:
5256*da0073e9SAndroid Build Coastguard Worker        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
5257*da0073e9SAndroid Build Coastguard Worker    else:
5258*da0073e9SAndroid Build Coastguard Worker        reduction_enum = _Reduction.get_enum(reduction)
5259*da0073e9SAndroid Build Coastguard Worker    if margin <= 0:
5260*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"margin must be greater than 0, got {margin}")
5261*da0073e9SAndroid Build Coastguard Worker    return torch.triplet_margin_loss(
5262*da0073e9SAndroid Build Coastguard Worker        anchor, positive, negative, margin, p, eps, swap, reduction_enum
5263*da0073e9SAndroid Build Coastguard Worker    )
5264*da0073e9SAndroid Build Coastguard Worker
5265*da0073e9SAndroid Build Coastguard Worker
5266*da0073e9SAndroid Build Coastguard Workerdef triplet_margin_with_distance_loss(
5267*da0073e9SAndroid Build Coastguard Worker    anchor: Tensor,
5268*da0073e9SAndroid Build Coastguard Worker    positive: Tensor,
5269*da0073e9SAndroid Build Coastguard Worker    negative: Tensor,
5270*da0073e9SAndroid Build Coastguard Worker    *,
5271*da0073e9SAndroid Build Coastguard Worker    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
5272*da0073e9SAndroid Build Coastguard Worker    margin: float = 1.0,
5273*da0073e9SAndroid Build Coastguard Worker    swap: bool = False,
5274*da0073e9SAndroid Build Coastguard Worker    reduction: str = "mean",
5275*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5276*da0073e9SAndroid Build Coastguard Worker    r"""Compute the triplet margin loss for input tensors using a custom distance function.
5277*da0073e9SAndroid Build Coastguard Worker
5278*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
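
    A brief usage sketch; the cosine-based distance below is an illustrative
    choice, not a library default:

    Example::

        >>> anchor, positive, negative = (torch.randn(16, 128) for _ in range(3))
        >>> cosine_distance = lambda x, y: 1.0 - F.cosine_similarity(x, y)
        >>> loss = F.triplet_margin_with_distance_loss(
        ...     anchor, positive, negative, distance_function=cosine_distance, swap=True)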
5279*da0073e9SAndroid Build Coastguard Worker    """
5280*da0073e9SAndroid Build Coastguard Worker    if torch.jit.is_scripting():
5281*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError(
5282*da0073e9SAndroid Build Coastguard Worker            "F.triplet_margin_with_distance_loss does not support JIT scripting: "
5283*da0073e9SAndroid Build Coastguard Worker            "functions requiring Callables cannot be scripted."
5284*da0073e9SAndroid Build Coastguard Worker        )
5285*da0073e9SAndroid Build Coastguard Worker
5286*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(anchor, positive, negative):
5287*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5288*da0073e9SAndroid Build Coastguard Worker            triplet_margin_with_distance_loss,
5289*da0073e9SAndroid Build Coastguard Worker            (anchor, positive, negative),
5290*da0073e9SAndroid Build Coastguard Worker            anchor,
5291*da0073e9SAndroid Build Coastguard Worker            positive,
5292*da0073e9SAndroid Build Coastguard Worker            negative,
5293*da0073e9SAndroid Build Coastguard Worker            distance_function=distance_function,
5294*da0073e9SAndroid Build Coastguard Worker            margin=margin,
5295*da0073e9SAndroid Build Coastguard Worker            swap=swap,
5296*da0073e9SAndroid Build Coastguard Worker            reduction=reduction,
5297*da0073e9SAndroid Build Coastguard Worker        )
5298*da0073e9SAndroid Build Coastguard Worker
5299*da0073e9SAndroid Build Coastguard Worker    # Check validity of reduction mode
5300*da0073e9SAndroid Build Coastguard Worker    if reduction not in ("mean", "sum", "none"):
5301*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"{reduction} is not a valid value for reduction")
5302*da0073e9SAndroid Build Coastguard Worker
5303*da0073e9SAndroid Build Coastguard Worker    # Check validity of margin
5304*da0073e9SAndroid Build Coastguard Worker    if margin <= 0:
5305*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"margin must be greater than 0, got {margin}")
5306*da0073e9SAndroid Build Coastguard Worker
5307*da0073e9SAndroid Build Coastguard Worker    # Check dimensions
5308*da0073e9SAndroid Build Coastguard Worker    a_dim = anchor.ndim
5309*da0073e9SAndroid Build Coastguard Worker    p_dim = positive.ndim
5310*da0073e9SAndroid Build Coastguard Worker    n_dim = negative.ndim
5311*da0073e9SAndroid Build Coastguard Worker    if not (a_dim == p_dim and p_dim == n_dim):
5312*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
5313*da0073e9SAndroid Build Coastguard Worker            f"The anchor, positive, and negative tensors are expected to have "
5314*da0073e9SAndroid Build Coastguard Worker            f"the same number of dimensions, but got: anchor {a_dim}D, "
5315*da0073e9SAndroid Build Coastguard Worker            f"positive {p_dim}D, and negative {n_dim}D inputs"
5316*da0073e9SAndroid Build Coastguard Worker        )
5317*da0073e9SAndroid Build Coastguard Worker
5318*da0073e9SAndroid Build Coastguard Worker    # Calculate loss
5319*da0073e9SAndroid Build Coastguard Worker    if distance_function is None:
5320*da0073e9SAndroid Build Coastguard Worker        distance_function = torch.pairwise_distance
5321*da0073e9SAndroid Build Coastguard Worker
5322*da0073e9SAndroid Build Coastguard Worker    dist_pos = distance_function(anchor, positive)
5323*da0073e9SAndroid Build Coastguard Worker    dist_neg = distance_function(anchor, negative)
5324*da0073e9SAndroid Build Coastguard Worker    # The distance swap is described in the paper "Learning shallow
5325*da0073e9SAndroid Build Coastguard Worker    # convolutional feature descriptors with triplet losses" by V. Balntas,
5326*da0073e9SAndroid Build Coastguard Worker    # E. Riba et al. If `swap` is True and the positive example is closer to
5327*da0073e9SAndroid Build Coastguard Worker    # the negative example than the anchor is, the positive example and the
5328*da0073e9SAndroid Build Coastguard Worker    # anchor are swapped in the loss computation.
5329*da0073e9SAndroid Build Coastguard Worker    if swap:
5330*da0073e9SAndroid Build Coastguard Worker        dist_swap = distance_function(positive, negative)
5331*da0073e9SAndroid Build Coastguard Worker        dist_neg = torch.minimum(dist_neg, dist_swap)
5332*da0073e9SAndroid Build Coastguard Worker    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
5333*da0073e9SAndroid Build Coastguard Worker
5334*da0073e9SAndroid Build Coastguard Worker    # Apply reduction
5335*da0073e9SAndroid Build Coastguard Worker    if reduction == "sum":
5336*da0073e9SAndroid Build Coastguard Worker        return torch.sum(loss)
5337*da0073e9SAndroid Build Coastguard Worker    elif reduction == "mean":
5338*da0073e9SAndroid Build Coastguard Worker        return torch.mean(loss)
5339*da0073e9SAndroid Build Coastguard Worker    else:  # reduction == "none"
5340*da0073e9SAndroid Build Coastguard Worker        return loss
5341*da0073e9SAndroid Build Coastguard Worker
5342*da0073e9SAndroid Build Coastguard Worker
5343*da0073e9SAndroid Build Coastguard Workerdef normalize(
5344*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
5345*da0073e9SAndroid Build Coastguard Worker    p: float = 2.0,
5346*da0073e9SAndroid Build Coastguard Worker    dim: int = 1,
5347*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-12,
5348*da0073e9SAndroid Build Coastguard Worker    out: Optional[Tensor] = None,
5349*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5350*da0073e9SAndroid Build Coastguard Worker    r"""Perform :math:`L_p` normalization of inputs over the specified dimension.
5351*da0073e9SAndroid Build Coastguard Worker
5352*da0073e9SAndroid Build Coastguard Worker    For a tensor :attr:`input` of size :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
5353*da0073e9SAndroid Build Coastguard Worker    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`dim` is transformed as
5354*da0073e9SAndroid Build Coastguard Worker
5355*da0073e9SAndroid Build Coastguard Worker    .. math::
5356*da0073e9SAndroid Build Coastguard Worker        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
5357*da0073e9SAndroid Build Coastguard Worker
5358*da0073e9SAndroid Build Coastguard Worker    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.
5359*da0073e9SAndroid Build Coastguard Worker
5360*da0073e9SAndroid Build Coastguard Worker    Args:
5361*da0073e9SAndroid Build Coastguard Worker        input: input tensor of any shape
5362*da0073e9SAndroid Build Coastguard Worker        p (float): the exponent value in the norm formulation. Default: 2
5363*da0073e9SAndroid Build Coastguard Worker        dim (int or tuple of ints): the dimension to reduce. Default: 1
5364*da0073e9SAndroid Build Coastguard Worker        eps (float): small value to avoid division by zero. Default: 1e-12
5365*da0073e9SAndroid Build Coastguard Worker        out (Tensor, optional): the output tensor. If :attr:`out` is used, this
5366*da0073e9SAndroid Build Coastguard Worker                                operation won't be differentiable.
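
    A small sketch: with the defaults, each row is scaled to unit Euclidean norm
    (the input values are illustrative):

    Example::

        >>> x = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
        >>> F.normalize(x, p=2.0, dim=1)
        tensor([[0.6000, 0.8000],
                [0.0000, 1.0000]])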
5367*da0073e9SAndroid Build Coastguard Worker    """
5368*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(input, out):
5369*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5370*da0073e9SAndroid Build Coastguard Worker            normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out
5371*da0073e9SAndroid Build Coastguard Worker        )
5372*da0073e9SAndroid Build Coastguard Worker    if out is None:
5373*da0073e9SAndroid Build Coastguard Worker        denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
5374*da0073e9SAndroid Build Coastguard Worker        return input / denom
5375*da0073e9SAndroid Build Coastguard Worker    else:
5376*da0073e9SAndroid Build Coastguard Worker        denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input)
5377*da0073e9SAndroid Build Coastguard Worker        return torch.div(input, denom, out=out)
5378*da0073e9SAndroid Build Coastguard Worker
5379*da0073e9SAndroid Build Coastguard Worker
5380*da0073e9SAndroid Build Coastguard Workerdef assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None:
5381*da0073e9SAndroid Build Coastguard Worker    assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
5382*da0073e9SAndroid Build Coastguard Worker
5383*da0073e9SAndroid Build Coastguard Worker
5384*da0073e9SAndroid Build Coastguard Workerdef unfold(
5385*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
5386*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
5387*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList2[int] = 1,
5388*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList2[int] = 0,
5389*da0073e9SAndroid Build Coastguard Worker    stride: BroadcastingList2[int] = 1,
5390*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5391*da0073e9SAndroid Build Coastguard Worker    r"""Extract sliding local blocks from a batched input tensor.
5392*da0073e9SAndroid Build Coastguard Worker
5393*da0073e9SAndroid Build Coastguard Worker    .. warning::
5394*da0073e9SAndroid Build Coastguard Worker        Currently, only 4-D input tensors (batched image-like tensors) are
5395*da0073e9SAndroid Build Coastguard Worker        supported.
5396*da0073e9SAndroid Build Coastguard Worker
5397*da0073e9SAndroid Build Coastguard Worker    .. warning::
5398*da0073e9SAndroid Build Coastguard Worker
5399*da0073e9SAndroid Build Coastguard Worker        More than one element of the unfolded tensor may refer to a single
5400*da0073e9SAndroid Build Coastguard Worker        memory location. As a result, in-place operations (especially ones that
5401*da0073e9SAndroid Build Coastguard Worker        are vectorized) may result in incorrect behavior. If you need to write
5402*da0073e9SAndroid Build Coastguard Worker        to the tensor, please clone it first.
5403*da0073e9SAndroid Build Coastguard Worker
5404*da0073e9SAndroid Build Coastguard Worker
5405*da0073e9SAndroid Build Coastguard Worker    See :class:`torch.nn.Unfold` for details
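
    A small sketch (illustrative shapes): with :math:`C = 1`, each ``2x2`` block
    flattens to :math:`1 \cdot 2 \cdot 2 = 4` values, and a ``4x4`` image yields
    :math:`3 \times 3 = 9` blocks:

    Example::

        >>> x = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> F.unfold(x, kernel_size=2).shape
        torch.Size([1, 4, 9])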
5406*da0073e9SAndroid Build Coastguard Worker    """
5407*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
5408*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5409*da0073e9SAndroid Build Coastguard Worker            unfold,
5410*da0073e9SAndroid Build Coastguard Worker            (input,),
5411*da0073e9SAndroid Build Coastguard Worker            input,
5412*da0073e9SAndroid Build Coastguard Worker            kernel_size,
5413*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
5414*da0073e9SAndroid Build Coastguard Worker            padding=padding,
5415*da0073e9SAndroid Build Coastguard Worker            stride=stride,
5416*da0073e9SAndroid Build Coastguard Worker        )
5417*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.im2col(
5418*da0073e9SAndroid Build Coastguard Worker        input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
5419*da0073e9SAndroid Build Coastguard Worker    )
5420*da0073e9SAndroid Build Coastguard Worker
5421*da0073e9SAndroid Build Coastguard Worker
5422*da0073e9SAndroid Build Coastguard Workerdef fold(
5423*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
5424*da0073e9SAndroid Build Coastguard Worker    output_size: BroadcastingList2[int],
5425*da0073e9SAndroid Build Coastguard Worker    kernel_size: BroadcastingList2[int],
5426*da0073e9SAndroid Build Coastguard Worker    dilation: BroadcastingList2[int] = 1,
5427*da0073e9SAndroid Build Coastguard Worker    padding: BroadcastingList2[int] = 0,
5428*da0073e9SAndroid Build Coastguard Worker    stride: BroadcastingList2[int] = 1,
5429*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
5430*da0073e9SAndroid Build Coastguard Worker    r"""Combine an array of sliding local blocks into a large containing tensor.
5431*da0073e9SAndroid Build Coastguard Worker
5432*da0073e9SAndroid Build Coastguard Worker    .. warning::
5433*da0073e9SAndroid Build Coastguard Worker        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.
5434*da0073e9SAndroid Build Coastguard Worker
5435*da0073e9SAndroid Build Coastguard Worker    See :class:`torch.nn.Fold` for details
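
    A small sketch (illustrative shapes): :func:`fold` arranges, and sums the
    overlaps of, blocks shaped like those produced by :func:`unfold`:

    Example::

        >>> blocks = torch.randn(1, 4, 9)  # (N, C * kernel area, number of blocks)
        >>> F.fold(blocks, output_size=(4, 4), kernel_size=2).shape
        torch.Size([1, 1, 4, 4])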
5436*da0073e9SAndroid Build Coastguard Worker    """
5437*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
5438*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
5439*da0073e9SAndroid Build Coastguard Worker            fold,
5440*da0073e9SAndroid Build Coastguard Worker            (input,),
5441*da0073e9SAndroid Build Coastguard Worker            input,
5442*da0073e9SAndroid Build Coastguard Worker            output_size,
5443*da0073e9SAndroid Build Coastguard Worker            kernel_size,
5444*da0073e9SAndroid Build Coastguard Worker            dilation=dilation,
5445*da0073e9SAndroid Build Coastguard Worker            padding=padding,
5446*da0073e9SAndroid Build Coastguard Worker            stride=stride,
5447*da0073e9SAndroid Build Coastguard Worker        )
5448*da0073e9SAndroid Build Coastguard Worker    return torch._C._nn.col2im(
5449*da0073e9SAndroid Build Coastguard Worker        input,
5450*da0073e9SAndroid Build Coastguard Worker        _pair(output_size),
5451*da0073e9SAndroid Build Coastguard Worker        _pair(kernel_size),
5452*da0073e9SAndroid Build Coastguard Worker        _pair(dilation),
5453*da0073e9SAndroid Build Coastguard Worker        _pair(padding),
5454*da0073e9SAndroid Build Coastguard Worker        _pair(stride),
5455*da0073e9SAndroid Build Coastguard Worker    )
5456*da0073e9SAndroid Build Coastguard Worker
5457*da0073e9SAndroid Build Coastguard Worker
5458*da0073e9SAndroid Build Coastguard Worker#
5459*da0073e9SAndroid Build Coastguard Worker# multihead attention
5460*da0073e9SAndroid Build Coastguard Worker#
5461*da0073e9SAndroid Build Coastguard Worker
5462*da0073e9SAndroid Build Coastguard Worker
5463*da0073e9SAndroid Build Coastguard Workerdef _in_projection_packed(
5464*da0073e9SAndroid Build Coastguard Worker    q: Tensor,
5465*da0073e9SAndroid Build Coastguard Worker    k: Tensor,
5466*da0073e9SAndroid Build Coastguard Worker    v: Tensor,
5467*da0073e9SAndroid Build Coastguard Worker    w: Tensor,
5468*da0073e9SAndroid Build Coastguard Worker    b: Optional[Tensor] = None,
5469*da0073e9SAndroid Build Coastguard Worker) -> List[Tensor]:
5470*da0073e9SAndroid Build Coastguard Worker    r"""Perform the in-projection step of the attention operation, using packed weights.
5471*da0073e9SAndroid Build Coastguard Worker
5472*da0073e9SAndroid Build Coastguard Worker    Output is a triple containing projection tensors for query, key and value.
5473*da0073e9SAndroid Build Coastguard Worker
5474*da0073e9SAndroid Build Coastguard Worker    Args:
5475*da0073e9SAndroid Build Coastguard Worker        q, k, v: query, key and value tensors to be projected. For self-attention,
5476*da0073e9SAndroid Build Coastguard Worker            these are typically the same tensor; for encoder-decoder attention,
5477*da0073e9SAndroid Build Coastguard Worker            k and v are typically the same tensor. (We take advantage of these
5478*da0073e9SAndroid Build Coastguard Worker            identities for performance if they are present.) Regardless, q, k and v
5479*da0073e9SAndroid Build Coastguard Worker            must share a common embedding dimension; otherwise their shapes may vary.
5480*da0073e9SAndroid Build Coastguard Worker        w: projection weights for q, k and v, packed into a single tensor. Weights
5481*da0073e9SAndroid Build Coastguard Worker            are packed along dimension 0, in q, k, v order.
5482*da0073e9SAndroid Build Coastguard Worker        b: optional projection biases for q, k and v, packed into a single tensor
5483*da0073e9SAndroid Build Coastguard Worker            in q, k, v order.
5484*da0073e9SAndroid Build Coastguard Worker
5485*da0073e9SAndroid Build Coastguard Worker    Shape:
5486*da0073e9SAndroid Build Coastguard Worker        Inputs:
5487*da0073e9SAndroid Build Coastguard Worker        - q: :math:`(..., E)` where E is the embedding dimension
5488*da0073e9SAndroid Build Coastguard Worker        - k: :math:`(..., E)` where E is the embedding dimension
5489*da0073e9SAndroid Build Coastguard Worker        - v: :math:`(..., E)` where E is the embedding dimension
5490*da0073e9SAndroid Build Coastguard Worker        - w: :math:`(E * 3, E)` where E is the embedding dimension
5491*da0073e9SAndroid Build Coastguard Worker        - b: :math:`(E * 3)` where E is the embedding dimension
5492*da0073e9SAndroid Build Coastguard Worker
5493*da0073e9SAndroid Build Coastguard Worker        Output:
5494*da0073e9SAndroid Build Coastguard Worker        - in output list :math:`[q', k', v']`, each output tensor will have the
5495*da0073e9SAndroid Build Coastguard Worker            same shape as the corresponding input tensor.
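
    A shape-only sketch of the self-attention fast path (``q is k is v``), calling
    this helper by its module-local name; the values are illustrative:

    Example::

        >>> q = torch.randn(2, 5, 8)   # (batch, seq, E)
        >>> w = torch.randn(24, 8)     # packed q/k/v weights, (E * 3, E)
        >>> q2, k2, v2 = _in_projection_packed(q, q, q, w)
        >>> q2.shape, k2.shape, v2.shape
        (torch.Size([2, 5, 8]), torch.Size([2, 5, 8]), torch.Size([2, 5, 8]))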
5496*da0073e9SAndroid Build Coastguard Worker    """
5497*da0073e9SAndroid Build Coastguard Worker    E = q.size(-1)
5498*da0073e9SAndroid Build Coastguard Worker    if k is v:
5499*da0073e9SAndroid Build Coastguard Worker        if q is k:
5500*da0073e9SAndroid Build Coastguard Worker            # self-attention
5501*da0073e9SAndroid Build Coastguard Worker            proj = linear(q, w, b)
5502*da0073e9SAndroid Build Coastguard Worker            # Reshaping to (3, E) rather than (E, 3) is deliberate: it gives better memory coalescing and keeps the same output order as chunk()
5503*da0073e9SAndroid Build Coastguard Worker            proj = (
5504*da0073e9SAndroid Build Coastguard Worker                proj.unflatten(-1, (3, E))
5505*da0073e9SAndroid Build Coastguard Worker                .unsqueeze(0)
5506*da0073e9SAndroid Build Coastguard Worker                .transpose(0, -2)
5507*da0073e9SAndroid Build Coastguard Worker                .squeeze(-2)
5508*da0073e9SAndroid Build Coastguard Worker                .contiguous()
5509*da0073e9SAndroid Build Coastguard Worker            )
5510*da0073e9SAndroid Build Coastguard Worker            return proj[0], proj[1], proj[2]
5511*da0073e9SAndroid Build Coastguard Worker        else:
5512*da0073e9SAndroid Build Coastguard Worker            # encoder-decoder attention
5513*da0073e9SAndroid Build Coastguard Worker            w_q, w_kv = w.split([E, E * 2])
5514*da0073e9SAndroid Build Coastguard Worker            if b is None:
5515*da0073e9SAndroid Build Coastguard Worker                b_q = b_kv = None
5516*da0073e9SAndroid Build Coastguard Worker            else:
5517*da0073e9SAndroid Build Coastguard Worker                b_q, b_kv = b.split([E, E * 2])
5518*da0073e9SAndroid Build Coastguard Worker            q_proj = linear(q, w_q, b_q)
5519*da0073e9SAndroid Build Coastguard Worker            kv_proj = linear(k, w_kv, b_kv)
5520*da0073e9SAndroid Build Coastguard Worker            # Reshaping to (2, E) rather than (E, 2) is deliberate: it gives better memory coalescing and keeps the same output order as chunk()
5521*da0073e9SAndroid Build Coastguard Worker            kv_proj = (
5522*da0073e9SAndroid Build Coastguard Worker                kv_proj.unflatten(-1, (2, E))
5523*da0073e9SAndroid Build Coastguard Worker                .unsqueeze(0)
5524*da0073e9SAndroid Build Coastguard Worker                .transpose(0, -2)
5525*da0073e9SAndroid Build Coastguard Worker                .squeeze(-2)
5526*da0073e9SAndroid Build Coastguard Worker                .contiguous()
5527*da0073e9SAndroid Build Coastguard Worker            )
5528*da0073e9SAndroid Build Coastguard Worker            return (q_proj, kv_proj[0], kv_proj[1])
5529*da0073e9SAndroid Build Coastguard Worker    else:
5530*da0073e9SAndroid Build Coastguard Worker        w_q, w_k, w_v = w.chunk(3)
5531*da0073e9SAndroid Build Coastguard Worker        if b is None:
5532*da0073e9SAndroid Build Coastguard Worker            b_q = b_k = b_v = None
5533*da0073e9SAndroid Build Coastguard Worker        else:
5534*da0073e9SAndroid Build Coastguard Worker            b_q, b_k, b_v = b.chunk(3)
5535*da0073e9SAndroid Build Coastguard Worker        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
5536*da0073e9SAndroid Build Coastguard Worker
5537*da0073e9SAndroid Build Coastguard Worker
5538*da0073e9SAndroid Build Coastguard Workerdef _in_projection(
5539*da0073e9SAndroid Build Coastguard Worker    q: Tensor,
5540*da0073e9SAndroid Build Coastguard Worker    k: Tensor,
5541*da0073e9SAndroid Build Coastguard Worker    v: Tensor,
5542*da0073e9SAndroid Build Coastguard Worker    w_q: Tensor,
5543*da0073e9SAndroid Build Coastguard Worker    w_k: Tensor,
5544*da0073e9SAndroid Build Coastguard Worker    w_v: Tensor,
5545*da0073e9SAndroid Build Coastguard Worker    b_q: Optional[Tensor] = None,
5546*da0073e9SAndroid Build Coastguard Worker    b_k: Optional[Tensor] = None,
5547*da0073e9SAndroid Build Coastguard Worker    b_v: Optional[Tensor] = None,
5548*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, Tensor, Tensor]:
5549*da0073e9SAndroid Build Coastguard Worker    r"""Perform the in-projection step of the attention operation.
5550*da0073e9SAndroid Build Coastguard Worker
5551*da0073e9SAndroid Build Coastguard Worker    This is simply a triple of linear projections,
5552*da0073e9SAndroid Build Coastguard Worker    with shape constraints on the weights which
5553*da0073e9SAndroid Build Coastguard Worker    ensure embedding dimension uniformity in the projected outputs.
5554*da0073e9SAndroid Build Coastguard Worker    Output is a triple containing projection tensors for query, key and value.
5555*da0073e9SAndroid Build Coastguard Worker
5556*da0073e9SAndroid Build Coastguard Worker    Args:
5557*da0073e9SAndroid Build Coastguard Worker        q, k, v: query, key and value tensors to be projected.
5558*da0073e9SAndroid Build Coastguard Worker        w_q, w_k, w_v: weights for q, k and v, respectively.
5559*da0073e9SAndroid Build Coastguard Worker        b_q, b_k, b_v: optional biases for q, k and v, respectively.
5560*da0073e9SAndroid Build Coastguard Worker
5561*da0073e9SAndroid Build Coastguard Worker    Shape:
5562*da0073e9SAndroid Build Coastguard Worker        Inputs:
5563*da0073e9SAndroid Build Coastguard Worker        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
5564*da0073e9SAndroid Build Coastguard Worker            number of leading dimensions.
5565*da0073e9SAndroid Build Coastguard Worker        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
5566*da0073e9SAndroid Build Coastguard Worker            number of leading dimensions.
5567*da0073e9SAndroid Build Coastguard Worker        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
5568*da0073e9SAndroid Build Coastguard Worker            number of leading dimensions.
5569*da0073e9SAndroid Build Coastguard Worker        - w_q: :math:`(Eq, Eq)`
5570*da0073e9SAndroid Build Coastguard Worker        - w_k: :math:`(Eq, Ek)`
5571*da0073e9SAndroid Build Coastguard Worker        - w_v: :math:`(Eq, Ev)`
5572*da0073e9SAndroid Build Coastguard Worker        - b_q: :math:`(Eq)`
5573*da0073e9SAndroid Build Coastguard Worker        - b_k: :math:`(Eq)`
5574*da0073e9SAndroid Build Coastguard Worker        - b_v: :math:`(Eq)`
5575*da0073e9SAndroid Build Coastguard Worker
5576*da0073e9SAndroid Build Coastguard Worker        Output: in output triple :math:`(q', k', v')`,
5577*da0073e9SAndroid Build Coastguard Worker         - q': :math:`[Qdims..., Eq]`
5578*da0073e9SAndroid Build Coastguard Worker         - k': :math:`[Kdims..., Eq]`
5579*da0073e9SAndroid Build Coastguard Worker         - v': :math:`[Vdims..., Eq]`
5580*da0073e9SAndroid Build Coastguard Worker
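    A shape-only sketch with distinct key/value embedding dimensions, all projected
    to the query dimension ``Eq``, calling this helper by its module-local name
    (the values are illustrative):

    Example::

        >>> q, k, v = torch.randn(5, 8), torch.randn(7, 6), torch.randn(7, 4)
        >>> w_q, w_k, w_v = torch.randn(8, 8), torch.randn(8, 6), torch.randn(8, 4)
        >>> q2, k2, v2 = _in_projection(q, k, v, w_q, w_k, w_v)
        >>> q2.shape, k2.shape, v2.shape
        (torch.Size([5, 8]), torch.Size([7, 8]), torch.Size([7, 8]))
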
5581*da0073e9SAndroid Build Coastguard Worker    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (
        Eq,
        Eq,
    ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (
        Eq,
        Ek,
    ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (
        Eq,
        Ev,
    ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (
        Eq,
    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (
        Eq,
    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (
        Eq,
    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


scaled_dot_product_attention = _add_docstr(
    torch._C._nn.scaled_dot_product_attention,
    r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> Tensor:

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
    specified as a keyword argument.

    .. code-block:: python

        # Efficient implementation equivalent to the following:
        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
            L, S = query.size(-2), key.size(-2)
            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
            attn_bias = torch.zeros(L, S, dtype=query.dtype)
            if is_causal:
                assert attn_mask is None
                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
                attn_bias = attn_bias.to(query.dtype)

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
                else:
                    attn_bias += attn_mask

            if enable_gqa:
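                # each key/value head is shared by several query heads, so expand
                # K and V along the head dimension to match the query's head count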
                key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
                value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

            attn_weight = query @ key.transpose(-2, -1) * scale_factor
            attn_weight += attn_bias
            attn_weight = torch.softmax(attn_weight, dim=-1)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return attn_weight @ value

    .. warning::
        This function is beta and subject to change.

    .. warning::
        This function always applies dropout according to the specified ``dropout_p`` argument.
        To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
        that makes the function call is not in training mode.

        For example:

        .. code-block:: python

            class MyModel(nn.Module):
                def __init__(self, p=0.5):
                    super().__init__()
                    self.p = p

                def forward(self, ...):
                    return F.scaled_dot_product_attention(...,
                        dropout_p=(self.p if self.training else 0.0))

    Note:

        There are currently three supported implementations of scaled dot product attention:

            - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
            - `Memory-Efficient Attention`_
            - A PyTorch implementation defined in C++ matching the above formulation

        The function may call optimized kernels for improved performance when using the CUDA backend.
        For all other backends, the PyTorch implementation will be used.

        All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
        optimal implementation based on the inputs. In order to provide more fine-grained control over which
        implementation is used, the following functions are provided for enabling and disabling implementations.
        The context manager is the preferred mechanism:

            - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
            - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
            - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention.
            - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.

        Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
        disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
        In the event that a fused implementation is not available, a warning will be raised with the
        reasons why the fused implementation cannot run.

        Due to the nature of fusing floating point operations, the output of this function may be different
        depending on what backend kernel is chosen.
        The C++ implementation supports torch.float64 and can be used when higher precision is required.
        For the math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
        For more information please see :doc:`/notes/numerical_accuracy`.

        Grouped Query Attention (GQA) is an experimental feature. It currently works only with the FlashAttention
        and math kernels on CUDA tensors, and does not support nested tensors.
        Constraints for GQA:

            - number_of_heads_query % number_of_heads_key_value == 0, and
            - number_of_heads_key == number_of_heads_value

    Note:

        {cudnn_reproducibility_note}
    """.format(
        **reproducibility_notes
    )
    + r"""
    Args:
        query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
        key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
        value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
        attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
            which is :math:`(N,..., L, S)`. Two types of masks are supported:
            a boolean mask, where a value of True indicates that the element *should* take part in attention,
            or a float mask of the same type as query, key, value that is added to the attention score.
        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
        is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
            square matrix. The attention masking has the form of the upper left causal bias due to the alignment
            (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
            An error is thrown if both attn_mask and is_causal are set.
        scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
            to :math:`\frac{1}{\sqrt{E}}`.
        enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled. Default: False.

    Returns:
        output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.

    Shape legend:
        - :math:`N: \text{Batch size}`
        - :math:`...: \text{Any number of other batch dimensions (optional)}`
        - :math:`S: \text{Source sequence length}`
        - :math:`L: \text{Target sequence length}`
        - :math:`E: \text{Embedding dimension of the query and key}`
        - :math:`Ev: \text{Embedding dimension of the value}`
        - :math:`Hq: \text{Number of heads of query}`
        - :math:`H: \text{Number of heads of key and value}`

    Examples:

        >>> # Optionally use the context manager to ensure one of the fused kernels is run
        >>> from torch.nn.attention import sdpa_kernel, SDPBackend
        >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        >>>     F.scaled_dot_product_attention(query, key, value)


        >>> # Sample GQA call with Llama 3-style head counts
        >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.MATH]):
        >>>     F.scaled_dot_product_attention(query, key, value, enable_gqa=True)


    .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
        https://arxiv.org/abs/2307.08691
    .. _Memory-Efficient Attention:
        https://github.com/facebookresearch/xformers
    .. _Grouped-Query Attention:
        https://arxiv.org/pdf/2305.13245
    """,
)


def _mha_shape_check(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
    num_heads: int,
):
    # Verifies the expected shape for `query`, `key`, `value`, `key_padding_mask` and `attn_mask`
    # and returns whether the input is batched or not.
    # Raises an error if `query` is not a 2-D (unbatched) or 3-D (batched) tensor.

    # Shape check.
    if query.dim() == 3:
        # Batched Inputs
        is_batched = True
        assert key.dim() == 3 and value.dim() == 3, (
            "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
        )
        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 2, (
                "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
                f" but found {key_padding_mask.dim()}-D tensor instead"
            )
        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), (
                "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                f" but found {attn_mask.dim()}-D tensor instead"
            )
    elif query.dim() == 2:
        # Unbatched Inputs
        is_batched = False
        assert key.dim() == 2 and value.dim() == 2, (
            "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
        )

        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 1, (
                "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
                f" but found {key_padding_mask.dim()}-D tensor instead"
            )

        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), (
                "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                f" but found {attn_mask.dim()}-D tensor instead"
            )
            if attn_mask.dim() == 3:
                expected_shape = (num_heads, query.shape[0], key.shape[0])
                assert (
                    attn_mask.shape == expected_shape
                ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
    else:
        raise AssertionError(
            f"query should be an unbatched 2-D or batched 3-D tensor but received a {query.dim()}-D query tensor"
        )

    return is_batched


def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
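    # Canonicalize `mask` to the additive floating-point form expected by the
    # attention code below: a bool mask (True = position masked out) becomes a
    # float mask holding ``-inf`` at masked positions and ``0.0`` elsewhere,
    # while a float mask is passed through unchanged.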
    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported"
            )
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use the same type for both instead."
                )
        if not _mask_is_float:
            mask = torch.zeros_like(mask, dtype=target_type).masked_fill_(
                mask, float("-inf")
            )
    return mask


def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    See :class:`torch.nn.MultiheadAttention` for details.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: ``True``
            Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
            for best performance when attention weights are not needed.
            *Setting need_weights to ``True`` leads to a significant performance
            degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across all
            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as the attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that attn_mask is the
                causal mask. Providing an incorrect hint can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts projection weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True


    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a FloatTensor is provided, it will be added directly to the attention weight.
          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the
          unmasked positions. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
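
    Example:
        A minimal, illustrative self-attention call (batched input; arbitrary
        sizes, with random weights standing in for trained parameters):

        >>> L, N, E, num_heads = 4, 2, 8, 2
        >>> x = torch.randn(L, N, E)
        >>> in_proj_weight = torch.randn(3 * E, E)
        >>> out_proj_weight = torch.randn(E, E)
        >>> out, weights = multi_head_attention_forward(
        ...     x, x, x, E, num_heads, in_proj_weight, None, None, None,
        ...     False, 0.0, out_proj_weight, None, training=False)
        >>> out.shape, weights.shape
        (torch.Size([4, 2, 8]), torch.Size([2, 4, 4]))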
    """
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that
    # the input is batched, run the computation, and squeeze the batch dimension
    # before returning so the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # When we have a key_padding_mask or need weights, we need an attn_mask.
        # Otherwise, we pass the is_causal hint through to SDPA as the
        # is_causal indicator.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge the key_padding_mask
            # into it. Turn off use of the is_causal hint, as the merged mask is
            # no longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
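        # in_proj_weight stacks the q/k/v projections into one (3 * E, E) matrix;
        # _in_projection_packed splits it (and the optional bias) back into three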
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    # prep attention mask

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
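        # the appended bias token adds one key/value position, so pad the masks
        # with one extra trailing (unmasked) column to match the new source length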
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
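    # heads are folded into the batch dimension:
    # (tgt_len, bsz, embed_dim) -> (bsz * num_heads, tgt_len, head_dim)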
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # append a zero key/value position along the source-sequence dimension
    # (batch is now the first dimension)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
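        # broadcast the per-batch (bsz, src_len) padding mask across heads so it
        # lines up with the (bsz * num_heads, L, S) attention-weight layout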
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

6224*da0073e9SAndroid Build Coastguard Worker    #
6225*da0073e9SAndroid Build Coastguard Worker    # (deep breath) calculate attention and out projection
6226*da0073e9SAndroid Build Coastguard Worker    #
6227*da0073e9SAndroid Build Coastguard Worker
6228*da0073e9SAndroid Build Coastguard Worker    if need_weights:
6229*da0073e9SAndroid Build Coastguard Worker        B, Nt, E = q.shape
6230*da0073e9SAndroid Build Coastguard Worker        q_scaled = q * math.sqrt(1.0 / float(E))
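        # NOTE (added commentary): scaling q by 1 / sqrt(E) (E = head_dim here)
        # before the matmul is algebraically identical to dividing the score
        # matrix q @ k^T by sqrt(E), the usual scaled dot-product normalization.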

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
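        # NOTE (added commentary): torch.baddbmm(mask, a, b) computes
        # mask + a @ b (beta = alpha = 1 by default), fusing the additive float
        # mask into the batched matmul instead of adding it in a separate step.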
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
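        # NOTE (added commentary): attn_output starts as
        # (bsz * num_heads, tgt_len, head_dim); the transpose + view folds the
        # heads back into embed_dim = num_heads * head_dim, the output
        # projection then runs on (tgt_len * bsz, embed_dim) rows, and the
        # result is reshaped to the (tgt_len, bsz, embed_dim) layout.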

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)
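        # NOTE (added commentary): with averaging the returned weights are
        # (bsz, tgt_len, src_len); without it they stay per-head as
        # (bsz, num_heads, tgt_len, src_len).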

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
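        # NOTE (added commentary): scaled_dot_product_attention wants a mask
        # broadcastable to (N, num_heads, L, S): a shared (1, L, S) mask just
        # gains a leading batch dim, while a per-head (N * num_heads, L, S)
        # mask is unflattened back to (N, num_heads, L, S) by the view.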

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)
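        # NOTE (added commentary): q/k/v move from the flattened
        # (bsz * num_heads, seq, head_dim) layout used above to the 4-D
        # (bsz, num_heads, seq, head_dim) layout SDPA consumes.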

        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )
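        # NOTE (added commentary): SDPA returns
        # (bsz, num_heads, tgt_len, head_dim); permute(2, 0, 1, 3) reorders to
        # (tgt_len, bsz, num_heads, head_dim) so the contiguous view can merge
        # the head dims into embed_dim rows for the output projection.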

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

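# A minimal usage sketch (added commentary; kept as a comment so this module's
# import-time behavior is unchanged). All shapes and values below are
# illustrative assumptions, not part of this file. Unbatched (L, E) inputs
# exercise the is_batched squeeze paths above:
#
#   import torch
#   import torch.nn.functional as F
#
#   E, H, L, S = 8, 2, 4, 5
#   q = torch.randn(L, E)           # unbatched query
#   k = torch.randn(S, E)           # unbatched key
#   v = torch.randn(S, E)           # unbatched value
#   in_w = torch.randn(3 * E, E)    # packed q/k/v input projection
#   in_b = torch.zeros(3 * E)
#   out_w = torch.randn(E, E)       # output projection
#   out_b = torch.zeros(E)
#   out, weights = F.multi_head_attention_forward(
#       q, k, v, E, H, in_w, in_b,
#       None, None, False, 0.0, out_w, out_b,
#   )
#   # out: (L, E); weights: (L, S) after head-averaging and the final squeeze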