# mypy: allow-untyped-defs
from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap

from . import forward_ad as fwAD


__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp

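# Illustrative sketch (not part of the original file): _as_tuple normalizes a
# single Tensor into a 1-tuple and records whether the input was already a
# tuple, so callers can later restore the original packing:
#
#     >>> x = torch.randn(2)
#     >>> is_tuple, tup = _as_tuple(x, "inputs", "vjp")
#     >>> is_tuple, len(tup)
#     (False, 1)
#     >>> _as_tuple((x, x), "inputs", "vjp")[0]
#     True
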

def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res

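# Illustrative sketch (not part of the original file): _tuple_postprocess is
# the inverse of _as_tuple. A single boolean undoes one level of tupling; a
# pair of booleans undoes the two levels created by nested _as_tuple calls:
#
#     >>> t = torch.zeros(3)
#     >>> _tuple_postprocess((t,), False) is t  # input was not a tuple: unwrap
#     True
#     >>> _tuple_postprocess(((t,),), (False, False)) is t  # unwrap both levels
#     True
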

def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)

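# Illustrative sketch (not part of the original file): with create_graph=False
# the inputs are detached from any existing graph and, when need_graph=True,
# re-marked as requiring grad, so each returned Tensor is a fresh leaf:
#
#     >>> x = torch.randn(2, requires_grad=True)
#     >>> (y,) = _grad_preprocess((x,), create_graph=False, need_graph=True)
#     >>> y is x, y.requires_grad, y.grad_fn is None
#     (False, True, True)
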

def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)

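# Illustrative sketch (not part of the original file): when create_graph=False,
# results are detached recursively so no autograd history leaks to the user:
#
#     >>> x = torch.randn(2, requires_grad=True)
#     >>> (out,) = _grad_postprocess((x * 2,), create_graph=False)
#     >>> out.requires_grad
#     False
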

def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )

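# Illustrative sketch (not part of the original file): _validate_v checks only
# the length of v and the per-entry sizes against `other`:
#
#     >>> out = (torch.zeros(2), torch.zeros(3))
#     >>> _validate_v((torch.ones(2), torch.ones(3)), out, True)  # passes silently
#     >>> _validate_v((torch.ones(2),), out, True)  # doctest: +SKIP
#     RuntimeError: v is a tuple of invalid length: should be 2 but got 1.
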

def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )

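# Illustrative sketch (not part of the original file): strict mode turns a
# disconnected output into an immediate error instead of a silent zero:
#
#     >>> _check_requires_grad((torch.zeros(1),), "outputs", strict=False)  # no-op
#     >>> _check_requires_grad((torch.zeros(1),), "outputs", strict=True)  # doctest: +SKIP
#     RuntimeError: Output 0 of the user-provided function does not require gradients. ...
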

def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = ()
    new_grad_outputs: Tuple[torch.Tensor, ...] = ()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )

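# Illustrative sketch (not part of the original file): outputs that are None or
# do not require grad are filtered out before calling torch.autograd.grad, so
# they contribute nothing instead of raising:
#
#     >>> x = torch.randn(3, requires_grad=True)
#     >>> outs = (x.sum(), torch.tensor(1.0))  # second output has no graph
#     >>> _autograd_grad(outs, (x,), (torch.tensor(1.0), None))
#     (tensor([1., 1., 1.]),)
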

def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage identifies which backward call we are considering, so we can give a good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = ()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents us from using the double backward trick to "
                        "replace forward mode AD."
                    )

            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res

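# Illustrative sketch (not part of the original file): with strict=False, a
# None gradient is replaced by zeros shaped like the matching reference Tensor,
# the mathematically expected value when an output is independent of an input:
#
#     >>> refs = (torch.ones(2),)
#     >>> _fill_in_zeros((None,), refs, False, False, "back")
#     (tensor([0., 0.]),)
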

# Public API


def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed.  Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )


def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )


def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs an N=sum(tensor_numels) standard basis, i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
    #           [0],             [1, 0],              [0],
    #           [0],             [0, 1],              [0],
    #           [0]])  ,         [0, 0]])  ,          [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks

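# Illustrative sketch (not part of the original file): concatenating the
# returned chunks along dim=1 recovers the N x N identity matrix:
#
#     >>> a, b = _construct_standard_basis_for((torch.zeros(1), torch.zeros(2)), (1, 2))
#     >>> torch.cat([a, b], dim=1)
#     tensor([[1., 0., 0.],
#             [0., 1., 0.],
#             [0., 0., 1.]])
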

def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )  # noqa: C409

                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )

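# Illustrative sketch (not part of the original file): _jacfwd is reached via
# jacobian(..., strategy="forward-mode", vectorize=True); it pushes the rows of
# a standard basis through func as forward-mode tangents under vmap:
#
#     >>> x = torch.tensor([1.0, 2.0])
#     >>> _jacfwd(lambda t: t**2, x, strict=False, vectorize=True)
#     tensor([[2., 0.],
#             [0., 4.]])
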
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Workerdef jacobian(
576*da0073e9SAndroid Build Coastguard Worker    func,
577*da0073e9SAndroid Build Coastguard Worker    inputs,
578*da0073e9SAndroid Build Coastguard Worker    create_graph=False,
579*da0073e9SAndroid Build Coastguard Worker    strict=False,
580*da0073e9SAndroid Build Coastguard Worker    vectorize=False,
581*da0073e9SAndroid Build Coastguard Worker    strategy="reverse-mode",
582*da0073e9SAndroid Build Coastguard Worker):
583*da0073e9SAndroid Build Coastguard Worker    r"""Compute the Jacobian of a given function.
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker    Args:
586*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
587*da0073e9SAndroid Build Coastguard Worker            a tuple of Tensors or a Tensor.
588*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
589*da0073e9SAndroid Build Coastguard Worker        create_graph (bool, optional): If ``True``, the Jacobian will be
590*da0073e9SAndroid Build Coastguard Worker            computed in a differentiable manner. Note that when ``strict`` is
591*da0073e9SAndroid Build Coastguard Worker            ``False``, the result can not require gradients or be disconnected
592*da0073e9SAndroid Build Coastguard Worker            from the inputs.  Defaults to ``False``.
593*da0073e9SAndroid Build Coastguard Worker        strict (bool, optional): If ``True``, an error will be raised when we
594*da0073e9SAndroid Build Coastguard Worker            detect that there exists an input such that all the outputs are
595*da0073e9SAndroid Build Coastguard Worker            independent of it. If ``False``, we return a Tensor of zeros as the
596*da0073e9SAndroid Build Coastguard Worker            jacobian for said inputs, which is the expected mathematical value.
597*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
598*da0073e9SAndroid Build Coastguard Worker        vectorize (bool, optional): This feature is experimental.
599*da0073e9SAndroid Build Coastguard Worker            Please consider using :func:`torch.func.jacrev` or
600*da0073e9SAndroid Build Coastguard Worker            :func:`torch.func.jacfwd` instead if you are looking for something
601*da0073e9SAndroid Build Coastguard Worker            less experimental and more performant.
602*da0073e9SAndroid Build Coastguard Worker            When computing the jacobian, usually we invoke
603*da0073e9SAndroid Build Coastguard Worker            ``autograd.grad`` once per row of the jacobian. If this flag is
604*da0073e9SAndroid Build Coastguard Worker            ``True``, we perform only a single ``autograd.grad`` call with
605*da0073e9SAndroid Build Coastguard Worker            ``batched_grad=True`` which uses the vmap prototype feature.
606*da0073e9SAndroid Build Coastguard Worker            Though this should lead to performance improvements in many cases,
607*da0073e9SAndroid Build Coastguard Worker            because this feature is still experimental, there may be performance
608*da0073e9SAndroid Build Coastguard Worker            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
609*da0073e9SAndroid Build Coastguard Worker            more information.
610*da0073e9SAndroid Build Coastguard Worker        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
611*da0073e9SAndroid Build Coastguard Worker            determine whether the Jacobian will be computed with forward or reverse
612*da0073e9SAndroid Build Coastguard Worker            mode AD. Currently, ``"forward-mode"`` requires ``vectorized=True``.
613*da0073e9SAndroid Build Coastguard Worker            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
614*da0073e9SAndroid Build Coastguard Worker            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
615*da0073e9SAndroid Build Coastguard Worker            prefer to use ``"reverse-mode"``.
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Worker    Returns:
618*da0073e9SAndroid Build Coastguard Worker        Jacobian (Tensor or nested tuple of Tensors): if there is a single
619*da0073e9SAndroid Build Coastguard Worker        input and output, this will be a single Tensor containing the
620*da0073e9SAndroid Build Coastguard Worker        Jacobian for the linearized inputs and output. If one of the two is
621*da0073e9SAndroid Build Coastguard Worker        a tuple, then the Jacobian will be a tuple of Tensors. If both of
622*da0073e9SAndroid Build Coastguard Worker        them are tuples, then the Jacobian will be a tuple of tuple of
623*da0073e9SAndroid Build Coastguard Worker        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
624*da0073e9SAndroid Build Coastguard Worker        ``i``\th output and ``j``\th input and will have as size the
625*da0073e9SAndroid Build Coastguard Worker        concatenation of the sizes of the corresponding output and the
626*da0073e9SAndroid Build Coastguard Worker        corresponding input and will have same dtype and device as the
627*da0073e9SAndroid Build Coastguard Worker        corresponding input. If strategy is ``forward-mode``, the dtype will be
628*da0073e9SAndroid Build Coastguard Worker        that of the output; otherwise, the input.
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker    Example:
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
633*da0073e9SAndroid Build Coastguard Worker        >>> def exp_reducer(x):
634*da0073e9SAndroid Build Coastguard Worker        ...     return x.exp().sum(dim=1)
635*da0073e9SAndroid Build Coastguard Worker        >>> inputs = torch.rand(2, 2)
636*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
637*da0073e9SAndroid Build Coastguard Worker        >>> jacobian(exp_reducer, inputs)
638*da0073e9SAndroid Build Coastguard Worker        tensor([[[1.4917, 2.4352],
639*da0073e9SAndroid Build Coastguard Worker                 [0.0000, 0.0000]],
640*da0073e9SAndroid Build Coastguard Worker                [[0.0000, 0.0000],
641*da0073e9SAndroid Build Coastguard Worker                 [2.4369, 2.3799]]])
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker        >>> jacobian(exp_reducer, inputs, create_graph=True)
644*da0073e9SAndroid Build Coastguard Worker        tensor([[[1.4917, 2.4352],
645*da0073e9SAndroid Build Coastguard Worker                 [0.0000, 0.0000]],
646*da0073e9SAndroid Build Coastguard Worker                [[0.0000, 0.0000],
647*da0073e9SAndroid Build Coastguard Worker                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker        >>> def exp_adder(x, y):
650*da0073e9SAndroid Build Coastguard Worker        ...     return 2 * x.exp() + 3 * y
651*da0073e9SAndroid Build Coastguard Worker        >>> inputs = (torch.rand(2), torch.rand(2))
652*da0073e9SAndroid Build Coastguard Worker        >>> jacobian(exp_adder, inputs)
653*da0073e9SAndroid Build Coastguard Worker        (tensor([[2.8052, 0.0000],
654*da0073e9SAndroid Build Coastguard Worker                [0.0000, 3.3963]]),
655*da0073e9SAndroid Build Coastguard Worker         tensor([[3., 0.],
656*da0073e9SAndroid Build Coastguard Worker                 [0., 3.]]))
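657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker        The same Jacobian can be computed with forward-mode AD. A minimal
659*da0073e9SAndroid Build Coastguard Worker        sketch (not run by the doctests), assuming every op used by ``func``
660*da0073e9SAndroid Build Coastguard Worker        supports forward-mode AD:
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP
663*da0073e9SAndroid Build Coastguard Worker        >>> jacobian(exp_adder, inputs, vectorize=True, strategy="forward-mode")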
657*da0073e9SAndroid Build Coastguard Worker    """
658*da0073e9SAndroid Build Coastguard Worker    assert strategy in ("forward-mode", "reverse-mode"), (
659*da0073e9SAndroid Build Coastguard Worker        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
660*da0073e9SAndroid Build Coastguard Worker        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
661*da0073e9SAndroid Build Coastguard Worker        'Otherwise, prefer to use "reverse-mode".'
662*da0073e9SAndroid Build Coastguard Worker    )
663*da0073e9SAndroid Build Coastguard Worker    if strategy == "forward-mode":
664*da0073e9SAndroid Build Coastguard Worker        if create_graph:
665*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError(
666*da0073e9SAndroid Build Coastguard Worker                "torch.autograd.functional.jacobian: `create_graph=True` "
667*da0073e9SAndroid Build Coastguard Worker                'and `strategy="forward-mode"` are not supported together (yet). '
668*da0073e9SAndroid Build Coastguard Worker                "Please either set `create_graph=False` or "
669*da0073e9SAndroid Build Coastguard Worker                '`strategy="reverse-mode"`.'
670*da0073e9SAndroid Build Coastguard Worker            )
671*da0073e9SAndroid Build Coastguard Worker        return _jacfwd(func, inputs, strict, vectorize)
672*da0073e9SAndroid Build Coastguard Worker
673*da0073e9SAndroid Build Coastguard Worker    with torch.enable_grad():
674*da0073e9SAndroid Build Coastguard Worker        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
675*da0073e9SAndroid Build Coastguard Worker        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        outputs = func(*inputs)
678*da0073e9SAndroid Build Coastguard Worker        is_outputs_tuple, outputs = _as_tuple(
679*da0073e9SAndroid Build Coastguard Worker            outputs, "outputs of the user-provided function", "jacobian"
680*da0073e9SAndroid Build Coastguard Worker        )
681*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(outputs, "outputs", strict=strict)
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        if vectorize:
684*da0073e9SAndroid Build Coastguard Worker            if strict:
685*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
686*da0073e9SAndroid Build Coastguard Worker                    "torch.autograd.functional.jacobian: `strict=True` "
687*da0073e9SAndroid Build Coastguard Worker                    "and `vectorize=True` are not supported together. "
688*da0073e9SAndroid Build Coastguard Worker                    "Please either set `strict=False` or "
689*da0073e9SAndroid Build Coastguard Worker                    "`vectorize=False`."
690*da0073e9SAndroid Build Coastguard Worker                )
691*da0073e9SAndroid Build Coastguard Worker            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
692*da0073e9SAndroid Build Coastguard Worker            #
693*da0073e9SAndroid Build Coastguard Worker            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
694*da0073e9SAndroid Build Coastguard Worker            # It turns out we can compute the jacobian of this function with a single
695*da0073e9SAndroid Build Coastguard Worker            # call to autograd.grad by using vmap over the correct grad_outputs.
696*da0073e9SAndroid Build Coastguard Worker            #
697*da0073e9SAndroid Build Coastguard Worker            # Firstly, one way to compute the jacobian is to concatenate x**2
698*da0073e9SAndroid Build Coastguard Worker            # and x.sum() into a vector with 4 elements. E.g., use
699*da0073e9SAndroid Build Coastguard Worker            # g(x) = torch.cat([x**2, x.sum().unsqueeze(0)])
699*da0073e9SAndroid Build Coastguard Worker            #
700*da0073e9SAndroid Build Coastguard Worker            # To get the first row of the jacobian, we call
701*da0073e9SAndroid Build Coastguard Worker            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1., 0., 0., 0.]))
702*da0073e9SAndroid Build Coastguard Worker            # To get the 2nd row of the jacobian, we call
703*da0073e9SAndroid Build Coastguard Worker            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0., 1., 0., 0.]))
704*da0073e9SAndroid Build Coastguard Worker            # and so on.
705*da0073e9SAndroid Build Coastguard Worker            #
706*da0073e9SAndroid Build Coastguard Worker            # Using vmap, we can vectorize all 4 of these computations into one by
707*da0073e9SAndroid Build Coastguard Worker            # passing the standard basis for R^4 as the grad_output.
708*da0073e9SAndroid Build Coastguard Worker            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
709*da0073e9SAndroid Build Coastguard Worker            #
710*da0073e9SAndroid Build Coastguard Worker            # Now, how do we compute the jacobian *without stacking the output*?
711*da0073e9SAndroid Build Coastguard Worker            # We can just split the standard basis across the outputs. So to
712*da0073e9SAndroid Build Coastguard Worker            # compute the jacobian of f(x), we'd use
713*da0073e9SAndroid Build Coastguard Worker            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
714*da0073e9SAndroid Build Coastguard Worker            # The grad_outputs looks like the following:
715*da0073e9SAndroid Build Coastguard Worker            # ( torch.tensor([[1, 0, 0],
716*da0073e9SAndroid Build Coastguard Worker            #                 [0, 1, 0],
717*da0073e9SAndroid Build Coastguard Worker            #                 [0, 0, 1],
718*da0073e9SAndroid Build Coastguard Worker            #                 [0, 0, 0]]),
719*da0073e9SAndroid Build Coastguard Worker            #   torch.tensor([[0],
720*da0073e9SAndroid Build Coastguard Worker            #                 [0],
721*da0073e9SAndroid Build Coastguard Worker            #                 [0],
722*da0073e9SAndroid Build Coastguard Worker            #                 [1]]) )
723*da0073e9SAndroid Build Coastguard Worker            #
724*da0073e9SAndroid Build Coastguard Worker            # But we're not done yet!
725*da0073e9SAndroid Build Coastguard Worker            # >>> vmap(partial(autograd.grad, f(x), x))(grad_outputs)
726*da0073e9SAndroid Build Coastguard Worker            # returns a Tensor of shape [4, 3]. We have to remember to split the
727*da0073e9SAndroid Build Coastguard Worker            # jacobian of shape [4, 3] into two:
728*da0073e9SAndroid Build Coastguard Worker            # - one of shape [3, 3] for the first output
729*da0073e9SAndroid Build Coastguard Worker            # - one of shape [   3] for the second output
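730*da0073e9SAndroid Build Coastguard Worker            #
731*da0073e9SAndroid Build Coastguard Worker            # A self-contained sketch of the trick (hypothetical shapes, not
732*da0073e9SAndroid Build Coastguard Worker            # part of this function):
733*da0073e9SAndroid Build Coastguard Worker            # >>> x = torch.randn(3, requires_grad=True)
734*da0073e9SAndroid Build Coastguard Worker            # >>> flat = torch.cat([(x**2).reshape(-1), x.sum().reshape(-1)])
735*da0073e9SAndroid Build Coastguard Worker            # >>> basis = torch.eye(4)
736*da0073e9SAndroid Build Coastguard Worker            # >>> rows = torch.autograd.grad(flat, x, basis, is_grads_batched=True)[0]
737*da0073e9SAndroid Build Coastguard Worker            # >>> jac_sq, jac_sum = rows.split([3, 1], dim=0)  # [3, 3] and [1, 3]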
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker            # Step 1: Construct grad_outputs by splitting the standard basis
732*da0073e9SAndroid Build Coastguard Worker            output_numels = tuple(output.numel() for output in outputs)
733*da0073e9SAndroid Build Coastguard Worker            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
734*da0073e9SAndroid Build Coastguard Worker            flat_outputs = tuple(output.reshape(-1) for output in outputs)
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker            # Step 2: Call vmap + autograd.grad
737*da0073e9SAndroid Build Coastguard Worker            def vjp(grad_output):
738*da0073e9SAndroid Build Coastguard Worker                vj = list(
739*da0073e9SAndroid Build Coastguard Worker                    _autograd_grad(
740*da0073e9SAndroid Build Coastguard Worker                        flat_outputs,
741*da0073e9SAndroid Build Coastguard Worker                        inputs,
742*da0073e9SAndroid Build Coastguard Worker                        grad_output,
743*da0073e9SAndroid Build Coastguard Worker                        create_graph=create_graph,
744*da0073e9SAndroid Build Coastguard Worker                        is_grads_batched=True,
745*da0073e9SAndroid Build Coastguard Worker                    )
746*da0073e9SAndroid Build Coastguard Worker                )
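747*da0073e9SAndroid Build Coastguard Worker                # Inputs that the outputs do not depend on come back as None;
748*da0073e9SAndroid Build Coastguard Worker                # replace them with batched zeros so the split/view below works.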
747*da0073e9SAndroid Build Coastguard Worker                for el_idx, vj_el in enumerate(vj):
748*da0073e9SAndroid Build Coastguard Worker                    if vj_el is not None:
749*da0073e9SAndroid Build Coastguard Worker                        continue
750*da0073e9SAndroid Build Coastguard Worker                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
751*da0073e9SAndroid Build Coastguard Worker                        (sum(output_numels),) + inputs[el_idx].shape
752*da0073e9SAndroid Build Coastguard Worker                    )
753*da0073e9SAndroid Build Coastguard Worker                return tuple(vj)
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker            jacobians_of_flat_output = vjp(grad_outputs)
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker            # Step 3: The returned jacobian is one big tensor per input. In this step,
758*da0073e9SAndroid Build Coastguard Worker            # we split each Tensor by output.
759*da0073e9SAndroid Build Coastguard Worker            jacobian_input_output = []
760*da0073e9SAndroid Build Coastguard Worker            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
761*da0073e9SAndroid Build Coastguard Worker                jacobian_input_i_output = []
762*da0073e9SAndroid Build Coastguard Worker                for jac, output_j in zip(
763*da0073e9SAndroid Build Coastguard Worker                    jac_input_i.split(output_numels, dim=0), outputs
764*da0073e9SAndroid Build Coastguard Worker                ):
765*da0073e9SAndroid Build Coastguard Worker                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
766*da0073e9SAndroid Build Coastguard Worker                    jacobian_input_i_output.append(jacobian_input_i_output_j)
767*da0073e9SAndroid Build Coastguard Worker                jacobian_input_output.append(jacobian_input_i_output)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker            # Step 4: Right now, `jacobian_input_output` is a List[List[Tensor]].
770*da0073e9SAndroid Build Coastguard Worker            # The outer List corresponds to the number of inputs,
771*da0073e9SAndroid Build Coastguard Worker            # the inner List corresponds to the number of outputs.
772*da0073e9SAndroid Build Coastguard Worker            # We need to exchange the order of these and convert to tuples
773*da0073e9SAndroid Build Coastguard Worker            # before returning.
774*da0073e9SAndroid Build Coastguard Worker            jacobian_output_input = tuple(zip(*jacobian_input_output))
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker            jacobian_output_input = _grad_postprocess(
777*da0073e9SAndroid Build Coastguard Worker                jacobian_output_input, create_graph
778*da0073e9SAndroid Build Coastguard Worker            )
779*da0073e9SAndroid Build Coastguard Worker            return _tuple_postprocess(
780*da0073e9SAndroid Build Coastguard Worker                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
781*da0073e9SAndroid Build Coastguard Worker            )
782*da0073e9SAndroid Build Coastguard Worker
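783*da0073e9SAndroid Build Coastguard Worker        # Non-vectorized (default) path: compute the Jacobian one scalar output
784*da0073e9SAndroid Build Coastguard Worker        # element at a time, with one autograd.grad call per element.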
783*da0073e9SAndroid Build Coastguard Worker        jacobian: Tuple[torch.Tensor, ...] = ()
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker        for i, out in enumerate(outputs):
786*da0073e9SAndroid Build Coastguard Worker            # mypy complains that expression and variable have different types due to the empty list
787*da0073e9SAndroid Build Coastguard Worker            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
788*da0073e9SAndroid Build Coastguard Worker            for j in range(out.nelement()):
789*da0073e9SAndroid Build Coastguard Worker                vj = _autograd_grad(
790*da0073e9SAndroid Build Coastguard Worker                    (out.reshape(-1)[j],),
791*da0073e9SAndroid Build Coastguard Worker                    inputs,
792*da0073e9SAndroid Build Coastguard Worker                    retain_graph=True,
793*da0073e9SAndroid Build Coastguard Worker                    create_graph=create_graph,
794*da0073e9SAndroid Build Coastguard Worker                )
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
797*da0073e9SAndroid Build Coastguard Worker                    zip(jac_i, vj, inputs)
798*da0073e9SAndroid Build Coastguard Worker                ):
799*da0073e9SAndroid Build Coastguard Worker                    if vj_el is not None:
800*da0073e9SAndroid Build Coastguard Worker                        if strict and create_graph and not vj_el.requires_grad:
801*da0073e9SAndroid Build Coastguard Worker                            msg = (
802*da0073e9SAndroid Build Coastguard Worker                                "The jacobian of the user-provided function is "
803*da0073e9SAndroid Build Coastguard Worker                                f"independent of input {el_idx}. This is not allowed in "
804*da0073e9SAndroid Build Coastguard Worker                                "strict mode when create_graph=True."
805*da0073e9SAndroid Build Coastguard Worker                            )
806*da0073e9SAndroid Build Coastguard Worker                            raise RuntimeError(msg)
807*da0073e9SAndroid Build Coastguard Worker                        jac_i_el.append(vj_el)
808*da0073e9SAndroid Build Coastguard Worker                    else:
809*da0073e9SAndroid Build Coastguard Worker                        if strict:
810*da0073e9SAndroid Build Coastguard Worker                            msg = (
811*da0073e9SAndroid Build Coastguard Worker                                f"Output {i} of the user-provided function is "
812*da0073e9SAndroid Build Coastguard Worker                                f"independent of input {el_idx}. This is not allowed in "
813*da0073e9SAndroid Build Coastguard Worker                                "strict mode."
814*da0073e9SAndroid Build Coastguard Worker                            )
815*da0073e9SAndroid Build Coastguard Worker                            raise RuntimeError(msg)
816*da0073e9SAndroid Build Coastguard Worker                        jac_i_el.append(torch.zeros_like(inp_el))
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker            jacobian += (
819*da0073e9SAndroid Build Coastguard Worker                tuple(
820*da0073e9SAndroid Build Coastguard Worker                    torch.stack(jac_i_el, dim=0).view(
821*da0073e9SAndroid Build Coastguard Worker                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
822*da0073e9SAndroid Build Coastguard Worker                    )
823*da0073e9SAndroid Build Coastguard Worker                    for (el_idx, jac_i_el) in enumerate(jac_i)
824*da0073e9SAndroid Build Coastguard Worker                ),
825*da0073e9SAndroid Build Coastguard Worker            )
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker        jacobian = _grad_postprocess(jacobian, create_graph)
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Workerdef hessian(
833*da0073e9SAndroid Build Coastguard Worker    func,
834*da0073e9SAndroid Build Coastguard Worker    inputs,
835*da0073e9SAndroid Build Coastguard Worker    create_graph=False,
836*da0073e9SAndroid Build Coastguard Worker    strict=False,
837*da0073e9SAndroid Build Coastguard Worker    vectorize=False,
838*da0073e9SAndroid Build Coastguard Worker    outer_jacobian_strategy="reverse-mode",
839*da0073e9SAndroid Build Coastguard Worker):
840*da0073e9SAndroid Build Coastguard Worker    r"""Compute the Hessian of a given scalar function.
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker    Args:
843*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
844*da0073e9SAndroid Build Coastguard Worker            a Tensor with a single element.
845*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
846*da0073e9SAndroid Build Coastguard Worker        create_graph (bool, optional): If ``True``, the Hessian will be computed in
847*da0073e9SAndroid Build Coastguard Worker            a differentiable manner. Note that when ``strict`` is ``False``, the result may not
848*da0073e9SAndroid Build Coastguard Worker            require gradients or may be disconnected from the inputs.
849*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
850*da0073e9SAndroid Build Coastguard Worker        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
851*da0073e9SAndroid Build Coastguard Worker            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
852*da0073e9SAndroid Build Coastguard Worker            hessian for said inputs, which is the expected mathematical value.
853*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
854*da0073e9SAndroid Build Coastguard Worker        vectorize (bool, optional): This feature is experimental.
855*da0073e9SAndroid Build Coastguard Worker            Please consider using :func:`torch.func.hessian`
856*da0073e9SAndroid Build Coastguard Worker            instead if you are looking for something less experimental and more performant.
857*da0073e9SAndroid Build Coastguard Worker            When computing the hessian, usually we invoke
858*da0073e9SAndroid Build Coastguard Worker            ``autograd.grad`` once per row of the hessian. If this flag is
859*da0073e9SAndroid Build Coastguard Worker            ``True``, we use the vmap prototype feature as the backend to
860*da0073e9SAndroid Build Coastguard Worker            vectorize calls to ``autograd.grad`` so we only invoke it once
861*da0073e9SAndroid Build Coastguard Worker            instead of once per row. This should lead to performance
862*da0073e9SAndroid Build Coastguard Worker            improvements in many use cases, however, due to this feature
863*da0073e9SAndroid Build Coastguard Worker            being incomplete, there may be performance cliffs. Please
864*da0073e9SAndroid Build Coastguard Worker            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
865*da0073e9SAndroid Build Coastguard Worker            to show any performance warnings and file us issues if
866*da0073e9SAndroid Build Coastguard Worker            warnings exist for your use case. Defaults to ``False``.
867*da0073e9SAndroid Build Coastguard Worker        outer_jacobian_strategy (str, optional): The Hessian is computed by
868*da0073e9SAndroid Build Coastguard Worker            computing the Jacobian of a Jacobian. The inner Jacobian is always
869*da0073e9SAndroid Build Coastguard Worker            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
870*da0073e9SAndroid Build Coastguard Worker            or ``"reverse-mode"`` determines whether the outer Jacobian will be
871*da0073e9SAndroid Build Coastguard Worker            computed with forward or reverse mode AD. Currently, computing the outer
872*da0073e9SAndroid Build Coastguard Worker            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
873*da0073e9SAndroid Build Coastguard Worker            to ``"reverse-mode"``.
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker    Returns:
876*da0073e9SAndroid Build Coastguard Worker        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
877*da0073e9SAndroid Build Coastguard Worker        this will be a single Tensor containing the Hessian for the input.
878*da0073e9SAndroid Build Coastguard Worker        If it is a tuple, then the Hessian will be a tuple of tuples where
879*da0073e9SAndroid Build Coastguard Worker        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
880*da0073e9SAndroid Build Coastguard Worker        and ``j``\th input, with size the concatenation of the sizes of the
881*da0073e9SAndroid Build Coastguard Worker        ``i``\th and ``j``\th inputs. ``Hessian[i][j]`` will have the same
882*da0073e9SAndroid Build Coastguard Worker        dtype and device as the corresponding ``i``\th input.
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker    Example:
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
887*da0073e9SAndroid Build Coastguard Worker        >>> def pow_reducer(x):
888*da0073e9SAndroid Build Coastguard Worker        ...     return x.pow(3).sum()
889*da0073e9SAndroid Build Coastguard Worker        >>> inputs = torch.rand(2, 2)
890*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
891*da0073e9SAndroid Build Coastguard Worker        >>> hessian(pow_reducer, inputs)
892*da0073e9SAndroid Build Coastguard Worker        tensor([[[[5.2265, 0.0000],
893*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 0.0000]],
894*da0073e9SAndroid Build Coastguard Worker                 [[0.0000, 4.8221],
895*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 0.0000]]],
896*da0073e9SAndroid Build Coastguard Worker                [[[0.0000, 0.0000],
897*da0073e9SAndroid Build Coastguard Worker                  [1.9456, 0.0000]],
898*da0073e9SAndroid Build Coastguard Worker                 [[0.0000, 0.0000],
899*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 3.2550]]]])
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker        >>> hessian(pow_reducer, inputs, create_graph=True)
902*da0073e9SAndroid Build Coastguard Worker        tensor([[[[5.2265, 0.0000],
903*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 0.0000]],
904*da0073e9SAndroid Build Coastguard Worker                 [[0.0000, 4.8221],
905*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 0.0000]]],
906*da0073e9SAndroid Build Coastguard Worker                [[[0.0000, 0.0000],
907*da0073e9SAndroid Build Coastguard Worker                  [1.9456, 0.0000]],
908*da0073e9SAndroid Build Coastguard Worker                 [[0.0000, 0.0000],
909*da0073e9SAndroid Build Coastguard Worker                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker        >>> def pow_adder_reducer(x, y):
913*da0073e9SAndroid Build Coastguard Worker        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
914*da0073e9SAndroid Build Coastguard Worker        >>> inputs = (torch.rand(2), torch.rand(2))
915*da0073e9SAndroid Build Coastguard Worker        >>> hessian(pow_adder_reducer, inputs)
916*da0073e9SAndroid Build Coastguard Worker        ((tensor([[4., 0.],
917*da0073e9SAndroid Build Coastguard Worker                  [0., 4.]]),
918*da0073e9SAndroid Build Coastguard Worker          tensor([[0., 0.],
919*da0073e9SAndroid Build Coastguard Worker                  [0., 0.]])),
920*da0073e9SAndroid Build Coastguard Worker         (tensor([[0., 0.],
921*da0073e9SAndroid Build Coastguard Worker                  [0., 0.]]),
922*da0073e9SAndroid Build Coastguard Worker          tensor([[6., 0.],
923*da0073e9SAndroid Build Coastguard Worker                  [0., 6.]])))
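924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        The vectorized path is a one-argument change. A minimal sketch (not
926*da0073e9SAndroid Build Coastguard Worker        run by the doctests), with the same expected values as above:
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP
929*da0073e9SAndroid Build Coastguard Worker        >>> hessian(pow_adder_reducer, inputs, vectorize=True)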
924*da0073e9SAndroid Build Coastguard Worker    """
925*da0073e9SAndroid Build Coastguard Worker    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
926*da0073e9SAndroid Build Coastguard Worker    assert outer_jacobian_strategy in (
927*da0073e9SAndroid Build Coastguard Worker        "forward-mode",
928*da0073e9SAndroid Build Coastguard Worker        "reverse-mode",
929*da0073e9SAndroid Build Coastguard Worker    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker    def ensure_single_output_function(*inp):
932*da0073e9SAndroid Build Coastguard Worker        out = func(*inp)
933*da0073e9SAndroid Build Coastguard Worker        is_out_tuple, t_out = _as_tuple(
934*da0073e9SAndroid Build Coastguard Worker            out, "outputs of the user-provided function", "hessian"
935*da0073e9SAndroid Build Coastguard Worker        )
936*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(t_out, "outputs", strict=strict)
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker        if is_out_tuple or not isinstance(out, torch.Tensor):
939*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
940*da0073e9SAndroid Build Coastguard Worker                "The function given to hessian should return a single Tensor"
941*da0073e9SAndroid Build Coastguard Worker            )
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker        if out.nelement() != 1:
944*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
945*da0073e9SAndroid Build Coastguard Worker                "The Tensor returned by the function given to hessian should contain a single element"
946*da0073e9SAndroid Build Coastguard Worker            )
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker        return out.squeeze()
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker    def jac_func(*inp):
951*da0073e9SAndroid Build Coastguard Worker        if outer_jacobian_strategy == "forward-mode":
952*da0073e9SAndroid Build Coastguard Worker            # _grad_preprocess requires create_graph=True and input to require_grad
953*da0073e9SAndroid Build Coastguard Worker            # or else the input will be detached
954*da0073e9SAndroid Build Coastguard Worker            inp = tuple(t.requires_grad_(True) for t in inp)
955*da0073e9SAndroid Build Coastguard Worker        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
956*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(jac, "jacobian", strict=strict)
957*da0073e9SAndroid Build Coastguard Worker        return jac
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker    res = jacobian(
960*da0073e9SAndroid Build Coastguard Worker        jac_func,
961*da0073e9SAndroid Build Coastguard Worker        inputs,
962*da0073e9SAndroid Build Coastguard Worker        create_graph=create_graph,
963*da0073e9SAndroid Build Coastguard Worker        strict=strict,
964*da0073e9SAndroid Build Coastguard Worker        vectorize=vectorize,
965*da0073e9SAndroid Build Coastguard Worker        strategy=outer_jacobian_strategy,
966*da0073e9SAndroid Build Coastguard Worker    )
967*da0073e9SAndroid Build Coastguard Worker    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Workerdef vhp(func, inputs, v=None, create_graph=False, strict=False):
971*da0073e9SAndroid Build Coastguard Worker    r"""Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker    Args:
974*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
975*da0073e9SAndroid Build Coastguard Worker            a Tensor with a single element.
976*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
977*da0073e9SAndroid Build Coastguard Worker        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
978*da0073e9SAndroid Build Coastguard Worker            product is computed. Must be the same size as the input of
979*da0073e9SAndroid Build Coastguard Worker            ``func``. This argument is optional when ``func``'s input contains
980*da0073e9SAndroid Build Coastguard Worker            a single element and (if it is not provided) will be set as a
981*da0073e9SAndroid Build Coastguard Worker            Tensor containing a single ``1``.
982*da0073e9SAndroid Build Coastguard Worker        create_graph (bool, optional): If ``True``, both the output and result
983*da0073e9SAndroid Build Coastguard Worker            will be computed in a differentiable way. Note that when ``strict``
984*da0073e9SAndroid Build Coastguard Worker            is ``False``, the result may not require gradients or may be
985*da0073e9SAndroid Build Coastguard Worker            disconnected from the inputs.
986*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
987*da0073e9SAndroid Build Coastguard Worker        strict (bool, optional): If ``True``, an error will be raised when we
988*da0073e9SAndroid Build Coastguard Worker            detect that there exists an input such that all the outputs are
989*da0073e9SAndroid Build Coastguard Worker            independent of it. If ``False``, we return a Tensor of zeros as the
990*da0073e9SAndroid Build Coastguard Worker            vhp for said inputs, which is the expected mathematical value.
991*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker    Returns:
994*da0073e9SAndroid Build Coastguard Worker        output (tuple): tuple with:
995*da0073e9SAndroid Build Coastguard Worker            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
996*da0073e9SAndroid Build Coastguard Worker
997*da0073e9SAndroid Build Coastguard Worker            vhp (tuple of Tensors or Tensor): result of the dot product with the
998*da0073e9SAndroid Build Coastguard Worker            same shape as the inputs.
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker    Example:
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
1003*da0073e9SAndroid Build Coastguard Worker        >>> def pow_reducer(x):
1004*da0073e9SAndroid Build Coastguard Worker        ...     return x.pow(3).sum()
1005*da0073e9SAndroid Build Coastguard Worker        >>> inputs = torch.rand(2, 2)
1006*da0073e9SAndroid Build Coastguard Worker        >>> v = torch.ones(2, 2)
1007*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1008*da0073e9SAndroid Build Coastguard Worker        >>> vhp(pow_reducer, inputs, v)
1009*da0073e9SAndroid Build Coastguard Worker        (tensor(0.5591),
1010*da0073e9SAndroid Build Coastguard Worker         tensor([[1.0689, 1.2431],
1011*da0073e9SAndroid Build Coastguard Worker                 [3.0989, 4.4456]]))
1012*da0073e9SAndroid Build Coastguard Worker        >>> vhp(pow_reducer, inputs, v, create_graph=True)
1013*da0073e9SAndroid Build Coastguard Worker        (tensor(0.5591, grad_fn=<SumBackward0>),
1014*da0073e9SAndroid Build Coastguard Worker         tensor([[1.0689, 1.2431],
1015*da0073e9SAndroid Build Coastguard Worker                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
1016*da0073e9SAndroid Build Coastguard Worker        >>> def pow_adder_reducer(x, y):
1017*da0073e9SAndroid Build Coastguard Worker        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
1018*da0073e9SAndroid Build Coastguard Worker        >>> inputs = (torch.rand(2), torch.rand(2))
1019*da0073e9SAndroid Build Coastguard Worker        >>> v = (torch.zeros(2), torch.ones(2))
1020*da0073e9SAndroid Build Coastguard Worker        >>> vhp(pow_adder_reducer, inputs, v)
1021*da0073e9SAndroid Build Coastguard Worker        (tensor(4.8053),
1022*da0073e9SAndroid Build Coastguard Worker         (tensor([0., 0.]),
1023*da0073e9SAndroid Build Coastguard Worker          tensor([6., 6.])))
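1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker        The result equals ``v`` contracted with the full Hessian. A minimal
1026*da0073e9SAndroid Build Coastguard Worker        sketch (not run by the doctests) checking this against :func:`hessian`:
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP
1029*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.rand(2, 2)
1030*da0073e9SAndroid Build Coastguard Worker        >>> v = torch.ones(2, 2)
1031*da0073e9SAndroid Build Coastguard Worker        >>> H = hessian(pow_reducer, x)  # shape (2, 2, 2, 2)
1032*da0073e9SAndroid Build Coastguard Worker        >>> manual = (v.reshape(1, -1) @ H.reshape(4, 4)).reshape(2, 2)
1033*da0073e9SAndroid Build Coastguard Worker        >>> torch.allclose(manual, vhp(pow_reducer, x, v)[1])
1034*da0073e9SAndroid Build Coastguard Worker        True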
1024*da0073e9SAndroid Build Coastguard Worker    """
1025*da0073e9SAndroid Build Coastguard Worker    with torch.enable_grad():
1026*da0073e9SAndroid Build Coastguard Worker        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
1027*da0073e9SAndroid Build Coastguard Worker        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        if v is not None:
1030*da0073e9SAndroid Build Coastguard Worker            _, v = _as_tuple(v, "v", "vhp")
1031*da0073e9SAndroid Build Coastguard Worker            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
1032*da0073e9SAndroid Build Coastguard Worker            _validate_v(v, inputs, is_inputs_tuple)
1033*da0073e9SAndroid Build Coastguard Worker        else:
1034*da0073e9SAndroid Build Coastguard Worker            if len(inputs) != 1 or inputs[0].nelement() != 1:
1035*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1036*da0073e9SAndroid Build Coastguard Worker                    "The vector v can only be None if the input to the user-provided function "
1037*da0073e9SAndroid Build Coastguard Worker                    "is a single Tensor with a single element."
1038*da0073e9SAndroid Build Coastguard Worker                )
1039*da0073e9SAndroid Build Coastguard Worker        outputs = func(*inputs)
1040*da0073e9SAndroid Build Coastguard Worker        is_outputs_tuple, outputs = _as_tuple(
1041*da0073e9SAndroid Build Coastguard Worker            outputs, "outputs of the user-provided function", "vhp"
1042*da0073e9SAndroid Build Coastguard Worker        )
1043*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(outputs, "outputs", strict=strict)
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
1046*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1047*da0073e9SAndroid Build Coastguard Worker                "The function given to vhp should return a single Tensor"
1048*da0073e9SAndroid Build Coastguard Worker            )
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        if outputs[0].nelement() != 1:
1051*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1052*da0073e9SAndroid Build Coastguard Worker                "The Tensor returned by the function given to vhp should contain a single element"
1053*da0073e9SAndroid Build Coastguard Worker            )
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        jac = _autograd_grad(outputs, inputs, create_graph=True)
1056*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(jac, "jacobian", strict=strict)
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker    enable_grad = True if create_graph else torch.is_grad_enabled()
1059*da0073e9SAndroid Build Coastguard Worker    with torch.set_grad_enabled(enable_grad):
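1060*da0073e9SAndroid Build Coastguard Worker        # Backprop ``v`` through the gradient: this contracts v with the Hessian,
1061*da0073e9SAndroid Build Coastguard Worker        # grad_res[i] = sum_j v[j] * d(jac[j]) / d(input[i]), i.e. v^T H.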
1060*da0073e9SAndroid Build Coastguard Worker        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
1061*da0073e9SAndroid Build Coastguard Worker        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker    outputs = _grad_postprocess(outputs, create_graph)
1064*da0073e9SAndroid Build Coastguard Worker    vhp = _grad_postprocess(vhp, create_graph)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
1067*da0073e9SAndroid Build Coastguard Worker        vhp, is_inputs_tuple
1068*da0073e9SAndroid Build Coastguard Worker    )
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Workerdef hvp(func, inputs, v=None, create_graph=False, strict=False):
1072*da0073e9SAndroid Build Coastguard Worker    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker    Args:
1075*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
1076*da0073e9SAndroid Build Coastguard Worker            a Tensor with a single element.
1077*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
1078*da0073e9SAndroid Build Coastguard Worker        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
1079*da0073e9SAndroid Build Coastguard Worker            product is computed. Must be the same size as the input of
1080*da0073e9SAndroid Build Coastguard Worker            ``func``. This argument is optional when ``func``'s input contains
1081*da0073e9SAndroid Build Coastguard Worker            a single element and (if it is not provided) will be set as a
1082*da0073e9SAndroid Build Coastguard Worker            Tensor containing a single ``1``.
1083*da0073e9SAndroid Build Coastguard Worker        create_graph (bool, optional): If ``True``, both the output and result will be
1084*da0073e9SAndroid Build Coastguard Worker            computed in a differentiable way. Note that when ``strict`` is
1085*da0073e9SAndroid Build Coastguard Worker            ``False``, the result may not require gradients or may be disconnected
1086*da0073e9SAndroid Build Coastguard Worker            from the inputs. Defaults to ``False``.
1087*da0073e9SAndroid Build Coastguard Worker        strict (bool, optional): If ``True``, an error will be raised when we
1088*da0073e9SAndroid Build Coastguard Worker            detect that there exists an input such that all the outputs are
1089*da0073e9SAndroid Build Coastguard Worker            independent of it. If ``False``, we return a Tensor of zeros as the
1090*da0073e9SAndroid Build Coastguard Worker            hvp for said inputs, which is the expected mathematical value.
1091*da0073e9SAndroid Build Coastguard Worker            Defaults to ``False``.
1092*da0073e9SAndroid Build Coastguard Worker    Returns:
1093*da0073e9SAndroid Build Coastguard Worker        output (tuple): tuple with:
1094*da0073e9SAndroid Build Coastguard Worker            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker            hvp (tuple of Tensors or Tensor): result of the dot product with
1097*da0073e9SAndroid Build Coastguard Worker            the same shape as the inputs.
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker    Example:
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
1102*da0073e9SAndroid Build Coastguard Worker        >>> def pow_reducer(x):
1103*da0073e9SAndroid Build Coastguard Worker        ...     return x.pow(3).sum()
1104*da0073e9SAndroid Build Coastguard Worker        >>> inputs = torch.rand(2, 2)
1105*da0073e9SAndroid Build Coastguard Worker        >>> v = torch.ones(2, 2)
1106*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1107*da0073e9SAndroid Build Coastguard Worker        >>> hvp(pow_reducer, inputs, v)
1108*da0073e9SAndroid Build Coastguard Worker        (tensor(0.1448),
1109*da0073e9SAndroid Build Coastguard Worker         tensor([[2.0239, 1.6456],
1110*da0073e9SAndroid Build Coastguard Worker                 [2.4988, 1.4310]]))
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker        >>> hvp(pow_reducer, inputs, v, create_graph=True)
1113*da0073e9SAndroid Build Coastguard Worker        (tensor(0.1448, grad_fn=<SumBackward0>),
1114*da0073e9SAndroid Build Coastguard Worker         tensor([[2.0239, 1.6456],
1115*da0073e9SAndroid Build Coastguard Worker                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker        >>> def pow_adder_reducer(x, y):
1119*da0073e9SAndroid Build Coastguard Worker        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
1120*da0073e9SAndroid Build Coastguard Worker        >>> inputs = (torch.rand(2), torch.rand(2))
1121*da0073e9SAndroid Build Coastguard Worker        >>> v = (torch.zeros(2), torch.ones(2))
1122*da0073e9SAndroid Build Coastguard Worker        >>> hvp(pow_adder_reducer, inputs, v)
1123*da0073e9SAndroid Build Coastguard Worker        (tensor(2.3030),
1124*da0073e9SAndroid Build Coastguard Worker         (tensor([0., 0.]),
1125*da0073e9SAndroid Build Coastguard Worker          tensor([6., 6.])))
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker    Note:
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker        This function is significantly slower than `vhp` due to backward mode AD constraints.
1130*da0073e9SAndroid Build Coastguard Worker        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
1131*da0073e9SAndroid Build Coastguard Worker        know that your function satisfies this condition, you should use vhp instead, which is
1132*da0073e9SAndroid Build Coastguard Worker        much faster with the current implementation.
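1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        A minimal sketch of that faster route (not run by the doctests): for
1135*da0073e9SAndroid Build Coastguard Worker        such a function the two products coincide elementwise, so
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP
1138*da0073e9SAndroid Build Coastguard Worker        >>> func_output, hvp_result = vhp(pow_reducer, inputs, v)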
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker    """
1135*da0073e9SAndroid Build Coastguard Worker    with torch.enable_grad():
1136*da0073e9SAndroid Build Coastguard Worker        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
1137*da0073e9SAndroid Build Coastguard Worker        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
1138*da0073e9SAndroid Build Coastguard Worker
1139*da0073e9SAndroid Build Coastguard Worker        if v is not None:
1140*da0073e9SAndroid Build Coastguard Worker            _, v = _as_tuple(v, "v", "hvp")
1141*da0073e9SAndroid Build Coastguard Worker            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
1142*da0073e9SAndroid Build Coastguard Worker            _validate_v(v, inputs, is_inputs_tuple)
1143*da0073e9SAndroid Build Coastguard Worker        else:
1144*da0073e9SAndroid Build Coastguard Worker            if len(inputs) != 1 or inputs[0].nelement() != 1:
1145*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1146*da0073e9SAndroid Build Coastguard Worker                    "The vector v can only be None if the input to the user-provided function "
1147*da0073e9SAndroid Build Coastguard Worker                    "is a single Tensor with a single element."
1148*da0073e9SAndroid Build Coastguard Worker                )
1149*da0073e9SAndroid Build Coastguard Worker        outputs = func(*inputs)
1150*da0073e9SAndroid Build Coastguard Worker        is_outputs_tuple, outputs = _as_tuple(
1151*da0073e9SAndroid Build Coastguard Worker            outputs, "outputs of the user-provided function", "hvp"
1152*da0073e9SAndroid Build Coastguard Worker        )
1153*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(outputs, "outputs", strict=strict)
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
1156*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1157*da0073e9SAndroid Build Coastguard Worker                "The function given to hvp should return a single Tensor"
1158*da0073e9SAndroid Build Coastguard Worker            )
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker        if outputs[0].nelement() != 1:
1161*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1162*da0073e9SAndroid Build Coastguard Worker                "The Tensor returned by the function given to hvp should contain a single element"
1163*da0073e9SAndroid Build Coastguard Worker            )
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker        jac = _autograd_grad(outputs, inputs, create_graph=True)
1166*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(jac, "jacobian", strict=strict)
1167*da0073e9SAndroid Build Coastguard Worker
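1168*da0073e9SAndroid Build Coastguard Worker        # Double-backward trick: with dummy ``grad_jac`` tensors that require
1169*da0073e9SAndroid Build Coastguard Worker        # grad, double_back = d<jac, grad_jac>/d(inputs) is linear in grad_jac,
1170*da0073e9SAndroid Build Coastguard Worker        # so differentiating <double_back, v> w.r.t. grad_jac recovers H @ v.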
1168*da0073e9SAndroid Build Coastguard Worker        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
1171*da0073e9SAndroid Build Coastguard Worker        _check_requires_grad(double_back, "hessian", strict=strict)
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker    enable_grad = True if create_graph else torch.is_grad_enabled()
1174*da0073e9SAndroid Build Coastguard Worker    with torch.set_grad_enabled(enable_grad):
1175*da0073e9SAndroid Build Coastguard Worker        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
1176*da0073e9SAndroid Build Coastguard Worker        hvp = _fill_in_zeros(
1177*da0073e9SAndroid Build Coastguard Worker            grad_res, inputs, strict, create_graph, "double_back_trick"
1178*da0073e9SAndroid Build Coastguard Worker        )
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker    outputs = _grad_postprocess(outputs, create_graph)
1181*da0073e9SAndroid Build Coastguard Worker    hvp = _grad_postprocess(hvp, create_graph)
1182*da0073e9SAndroid Build Coastguard Worker
1183*da0073e9SAndroid Build Coastguard Worker    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
1184*da0073e9SAndroid Build Coastguard Worker        hvp, is_inputs_tuple
1185*da0073e9SAndroid Build Coastguard Worker    )
1186