"""Functional interface."""

import importlib
import math
import warnings
from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import _VF, sym_int as _sym_int, Tensor
from torch._C import _add_docstr, _infer_size
from torch._jit_internal import (
    _overload,
    boolean_dispatch,
    BroadcastingList1,
    BroadcastingList2,
    BroadcastingList3,
)
from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes
from torch.nn import _reduction as _Reduction, grad  # noqa: F401
from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple
from torch.overrides import (
    handle_torch_function,
    has_torch_function,
    has_torch_function_unary,
    has_torch_function_variadic,
)


if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int

try:
    import numpy as np
except ModuleNotFoundError:
    np = None


conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or
      a one-element tuple `(sW,)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a one-element tuple `(padW,)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a one-element tuple `(dW,)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)
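    >>> # a grouped convolution with illustrative sizes: out_channels (20) and
    >>> # in_channels (16) are both divisible by groups=4, so each filter
    >>> # sees 16 / 4 = 4 input channels
    >>> grouped_filters = torch.randn(20, 4, 5)
    >>> F.conv1d(inputs, grouped_filters, groups=4)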
""",
)

conv2d = _add_docstr(
    torch.conv2d,
    r"""
conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dH, dW)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
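    >>> # ``padding='same'`` keeps the 5x5 spatial size (requires stride 1)
    >>> F.conv2d(inputs, filters, padding="same")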
""",
)  # noqa: E501

conv3d = _add_docstr(
    torch.conv3d,
    r"""
conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 3D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padT, padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> filters = torch.randn(33, 16, 3, 3, 3)
    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> F.conv3d(inputs, filters)
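    >>> # ``padding='same'`` preserves the 50x10x20 spatial size
    >>> F.conv3d(inputs, filters, padding="same")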
""",
)  # noqa: E501

conv_transpose1d = _add_docstr(
    torch.conv_transpose1d,
    r"""
conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sW,)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padW,)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padW,)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dW,)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50)
    >>> weights = torch.randn(16, 33, 5)
    >>> F.conv_transpose1d(inputs, weights)
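    >>> # stride=2 upsamples the sequence: (50 - 1) * 2 + 5 = 103 output steps
    >>> F.conv_transpose1d(inputs, weights, stride=2).shape
    torch.Size([20, 33, 103])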
""",
)

conv_transpose2d = _add_docstr(
    torch.conv_transpose2d,
    r"""
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
      Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dH, dW)``. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> weights = torch.randn(4, 8, 3, 3)
    >>> F.conv_transpose2d(inputs, weights, padding=1)
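    >>> # stride=2 with output_padding=1 exactly doubles each spatial dimension:
    >>> # (5 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 10
    >>> F.conv_transpose2d(inputs, weights, stride=2, padding=1, output_padding=1).shape
    torch.Size([1, 8, 10, 10])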
""",
)  # noqa: E501

conv_transpose3d = _add_docstr(
    torch.conv_transpose3d,
    r"""
conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sT, sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padT, padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple
      ``(out_padT, out_padH, out_padW)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dT, dH, dW)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> weights = torch.randn(16, 33, 3, 3, 3)
    >>> F.conv_transpose3d(inputs, weights)
""",
)  # noqa: E501

conv_tbc = _add_docstr(
    torch.conv_tbc,
    r"""
Applies a 1-dimensional sequence convolution over an input sequence.
Input and output dimensions are (Time, Batch, Channels) - hence TBC.

Args:
    input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})`
    weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
    bias: bias of shape (:math:`\text{out\_channels}`)
    pad: number of timesteps to pad. Default: 0
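
Examples::

    >>> # illustrative sizes; input is (time, batch, channels), and with
    >>> # pad=0 the output has 10 - 3 + 1 = 8 timesteps
    >>> inputs = torch.randn(10, 2, 16)
    >>> weight = torch.randn(3, 16, 32)
    >>> bias = torch.zeros(32)
    >>> F.conv_tbc(inputs, weight, bias).shape
    torch.Size([8, 2, 32])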
""",
)


# Pooling
avg_pool1d = _add_docstr(
    torch.avg_pool1d,
    r"""
avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor

Applies a 1D average pooling over an input signal composed of several
input planes.

See :class:`~torch.nn.AvgPool1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    kernel_size: the size of the window. Can be a single number or a
      tuple `(kW,)`
    stride: the stride of the window. Can be a single number or a tuple
      `(sW,)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padW,)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` to compute the
      output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``

Examples::

    >>> # pool of square window of size=3, stride=2
    >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
    >>> F.avg_pool1d(input, kernel_size=3, stride=2)
    tensor([[[ 2., 4., 6.]]])

""",
)


avg_pool2d = _add_docstr(
    torch._C._nn.avg_pool2d,
    r"""
avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.

See :class:`~torch.nn.AvgPool2d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
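
Examples::

    >>> # a 2x2 window with stride 2 halves each spatial dimension
    >>> input = torch.randn(1, 1, 4, 4)
    >>> F.avg_pool2d(input, kernel_size=2, stride=2).shape
    torch.Size([1, 1, 2, 2])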
""",
)

avg_pool3d = _add_docstr(
    torch._C._nn.avg_pool3d,
    r"""
avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
size :math:`sT \times sH \times sW` steps. The number of output features is equal to
:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`.

See :class:`~torch.nn.AvgPool3d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kT, kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padT, padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
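
Examples::

    >>> # a 2x2x2 window with stride 2 halves each spatial dimension
    >>> input = torch.randn(1, 1, 4, 4, 4)
    >>> F.avg_pool3d(input, kernel_size=2, stride=2).shape
    torch.Size([1, 1, 2, 2, 2])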
""",
)


def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
          Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
          or a tuple `(kH, kW)`
        output_size: the target output size of the image of the form :math:`oH \times oW`.
          Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
          This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
          Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.

    Examples::

        >>> input = torch.randn(20, 16, 50, 32)
        >>> # pool of square window of size=3, and target output size 13x12
        >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool2d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        if len(output_ratio) > 2:
            raise ValueError(
                "fractional_max_pool2d requires output_ratio to be a single float or a tuple of two floats."
            )
        _output_ratio = _pair(output_ratio)
        # derive the target output size from the ratio of the input's
        # last two (spatial) dimensions
        output_size = [
            int(input.size(-2) * _output_ratio[0]),
            int(input.size(-1) * _output_ratio[1]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool2d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool2d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool2d_with_indices,
    if_false=_fractional_max_pool2d,
    module_name=__name__,
    func_name="fractional_max_pool2d",
)


def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
          Can be a single number :math:`k` (for a cubic kernel of :math:`k \times k \times k`)
          or a tuple `(kT, kH, kW)`
        output_size: the target output size of the form :math:`oT \times oH \times oW`.
          Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
          :math:`oH \times oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
          This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
          Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples::

        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool3d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        _output_ratio = _triple(output_ratio)
        # derive the target output size from the ratio of the input's
        # last three (spatial) dimensions
        output_size = [
            int(input.size(-3) * _output_ratio[0]),
            int(input.size(-2) * _output_ratio[1]),
            int(input.size(-1) * _output_ratio[2]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 4 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool3d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool3d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool3d_with_indices,
    if_false=_fractional_max_pool3d,
    module_name=__name__,
    func_name="fractional_max_pool3d",
)


def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 1D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool1d` for details.

    Args:
        input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional.
        kernel_size: the size of the window. Can be a single number or a
          tuple `(kW,)`
        stride: the stride of the window. Can be a single number or a tuple
          `(sW,)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
          ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
          Useful for :func:`torch.nn.functional.max_unpool1d` later
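
    Examples::

        >>> # max over sliding windows of size 3 with stride 2
        >>> input = torch.tensor([[[1., 2., 3., 4., 5., 6., 7.]]])
        >>> F.max_pool1d(input, kernel_size=3, stride=2)
        tensor([[[3., 5., 7.]]])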
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool1d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool1d(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool1d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool1d_with_indices,
    if_false=_max_pool1d,
    module_name=__name__,
    func_name="max_pool1d",
)


def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 2D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool2d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :class:`torch.nn.functional.max_unpool2d` later
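
    Example (an illustrative sketch; the shape is arbitrary and ``F`` is assumed
    to be ``torch.nn.functional``)::

        >>> input = torch.randn(1, 16, 32, 32)
        >>> # pool with a 2x2 window and stride 2, keeping the argmax indices
        >>> output, indices = F.max_pool2d(input, kernel_size=2, stride=2, return_indices=True)
        >>> output.shape
        torch.Size([1, 16, 16, 16])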
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch._C._nn.max_pool2d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool2d_with_indices,
    if_false=_max_pool2d,
    module_name=__name__,
    func_name="max_pool2d",
)


def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 3D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool3d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD, iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kT, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :class:`torch.nn.functional.max_unpool3d` later
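
    Example (a minimal sketch with an arbitrary input shape; ``F`` is assumed to
    be ``torch.nn.functional``)::

        >>> input = torch.randn(1, 16, 8, 32, 32)
        >>> # pool with a 2x2x2 window (stride defaults to the kernel size)
        >>> output, indices = F.max_pool3d(input, kernel_size=2, return_indices=True)
        >>> output.shape
        torch.Size([1, 16, 4, 16, 16])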
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch._C._nn.max_pool3d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool3d_with_indices,
    if_false=_max_pool3d,
    module_name=__name__,
    func_name="max_pool3d",
)


def _unpool_output_size(
    input: Tensor,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    output_size: Optional[List[int]],
) -> List[int]:
    input_size = input.size()
    default_size = torch.jit.annotate(List[int], [])
    for d in range(len(kernel_size)):
        default_size.append(
            (input_size[-len(kernel_size) + d] - 1) * stride[d]
            + kernel_size[d]
            - 2 * padding[d]
        )
    if output_size is None:
        ret = default_size
    else:
        if len(output_size) == len(kernel_size) + 2:
            output_size = output_size[2:]
        if len(output_size) != len(kernel_size):
            raise ValueError(
                "output_size should be a sequence containing "
                f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'"
            )
        for d in range(len(kernel_size)):
            min_size = default_size[d] - stride[d]
            max_size = default_size[d] + stride[d]
            if not (min_size < output_size[d] < max_size):
                raise ValueError(
                    f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})'
                )

        ret = output_size
    return ret


def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    output_size: Optional[BroadcastingList1[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool1d`.

    See :class:`~torch.nn.MaxUnpool1d` for details.
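
    Example (a minimal pool/unpool round trip; the shape is arbitrary and ``F``
    is assumed to be ``torch.nn.functional``)::

        >>> input = torch.randn(1, 8, 50)
        >>> pooled, indices = F.max_pool1d(input, kernel_size=2, stride=2, return_indices=True)
        >>> # place the pooled maxima back at their original positions
        >>> F.max_unpool1d(pooled, indices, kernel_size=2, stride=2).shape
        torch.Size([1, 8, 50])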
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool1d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _single(kernel_size)
    if stride is not None:
        _stride = _single(stride)
    else:
        _stride = kernel_size
    padding = _single(padding)
    output_size = _unpool_output_size(
        input, kernel_size, _stride, padding, output_size
    )
    if isinstance(output_size, list):
        output_size = output_size + [1]
    else:
        output_size = output_size + (1,)
    return torch._C._nn.max_unpool2d(
        input.unsqueeze(-1), indices.unsqueeze(-1), output_size
    ).squeeze(-1)


def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    output_size: Optional[BroadcastingList2[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool2d`.

    See :class:`~torch.nn.MaxUnpool2d` for details.
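
    Example (a minimal sketch; assumes ``import torch.nn.functional as F`` and an
    arbitrarily chosen input shape)::

        >>> input = torch.randn(1, 16, 32, 32)
        >>> pooled, indices = F.max_pool2d(input, kernel_size=2, stride=2, return_indices=True)
        >>> # the unpooled result recovers the original spatial size
        >>> F.max_unpool2d(pooled, indices, kernel_size=2, stride=2).shape
        torch.Size([1, 16, 32, 32])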
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool2d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _pair(kernel_size)
    if stride is not None:
        _stride = _pair(stride)
    else:
        _stride = kernel_size
    padding = _pair(padding)
    output_size = _unpool_output_size(
        input, kernel_size, _stride, padding, output_size
    )
    return torch._C._nn.max_unpool2d(input, indices, output_size)


def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    output_size: Optional[BroadcastingList3[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool3d`.

    See :class:`~torch.nn.MaxUnpool3d` for details.
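
    Example (an illustrative round trip with arbitrary sizes; ``F`` is assumed
    to be ``torch.nn.functional``)::

        >>> input = torch.randn(1, 16, 8, 16, 16)
        >>> pooled, indices = F.max_pool3d(input, kernel_size=2, return_indices=True)
        >>> F.max_unpool3d(pooled, indices, kernel_size=2).shape
        torch.Size([1, 16, 8, 16, 16])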
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool3d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _triple(kernel_size)
    if stride is not None:
        _stride = _triple(stride)
    else:
        _stride = kernel_size
    padding = _triple(padding)
    output_size = _unpool_output_size(
        input, kernel_size, _stride, padding, output_size
    )
    return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)


def lp_pool3d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 3D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool3d` for details.
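
    Example (a minimal sketch; the shape and norm are arbitrary choices, and
    ``F`` is assumed to be ``torch.nn.functional``)::

        >>> input = torch.randn(1, 4, 8, 8, 8)
        >>> # power-average pooling with p=2 over 2x2x2 windows
        >>> F.lp_pool3d(input, norm_type=2, kernel_size=2).shape
        torch.Size([1, 4, 4, 4, 4])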
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool3d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    kd, kw, kh = _triple(kernel_size)
    if stride is not None:
        out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
    else:
        out = avg_pool3d(
            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
        )

    return (
        (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type)
    )


def lp_pool2d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 2D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool2d` for details.
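
    Example (a minimal sketch with arbitrary sizes; ``F`` is assumed to be
    ``torch.nn.functional``)::

        >>> input = torch.randn(1, 4, 8, 8)
        >>> # p=2 power-average pooling over 2x2 windows
        >>> F.lp_pool2d(input, norm_type=2, kernel_size=2).shape
        torch.Size([1, 4, 4, 4])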
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool2d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    kw, kh = _pair(kernel_size)
    if stride is not None:
        out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
    else:
        out = avg_pool2d(
            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
        )

    return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type)


def lp_pool1d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: int,
    stride: Optional[BroadcastingList1[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""Apply a 1D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool1d` for details.
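
    Example (a minimal sketch with arbitrary sizes; ``F`` is assumed to be
    ``torch.nn.functional``)::

        >>> input = torch.randn(1, 4, 8)
        >>> F.lp_pool1d(input, norm_type=2, kernel_size=2).shape
        torch.Size([1, 4, 4])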
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool1d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    if stride is not None:
        out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
    else:
        out = avg_pool1d(
            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
        )

    return (
        (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type)
    )


def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool1d(input, output_size, return_indices=False)

    Applies a 1D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape.

    Args:
        output_size: the target output size (single integer)
        return_indices: whether to return pooling indices. Default: ``False``
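
    Example (a minimal sketch; assumes ``import torch.nn.functional as F`` and an
    arbitrarily chosen input shape)::

        >>> input = torch.randn(1, 8, 32)
        >>> # only the target length is given; window sizes are derived internally
        >>> F.adaptive_max_pool1d(input, output_size=5).shape
        torch.Size([1, 8, 5])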
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return torch.adaptive_max_pool1d(input, output_size)


def _adaptive_max_pool1d(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool1d_with_indices(input, output_size)[0]


adaptive_max_pool1d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool1d_with_indices,
    if_false=_adaptive_max_pool1d,
    module_name=__name__,
    func_name="adaptive_max_pool1d",
)


def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""adaptive_max_pool2d(input, output_size, return_indices=False)

    Applies a 2D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
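
    Example (a minimal sketch with arbitrary sizes; ``F`` is assumed to be
    ``torch.nn.functional``)::

        >>> input = torch.randn(1, 8, 32, 32)
        >>> # the target H and W may differ per dimension
        >>> F.adaptive_max_pool2d(input, output_size=(5, 7)).shape
        torch.Size([1, 8, 5, 7])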
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool2d(input, output_size)


def _adaptive_max_pool2d(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool2d_with_indices(input, output_size)[0]


adaptive_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool2d_with_indices,
    if_false=_adaptive_max_pool2d,
    module_name=__name__,
    func_name="adaptive_max_pool2d",
)


def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool3d(input, output_size, return_indices=False)

    Applies a 3D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            triple-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
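
    Example (a minimal sketch; assumes ``import torch.nn.functional as F`` and an
    arbitrarily chosen input shape)::

        >>> input = torch.randn(1, 8, 16, 16, 16)
        >>> # a single integer targets the same size in all three dimensions
        >>> F.adaptive_max_pool3d(input, output_size=4).shape
        torch.Size([1, 8, 4, 4, 4])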
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool3d(input, output_size)


def _adaptive_max_pool3d(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool3d_with_indices(input, output_size)[0]


adaptive_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool3d_with_indices,
    if_false=_adaptive_max_pool3d,
    module_name=__name__,
    func_name="adaptive_max_pool3d",
)


adaptive_avg_pool1d = _add_docstr(
    torch.adaptive_avg_pool1d,
    r"""
adaptive_avg_pool1d(input, output_size) -> Tensor

Applies a 1D adaptive average pooling over an input signal composed of
several input planes.

See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.

Args:
    output_size: the target output size (single integer)
""",
)


def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
    _output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_avg_pool2d(input, _output_size)


def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
    r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            triple-integer tuple)
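
    Example (a minimal sketch with arbitrary sizes; ``F`` is assumed to be
    ``torch.nn.functional``)::

        >>> input = torch.randn(1, 8, 16, 16, 16)
        >>> F.adaptive_avg_pool3d(input, output_size=(2, 4, 4)).shape
        torch.Size([1, 8, 2, 4, 4])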
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
    _output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_avg_pool3d(input, _output_size)


# Activation functions
def dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`.

    Uses samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
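
    Example (a minimal sketch; the shape and probability are arbitrary choices,
    and ``F`` is assumed to be ``torch.nn.functional``)::

        >>> x = torch.randn(3, 4)
        >>> # surviving elements are rescaled by 1 / (1 - p) during training
        >>> F.dropout(x, p=0.5, training=True).shape
        torch.Size([3, 4])
        >>> # with training=False the input passes through unchanged
        >>> torch.equal(F.dropout(x, p=0.5, training=False), x)
        True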
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
    )


def alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Apply alpha dropout to the input.

    See :class:`~torch.nn.AlphaDropout` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.alpha_dropout_(input, p, training)
        if inplace
        else _VF.alpha_dropout(input, p, training)
    )


def dropout1d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 1D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout1d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
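
    Example (a minimal sketch; assumes ``import torch.nn.functional as F`` and an
    arbitrarily chosen input shape)::

        >>> input = torch.randn(2, 8, 16)  # (N, C, L)
        >>> # whole channels are zeroed, not individual elements
        >>> F.dropout1d(input, p=0.5, training=True).shape
        torch.Size([2, 8, 16])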
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout1d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (2, 3):
        raise RuntimeError(
            f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. "
            "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
            "spatial dimension, a channel dimension, and an optional batch dimension "
            "(i.e. 2D or 3D inputs)."
        )

    is_batched = inp_dim == 3
    if not is_batched:
        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    if not is_batched:
        result = result.squeeze_(0) if inplace else result.squeeze(0)

    return result


def dropout2d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 2D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout2d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
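
    Example (a minimal sketch with an arbitrary input shape; ``F`` is assumed to
    be ``torch.nn.functional``)::

        >>> input = torch.randn(2, 8, 16, 16)  # (N, C, H, W)
        >>> # each (H, W) feature map is kept or zeroed as a unit
        >>> F.dropout2d(input, p=0.5, training=True).shape
        torch.Size([2, 8, 16, 16])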
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout2d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (3, 4):
        warn_msg = (
            f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout2d "
            "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."
        )
        warnings.warn(warn_msg)

    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
    # a 3D input will perform dropout1d behavior instead. This was done historically and the
    # behavior is maintained here for now.
    # See https://github.com/pytorch/pytorch/issues/77081
    if inp_dim == 3:
        warnings.warn(
            "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
            "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
            "is the channel dim. This behavior will change in a future release to interpret the "
            "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
            "channel-wise dropout behavior, please switch to using dropout1d instead."
        )

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    return result


def dropout3d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 3D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout3d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout3d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (4, 5):
        warn_msg = (
            f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout3d "
            "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."
        )
        warnings.warn(warn_msg)

    is_batched = inp_dim == 5
    if not is_batched:
        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    if not is_batched:
        result = result.squeeze_(0) if inplace else result.squeeze(0)
    return result


def feature_alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly masks out entire channels (a channel is a feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input
    is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
    setting activations to zero, as in regular Dropout, the activations are set
    to the negative saturation value of the SELU activation function.

    Each element will be masked independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.
    The elements to be masked are randomized on every forward call, and scaled
    and shifted to maintain zero mean and unit variance.

    See :class:`~torch.nn.FeatureAlphaDropout` for details.

    Args:
        p: dropout probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``False``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
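
    A minimal usage sketch (``F`` is assumed to be ``torch.nn.functional``; note
    that ``training`` defaults to ``False`` here, so it must be enabled explicitly)::

        >>> x = torch.randn(2, 3, 4)
        >>> # masked channels take the SELU negative saturation value instead of 0
        >>> out = F.feature_alpha_dropout(x, p=0.5, training=True)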
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            feature_alpha_dropout,
            (input,),
            input,
            p=p,
            training=training,
            inplace=inplace,
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.feature_alpha_dropout_(input, p, training)
        if inplace
        else _VF.feature_alpha_dropout(input, p, training)
    )


def _threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = False,
) -> Tensor:
    r"""Apply a threshold to each element of the input Tensor.

    See :class:`~torch.nn.Threshold` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            _threshold, (input,), input, threshold, value, inplace=inplace
        )
    if inplace:
        result = _VF.threshold_(input, threshold, value)
    else:
        result = _VF.threshold(input, threshold, value)
    return result


# We define this function as _threshold because it takes an argument
# named threshold, which clobbers the recursive reference to the
# function needed for __torch_function__ support
threshold = _threshold

threshold_ = _add_docstr(
    _VF.threshold_,
    r"""
threshold_(input, threshold, value) -> Tensor

In-place version of :func:`~threshold`.
""",
)


def relu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(relu, (input,), input, inplace=inplace)
    if inplace:
        result = torch.relu_(input)
    else:
        result = torch.relu(input)
    return result


relu_ = _add_docstr(
    torch.relu_,
    r"""
relu_(input) -> Tensor

In-place version of :func:`~relu`.
""",
)


def glu(input: Tensor, dim: int = -1) -> Tensor:  # noqa: D400,D402
    r"""
    glu(input, dim=-1) -> Tensor

    The gated linear unit. Computes:

    .. math::
        \text{GLU}(a, b) = a \otimes \sigma(b)

    where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma`
    is the sigmoid function and :math:`\otimes` is the element-wise product between matrices.

    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.

    Args:
        input (Tensor): input tensor
        dim (int): dimension on which to split the input. Default: -1
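
    A shape-level illustration (``F`` is assumed to be ``torch.nn.functional``;
    the size of `dim` must be even, since it is split in half)::

        >>> x = torch.randn(4, 8)
        >>> F.glu(x, dim=-1).shape  # `a` and `b` each get 4 of the 8 features
        torch.Size([4, 4])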
    """
    if has_torch_function_unary(input):
        return handle_torch_function(glu, (input,), input, dim=dim)
    if input.dim() == 0:
        raise RuntimeError(
            "glu does not support scalars because halving size must be even"
        )
    return torch._C._nn.glu(input, dim)


def hardtanh(
    input: Tensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""
    hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor

    Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more
    details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace
        )
    if min_val > max_val:
        raise ValueError("min_val cannot be greater than max_val")
    if inplace:
        result = torch._C._nn.hardtanh_(input, min_val, max_val)
    else:
        result = torch._C._nn.hardtanh(input, min_val, max_val)
    return result


hardtanh_ = _add_docstr(
    torch._C._nn.hardtanh_,
    r"""
hardtanh_(input, min_val=-1., max_val=1.) -> Tensor

In-place version of :func:`~hardtanh`.
""",
)


def relu6(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""relu6(input, inplace=False) -> Tensor

    Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.

    See :class:`~torch.nn.ReLU6` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(relu6, (input,), input, inplace=inplace)
    if inplace:
        result = torch._C._nn.relu6_(input)
    else:
        result = torch._C._nn.relu6(input)
    return result


def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor:
    r"""Apply the Exponential Linear Unit (ELU) function element-wise.

    See :class:`~torch.nn.ELU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
    if inplace:
        result = torch._C._nn.elu_(input, alpha)
    else:
        result = torch._C._nn.elu(input, alpha)
    return result


elu_ = _add_docstr(
    torch._C._nn.elu_,
    r"""
elu_(input, alpha=1.) -> Tensor

In-place version of :func:`~elu`.
""",
)


def selu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""selu(input, inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`,
    with :math:`\alpha=1.6732632423543772848170429916717` and
    :math:`scale=1.0507009873554804934193349852946`.

    See :class:`~torch.nn.SELU` for more details.
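
    A minimal usage sketch (``F`` is assumed to be ``torch.nn.functional``)::

        >>> x = torch.randn(2)
        >>> out = F.selu(x)  # negative inputs saturate toward -scale * alpha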
    """
    if has_torch_function_unary(input):
        return handle_torch_function(selu, (input,), input, inplace=inplace)
    if inplace:
        result = torch.selu_(input)
    else:
        result = torch.selu(input)
    return result


selu_ = _add_docstr(
    torch.selu_,
    r"""
selu_(input) -> Tensor

In-place version of :func:`~selu`.
""",
)


def celu(
    input: Tensor,
    alpha: float = 1.0,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""celu(input, alpha=1., inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.

    See :class:`~torch.nn.CELU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            celu, (input,), input, alpha=alpha, inplace=inplace
        )
    if inplace:
        result = torch.celu_(input, alpha)
    else:
        result = torch.celu(input, alpha)
    return result


celu_ = _add_docstr(
    torch.celu_,
    r"""
celu_(input, alpha=1.) -> Tensor

In-place version of :func:`~celu`.
""",
)


def leaky_relu(
    input: Tensor,
    negative_slope: float = 0.01,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""
    leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`

    See :class:`~torch.nn.LeakyReLU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace
        )
    if inplace:
        result = torch._C._nn.leaky_relu_(input, negative_slope)
    else:
        result = torch._C._nn.leaky_relu(input, negative_slope)
    return result


leaky_relu_ = _add_docstr(
    torch._C._nn.leaky_relu_,
    r"""
leaky_relu_(input, negative_slope=0.01) -> Tensor

In-place version of :func:`~leaky_relu`.
""",
)


prelu = _add_docstr(
    torch.prelu,
    r"""prelu(input, weight) -> Tensor

Applies element-wise the function
:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a
learnable parameter.

.. note::
    `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D,
    its size must match the number of input channels, determined by
    `input.size(1)` when `input.dim() >= 2`, otherwise 1.
    In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded
    to the shape of `input` in a way that is not possible using normal
    :ref:`broadcasting semantics<broadcasting-semantics>`.

See :class:`~torch.nn.PReLU` for more details.
""",
)


def rrelu(
    input: Tensor,
    lower: float = 1.0 / 8,
    upper: float = 1.0 / 3,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor

    Randomized leaky ReLU.

    See :class:`~torch.nn.RReLU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            rrelu,
            (input,),
            input,
            lower=lower,
            upper=upper,
            training=training,
            inplace=inplace,
        )
    if inplace:
        result = torch.rrelu_(input, lower, upper, training)
    else:
        result = torch.rrelu(input, lower, upper, training)
    return result


rrelu_ = _add_docstr(
    torch.rrelu_,
    r"""
rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor

In-place version of :func:`~rrelu`.
""",
)

logsigmoid = _add_docstr(
    torch._C._nn.log_sigmoid,
    r"""
logsigmoid(input) -> Tensor

Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)`

See :class:`~torch.nn.LogSigmoid` for more details.
""",
)

gelu = _add_docstr(
    torch._C._nn.gelu,
    r"""
gelu(input, approximate='none') -> Tensor

When the approximate argument is 'none', it applies element-wise the function
:math:`\text{GELU}(x) = x * \Phi(x)`

where :math:`\Phi(x)` is the cumulative distribution function of the Gaussian distribution.

When the approximate argument is 'tanh', GELU is estimated with

.. math::
    \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
""",
)

hardshrink = _add_docstr(
    torch.hardshrink,
    r"""
hardshrink(input, lambd=0.5) -> Tensor

Applies the hard shrinkage function element-wise.

See :class:`~torch.nn.Hardshrink` for more details.
""",
)


def tanhshrink(input):  # noqa: D400,D402
    r"""tanhshrink(input) -> Tensor

    Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`

    See :class:`~torch.nn.Tanhshrink` for more details.
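
    A minimal usage sketch (``F`` is assumed to be ``torch.nn.functional``)::

        >>> x = torch.tensor([-1.0, 0.0, 1.0])
        >>> out = F.tanhshrink(x)  # computed as x - x.tanh(); zero maps to exactly zero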
    """
    if has_torch_function_unary(input):
        return handle_torch_function(tanhshrink, (input,), input)
    return input - input.tanh()


def softsign(input):  # noqa: D400,D402
    r"""softsign(input) -> Tensor

    Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`

    See :class:`~torch.nn.Softsign` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(softsign, (input,), input)
    return input / (input.abs() + 1)


softplus = _add_docstr(
    torch._C._nn.softplus,
    r"""
softplus(input, beta=1, threshold=20) -> Tensor

Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`.

For numerical stability the implementation reverts to the linear function
when :math:`input \times \beta > threshold`.

See :class:`~torch.nn.Softplus` for more details.
""",
)


def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
    warnings.warn(
        f"Implicit dimension choice for {name} has been deprecated. "
" 2060*da0073e9SAndroid Build Coastguard Worker "Change the call to include dim=X as an argument.", 2061*da0073e9SAndroid Build Coastguard Worker stacklevel=stacklevel, 2062*da0073e9SAndroid Build Coastguard Worker ) 2063*da0073e9SAndroid Build Coastguard Worker if ndim == 0 or ndim == 1 or ndim == 3: 2064*da0073e9SAndroid Build Coastguard Worker ret = 0 2065*da0073e9SAndroid Build Coastguard Worker else: 2066*da0073e9SAndroid Build Coastguard Worker ret = 1 2067*da0073e9SAndroid Build Coastguard Worker return ret 2068*da0073e9SAndroid Build Coastguard Worker 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Workerdef softmin( 2071*da0073e9SAndroid Build Coastguard Worker input: Tensor, 2072*da0073e9SAndroid Build Coastguard Worker dim: Optional[int] = None, 2073*da0073e9SAndroid Build Coastguard Worker _stacklevel: int = 3, 2074*da0073e9SAndroid Build Coastguard Worker dtype: Optional[DType] = None, 2075*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 2076*da0073e9SAndroid Build Coastguard Worker r"""Apply a softmin function. 2077*da0073e9SAndroid Build Coastguard Worker 2078*da0073e9SAndroid Build Coastguard Worker Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. 2079*da0073e9SAndroid Build Coastguard Worker 2080*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.nn.Softmin` for more details. 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker Args: 2083*da0073e9SAndroid Build Coastguard Worker input (Tensor): input 2084*da0073e9SAndroid Build Coastguard Worker dim (int): A dimension along which softmin will be computed (so every slice 2085*da0073e9SAndroid Build Coastguard Worker along dim will sum to 1). 2086*da0073e9SAndroid Build Coastguard Worker dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 2087*da0073e9SAndroid Build Coastguard Worker If specified, the input tensor is casted to :attr:`dtype` before the operation 2088*da0073e9SAndroid Build Coastguard Worker is performed. This is useful for preventing data type overflows. Default: None. 
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
    if dtype is None:
        ret = (-input).softmax(dim)
    else:
        ret = (-input).softmax(dim, dtype=dtype)
    return ret


def softmax(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmax function.

    Softmax is defined as:

    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.

    See :class:`~torch.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is cast to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the log to be computed between the softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).
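
    A minimal usage sketch (``F`` is assumed to be ``torch.nn.functional``)::

        >>> x = torch.randn(2, 3)
        >>> out = F.softmax(x, dim=1)
        >>> row_sums = out.sum(dim=1)  # every entry of row_sums is 1.0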

    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    if dtype is None:
        ret = input.softmax(dim)
    else:
        ret = input.softmax(dim, dtype=dtype)
    return ret


def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize.

    Args:
        logits: `[..., num_features]` unnormalized log probabilities
        tau: non-negative scalar temperature
        hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if they were the soft samples in autograd
        eps: deprecated and has no effect. Default: 1e-10
        dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
        Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons and may be removed from nn.Functional in the future.
    .. note::
        The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft`.

        It achieves two things:

        - makes the output value exactly one-hot
          (since we add then subtract the y_soft value)
        - makes the gradient equal to the y_soft gradient
          (since we strip all other gradients)

    Examples::

        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    .. _Link 2:
        https://arxiv.org/abs/1611.01144
    """
    if has_torch_function_unary(logits):
        return handle_torch_function(
            gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim
        )
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_()
        .log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
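        # The soft sample is returned as-is: gradients flow through the full
        # softmax over the Gumbel-perturbed logits, so no straight-through
        # correction is needed in this branch.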
        ret = y_soft
    return ret


def log_softmax(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmax followed by a logarithm.

    While mathematically equivalent to log(softmax(x)), doing these two
    operations separately is slower and numerically unstable. This function
    uses an alternative formulation to compute the output and gradient correctly.

    See :class:`~torch.nn.LogSoftmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which log_softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is cast to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
    if dtype is None:
        ret = input.log_softmax(dim)
    else:
        ret = input.log_softmax(dim, dtype=dtype)
    return ret


softshrink = _add_docstr(
    torch._C._nn.softshrink,
    r"""
softshrink(input, lambd=0.5) -> Tensor

Applies the soft shrinkage function element-wise.
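
For reference, the transformation (as defined for :class:`~torch.nn.Softshrink`) is:

.. math::
    \text{SoftShrinkage}(x) =
    \begin{cases}
    x - \lambda, & \text{if } x > \lambda \\
    x + \lambda, & \text{if } x < -\lambda \\
    0, & \text{otherwise}
    \end{cases}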

See :class:`~torch.nn.Softshrink` for more details.
""",
)


def tanh(input):  # noqa: D400,D402
    r"""tanh(input) -> Tensor

    Applies element-wise,
    :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}`

    See :class:`~torch.nn.Tanh` for more details.
    """
    return input.tanh()


def sigmoid(input):  # noqa: D400,D402
    r"""sigmoid(input) -> Tensor

    Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}`

    See :class:`~torch.nn.Sigmoid` for more details.
    """
    return input.sigmoid()


def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Hardsigmoid function element-wise.

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    See :class:`~torch.nn.Hardsigmoid` for more details.
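
    A minimal usage sketch (``F`` is assumed to be ``torch.nn.functional``)::

        >>> x = torch.tensor([-4.0, 0.0, 4.0])
        >>> out = F.hardsigmoid(x)  # piecewise-linear ramp: values 0.0, 0.5, 1.0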
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.hardsigmoid_(input)
    return torch._C._nn.hardsigmoid(input)


linear = _add_docstr(
    torch._C._nn.linear,
    r"""
linear(input, weight, bias=None) -> Tensor

Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>`.

{sparse_beta_warning}

This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

Shape:

    - Input: :math:`(*, in\_features)` where `*` means any number of
      additional dimensions, including none
    - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
    - Bias: :math:`(out\_features)` or :math:`()`
    - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
""".format(
        **sparse_support_notes
    ),
)


bilinear = _add_docstr(
    torch.bilinear,
    r"""
bilinear(input1, input2, weight, bias=None) -> Tensor

Applies a bilinear transformation to the incoming data:
:math:`y = x_1^T A x_2 + b`

Shape:

    - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}`
      and :math:`*` means any number of additional dimensions.
      All but the last dimension of the inputs should be the same.
    - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`
    - weight: :math:`(\text{out\_features}, \text{in1\_features},
      \text{in2\_features})`
    - bias: :math:`(\text{out\_features})`
    - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
      and all but the last dimension are the same shape as the input.
""",
)


def silu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.

    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    See :class:`~torch.nn.SiLU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(silu, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.silu_(input)
    return torch._C._nn.silu(input)


def mish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    See :class:`~torch.nn.Mish` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(mish, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.mish_(input)
    return torch._C._nn.mish(input)


def hardswish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply hardswish function, element-wise.

    Follows implementation as described in the paper:
    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) / 6 & \text{otherwise}
        \end{cases}

    See :class:`~torch.nn.Hardswish` for more details.

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardswish, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.hardswish_(input)
    return torch._C._nn.hardswish(input)


def _no_grad_embedding_renorm_(
    weight: Tensor,
    input: Tensor,
    max_norm: float,
    norm_type: float,
) -> None:  # the in-place renorm below returns nothing, so annotate accordingly
    torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)


def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.

    This module is often used to retrieve word embeddings using indices.
    The input to the module is a list of indices, and the embedding matrix,
    and the output is the corresponding word embeddings.

    See :class:`torch.nn.Embedding` for more details.

    .. note::
        The analytical gradients of this function with respect to
        entries in :attr:`weight` at the row specified by :attr:`padding_idx`
        are expected to differ from the numerical ones.

    .. note::
        :class:`torch.nn.Embedding` differs from this function in
        that it initializes the row of :attr:`weight` specified by
        :attr:`padding_idx` to all zeros on construction.

    Args:
        input (LongTensor): Tensor containing indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
            therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
            i.e. it remains as a fixed "pad".
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
            is renormalized to have norm :attr:`max_norm`.
            Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
            the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
            :class:`torch.nn.Embedding` for more details regarding sparse gradients.

    Shape:
        - Input: LongTensor of arbitrary shape containing the indices to extract
        - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
          where V = maximum index + 1 and embedding_dim = the embedding size
        - Output: `(*, embedding_dim)`, where `*` is the input shape

    Examples::

        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> # an embedding matrix containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding(input, embedding_matrix)
        tensor([[[ 0.8490,  0.9625,  0.6753],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.6246,  0.9751,  0.3618],
                 [ 0.4161,  0.2419,  0.7383]],

                [[ 0.6246,  0.9751,  0.3618],
                 [ 0.0237,  0.7794,  0.0528],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.3385,  0.8612,  0.1867]]])

        >>> # example with padding_idx
        >>> weights = torch.rand(10, 3)
        >>> weights[0, :].zero_()
        >>> embedding_matrix = weights
        >>> input = torch.tensor([[0, 2, 0, 5]])
        >>> F.embedding(input, embedding_matrix, padding_idx=0)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.5609,  0.5384,  0.8720],
                 [ 0.0000,  0.0000,  0.0000],
                 [ 0.6262,  0.2438,  0.7471]]])
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            embedding,
            (input, weight),
            input,
            weight,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < weight.size(
                0
            ), "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -weight.size(
                0
            ), "Padding_idx must be within num_embeddings"
            padding_idx = weight.size(0) + padding_idx
    else:
        padding_idx = -1
    if max_norm is not None:
        # Note [embedding_renorm contiguous]
        # `embedding_renorm_` will call .contiguous() on input anyways, so we
        # call it here and take advantage of the improved locality in the
        # `embedding` call below too.
        input = input.contiguous()
        # Note [embedding_renorm set_grad_enabled]
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)


def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2,
    scale_grad_by_freq: bool = False,
    mode: str = "mean",
    sparse: bool = False,
    per_sample_weights: Optional[Tensor] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> Tensor:
    r"""Compute sums, means or maxes of `bags` of embeddings.

    Calculation is done without instantiating the intermediate embeddings.
    See :class:`torch.nn.EmbeddingBag` for more details.

    Note:
        {backward_reproducibility_note}

    Args:
        input (LongTensor): Tensor containing bags of indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
            the starting index position of each bag (sequence) in :attr:`input`.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
            is renormalized to have norm :attr:`max_norm`.
            Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option.
            Default ``2``.
        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
            the words in the mini-batch. Default ``False``.
            Note: this option is not supported when ``mode="max"``.
        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
            Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
            :class:`torch.nn.Embedding` for more details regarding sparse gradients.
            Note: this option is not supported when ``mode="max"``.
        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not None.

        include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1.
            The last element is the size of the input, or the ending index position of the last bag (sequence).

        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
            gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
            during training, i.e. it remains as a fixed "pad". Note that the embedding
            vector at :attr:`padding_idx` is excluded from the reduction.

    Shape:
        - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)

          - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
            each of fixed length ``N``, and this will return ``B`` values aggregated in a way
            depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.

          - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
            multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing
            the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets`
            of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags.
            Empty bags (i.e., having 0-length) will have returned vectors filled by zeros.

        - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`

        - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`.

        - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding_bag(input, embedding_matrix, offsets)
        tensor([[ 0.3397,  0.3552,  0.5545],
                [ 0.5893,  0.4386,  0.5882]])

        >>> # example with padding_idx
        >>> embedding_matrix = torch.rand(10, 3)
        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7082,  3.2145, -2.6251]])
    """
    if has_torch_function_variadic(input, weight, offsets, per_sample_weights):
        return handle_torch_function(
            embedding_bag,
            (input, weight, offsets, per_sample_weights),
            input,
            weight,
            offsets=offsets,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            padding_idx=padding_idx,
        )
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is     embedding_bag(input, weight, ...)
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn(
            "Argument order of nn.functional.embedding_bag was changed. "
" 2668*da0073e9SAndroid Build Coastguard Worker "Usage `embedding_bag(weight, input, ...)` is deprecated, " 2669*da0073e9SAndroid Build Coastguard Worker "and should now be `embedding_bag(input, weight, ...)`." 2670*da0073e9SAndroid Build Coastguard Worker ) 2671*da0073e9SAndroid Build Coastguard Worker weight, input = input, weight 2672*da0073e9SAndroid Build Coastguard Worker 2673*da0073e9SAndroid Build Coastguard Worker if per_sample_weights is not None and input.size() != per_sample_weights.size(): 2674*da0073e9SAndroid Build Coastguard Worker raise ValueError( 2675*da0073e9SAndroid Build Coastguard Worker f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, " 2676*da0073e9SAndroid Build Coastguard Worker f"then it must have the same shape as the input ({input.shape})" 2677*da0073e9SAndroid Build Coastguard Worker ) 2678*da0073e9SAndroid Build Coastguard Worker 2679*da0073e9SAndroid Build Coastguard Worker if not weight.dim() == 2: 2680*da0073e9SAndroid Build Coastguard Worker raise ValueError( 2681*da0073e9SAndroid Build Coastguard Worker f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}" 2682*da0073e9SAndroid Build Coastguard Worker ) 2683*da0073e9SAndroid Build Coastguard Worker 2684*da0073e9SAndroid Build Coastguard Worker if input.dim() == 2: 2685*da0073e9SAndroid Build Coastguard Worker if offsets is not None: 2686*da0073e9SAndroid Build Coastguard Worker type_str = "<unknown>" 2687*da0073e9SAndroid Build Coastguard Worker # TODO: Remove this once script supports type() calls 2688*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_scripting(): 2689*da0073e9SAndroid Build Coastguard Worker type_str = str(type(offsets)) 2690*da0073e9SAndroid Build Coastguard Worker raise ValueError( 2691*da0073e9SAndroid Build Coastguard Worker "if input is 2D, then offsets has to be None" 2692*da0073e9SAndroid Build Coastguard Worker ", as input is treated is a mini-batch of" 2693*da0073e9SAndroid Build Coastguard Worker " fixed length sequences. 
                f"offsets of type {type_str}"
            )
        offsets = torch.arange(
            0, input.numel(), input.size(1), dtype=input.dtype, device=input.device
        )

        input = input.reshape(-1)
        if per_sample_weights is not None:
            per_sample_weights = per_sample_weights.reshape(-1)
    elif input.dim() == 1:
        if offsets is None:
            raise ValueError("offsets has to be a 1D Tensor but got None")
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")
    else:
        raise ValueError(
            f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}"
        )
    if mode == "sum":
        mode_enum = 0
    elif mode == "mean":
        mode_enum = 1
    elif mode == "max":
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if max_norm is not None:
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)

    if per_sample_weights is not None and mode != "sum":
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            f"(got mode='{mode}'). Please open a feature request on GitHub."
        )

    ret, _, _, _ = torch.embedding_bag(
        weight,
        input,
        offsets,
        scale_grad_by_freq,
        mode_enum,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )
    return ret


if embedding_bag.__doc__:
    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)


def _verify_batch_size(size: List[int]) -> None:
    # XXX: JIT script does not support the reduce from functools, and mul op is a
    # builtin, which cannot be used as a value to a func yet, so rewrite this size
    # check to a simple equivalent for loop
    #
    # TODO: make use of reduce like below when JIT is ready with the missing features:
    # from operator import mul
    # from functools import reduce
    #
    #   if reduce(mul, size[2:], size[0]) == 1
    size_prods = size[0]
    for i in range(len(size) - 2):
        size_prods *= size[i + 2]
    if size_prods == 1:
        raise ValueError(
            f"Expected more than 1 value per channel when training, got input size {size}"
        )


def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Batch Normalization for each channel across a batch of data.
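
    A minimal illustrative call (inference mode; the statistics tensors here are
    just example values, not ones a trained model would use)::

        >>> x = torch.randn(20, 3, 35, 45)
        >>> running_mean, running_var = torch.zeros(3), torch.ones(3)
        >>> out = F.batch_norm(x, running_mean, running_var, training=False)
        >>> out.shape
        torch.Size([20, 3, 35, 45])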

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.
    """
    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
        return handle_torch_function(
            batch_norm,
            (input, running_mean, running_var, weight, bias),
            input,
            running_mean,
            running_var,
            weight=weight,
            bias=bias,
            training=training,
            momentum=momentum,
            eps=eps,
        )
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        torch.backends.cudnn.enabled,
    )


def _verify_spatial_size(size: List[int]) -> None:
    # Verify that there is > 1 spatial element for instance norm calculation.
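    # (e.g. a (N, C, 1, 1) input has a single spatial element per channel,
    # which would make the per-sample statistics degenerate)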
    size_prods = 1
    for i in range(2, len(size)):
        size_prods *= size[i]
    if size_prods == 1:
        raise ValueError(
            f"Expected more than 1 spatial element when training, got input size {size}"
        )


def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Instance Normalization independently for each channel in every data sample within a batch.

    See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
    :class:`~torch.nn.InstanceNorm3d` for details.
    """
    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
        return handle_torch_function(
            instance_norm,
            (input, running_mean, running_var, weight, bias),
            input,
            running_mean=running_mean,
            running_var=running_var,
            weight=weight,
            bias=bias,
            use_input_stats=use_input_stats,
            momentum=momentum,
            eps=eps,
        )
    if use_input_stats:
        _verify_spatial_size(input.size())
    return torch.instance_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
        torch.backends.cudnn.enabled,
    )
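

# Usage sketch for instance_norm (illustrative only, kept as a comment so the
# module has no import-time side effects). Instance norm normalizes each
# sample's channels independently, so no running statistics are required:
#
#   x = torch.randn(8, 4, 32, 32)                # (N, C, H, W)
#   out = torch.nn.functional.instance_norm(x)   # per-sample, per-channel stats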


def layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Layer Normalization over the last certain number of dimensions.

    See :class:`~torch.nn.LayerNorm` for details.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(
            layer_norm,
            (input, weight, bias),
            input,
            normalized_shape,
            weight=weight,
            bias=bias,
            eps=eps,
        )
    return torch.layer_norm(
        input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled
    )


def rms_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor] = None,
    eps: Optional[float] = None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization.

    See :class:`~torch.nn.RMSNorm` for details.
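
    Example (an illustrative sketch; RMSNorm rescales by the root mean square
    over the normalized dimensions, here the last one)::

        >>> x = torch.randn(2, 4, 8)
        >>> out = F.rms_norm(x, normalized_shape=[8])
        >>> out.shape
        torch.Size([2, 4, 8])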
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
        )
    return torch.rms_norm(input, normalized_shape, weight, eps)


def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Group Normalization over the last certain number of dimensions.

    See :class:`~torch.nn.GroupNorm` for details.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(
            group_norm,
            (
                input,
                weight,
                bias,
            ),
            input,
            num_groups,
            weight=weight,
            bias=bias,
            eps=eps,
        )
    if input.dim() < 2:
        raise RuntimeError(
            f"Expected at least 2 dimensions for input tensor but received {input.dim()}"
        )
    _verify_batch_size(
        [input.size(0) * input.size(1) // num_groups, num_groups]
        + list(input.size()[2:])
    )
    return torch.group_norm(
        input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled
    )


def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = 1e-4,
    beta: float = 0.75,
    k: float = 1.0,
) -> Tensor:
    r"""Apply local response normalization over an input signal.

    The input signal is composed of several input planes, where channels occupy the second dimension.
    Normalization is applied across channels.

    See :class:`~torch.nn.LocalResponseNorm` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k
        )
    dim = input.dim()
    if dim < 3:
        raise ValueError(
            f"Expected 3D or higher dimensionality input (got {dim} dimensions)"
        )

    if input.numel() == 0:
        return input

    div = input.mul(input)
    if dim == 3:
        div = div.unsqueeze(1)
        div = pad(div, (0, 0, size // 2, (size - 1) // 2))
        div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
    else:
        sizes = input.size()
        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
        div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
        div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
        div = div.view(sizes)
    div = div.mul(alpha).add(k).pow(beta)
    return input / div


# loss


def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
    zero_infinity: bool = False,
) -> Tensor:
    r"""Apply the Connectionist Temporal Classification loss.

    See :class:`~torch.nn.CTCLoss` for details.

    Note:
        {cudnn_reproducibility_note}

    Note:
        {backward_reproducibility_note}

    Args:
        log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`,
            `T = input length`, and `N = batch size`.
            The logarithmized probabilities of the outputs
            (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
        targets: :math:`(N, S)` or `(sum(target_lengths))`.
            Targets cannot be blank. In the second form, the targets are assumed to be concatenated.
        input_lengths: :math:`(N)` or :math:`()`.
            Lengths of the inputs (must each be :math:`\leq T`)
        target_lengths: :math:`(N)` or :math:`()`.
            Lengths of the targets
        blank (int, optional):
            Blank label. Default :math:`0`.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output losses will be divided by the target lengths and
            then the mean over the batch is taken, ``'sum'``: the output will be
            summed. Default: ``'mean'``
        zero_infinity (bool, optional):
            Whether to zero infinite losses and the associated gradients.
            Infinite losses mainly occur when the inputs are too short
            to be aligned to the targets. Default: ``False``

    Example::

        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
        >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
        >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
        >>> loss.backward()
    """
    if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths):
        return handle_torch_function(
            ctc_loss,
            (log_probs, targets, input_lengths, target_lengths),
            log_probs,
            targets,
            input_lengths,
            target_lengths,
            blank=blank,
            reduction=reduction,
            zero_infinity=zero_infinity,
        )
    return torch.ctc_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        _Reduction.get_enum(reduction),
        zero_infinity,
    )


if ctc_loss.__doc__:
    ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes)


def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the negative log likelihood loss.

    See :class:`~torch.nn.NLLLoss` for details.

    Args:
        input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
            in the case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
            in the case of K-dimensional loss. `input` is expected to be log-probabilities.
        target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
            or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
            K-dimensional loss.
        weight (Tensor, optional): a manual rescaling weight given to each
            class. If given, has to be a Tensor of size `C`.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Default: -100
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`.
            Default: ``'mean'``

    Example::

        >>> # input is of size N x C = 3 x 5
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> # each element in target has to have 0 <= value < C
        >>> target = torch.tensor([1, 0, 4])
        >>> output = F.nll_loss(F.log_softmax(input, dim=1), target)
        >>> output.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            nll_loss,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    return torch._C._nn.nll_loss_nd(
        input, target, weight, _Reduction.get_enum(reduction), ignore_index
    )


def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Poisson negative log likelihood loss.

    See :class:`~torch.nn.PoissonNLLLoss` for details.

    Args:
        input: expectation of underlying Poisson distribution.
        target: random sample :math:`target \sim \text{Poisson}(input)`.
        log_input: if ``True`` the loss is computed as
            :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is
            :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True``
        full: whether to compute the full loss, i.e. to add the Stirling
            approximation term
            :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`.
            Default: ``False``
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
            :attr:`log_input`\ =\ ``False``. Default: 1e-8
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`.
            Default: ``'mean'``

    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            poisson_nll_loss,
            (input, target),
            input,
            target,
            log_input=log_input,
            full=full,
            size_average=size_average,
            eps=eps,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        ret = input  # dead store: execution always hits the ValueError below
        raise ValueError(reduction + " is not a valid value for reduction")

    ret = torch.poisson_nll_loss(
        input, target, log_input, full, eps, _Reduction.get_enum(reduction)
    )
    return ret
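
# A minimal usage sketch (illustrative only, not part of the original module),
# assuming ``F`` is ``torch.nn.functional`` as in the docstring examples: with
# the default ``log_input=True`` the input holds log-rates, and
# ``torch.poisson`` draws matching targets.
#
#     >>> input = torch.randn(5, 2, requires_grad=True)     # log-rates
#     >>> target = torch.poisson(torch.rand(5, 2) * 4)      # Poisson samples
#     >>> loss = F.poisson_nll_loss(input, target)
#     >>> loss.backward()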


def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: bool = False,
    eps: float = 1e-6,
    reduction: str = "mean",
) -> Tensor:
    r"""Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details.

    Args:
        input: expectation of the Gaussian distribution.
        target: sample from the Gaussian distribution.
        var: tensor of positive variance(s), one for each of the expectations
            in the input (heteroscedastic), or a single one (homoscedastic).
        full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
        eps (float, optional): value added to var, for stability. Default: 1e-6.
        reduction (str, optional): specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the average of all batch member losses,
            ``'sum'``: the output is the sum of all batch member losses.
            Default: ``'mean'``.
    """
    if has_torch_function_variadic(input, target, var):
        return handle_torch_function(
            gaussian_nll_loss,
            (input, target, var),
            input,
            target,
            var,
            full=full,
            eps=eps,
            reduction=reduction,
        )

    # Check var size.
    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
    # Otherwise:
    if var.size() != input.size():
        # If var is one dimension short of input, but the sizes match otherwise,
        # then this is a homoscedastic case.
        # e.g. input.size = (10, 2, 3), var.size = (10, 2)
        # -> unsqueeze var so that var.shape = (10, 2, 1)
        # This is done so that broadcasting can happen in the loss calculation.
        if input.size()[:-1] == var.size():
            var = torch.unsqueeze(var, -1)

        # This checks if the sizes match up to the final dimension, and the final
        # dimension of var is of size 1. This is also a homoscedastic case.
        # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
        elif (
            input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
        ):  # Homoscedastic case
            pass

        # If none of the above pass, then the size of var is incorrect.
        else:
            raise ValueError("var is of incorrect size")

    # Check validity of reduction mode
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        raise ValueError(reduction + " is not a valid value for reduction")

    # Entries of var must be non-negative
    if torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Clamp for stability
    var = var.clone()
    with torch.no_grad():
        var.clamp_(min=eps)

    # Calculate the loss
    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
    if full:
        loss += 0.5 * math.log(2 * math.pi)

    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    else:
        return loss
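
# A minimal usage sketch (illustrative only, not part of the original module):
# the shape check above accepts either a per-element (heteroscedastic) variance
# or a homoscedastic one whose trailing dimension is 1.
#
#     >>> input = torch.randn(4, 3, requires_grad=True)
#     >>> target = torch.randn(4, 3)
#     >>> hetero_var = torch.rand(4, 3) + 0.1   # one variance per prediction
#     >>> homo_var = torch.rand(4, 1) + 0.1     # shared across the last dim
#     >>> F.gaussian_nll_loss(input, target, hetero_var)
#     >>> F.gaussian_nll_loss(input, target, homo_var)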


def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    log_target: bool = False,
) -> Tensor:
    r"""Compute the KL Divergence loss.

    Refer to the `Kullback-Leibler divergence Loss
    <https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`__ for background.

    See :class:`~torch.nn.KLDivLoss` for details.

    Args:
        input: Tensor of arbitrary shape in log-probabilities.
        target: Tensor of the same shape as input. See :attr:`log_target` for
            the target's interpretation.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied.
            ``'batchmean'``: the sum of the output will be divided by the batch size.
            ``'sum'``: the output will be summed.
            ``'mean'``: the output will be divided by the number of elements in the output.
            Default: ``'mean'``
        log_target (bool): A flag indicating whether ``target`` is passed in the log space.
            It is recommended to pass certain distributions (like ``softmax``)
            in the log space to avoid numerical issues caused by an explicit ``log``.
            Default: ``False``

    .. note::
        :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
        and in the meantime, specifying either of those two args will override :attr:`reduction`.

    .. warning::
        :attr:`reduction` = ``'mean'`` doesn't return the true KL divergence value; please use
        :attr:`reduction` = ``'batchmean'``, which aligns with the KL math definition.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            kl_div,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            log_target=log_target,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        if reduction == "mean":
            warnings.warn(
                "reduction: 'mean' divides the total loss by both the batch size and the support size. "
                "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
                "'mean' will be changed to behave the same as 'batchmean' in the next major release."
            )

        # special case for batchmean
        if reduction == "batchmean":
            reduction_enum = _Reduction.get_enum("sum")
        else:
            reduction_enum = _Reduction.get_enum(reduction)

    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)

    if reduction == "batchmean" and input.dim() != 0:
        reduced = reduced / input.size()[0]

    return reduced
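
# A minimal usage sketch (illustrative only, not part of the original module):
# ``input`` must be log-probabilities, and ``'batchmean'`` is the reduction
# that matches the mathematical definition of KL divergence.
#
#     >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
#     >>> target = F.softmax(torch.randn(3, 5), dim=1)
#     >>> loss = F.kl_div(input, target, reduction="batchmean")
#     >>> loss.backward()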


def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> Tensor:
    r"""Compute the cross entropy loss between input logits and target.

    See :class:`~torch.nn.CrossEntropyLoss` for details.

    Args:
        input (Tensor): Predicted unnormalized logits;
            see Shape section below for supported shapes.
        target (Tensor): Ground truth class indices or class probabilities;
            see Shape section below for supported shapes.
        weight (Tensor, optional): a manual rescaling weight given to each
            class. If given, has to be a Tensor of size `C`.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Note that
            :attr:`ignore_index` is only applicable when the target contains class indices.
            Default: -100
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
            of smoothing when computing the loss, where 0.0 means no smoothing.
            The targets become a mixture of the original ground truth and a uniform
            distribution as described in `Rethinking the Inception Architecture for
            Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

    Shape:
        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
          :math:`K \geq 1` in the case of K-dimensional loss where each value should be in :math:`[0, C)`.
          If containing class probabilities, same shape as the input and each value should be in :math:`[0, 1]`.

        where:

        .. math::
            \begin{aligned}
                C ={} & \text{number of classes} \\
                N ={} & \text{batch size} \\
            \end{aligned}

    Examples::

        >>> # Example of target with class indices
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randint(5, (3,), dtype=torch.int64)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
        >>>
        >>> # Example of target with class probabilities
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5).softmax(dim=1)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            cross_entropy,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    return torch._C._nn.cross_entropy_loss(
        input,
        target,
        weight,
        _Reduction.get_enum(reduction),
        ignore_index,
        label_smoothing,
    )


def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Measure Binary Cross Entropy between the target and input probabilities.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape as probabilities.
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight;
            if provided, it is repeated to match the input tensor shape.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Examples::

        >>> input = torch.randn(3, 2, requires_grad=True)
        >>> target = torch.rand(3, 2, requires_grad=False)
        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            binary_cross_entropy,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if target.size() != input.size():
        raise ValueError(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. "
            "Please ensure they have the same size."
        )

    if weight is not None:
        new_size = _infer_size(target.size(), weight.size())
        weight = weight.expand(new_size)

    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)


def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    pos_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Calculate Binary Cross Entropy between target and input logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight;
            if provided, it is repeated to match the input tensor shape.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
            Must be a tensor whose size along the class dimension equals the number of classes.
            Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
            operations. For a target of size [B, C, H, W] (where B is batch size), a pos_weight of
            size [B, C, H, W] applies a different positive weight to each element of the batch,
            while one of size [C, H, W] applies the same positive weights across the batch. To
            apply the same positive weight along all spatial dimensions for a 2D multi-class
            target [C, H, W], use: [C, 1, 1]. Default: ``None``

    Examples::

        >>> input = torch.randn(3, requires_grad=True)
        >>> target = torch.empty(3).random_(2)
        >>> loss = F.binary_cross_entropy_with_logits(input, target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight, pos_weight):
        return handle_torch_function(
            binary_cross_entropy_with_logits,
            (input, target, weight, pos_weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    if not (target.size() == input.size()):
        raise ValueError(
            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
        )

    return torch.binary_cross_entropy_with_logits(
        input, target, weight, pos_weight, reduction_enum
    )
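
# A minimal usage sketch (illustrative only, not part of the original module):
# a per-class ``pos_weight`` broadcasts against an ``N x C`` target,
# up-weighting the positive term of the chosen classes.
#
#     >>> input = torch.randn(4, 3, requires_grad=True)    # N x C logits
#     >>> target = torch.empty(4, 3).random_(2)
#     >>> pos_weight = torch.tensor([1.0, 2.0, 3.0])       # one weight per class
#     >>> loss = F.binary_cross_entropy_with_logits(input, target, pos_weight=pos_weight)
#     >>> loss.backward()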
" 3663*da0073e9SAndroid Build Coastguard Worker "Please ensure they have the same size.", 3664*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 3665*da0073e9SAndroid Build Coastguard Worker ) 3666*da0073e9SAndroid Build Coastguard Worker if size_average is not None or reduce is not None: 3667*da0073e9SAndroid Build Coastguard Worker reduction = _Reduction.legacy_get_string(size_average, reduce) 3668*da0073e9SAndroid Build Coastguard Worker 3669*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3670*da0073e9SAndroid Build Coastguard Worker 3671*da0073e9SAndroid Build Coastguard Worker if beta == 0.0: 3672*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.l1_loss( 3673*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target, _Reduction.get_enum(reduction) 3674*da0073e9SAndroid Build Coastguard Worker ) 3675*da0073e9SAndroid Build Coastguard Worker else: 3676*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.smooth_l1_loss( 3677*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target, _Reduction.get_enum(reduction), beta 3678*da0073e9SAndroid Build Coastguard Worker ) 3679*da0073e9SAndroid Build Coastguard Worker 3680*da0073e9SAndroid Build Coastguard Worker 3681*da0073e9SAndroid Build Coastguard Workerdef huber_loss( 3682*da0073e9SAndroid Build Coastguard Worker input: Tensor, 3683*da0073e9SAndroid Build Coastguard Worker target: Tensor, 3684*da0073e9SAndroid Build Coastguard Worker reduction: str = "mean", 3685*da0073e9SAndroid Build Coastguard Worker delta: float = 1.0, 3686*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 3687*da0073e9SAndroid Build Coastguard Worker r"""Compute the Huber loss. 3688*da0073e9SAndroid Build Coastguard Worker 3689*da0073e9SAndroid Build Coastguard Worker Function uses a squared term if the absolute 3690*da0073e9SAndroid Build Coastguard Worker element-wise error falls below delta and a delta-scaled L1 term otherwise. 3691*da0073e9SAndroid Build Coastguard Worker 3692*da0073e9SAndroid Build Coastguard Worker When delta equals 1, this loss is equivalent to SmoothL1Loss. 3693*da0073e9SAndroid Build Coastguard Worker In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1). 3694*da0073e9SAndroid Build Coastguard Worker 3695*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.nn.HuberLoss` for details. 3696*da0073e9SAndroid Build Coastguard Worker """ 3697*da0073e9SAndroid Build Coastguard Worker if has_torch_function_variadic(input, target): 3698*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 3699*da0073e9SAndroid Build Coastguard Worker huber_loss, 3700*da0073e9SAndroid Build Coastguard Worker (input, target), 3701*da0073e9SAndroid Build Coastguard Worker input, 3702*da0073e9SAndroid Build Coastguard Worker target, 3703*da0073e9SAndroid Build Coastguard Worker reduction=reduction, 3704*da0073e9SAndroid Build Coastguard Worker delta=delta, 3705*da0073e9SAndroid Build Coastguard Worker ) 3706*da0073e9SAndroid Build Coastguard Worker if not (target.size() == input.size()): 3707*da0073e9SAndroid Build Coastguard Worker warnings.warn( 3708*da0073e9SAndroid Build Coastguard Worker f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3709*da0073e9SAndroid Build Coastguard Worker "This will likely lead to incorrect results due to broadcasting. 
" 3710*da0073e9SAndroid Build Coastguard Worker "Please ensure they have the same size.", 3711*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 3712*da0073e9SAndroid Build Coastguard Worker ) 3713*da0073e9SAndroid Build Coastguard Worker 3714*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3715*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.huber_loss( 3716*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target, _Reduction.get_enum(reduction), delta 3717*da0073e9SAndroid Build Coastguard Worker ) 3718*da0073e9SAndroid Build Coastguard Worker 3719*da0073e9SAndroid Build Coastguard Worker 3720*da0073e9SAndroid Build Coastguard Workerdef l1_loss( 3721*da0073e9SAndroid Build Coastguard Worker input: Tensor, 3722*da0073e9SAndroid Build Coastguard Worker target: Tensor, 3723*da0073e9SAndroid Build Coastguard Worker size_average: Optional[bool] = None, 3724*da0073e9SAndroid Build Coastguard Worker reduce: Optional[bool] = None, 3725*da0073e9SAndroid Build Coastguard Worker reduction: str = "mean", 3726*da0073e9SAndroid Build Coastguard Worker) -> Tensor: # noqa: D400,D402 3727*da0073e9SAndroid Build Coastguard Worker r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3728*da0073e9SAndroid Build Coastguard Worker 3729*da0073e9SAndroid Build Coastguard Worker Function that takes the mean element-wise absolute value difference. 3730*da0073e9SAndroid Build Coastguard Worker 3731*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.nn.L1Loss` for details. 3732*da0073e9SAndroid Build Coastguard Worker """ 3733*da0073e9SAndroid Build Coastguard Worker if has_torch_function_variadic(input, target): 3734*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 3735*da0073e9SAndroid Build Coastguard Worker l1_loss, 3736*da0073e9SAndroid Build Coastguard Worker (input, target), 3737*da0073e9SAndroid Build Coastguard Worker input, 3738*da0073e9SAndroid Build Coastguard Worker target, 3739*da0073e9SAndroid Build Coastguard Worker size_average=size_average, 3740*da0073e9SAndroid Build Coastguard Worker reduce=reduce, 3741*da0073e9SAndroid Build Coastguard Worker reduction=reduction, 3742*da0073e9SAndroid Build Coastguard Worker ) 3743*da0073e9SAndroid Build Coastguard Worker if not (target.size() == input.size()): 3744*da0073e9SAndroid Build Coastguard Worker warnings.warn( 3745*da0073e9SAndroid Build Coastguard Worker f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3746*da0073e9SAndroid Build Coastguard Worker "This will likely lead to incorrect results due to broadcasting. 
" 3747*da0073e9SAndroid Build Coastguard Worker "Please ensure they have the same size.", 3748*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 3749*da0073e9SAndroid Build Coastguard Worker ) 3750*da0073e9SAndroid Build Coastguard Worker if size_average is not None or reduce is not None: 3751*da0073e9SAndroid Build Coastguard Worker reduction = _Reduction.legacy_get_string(size_average, reduce) 3752*da0073e9SAndroid Build Coastguard Worker 3753*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3754*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.l1_loss( 3755*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target, _Reduction.get_enum(reduction) 3756*da0073e9SAndroid Build Coastguard Worker ) 3757*da0073e9SAndroid Build Coastguard Worker 3758*da0073e9SAndroid Build Coastguard Worker 3759*da0073e9SAndroid Build Coastguard Workerdef mse_loss( 3760*da0073e9SAndroid Build Coastguard Worker input: Tensor, 3761*da0073e9SAndroid Build Coastguard Worker target: Tensor, 3762*da0073e9SAndroid Build Coastguard Worker size_average: Optional[bool] = None, 3763*da0073e9SAndroid Build Coastguard Worker reduce: Optional[bool] = None, 3764*da0073e9SAndroid Build Coastguard Worker reduction: str = "mean", 3765*da0073e9SAndroid Build Coastguard Worker) -> Tensor: # noqa: D400,D402 3766*da0073e9SAndroid Build Coastguard Worker r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3767*da0073e9SAndroid Build Coastguard Worker 3768*da0073e9SAndroid Build Coastguard Worker Measures the element-wise mean squared error. 3769*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.nn.MSELoss` for details. 3770*da0073e9SAndroid Build Coastguard Worker """ 3771*da0073e9SAndroid Build Coastguard Worker if has_torch_function_variadic(input, target): 3772*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 3773*da0073e9SAndroid Build Coastguard Worker mse_loss, 3774*da0073e9SAndroid Build Coastguard Worker (input, target), 3775*da0073e9SAndroid Build Coastguard Worker input, 3776*da0073e9SAndroid Build Coastguard Worker target, 3777*da0073e9SAndroid Build Coastguard Worker size_average=size_average, 3778*da0073e9SAndroid Build Coastguard Worker reduce=reduce, 3779*da0073e9SAndroid Build Coastguard Worker reduction=reduction, 3780*da0073e9SAndroid Build Coastguard Worker ) 3781*da0073e9SAndroid Build Coastguard Worker if not (target.size() == input.size()): 3782*da0073e9SAndroid Build Coastguard Worker warnings.warn( 3783*da0073e9SAndroid Build Coastguard Worker f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3784*da0073e9SAndroid Build Coastguard Worker "This will likely lead to incorrect results due to broadcasting. 
" 3785*da0073e9SAndroid Build Coastguard Worker "Please ensure they have the same size.", 3786*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 3787*da0073e9SAndroid Build Coastguard Worker ) 3788*da0073e9SAndroid Build Coastguard Worker if size_average is not None or reduce is not None: 3789*da0073e9SAndroid Build Coastguard Worker reduction = _Reduction.legacy_get_string(size_average, reduce) 3790*da0073e9SAndroid Build Coastguard Worker 3791*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3792*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.mse_loss( 3793*da0073e9SAndroid Build Coastguard Worker expanded_input, expanded_target, _Reduction.get_enum(reduction) 3794*da0073e9SAndroid Build Coastguard Worker ) 3795*da0073e9SAndroid Build Coastguard Worker 3796*da0073e9SAndroid Build Coastguard Worker 3797*da0073e9SAndroid Build Coastguard Workerdef margin_ranking_loss( 3798*da0073e9SAndroid Build Coastguard Worker input1: Tensor, 3799*da0073e9SAndroid Build Coastguard Worker input2: Tensor, 3800*da0073e9SAndroid Build Coastguard Worker target: Tensor, 3801*da0073e9SAndroid Build Coastguard Worker margin: float = 0, 3802*da0073e9SAndroid Build Coastguard Worker size_average: Optional[bool] = None, 3803*da0073e9SAndroid Build Coastguard Worker reduce: Optional[bool] = None, 3804*da0073e9SAndroid Build Coastguard Worker reduction: str = "mean", 3805*da0073e9SAndroid Build Coastguard Worker) -> Tensor: # noqa: D400,D402 3806*da0073e9SAndroid Build Coastguard Worker r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor 3807*da0073e9SAndroid Build Coastguard Worker 3808*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.nn.MarginRankingLoss` for details. 


def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MarginRankingLoss` for details.
    """
    if has_torch_function_variadic(input1, input2, target):
        return handle_torch_function(
            margin_ranking_loss,
            (input1, input2, target),
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if input1.dim() != input2.dim() or input1.dim() != target.dim():
        raise RuntimeError(
            f"margin_ranking_loss : All input tensors should have the same number of dimensions, but got sizes: "
            f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()}"
        )
    return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
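

# [Illustrative usage sketch; not part of the module.] With ``target = 1`` the
# loss encourages ``input1`` to rank above ``input2`` by at least ``margin``
# (elementwise ``max(0, -target * (input1 - input2) + margin)``):
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input1 = torch.randn(4, requires_grad=True)
#   >>> input2 = torch.randn(4, requires_grad=True)
#   >>> target = torch.ones(4)
#   >>> loss = F.margin_ranking_loss(input1, input2, target, margin=0.5)
#   >>> loss.backward()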


def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = 1.0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.HingeEmbeddingLoss` for details.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            hinge_embedding_loss,
            (input, target),
            input,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
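

# [Illustrative usage sketch; not part of the module.] Targets must be +1 or
# -1: the loss is ``input`` where ``target == 1`` and ``max(0, margin - input)``
# where ``target == -1``:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input = torch.randn(3, 5, requires_grad=True)
#   >>> target = (torch.randint(0, 2, (3, 5)) * 2 - 1).float()  # random +/-1
#   >>> loss = F.hinge_embedding_loss(input, target, margin=1.0)
#   >>> loss.backward()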


def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiLabelMarginLoss` for details.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            multilabel_margin_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
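

# [Illustrative usage sketch; not part of the module.] ``target`` holds class
# indices (not a one-hot mask): each row lists the positive classes for that
# sample and is terminated/padded with -1:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input = torch.randn(2, 4, requires_grad=True)
#   >>> target = torch.tensor([[3, 0, -1, -1], [1, -1, -1, -1]])
#   >>> loss = F.multilabel_margin_loss(input, target)
#   >>> loss.backward()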


def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.SoftMarginLoss` for details.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            soft_margin_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
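

# [Illustrative usage sketch; not part of the module.] A two-class logistic
# loss over +/-1 targets, ``log(1 + exp(-target * input))`` per element:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input = torch.randn(3, 5, requires_grad=True)
#   >>> target = (torch.randint(0, 2, (3, 5)) * 2 - 1).float()
#   >>> loss = F.soft_margin_loss(input, target)
#   >>> loss.backward()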


def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            multilabel_soft_margin_loss,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))

    if weight is not None:
        loss = loss * weight

    class_dim = input.dim() - 1
    C = input.size(class_dim)
    loss = loss.sum(dim=class_dim) / C  # only return N loss values

    if reduction == "none":
        ret = loss
    elif reduction == "mean":
        ret = loss.mean()
    elif reduction == "sum":
        ret = loss.sum()
    else:
        ret = input
        raise ValueError(reduction + " is not valid")
    return ret
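

# [Illustrative usage sketch; not part of the module.] Unlike
# ``multilabel_margin_loss``, the target here is a binary (0/1) mask of the
# same shape as the input:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input = torch.randn(3, 5, requires_grad=True)
#   >>> target = torch.randint(0, 2, (3, 5)).float()
#   >>> loss = F.multilabel_soft_margin_loss(input, target)
#   >>> loss.backward()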


def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.CosineEmbeddingLoss` for details.
    """
    if has_torch_function_variadic(input1, input2, target):
        return handle_torch_function(
            cosine_embedding_loss,
            (input1, input2, target),
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
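

# [Illustrative usage sketch; not part of the module.] Inputs are batches of
# vectors; ``target = 1`` pulls a pair together (``1 - cos``), ``target = -1``
# pushes it apart (``max(0, cos - margin)``):
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input1 = torch.randn(4, 8, requires_grad=True)
#   >>> input2 = torch.randn(4, 8, requires_grad=True)
#   >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])
#   >>> loss = F.cosine_embedding_loss(input1, input2, target, margin=0.5)
#   >>> loss.backward()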


def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = 1,
    margin: float = 1.0,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""multi_margin_loss(input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor

    See :class:`~torch.nn.MultiMarginLoss` for details.
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            multi_margin_loss,
            (input, target, weight),
            input,
            target,
            p=p,
            margin=margin,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if p != 1 and p != 2:
        raise ValueError("only p == 1 and p == 2 supported")
    if weight is not None:
        if weight.dim() != 1:
            raise ValueError("weight must be one-dimensional")

    return torch._C._nn.multi_margin_loss(
        input, target, p, margin, weight, reduction_enum
    )
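

# [Illustrative usage sketch; not part of the module.] A multi-class hinge
# loss: ``input`` is an (N, C) score matrix and ``target`` is an (N,) vector
# of class indices:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> input = torch.randn(3, 4, requires_grad=True)
#   >>> target = torch.tensor([1, 0, 3])
#   >>> loss = F.multi_margin_loss(input, target, p=1, margin=1.0)
#   >>> loss.backward()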


pixel_shuffle = _add_docstr(
    torch.pixel_shuffle,
    r"""
pixel_shuffle(input, upscale_factor) -> Tensor

Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.

See :class:`~torch.nn.PixelShuffle` for details.

Args:
    input (Tensor): the input tensor
    upscale_factor (int): factor to increase spatial resolution by

Examples::

    >>> input = torch.randn(1, 9, 4, 4)
    >>> output = torch.nn.functional.pixel_shuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 1, 12, 12])
""",
)


pixel_unshuffle = _add_docstr(
    torch.pixel_unshuffle,
    r"""
pixel_unshuffle(input, downscale_factor) -> Tensor

Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.

See :class:`~torch.nn.PixelUnshuffle` for details.

Args:
    input (Tensor): the input tensor
    downscale_factor (int): factor to decrease spatial resolution by

Examples::

    >>> input = torch.randn(1, 1, 12, 12)
    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 9, 4, 4])
""",
)


channel_shuffle = _add_docstr(
    torch.channel_shuffle,
    r"""
channel_shuffle(input, groups) -> Tensor

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
    ]]
    >>> output = torch.nn.functional.channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
    ]]
""",
)


native_channel_shuffle = _add_docstr(
    torch.native_channel_shuffle,
    r"""
native_channel_shuffle(input, groups) -> Tensor

Native kernel-level implementation of ``channel_shuffle``.
This function might become private in future releases, use with caution.

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
    ]]
    >>> output = torch.nn.functional.native_channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
    ]]
""",
)
Worker r"""Upsample input. 4207*da0073e9SAndroid Build Coastguard Worker 4208*da0073e9SAndroid Build Coastguard Worker Provided tensor is upsampled to either the given :attr:`size` or the given 4209*da0073e9SAndroid Build Coastguard Worker :attr:`scale_factor` 4210*da0073e9SAndroid Build Coastguard Worker 4211*da0073e9SAndroid Build Coastguard Worker .. warning:: 4212*da0073e9SAndroid Build Coastguard Worker This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. 4213*da0073e9SAndroid Build Coastguard Worker This is equivalent with ``nn.functional.interpolate(...)``. 4214*da0073e9SAndroid Build Coastguard Worker 4215*da0073e9SAndroid Build Coastguard Worker Note: 4216*da0073e9SAndroid Build Coastguard Worker {backward_reproducibility_note} 4217*da0073e9SAndroid Build Coastguard Worker 4218*da0073e9SAndroid Build Coastguard Worker The algorithm used for upsampling is determined by :attr:`mode`. 4219*da0073e9SAndroid Build Coastguard Worker 4220*da0073e9SAndroid Build Coastguard Worker Currently temporal, spatial and volumetric upsampling are supported, i.e. 4221*da0073e9SAndroid Build Coastguard Worker expected inputs are 3-D, 4-D or 5-D in shape. 4222*da0073e9SAndroid Build Coastguard Worker 4223*da0073e9SAndroid Build Coastguard Worker The input dimensions are interpreted in the form: 4224*da0073e9SAndroid Build Coastguard Worker `mini-batch x channels x [optional depth] x [optional height] x width`. 4225*da0073e9SAndroid Build Coastguard Worker 4226*da0073e9SAndroid Build Coastguard Worker The modes available for upsampling are: `nearest`, `linear` (3D-only), 4227*da0073e9SAndroid Build Coastguard Worker `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only) 4228*da0073e9SAndroid Build Coastguard Worker 4229*da0073e9SAndroid Build Coastguard Worker Args: 4230*da0073e9SAndroid Build Coastguard Worker input (Tensor): the input tensor 4231*da0073e9SAndroid Build Coastguard Worker size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): 4232*da0073e9SAndroid Build Coastguard Worker output spatial size. 4233*da0073e9SAndroid Build Coastguard Worker scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. 4234*da0073e9SAndroid Build Coastguard Worker mode (str): algorithm used for upsampling: 4235*da0073e9SAndroid Build Coastguard Worker ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | 4236*da0073e9SAndroid Build Coastguard Worker ``'trilinear'``. Default: ``'nearest'`` 4237*da0073e9SAndroid Build Coastguard Worker align_corners (bool, optional): Geometrically, we consider the pixels of the 4238*da0073e9SAndroid Build Coastguard Worker input and output as squares rather than points. 4239*da0073e9SAndroid Build Coastguard Worker If set to ``True``, the input and output tensors are aligned by the 4240*da0073e9SAndroid Build Coastguard Worker center points of their corner pixels, preserving the values at the corner pixels. 4241*da0073e9SAndroid Build Coastguard Worker If set to ``False``, the input and output tensors are aligned by the corner 4242*da0073e9SAndroid Build Coastguard Worker points of their corner pixels, and the interpolation uses edge value padding 4243*da0073e9SAndroid Build Coastguard Worker for out-of-boundary values, making this operation *independent* of input size 4244*da0073e9SAndroid Build Coastguard Worker when :attr:`scale_factor` is kept the same. 


def _is_integer(x) -> bool:
    r"""Type-check whether the input number is an integer.

    Will return True for int, SymInt, Numpy integers and Tensors with integer elements.
    """
    if isinstance(x, (int, torch.SymInt)):
        return True
    if np is not None and isinstance(x, np.integer):
        return True
    return isinstance(x, Tensor) and not x.is_floating_point()
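

# [Illustrative behavior of the helper above; not part of the module.]
#
#   _is_integer(3)                  -> True   (plain int)
#   _is_integer(torch.tensor(3))    -> True   (integer-dtype Tensor)
#   _is_integer(3.0)                -> False  (float)
#   _is_integer(torch.tensor(3.0))  -> False  (floating-point Tensor)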


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:
    pass
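

# [Illustrative note; not part of the module.] The four ``@_overload`` stubs
# above presumably exist because TorchScript cannot express a Union-typed
# parameter, so every combination of scalar/list ``size`` and scalar/list
# ``scale_factor`` gets its own signature. The single implementation below
# handles all of them:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> x = torch.randn(1, 3, 8, 8)
#   >>> F.interpolate(x, size=16).shape                   # scalar size
#   torch.Size([1, 3, 16, 16])
#   >>> F.interpolate(x, scale_factor=[2.0, 2.0]).shape   # per-dim scales
#   torch.Size([1, 3, 16, 16])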


def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    r"""Down/up samples the input.

    The tensor is interpolated to either the given :attr:`size` or the given
    :attr:`scale_factor`.

    The algorithm used for interpolation is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric sampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    The modes available for resizing are: `nearest`, `linear` (3D-only),
    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
            its length has to match the number of spatial dimensions, i.e. `input.dim() - 2`.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
            Default: ``False``
        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
            interpolation calculation. If `recompute_scale_factor` is ``True``, then
            `scale_factor` must be passed in and `scale_factor` is used to compute the
            output `size`. The computed output `size` will be used to infer new scales for
            the interpolation. Note that when `scale_factor` is floating-point, it may differ
            from the recomputed `scale_factor` due to rounding and precision issues.
            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
            be used directly for interpolation. Default: ``None``.
        antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. When
            anti-aliasing is used together with ``align_corners=False``, the interpolation
            result matches the Pillow result for downsampling operations. Supported modes:
            ``'bilinear'``, ``'bicubic'``.

    .. note::
        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
        negative values or values greater than 255 for images.
        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
        when displaying the image.

    .. note::
        Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation
        algorithms and fixes known issues with ``mode='nearest'``.
        Mode ``mode='nearest'`` is kept for backward compatibility; it matches
        OpenCV's (buggy) ``INTER_NEAREST`` interpolation algorithm.

    .. note::
        The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation
        when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``.
        For more details, please refer to the discussion in
        `issue#104157 <https://github.com/pytorch/pytorch/issues/104157>`_.

    Note:
        {backward_reproducibility_note}
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            interpolate,
            (input,),
            input,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )

    if mode in ("nearest", "area", "nearest-exact"):
        if align_corners is not None:
            raise ValueError(
                "align_corners option can only be set with the "
                "interpolating modes: linear | bilinear | bicubic | trilinear"
            )
    else:
        if align_corners is None:
            align_corners = False

    dim = input.dim() - 2  # Number of spatial dimensions.

    # Process size and scale_factor. Validate that exactly one is set.
    # Validate its length if it is a list, or expand it if it is a scalar.
    # After this block, exactly one of output_size and scale_factors will
    # be non-None, and it will be a list (or tuple).
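    # Illustrative: for a 4-D input (N, C, H, W) there are dim == 2 spatial
    # dimensions, so a scalar ``size=32`` expands to ``[32, 32]`` and a scalar
    # ``scale_factor=2.0`` expands to ``[2.0, 2.0]`` in the branches below.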
    if size is not None and scale_factor is not None:
        raise ValueError("only one of size or scale_factor should be defined")
    elif size is not None:
        assert scale_factor is None
        scale_factors = None
        if isinstance(size, (list, tuple)):
            if len(size) != dim:
                raise ValueError(
                    "Input and output must have the same number of spatial dimensions, but got "
                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
                    "output size in (o1, o2, ...,oK) format."
                )
            if not torch.jit.is_scripting():
                if not all(_is_integer(x) for x in size):
                    raise TypeError(
                        "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
                        f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
                    )
            output_size = size
        else:
            output_size = [size for _ in range(dim)]
    elif scale_factor is not None:
        assert size is None
        output_size = None
        if isinstance(scale_factor, (list, tuple)):
            if len(scale_factor) != dim:
                raise ValueError(
                    "Input and scale_factor must have the same number of spatial dimensions, but "
                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
                    f"scale_factor of shape {scale_factor}. "
                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
                    "scale_factor in (s1, s2, ...,sK) format."
                )
            scale_factors = scale_factor
        else:
            scale_factors = [scale_factor for _ in range(dim)]
    else:
        raise ValueError("either size or scale_factor should be defined")

    if (
        recompute_scale_factor is not None
        and recompute_scale_factor
        and size is not None
    ):
        raise ValueError(
            "recompute_scale_factor is not meaningful with an explicit size."
        )

    # "area" mode always requires an explicit size rather than scale factor.
    # Re-use the recompute_scale_factor code path.
    if mode == "area" and output_size is None:
        recompute_scale_factor = True

    if recompute_scale_factor is not None and recompute_scale_factor:
        # We compute output_size here, then un-set scale_factors.
        # The C++ code will recompute it based on the (integer) output size.
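        # Illustrative: with input.size(2) == 10 and scale_factors == [1.5],
        # the eager branch below computes output_size == [15] and then clears
        # scale_factors, so the kernel re-derives the effective scale from
        # the integer sizes (15 / 10).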
        assert scale_factors is not None
        if not torch.jit.is_scripting() and torch._C._get_tracing_state():
            # make scale_factor a tensor in tracing so constant doesn't get baked in
            output_size = [
                (
                    torch.floor(
                        (
                            input.size(i + 2).float()
                            * torch.tensor(scale_factors[i], dtype=torch.float32)
                        ).float()
                    )
                )
                for i in range(dim)
            ]
        elif torch.jit.is_scripting():
            output_size = [
                int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
                for i in range(dim)
            ]
        else:
            output_size = [
                _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim)
            ]
        scale_factors = None

    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
        raise ValueError(
            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
        )

    if input.dim() == 3 and mode == "nearest":
        return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
    if input.dim() == 4 and mode == "nearest":
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
    if input.dim() == 5 and mode == "nearest":
        return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

    if input.dim() == 3 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
    if input.dim() == 4 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
    if input.dim() == 5 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
    if input.dim() == 3 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool1d(input, output_size)
    if input.dim() == 4 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool2d(input, output_size)
    if input.dim() == 5 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool3d(input, output_size)

    if input.dim() == 3 and mode == "linear":
        assert align_corners is not None
        return torch._C._nn.upsample_linear1d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 4 and mode == "bilinear":
        assert align_corners is not None
        if antialias:
            return torch._C._nn._upsample_bilinear2d_aa(
                input, output_size, align_corners, scale_factors
            )
        # Two levels are necessary to prevent TorchScript from touching
        # are_deterministic_algorithms_enabled.
        if not torch.jit.is_scripting():
            if torch.are_deterministic_algorithms_enabled() and (
                input.is_cuda or input.is_xpu
            ):
                # Use slow decomp whose backward will be in terms of index_put
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
        return torch._C._nn.upsample_bilinear2d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 5 and mode == "trilinear":
        assert align_corners is not None
        return torch._C._nn.upsample_trilinear3d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 4 and mode == "bicubic":
        assert align_corners is not None
        if antialias:
            return torch._C._nn._upsample_bicubic2d_aa(
                input, output_size, align_corners, scale_factors
            )
        return torch._C._nn.upsample_bicubic2d(
            input, output_size, align_corners, scale_factors
        )

    if input.dim() == 3 and mode == "bilinear":
        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
    if input.dim() == 3 and mode == "trilinear":
        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
    if input.dim() == 4 and mode == "linear":
        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
    if input.dim() == 4 and mode == "trilinear":
        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
    if input.dim() == 5 and mode == "linear":
        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
    if input.dim() == 5 and mode == "bilinear":
        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")

    raise NotImplementedError(
        "Input Error: Only 3D, 4D and 5D input Tensors supported"
        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
        f" (got {mode})"
    )


if interpolate.__doc__:
    interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes)


@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


def upsample_nearest(input, size=None, scale_factor=None):  # noqa: F811
    r"""Upsamples the input, using nearest neighbours' pixel values.

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        This is equivalent with ``nn.functional.interpolate(..., mode='nearest')``.

    Currently spatial and volumetric upsampling are supported (i.e. expected
    inputs are 4 or 5 dimensional).

    Args:
        input (Tensor): input
        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
            size.
        scale_factor (int): multiplier for spatial size. Has to be an integer.
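
    Example (illustrative)::

        >>> inp = torch.arange(4.0).view(1, 1, 2, 2)
        >>> F.upsample_nearest(inp, scale_factor=2).size()
        torch.Size([1, 1, 4, 4])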

    Note:
        {backward_reproducibility_note}
    """
    # DeprecationWarning is ignored by default
    warnings.warn(
        "`nn.functional.upsample_nearest` is deprecated. "
        "Use `nn.functional.interpolate` instead.",
        stacklevel=2,
    )
    return interpolate(input, size, scale_factor, mode="nearest")


if upsample_nearest.__doc__:
    upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes)


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[List[float]] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[List[float]] = None,
) -> Tensor:
    pass


def upsample_bilinear(input, size=None, scale_factor=None):  # noqa: F811
    r"""Upsamples the input, using bilinear upsampling.

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        This is equivalent with
        ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.

    Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for
    volumetric (5 dimensional) inputs.

    Args:
        input (Tensor): input
        size (int or Tuple[int, int]): output spatial size.
        scale_factor (int or Tuple[int, int]): multiplier for spatial size

    Note:
        {backward_reproducibility_note}
    """
    # DeprecationWarning is ignored by default
    warnings.warn(
        "`nn.functional.upsample_bilinear` is deprecated. "
        "Use `nn.functional.interpolate` instead.",
        stacklevel=2,
    )
    return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)


if upsample_bilinear.__doc__:
    upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(
        **reproducibility_notes
    )


GRID_SAMPLE_INTERPOLATION_MODES = {
    "bilinear": 0,
    "nearest": 1,
    "bicubic": 2,
}

GRID_SAMPLE_PADDING_MODES = {
    "zeros": 0,
    "border": 1,
    "reflection": 2,
}
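
# The dicts above mirror the string-to-enum conversion performed inside
# grid_sample below; e.g. GRID_SAMPLE_INTERPOLATION_MODES["bicubic"] == 2
# and GRID_SAMPLE_PADDING_MODES["reflection"] == 2.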


def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Compute grid sample.

    Given an :attr:`input` and a flow-field :attr:`grid`, computes the
    ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.

    Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
    supported.

    In the spatial (4-D) case, for :attr:`input` with shape
    :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
    :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
    :math:`(N, C, H_\text{out}, W_\text{out})`.

    For each output location ``output[n, :, h, w]``, the size-2 vector
    ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
    which are used to interpolate the output value ``output[n, :, h, w]``.
    In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
    ``x``, ``y``, ``z`` pixel locations for interpolating
    ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
    ``bilinear`` interpolation method to sample the input pixels.

    :attr:`grid` specifies the sampling pixel locations normalized by the
    :attr:`input` spatial dimensions. Therefore, it should have most values in
    the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the
    left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the
    right-bottom pixel of :attr:`input`.

    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
    outputs are handled as defined by :attr:`padding_mode`. Options are

        * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
        * ``padding_mode="border"``: use border values for out-of-bound grid locations,
        * ``padding_mode="reflection"``: use values at locations reflected by
          the border for out-of-bound grid locations.
          For locations far away
          from the border, it will keep being reflected until becoming in bound,
          e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
          and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
          ``x'' = -0.5``.

    Note:
        This function is often used in conjunction with :func:`affine_grid`
        to build `Spatial Transformer Networks`_ .

    Note:
        When using the CUDA backend, this operation may induce nondeterministic
        behaviour in its backward pass that is not easily switched off.
        Please see the notes on :doc:`/notes/randomness` for background.

    Note:
        NaN values in :attr:`grid` will be interpreted as ``-1``.

    Args:
        input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
            or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
        grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
            or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
        mode (str): interpolation mode to calculate output values
            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
            Note: ``mode='bicubic'`` supports only 4-D input.
            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
            used internally will actually be trilinear. However, when the input is 4-D,
            the interpolation mode will legitimately be bilinear.
        padding_mode (str): padding mode for outside grid values
            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input as squares rather than points.
            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
            to the center points of the input's corner pixels. If set to ``False``, they
            are instead considered as referring to the corner points of the input's corner
            pixels, making the sampling more resolution agnostic.
            This option parallels the ``align_corners`` option in
            :func:`interpolate`, and so whichever option is used here
            should also be used there to resize the input image before grid sampling.
            Default: ``False``

    Returns:
        output (Tensor): output Tensor

    .. _`Spatial Transformer Networks`:
        https://arxiv.org/abs/1506.02025

    .. warning::
        When ``align_corners = True``, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled by
        :func:`grid_sample` will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).
        The default behavior up to version 1.2.0 was ``align_corners = True``.
        Since then, the default behavior has been changed to ``align_corners = False``,
        in order to bring it in line with the default for :func:`interpolate`.

    .. note::
        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`.
        The constant :math:`\alpha` may differ from package to package.
        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
        This algorithm may "overshoot" the range of values it's interpolating.
        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
        Clamp the results with :func:`torch.clamp` to ensure they are within the valid range.

    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
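
    Example (an illustrative identity-grid round trip; the values are chosen
    for clarity and are not taken from the library's test suite)::

        >>> inp = torch.arange(4.0).reshape(1, 1, 2, 2)
        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> grid = F.affine_grid(theta, torch.Size((1, 1, 2, 2)), align_corners=False)
        >>> F.grid_sample(inp, grid, align_corners=False)
        tensor([[[[0., 1.],
                  [2., 3.]]]])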
    """
    if has_torch_function_variadic(input, grid):
        return handle_torch_function(
            grid_sample,
            (input, grid),
            input,
            grid,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
        )
    if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
        raise ValueError(
            f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'"
        )
    if (
        padding_mode != "zeros"
        and padding_mode != "border"
        and padding_mode != "reflection"
    ):
        raise ValueError(
            "nn.functional.grid_sample(): expected padding_mode "
            "to be 'zeros', 'border', or 'reflection', "
            f"but got: '{padding_mode}'"
        )

    if mode == "bilinear":
        mode_enum = 0
    elif mode == "nearest":
        mode_enum = 1
    else:  # mode == 'bicubic'
        mode_enum = 2

    if padding_mode == "zeros":
        padding_mode_enum = 0
    elif padding_mode == "border":
        padding_mode_enum = 1
    else:  # padding_mode == 'reflection'
        padding_mode_enum = 2

    if align_corners is None:
        warnings.warn(
            "Default grid_sample and affine_grid behavior has changed "
            "to align_corners=False since 1.3.0. Please specify "
            "align_corners=True if the old behavior is desired. "
            "See the documentation of grid_sample for details."
        )
        align_corners = False

    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)


def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.

    .. note::
        This function is often used in conjunction with :func:`grid_sample`
        to build `Spatial Transformer Networks`_ .

    Args:
        theta (Tensor): input batch of affine matrices with shape
            (:math:`N \times 2 \times 3`) for 2D or
            (:math:`N \times 3 \times 4`) for 3D
        size (torch.Size): the target output image size.
            (:math:`N \times C \times H \times W` for 2D or
            :math:`N \times C \times D \times H \times W` for 3D)
            Example: torch.Size((32, 3, 24, 24))
        align_corners (bool, optional): if ``True``, consider ``-1`` and ``1``
            to refer to the centers of the corner pixels rather than the image corners.
            Refer to :func:`grid_sample` for a more complete description.
            A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`
            with the same setting for this option.
            Default: ``False``

    Returns:
        output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`)

    .. _`Spatial Transformer Networks`:
        https://arxiv.org/abs/1506.02025

    .. warning::
        When ``align_corners = True``, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled by
        :func:`grid_sample` will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).
        The default behavior up to version 1.2.0 was ``align_corners = True``.
        Since then, the default behavior has been changed to ``align_corners = False``,
        in order to bring it in line with the default for :func:`interpolate`.

    .. warning::
        When ``align_corners = True``, 2D affine transforms on 1D data and
        3D affine transforms on 2D data (that is, when one of the spatial
        dimensions has unit size) are ill-defined, and not an intended use case.
        This is not a problem when ``align_corners = False``.
        Up to version 1.2.0, all grid points along a unit dimension were
        considered arbitrarily to be at ``-1``.
        From version 1.3.0, under ``align_corners = True`` all grid points
        along a unit dimension are considered to be at ``0``
        (the center of the input image).
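
    Example (illustrative; an identity transform)::

        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> grid = F.affine_grid(theta, torch.Size((1, 1, 2, 2)), align_corners=False)
        >>> grid.shape
        torch.Size([1, 2, 2, 2])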
    """
    if has_torch_function_unary(theta):
        return handle_torch_function(
            affine_grid, (theta,), theta, size, align_corners=align_corners
        )
    if align_corners is None:
        warnings.warn(
            "Default grid_sample and affine_grid behavior has changed "
            "to align_corners=False since 1.3.0. Please specify "
            "align_corners=True if the old behavior is desired. "
            "See the documentation of grid_sample for details."
        )
        align_corners = False

    # enforce floating point dtype on theta
    if not theta.is_floating_point():
        raise ValueError(
            f"Expected theta to have floating point type, but got {theta.dtype}"
        )
    # check that shapes and sizes match
    if len(size) == 4:
        if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3:
            raise ValueError(
                f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}."
            )
        spatial_size = size[-2:]  # spatial dimension sizes
    elif len(size) == 5:
        if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4:
            raise ValueError(
                f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}."
            )
        spatial_size = size[-3:]  # spatial dimension sizes
    else:
        raise NotImplementedError(
            "affine_grid only supports 4D and 5D sizes, "
            "for 2D and 3D affine transforms, respectively. "
            f"Got size {size}."
        )
    # check for empty span
    if align_corners and min(spatial_size) == 1:
        warnings.warn(
            "Since version 1.3.0, affine_grid behavior has changed "
            "for unit-size grids when align_corners=True. "
            "This is not an intended use case of affine_grid. "
            "See the documentation of affine_grid for details."
        )
    elif min(size) <= 0:
        raise ValueError(f"Expected non-zero, positive output size. Got {size}")

    return torch.affine_grid_generator(theta, size, align_corners)


def pad(
    input: Tensor,
    pad: List[int],
    mode: str = "constant",
    value: Optional[float] = None,
) -> Tensor:
    r"""
    pad(input, pad, mode="constant", value=None) -> Tensor

    Pads tensor.

    Padding size:
        The padding size by which to pad some dimensions of :attr:`input`
        is described starting from the last dimension and moving forward.
        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
        of ``input`` will be padded.
        For example, to pad only the last dimension of the input tensor, then
        :attr:`pad` has the form
        :math:`(\text{padding\_left}, \text{padding\_right})`;
        to pad the last 2 dimensions of the input tensor, then use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom})`;
        to pad the last 3 dimensions, use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom},`
        :math:`\text{padding\_front}, \text{padding\_back})`.

    Padding mode:
        See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`,
        :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d`
        for concrete examples on how each of the padding modes works. Constant
        padding is implemented for arbitrary dimensions. Circular, replicate and
        reflection padding are implemented for padding the last 3 dimensions of a
        4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor,
        or the last dimension of a 2D or 3D input tensor.

    Note:
        When using the CUDA backend, this operation may induce nondeterministic
        behaviour in its backward pass that is not easily switched off.
        Please see the notes on :doc:`/notes/randomness` for background.

    Args:
        input (Tensor): N-dimensional tensor
        pad (tuple): m-elements tuple, where
            :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            Default: ``'constant'``
        value: fill value for ``'constant'`` padding.
            Default: ``0``

    Examples::

        >>> t4d = torch.empty(3, 3, 4, 2)
        >>> p1d = (1, 1)  # pad last dim by 1 on each side
        >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
        >>> print(out.size())
        torch.Size([3, 3, 4, 4])
        >>> p2d = (1, 1, 2, 2)  # pad last dim by (1, 1) and 2nd to last by (2, 2)
        >>> out = F.pad(t4d, p2d, "constant", 0)
        >>> print(out.size())
        torch.Size([3, 3, 8, 4])
        >>> t4d = torch.empty(3, 3, 4, 2)
        >>> p3d = (0, 1, 2, 1, 3, 3)  # pad by (0, 1), (2, 1), and (3, 3)
        >>> out = F.pad(t4d, p3d, "constant", 0)
        >>> print(out.size())
        torch.Size([3, 9, 7, 3])
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value
        )
    if not torch.jit.is_scripting():
        if torch.are_deterministic_algorithms_enabled() and (
            input.is_cuda or input.is_xpu
        ):
            if mode == "replicate":
                # Use slow decomp whose backward will be in terms of index_put.
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._replication_pad(input, pad)
    return torch._C._nn.pad(input, pad, mode, value)


# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
pad.__module__ = "torch.nn.functional"

# distance


pairwise_distance = _add_docstr(
    torch.pairwise_distance,
    r"""
pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor

See :class:`torch.nn.PairwiseDistance` for details
""",
)


pdist = _add_docstr(
    torch.pdist,
    r"""
pdist(input, p=2) -> Tensor

Computes the p-norm distance between every pair of row vectors in the input.
This is identical to the upper triangular portion, excluding the diagonal, of
`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
if the rows are contiguous.

If input has shape :math:`N \times M` then the output will have shape
:math:`\frac{1}{2} N (N - 1)`.

This function is equivalent to ``scipy.spatial.distance.pdist(input,
'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``.
When :math:`p = \infty`, the closest scipy function is
``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``.

Args:
    input: input tensor of shape :math:`N \times M`.
    p: p value for the p-norm distance to calculate between each vector pair
        :math:`\in [0, \infty]`.
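
Example (illustrative; four rows give :math:`\binom{4}{2} = 6` distances)::

    >>> a = torch.randn(4, 3)
    >>> F.pdist(a).shape
    torch.Size([6])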
""",
)


cosine_similarity = _add_docstr(
    torch.cosine_similarity,
    r"""
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor

Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having 1 fewer dimension.

.. math ::
    \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)}

Supports :ref:`type promotion <type-promotion-doc>`.

Args:
    x1 (Tensor): First input.
    x2 (Tensor): Second input.
    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
    eps (float, optional): Small value to avoid division by zero.
        Default: 1e-8

Example::

    >>> input1 = torch.randn(100, 128)
    >>> input2 = torch.randn(100, 128)
    >>> output = F.cosine_similarity(input1, input2)
    >>> print(output)
""",
)


one_hot = _add_docstr(
    torch._C._nn.one_hot,
    r"""
one_hot(tensor, num_classes=-1) -> LongTensor

Takes LongTensor with index values of shape ``(*)`` and returns a tensor
of shape ``(*, num_classes)`` that has zeros everywhere except where the
index of last dimension matches the corresponding value of the input tensor,
in which case it will be 1.

See also `One-hot on Wikipedia`_ .

.. _One-hot on Wikipedia:
    https://en.wikipedia.org/wiki/One-hot

Arguments:
    tensor (LongTensor): class values of any shape.
    num_classes (int): Total number of classes. If set to -1, the number
        of classes will be inferred as one greater than the largest class
        value in the input tensor.

Returns:
    LongTensor that has one more dimension with 1 values at the
    index of last dimension indicated by the input, and 0 everywhere
    else.

Examples:
    >>> F.one_hot(torch.arange(0, 5) % 3)
    tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]])
    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
    tensor([[1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0]])
    >>> F.one_hot(torch.arange(0, 6).view(3, 2) % 3)
    tensor([[[1, 0, 0],
             [0, 1, 0]],
            [[0, 0, 1],
             [1, 0, 0]],
            [[0, 1, 0],
             [0, 0, 1]]])
""",
)


def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the triplet loss between given input tensors and a margin greater than 0.

    See :class:`~torch.nn.TripletMarginLoss` for details.
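
    Example (illustrative)::

        >>> anchor = torch.randn(10, 128, requires_grad=True)
        >>> positive = torch.randn(10, 128, requires_grad=True)
        >>> negative = torch.randn(10, 128, requires_grad=True)
        >>> loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
        >>> loss.backward()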
    """
    if has_torch_function_variadic(anchor, positive, negative):
        return handle_torch_function(
            triplet_margin_loss,
            (anchor, positive, negative),
            anchor,
            positive,
            negative,
            margin=margin,
            p=p,
            eps=eps,
            swap=swap,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")
    return torch.triplet_margin_loss(
        anchor, positive, negative, margin, p, eps, swap, reduction_enum
    )


def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the triplet margin loss for input tensors using a custom distance function.

    See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
    """
    if torch.jit.is_scripting():
        raise NotImplementedError(
            "F.triplet_margin_with_distance_loss does not support JIT scripting: "
            "functions requiring Callables cannot be scripted."
        )

    if has_torch_function_variadic(anchor, positive, negative):
        return handle_torch_function(
            triplet_margin_with_distance_loss,
            (anchor, positive, negative),
            anchor,
            positive,
            negative,
            distance_function=distance_function,
            margin=margin,
            swap=swap,
            reduction=reduction,
        )

    # Check validity of reduction mode
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    # Check validity of margin
    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")

    # Check dimensions
    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    if not (a_dim == p_dim and p_dim == n_dim):
        raise RuntimeError(
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        )

    # Calculate loss
    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al. If True, and if the positive example is closer to the
    # negative example than the anchor is, swaps the positive example and the
    # anchor in the loss computation.
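    # Concretely: with d = distance_function, enabling swap replaces
    # d(anchor, negative) by min(d(anchor, negative), d(positive, negative)),
    # so the loss uses whichever of the two is the harder (smaller) negative
    # distance.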
    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)

    # Apply reduction
    if reduction == "sum":
        return torch.sum(loss)
    elif reduction == "mean":
        return torch.mean(loss)
    else:  # reduction == "none"
        return loss


def normalize(
    input: Tensor,
    p: float = 2.0,
    dim: int = 1,
    eps: float = 1e-12,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Perform :math:`L_p` normalization of inputs over the specified dimension.

    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`dim` is transformed as

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.

    Args:
        input: input tensor of any shape
        p (float): the exponent value in the norm formulation. Default: 2
        dim (int or tuple of ints): the dimension to reduce. Default: 1
        eps (float): small value to avoid division by zero. Default: 1e-12
        out (Tensor, optional): the output tensor. If :attr:`out` is used, this
            operation won't be differentiable.
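
    A small deterministic example (assuming the default :math:`L_2` norm): the
    vector :math:`(3, 4)` has norm 5, so it normalizes to :math:`(0.6, 0.8)`.

    Example::

        >>> x = torch.tensor([[3.0, 4.0]])
        >>> F.normalize(x, p=2.0, dim=1)
        tensor([[0.6000, 0.8000]])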
    """
    if has_torch_function_variadic(input, out):
        return handle_torch_function(
            normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out
        )
    if out is None:
        denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
        return input / denom
    else:
        denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input)
        return torch.div(input, denom, out=out)


def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None:
    assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)


def unfold(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    dilation: BroadcastingList2[int] = 1,
    padding: BroadcastingList2[int] = 0,
    stride: BroadcastingList2[int] = 1,
) -> Tensor:
    r"""Extract sliding local blocks from a batched input tensor.

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    .. warning::

        More than one element of the unfolded tensor may refer to a single
        memory location. As a result, in-place operations (especially ones that
        are vectorized) may result in incorrect behavior. If you need to write
        to the tensor, please clone it first.
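
    A shape-level sketch: a :math:`2 \times 2` kernel slid over a
    :math:`(1, 1, 3, 3)` input yields :math:`L = 4` blocks of
    :math:`C \cdot 2 \cdot 2 = 4` elements each.

    Example::

        >>> inp = torch.arange(9.0).reshape(1, 1, 3, 3)
        >>> F.unfold(inp, kernel_size=2).shape
        torch.Size([1, 4, 4])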

    See :class:`torch.nn.Unfold` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            unfold,
            (input,),
            input,
            kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
    return torch._C._nn.im2col(
        input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
    )


def fold(
    input: Tensor,
    output_size: BroadcastingList2[int],
    kernel_size: BroadcastingList2[int],
    dilation: BroadcastingList2[int] = 1,
    padding: BroadcastingList2[int] = 0,
    stride: BroadcastingList2[int] = 1,
) -> Tensor:
    r"""Combine an array of sliding local blocks into a large containing tensor.

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.
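
    A round-trip sketch: re-folding the blocks produced by :func:`unfold` sums
    overlapping elements, so the all-ones input below accumulates a count of how
    many blocks cover each position.

    Example::

        >>> inp = torch.ones(1, 1, 3, 3)
        >>> blocks = F.unfold(inp, kernel_size=2)
        >>> F.fold(blocks, output_size=(3, 3), kernel_size=2)
        tensor([[[[1., 2., 1.],
                  [2., 4., 2.],
                  [1., 2., 1.]]]])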

    See :class:`torch.nn.Fold` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            fold,
            (input,),
            input,
            output_size,
            kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
    return torch._C._nn.col2im(
        input,
        _pair(output_size),
        _pair(kernel_size),
        _pair(dilation),
        _pair(padding),
        _pair(stride),
    )


#
# multihead attention
#


def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""Perform the in-projection step of the attention operation, using packed weights.

    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            proj = linear(q, w, b)
            # reshaping to (3, E) rather than (E, 3) is deliberate: it gives better
            # memory coalescing and keeps the same order as chunk()
            proj = (
                proj.unflatten(-1, (3, E))
                .unsqueeze(0)
                .transpose(0, -2)
                .squeeze(-2)
                .contiguous()
            )
            return proj[0], proj[1], proj[2]
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_proj = linear(q, w_q, b_q)
            kv_proj = linear(k, w_kv, b_kv)
            # reshaping to (2, E) rather than (E, 2) is deliberate: it gives better
            # memory coalescing and keeps the same order as chunk()
            kv_proj = (
                kv_proj.unflatten(-1, (2, E))
                .unsqueeze(0)
                .transpose(0, -2)
                .squeeze(-2)
                .contiguous()
            )
            return (q_proj, kv_proj[0], kv_proj[1])
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Perform the in-projection step of the attention operation.

    This is simply a triple of linear projections,
    with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.

    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
            number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
            number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
            number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`

        Output: in output triple :math:`(q', k', v')`,
        - q': :math:`[Qdims..., Eq]`
        - k': :math:`[Kdims..., Eq]`
        - v': :math:`[Vdims..., Eq]`

    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (
        Eq,
        Eq,
    ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (
        Eq,
        Ek,
    ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (
        Eq,
        Ev,
    ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (
        Eq,
    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (
        Eq,
    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (
        Eq,
    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


scaled_dot_product_attention = _add_docstr(
    torch._C._nn.scaled_dot_product_attention,
    r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> Tensor:

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
    specified as a keyword argument.

    .. code-block:: python

        # Efficient implementation equivalent to the following:
        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
            L, S = query.size(-2), key.size(-2)
            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
            attn_bias = torch.zeros(L, S, dtype=query.dtype)
            if is_causal:
                assert attn_mask is None
                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
                attn_bias = attn_bias.to(query.dtype)

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
                else:
                    attn_bias += attn_mask

            if enable_gqa:
                key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
                value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

            attn_weight = query @ key.transpose(-2, -1) * scale_factor
            attn_weight += attn_bias
            attn_weight = torch.softmax(attn_weight, dim=-1)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return attn_weight @ value

    .. warning::
        This function is beta and subject to change.

    .. warning::
        This function always applies dropout according to the specified ``dropout_p`` argument.
        To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
        that makes the function call is not in training mode.

        For example:

        .. code-block:: python

            class MyModel(nn.Module):
                def __init__(self, p=0.5):
                    super().__init__()
                    self.p = p

                def forward(self, ...):
                    return F.scaled_dot_product_attention(...,
                        dropout_p=(self.p if self.training else 0.0))

    Note:

        There are currently three supported implementations of scaled dot product attention:

            - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
            - `Memory-Efficient Attention`_
            - A PyTorch implementation defined in C++ matching the above formulation

        The function may call optimized kernels for improved performance when using the CUDA backend.
        For all other backends, the PyTorch implementation will be used.

        All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
        optimal implementation based on the inputs. In order to provide more fine-grained control over which implementation
        is used, the following functions are provided for enabling and disabling implementations.
        The context manager is the preferred mechanism:

            - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
            - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
            - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention.
            - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.

        Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
        disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
        In the event that a fused implementation is not available, a warning will be raised with the
        reasons why the fused implementation cannot run.
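
        For instance, a sketch of pinning dispatch to the memory-efficient backend
        (assuming CUDA inputs; any backend listed above works the same way):

        .. code-block:: python

            from torch.nn.attention import sdpa_kernel, SDPBackend

            # Only the memory-efficient kernel may run here; if it cannot,
            # a warning explains which input constraint was violated.
            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value)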

        Due to the nature of fusing floating point operations, the output of this function may be different
        depending on what backend kernel is chosen.
        The C++ implementation supports torch.float64 and can be used when higher precision is required.
        For the math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
        For more information please see :doc:`/notes/numerical_accuracy`.

        Grouped Query Attention (GQA) is an experimental feature. It currently works only with the FlashAttention
        and math kernels on CUDA tensors, and does not support nested tensors.
        Constraints for GQA:

            - number_of_heads_query % number_of_heads_key_value == 0 and,
            - number_of_heads_key == number_of_heads_value

    Note:

        {cudnn_reproducibility_note}
    """.format(
        **reproducibility_notes
    )
    + r"""
    Args:
        query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
        key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
        value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
        attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
            which is :math:`(N,..., L, S)`. Two types of masks are supported.
            A boolean mask where a value of True indicates that the element *should* take part in attention.
            A float mask of the same type as query, key, value that is added to the attention score.
        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
        is_causal (bool): If set to True, the attention masking is a lower triangular matrix when the mask is a
            square matrix. The attention masking has the form of the upper left causal bias due to the alignment
            (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
            An error is thrown if both attn_mask and is_causal are set.
        scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
            to :math:`\frac{1}{\sqrt{E}}`.
        enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled; by default it is set to False.

    Returns:
        output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.

    Shape legend:
        - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
        - :math:`S: \text{Source sequence length}`
        - :math:`L: \text{Target sequence length}`
        - :math:`E: \text{Embedding dimension of the query and key}`
        - :math:`Ev: \text{Embedding dimension of the value}`
        - :math:`Hq: \text{Number of heads of query}`
        - :math:`H: \text{Number of heads of key and value}`

    Examples:

        >>> # Optionally use the context manager to ensure one of the fused kernels is run
        >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        ...     F.scaled_dot_product_attention(query, key, value)


        >>> # Sample for GQA for llama3
        >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.MATH]):
        ...     F.scaled_dot_product_attention(query, key, value, enable_gqa=True)


    .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
        https://arxiv.org/abs/2307.08691
    .. _Memory-Efficient Attention:
        https://github.com/facebookresearch/xformers
    .. _Grouped-Query Attention:
        https://arxiv.org/pdf/2305.13245
    """,
)


def _mha_shape_check(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
    num_heads: int,
):
    # Verifies the expected shape for `query`, `key`, `value`, `key_padding_mask` and `attn_mask`
    # and returns whether the input is batched or not.
    # Raises an error if `query` is not a 2-D (unbatched) or 3-D (batched) tensor.

    # Shape check.
    if query.dim() == 3:
        # Batched Inputs
        is_batched = True
        assert key.dim() == 3 and value.dim() == 3, (
            "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
        )
        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 2, (
                "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
                f" but found {key_padding_mask.dim()}-D tensor instead"
            )
        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), (
                "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                f" but found {attn_mask.dim()}-D tensor instead"
            )
    elif query.dim() == 2:
        # Unbatched Inputs
        is_batched = False
        assert key.dim() == 2 and value.dim() == 2, (
            "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
        )

        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 1, (
                "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
                f" but found {key_padding_mask.dim()}-D tensor instead"
            )

        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), (
                "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                f" but found {attn_mask.dim()}-D tensor instead"
            )
            if attn_mask.dim() == 3:
                expected_shape = (num_heads, query.shape[0], key.shape[0])
                assert (
                    attn_mask.shape == expected_shape
                ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
    else:
        raise AssertionError(
            f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
        )

    return is_batched


def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported"
            )
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            mask = torch.zeros_like(mask, dtype=target_type).masked_fill_(
                mask, float("-inf")
            )
    return mask


def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    See :class:`torch.nn.MultiheadAttention` for details.
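
    A minimal self-attention sketch with packed projection weights (the shapes
    and random weight initialization are illustrative only; inputs are laid out
    as :math:`(L, N, E)`):

    Example::

        >>> L, N, E, num_heads = 5, 2, 8, 2
        >>> query = key = value = torch.randn(L, N, E)
        >>> in_proj_weight = torch.randn(3 * E, E)
        >>> out_proj_weight = torch.randn(E, E)
        >>> attn_out, attn_weights = F.multi_head_attention_forward(
        ...     query, key, value, E, num_heads, in_proj_weight, None,
        ...     None, None, False, 0.0, out_proj_weight, None)
        >>> attn_out.shape, attn_weights.shape
        (torch.Size([5, 2, 8]), torch.Size([2, 5, 5]))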

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: ``True``
            Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
            for best performance when attention weights are not needed.
            *Setting need_weights to ``True``
            leads to a significant performance degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the projection weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True


    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
            the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
            the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
            the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
            If a FloatTensor is provided, it will be directly added to the value.
            If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
            positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
            N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
            N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
            E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
            attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
            :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
            :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
            head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
    """
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the
    # input is batched, run the computation, and squeeze the batch dimension before
    # returning so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
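
    # Note: ``_canonical_mask`` (a private helper in this module) validates the mask
    # dtype and, for boolean masks, builds an additive float mask holding ``-inf`` at
    # masked (``True``) positions and ``0.0`` elsewhere, so that any mask can simply
    # be added to the attention logits downstream.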
    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # When we have a kpm or need weights, we need attn_mask.
        # Otherwise, we pass the is_causal hint through as the is_causal
        # indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
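
    # Each of the ``num_heads`` heads attends within its own ``head_dim``-sized
    # subspace of the embedding (embed_dim == num_heads * head_dim); for example,
    # embed_dim=8 with num_heads=2 gives head_dim=4.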
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )
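
    # Sketch of what the packed in-projection computes, assuming the common
    # self-attention case where query, key and value are the same tensor
    # (``_in_projection_packed`` also handles the general case):
    #
    #     w_q, w_k, w_v = in_proj_weight.chunk(3)  # each (E, E)
    #     b_q, b_k, b_v = in_proj_bias.chunk(3)    # each (E,)
    #     q = linear(query, w_q, b_q)
    #     k = linear(key, w_k, b_k)
    #     v = linear(value, w_v, b_v)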

    # prep attention mask

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
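
    # After these views, q and k (and v just below) are laid out as
    # (bsz * num_heads, seq_len, head_dim), so a single batched matrix multiply
    # (``torch.bmm``) can score every head of every batch element at once.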
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0
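
    # Both branches below compute, in effect,
    #
    #     attn = softmax(q @ k.transpose(-2, -1) / sqrt(head_dim) + attn_mask) @ v
    #
    # The ``need_weights`` branch materializes the softmax weights explicitly so they
    # can be returned; the other branch defers to scaled_dot_product_attention, which
    # may use fused kernels but does not expose the intermediate weights.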

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q * math.sqrt(1.0 / float(E))

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L, S) or (N*num_heads, L, S);
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)
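
        # Here q/k/v carry the 4-D layout that scaled_dot_product_attention treats
        # as (batch, heads, seq_len, head_dim), with any mask broadcastable to
        # (N, num_heads, L, S); with is_causal=True it applies the lower-triangular
        # causal mask internally instead of taking one explicitly.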

        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None