# mypy: allow-untyped-defs
from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap

from . import forward_ad as fwAD


__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res
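

# A quick round-trip illustration of the two helpers above (a sketch kept in
# comments so nothing runs at import time; "fn" is just a placeholder name):
# >>> is_tuple, tup = _as_tuple(torch.ones(2), "inputs", "fn")
# >>> is_tuple, tup
# (False, (tensor([1., 1.]),))
# >>> _tuple_postprocess(tup, is_tuple)  # inverts _as_tuple
# tensor([1., 1.])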


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
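

# For intuition (illustrative comments only): with create_graph=False every
# input is detached and re-marked as requiring grad, so the graph built inside
# these APIs never leaks back to the caller's tensors:
# >>> x = torch.randn(2, requires_grad=True)
# >>> (y,) = _grad_preprocess((x,), create_graph=False, need_graph=True)
# >>> y.requires_grad, y is x
# (True, False)
# _grad_postprocess then detaches the results again unless create_graph=True.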


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )
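

# For example (illustrative, shown as comments): with strict=True, an output
# that is cut off from the autograd graph trips the last branch above:
# >>> x = torch.randn(3)
# >>> vjp(lambda t: t.detach().sum(), x, strict=True)
# RuntimeError: Output 0 of the user-provided function does not require gradients. ...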


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple.
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = ()
    new_grad_outputs: Tuple[torch.Tensor, ...] = ()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )
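

# Contract illustration (comments only): `None` entries and outputs that do
# not require grad are silently skipped, e.g.
# >>> x = torch.randn(2, requires_grad=True)
# >>> _autograd_grad((x.sum(), None), (x,))
# (tensor([1., 1.]),)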


def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and, depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage tells us which backward call we are considering, so that we can give a good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = ()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents using the double backward trick to "
                        "replace forward mode AD."
                    )

            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


# Public API


def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )
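

# Quick correctness intuition (illustrative comments, not a doctest): for an
# elementwise function the vjp is just the local derivative scaled by v, e.g.
# >>> x = torch.randn(3)
# >>> _, g = vjp(torch.sin, x, torch.ones(3))
# >>> torch.allclose(g, torch.cos(x))
# True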


def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
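        # Why the double backward trick works (sketch): the first backward
        # computes g(u) = J^T u, which is linear in u. Backpropagating through
        # g with cotangent v then differentiates <J^T u, v> = <u, J v> with
        # respect to u, which yields exactly J v, the desired jvp.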
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )


def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs an N=sum(tensor_numels) standard basis, i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],     tensor([[0],
    #           [0],             [1, 0],             [0],
    #           [0],             [0, 1],             [0],
    #           [0]]) ,          [0, 0]]) ,          [1]]) )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks
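

# Sanity check of the diagonal-fill trick above (illustrative comments, not
# executed at import time): concatenating the chunks rebuilds the identity.
# >>> tensors = (torch.zeros(1), torch.zeros(2), torch.zeros(1))
# >>> chunks = _construct_standard_basis_for(tensors, (1, 2, 1))
# >>> torch.cat(chunks, dim=1)
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 0., 1.]])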


def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )  # noqa: C409

                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )


def jacobian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    strategy="reverse-mode",
):
    r"""Compute the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.jacrev` or
            :func:`torch.func.jacfwd` instead if you are looking for something
            less experimental and more performant.
            When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we perform only a single ``autograd.grad`` call with
            ``batched_grad=True`` which uses the vmap prototype feature.
            Though this should lead to performance improvements in many cases,
            because this feature is still experimental, there may be performance
            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
            more information.
        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
            determine whether the Jacobian will be computed with forward or reverse
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
            prefer to use ``"reverse-mode"``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
            input and output, this will be a single Tensor containing the
            Jacobian for the linearized inputs and output. If one of the two is
            a tuple, then the Jacobian will be a tuple of Tensors. If both of
            them are tuples, then the Jacobian will be a tuple of tuples of
            Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
            ``i``\th output and ``j``\th input and will have as size the
            concatenation of the sizes of the corresponding output and the
            corresponding input and will have same dtype and device as the
            corresponding input. If strategy is ``forward-mode``, the dtype will be
            that of the output; otherwise, the input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...     return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                 [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    assert strategy in ("forward-mode", "reverse-mode"), (
        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
        'Otherwise, prefer to use "reverse-mode".'
    )
    if strategy == "forward-mode":
        if create_graph:
            raise NotImplementedError(
                "torch.autograd.functional.jacobian: `create_graph=True` "
                'and `strategy="forward-mode"` are not supported together (yet). '
                "Please either set `create_graph=False` or "
                '`strategy="reverse-mode"`.'
            )
        return _jacfwd(func, inputs, strict, vectorize)

    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jacobian"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError(
                    "torch.autograd.functional.jacobian: `strict=True` "
                    "and `vectorize=True` are not supported together. "
                    "Please either set `strict=False` or "
                    "`vectorize=False`."
                )
            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [3] for the second output
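            #
            # Concretely, for f(x) = (x**2, x.sum()) the recipe above amounts
            # to something like this (an illustrative sketch in comments, not
            # executed here):
            # >>> x = torch.randn(3, requires_grad=True)
            # >>> outs = (x**2, x.sum())
            # >>> flat = tuple(o.reshape(-1) for o in outs)
            # >>> basis = _construct_standard_basis_for(flat, (3, 1))
            # >>> (jac,) = torch.autograd.grad(flat, x, basis, is_grads_batched=True)
            # >>> jac.shape  # rows 0-2 belong to x**2, row 3 to x.sum()
            # torch.Size([4, 3])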

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp(grad_output):
                vj = list(
                    _autograd_grad(
                        flat_outputs,
                        inputs,
                        grad_output,
                        create_graph=create_graph,
                        is_grads_batched=True,
                    )
                )
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
                        (sum(output_numels),) + inputs[el_idx].shape
                    )
                return tuple(vj)

            jacobians_of_flat_output = vjp(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
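            # Continuing the f(x) = (x**2, x.sum()) example: the (4, 3) tensor
            # for input x splits into a (3, 3) block (jacobian of x**2) and a
            # (1, 3) block (jacobian of x.sum()), each then viewed as
            # output_shape + input_shape.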
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac, output_j in zip(
                    jac_input_i.split(output_numels, dim=0), outputs
                ):
                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian_input_output` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))

            jacobian_output_input = _grad_postprocess(
                jacobian_output_input, create_graph
            )
            return _tuple_postprocess(
                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
            )
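
        # Fallback (non-vectorized) path: call autograd.grad once per scalar
        # element of each output; iteration j of the inner loop below computes
        # row j of the jacobian block for output i.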
        jacobian: Tuple[torch.Tensor, ...] = ()

        for i, out in enumerate(outputs):
            # mypy complains that expression and variable have different types due to the empty list
            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
            for j in range(out.nelement()):
                vj = _autograd_grad(
                    (out.reshape(-1)[j],),
                    inputs,
                    retain_graph=True,
                    create_graph=create_graph,
                )

                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
                    zip(jac_i, vj, inputs)
                ):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            msg = (
                                "The jacobian of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode when create_graph=True."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(vj_el)
                    else:
                        if strict:
                            msg = (
                                f"Output {i} of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(torch.zeros_like(inp_el))

            jacobian += (
                tuple(
                    torch.stack(jac_i_el, dim=0).view(
                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
                    )
                    for (el_idx, jac_i_el) in enumerate(jac_i)
                ),
            )

        jacobian = _grad_postprocess(jacobian, create_graph)

        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))

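# A minimal usage sketch of the vectorized path above (illustrative only,
# mirroring the NOTE's f(x) = (x**2, x.sum()) example):
# >>> x = torch.randn(3)
# >>> jac = jacobian(lambda x: (x**2, x.sum()), x, vectorize=True)
# >>> [j.shape for j in jac]
# [torch.Size([3, 3]), torch.Size([3])]
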
def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result
            may not require gradients or may be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect
            that there exists an input such that all the outputs are independent of it.
            If ``False``, we return a Tensor of zeros as the hessian for said inputs,
            which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases; however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
            this will be a single Tensor containing the Hessian for the input.
            If it is a tuple, then the Hessian will be a tuple of tuples where
            ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
            and ``j``\th input, whose size is the concatenation of the sizes of
            the ``i``\th input and the ``j``\th input. ``Hessian[i][j]`` will have
            the same dtype and device as the corresponding ``i``\th input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
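
        The outer Jacobian can also be computed with forward-mode AD over the
        inner reverse-mode Jacobian; a sketch (illustrative only, output
        omitted, and requires ``vectorize=True``):

        >>> # xdoctest: +SKIP
        >>> hessian(pow_reducer, torch.rand(2, 2), vectorize=True,
        ...         outer_jacobian_strategy="forward-mode")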
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    assert outer_jacobian_strategy in (
        "forward-mode",
        "reverse-mode",
    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'

    def ensure_single_output_function(*inp):
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        return out.squeeze()

    def jac_func(*inp):
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and the input to
            # require grad, or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))

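# A minimal sketch of the Jacobian-of-Jacobian composition that `hessian`
# implements above (illustrative only):
# >>> f = lambda x: x.pow(3).sum()
# >>> x = torch.rand(2)
# >>> inner = lambda t: jacobian(f, t, create_graph=True)
# >>> torch.allclose(jacobian(inner, x), hessian(f, x))
# True
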
def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
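
    # One more backward pass, differentiating <jac, v> with respect to the
    # inputs, yields v^T @ H (the vector-Hessian product) without ever
    # materializing the full Hessian.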
    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    vhp = _grad_postprocess(vhp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vhp, is_inputs_tuple
    )


def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which
        is much faster with the current implementation.
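
        As an illustrative check (a sketch reusing the example objects above;
        not run by doctests):

        >>> # xdoctest: +SKIP
        >>> torch.allclose(hvp(pow_adder_reducer, inputs, v)[1][1],
        ...                vhp(pow_adder_reducer, inputs, v)[1][1])
        True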

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

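        # Double-backward trick: `double_back` below is H @ grad_jac viewed as
        # a function of the dummy `grad_jac`; differentiating it with respect
        # to `grad_jac`, contracted with v, therefore yields H @ v rather than
        # v^T @ H.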
        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        _check_requires_grad(double_back, "hessian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )