# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from typing import *  # noqa: F403
import enum
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
    "CheckpointPolicy",
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """
    Context manager that sets whether checkpoint should print additional debug
    information when running. See the ``debug`` flag for
    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
    when set, this context manager overrides the value of ``debug`` passed to
    checkpoint. To defer to the local setting, pass ``None`` to this context.

    Args:
        enabled (bool): Whether checkpoint should print debug information.
            Default is ``None``.
    """
    global _checkpoint_debug_enabled
    try:
        prev = _checkpoint_debug_enabled
        _checkpoint_debug_enabled = enabled
        yield
    finally:
        _checkpoint_debug_enabled = prev
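

# Illustrative sketch: combining ``set_checkpoint_debug_enabled`` with a
# checkpointed function. The helper below and its names are hypothetical and
# only demonstrate the intended usage.
def _example_set_checkpoint_debug_enabled():
    def _fn(x):
        return x.sin().cos()

    x = torch.randn(3, requires_grad=True)
    # While the context manager is active, checkpoint behaves as if
    # ``debug=True`` had been passed, regardless of the local ``debug`` flag.
    with set_checkpoint_debug_enabled(True):
        out = checkpoint(_fn, x, use_reentrant=False)
    out.sum().backward()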


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ",
            type(inputs).__name__,
        )


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )
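

# Illustrative sketch: ``detach_variable`` detaches tensor entries (preserving
# ``requires_grad``) and passes non-tensor entries through unchanged. The
# names below are hypothetical.
def _example_detach_variable():
    x = torch.randn(2, requires_grad=True) * 2  # non-leaf tensor with a grad_fn
    detached_x, flag = detach_variable((x, "some-flag"))
    assert detached_x.grad_fn is None and detached_x.requires_grad
    assert flag == "some-flag"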


def _get_device_module(device="cuda"):
    if device == "meta":
        return torch.device("meta")
    device_module = getattr(torch, device)
    return device_module


class DefaultDeviceType:
    r"""
    A class that manages the default device type for checkpointing.

    If no non-CPU tensors are present, the default device type will
    be used. The default value is 'cuda'. The device type is used in
    the checkpointing process when determining which device states
    to save and restore for recomputation.
    """

    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type
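

# Illustrative sketch: ``DefaultDeviceType`` only matters when checkpoint sees
# no non-CPU tensors and must still decide which device's state to manage. The
# value used below ("xpu") and the save/restore dance are just for
# demonstration.
def _example_default_device_type():
    previous = DefaultDeviceType.get_device_type()  # "cuda" unless changed
    try:
        DefaultDeviceType.set_device_type("xpu")
        assert DefaultDeviceType.get_device_type() == "xpu"
    finally:
        DefaultDeviceType.set_device_type(previous)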


def _infer_device_type(*args):
    device_types = []

    def add_device_types(arg):
        nonlocal device_types
        if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
            device_types.append(arg.device.type)
    tree_map(add_device_types, args)

    device_types_set = set(device_types)
    if len(device_types_set) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, CUDA will be prioritized; otherwise, the first device encountered will be selected.)"
            f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}"
        )
    if len(device_types) == 0:
        return DefaultDeviceType.get_device_type()
    elif "cuda" in device_types_set:
        return "cuda"
    else:
        return device_types[0]


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_device_ids = []

    def add_device_ids(arg):
        nonlocal fwd_device_ids
        if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
            fwd_device_ids.append(arg.get_device())
    tree_map(add_device_ids, args)

    fwd_device_states = []
    device_module = _get_device_module(_infer_device_type(*args))
    for device_id in fwd_device_ids:
        with device_module.device(device_id):
            fwd_device_states.append(device_module.get_rng_state())

    return fwd_device_ids, fwd_device_states


def set_device_states(devices, states, *, device_type=None) -> None:
    """Sets random number generator states for the specified devices.

    Args:
        devices: Device ids to set states for.
        states: States to set.
        device_type: ``device_type`` of the devices to set states for. Default
            is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
            which is ``cuda`` if not changed by calling ``DefaultDeviceType.set_device_type()``.
    """
    if device_type is None:
        device_type = DefaultDeviceType.get_device_type()
    if device_type == "meta":
        return
    device_module = _get_device_module(device_type)
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)
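

# Illustrative sketch: the stash-and-restore round trip that checkpoint performs
# with ``get_device_states`` / ``set_device_states``. Guarded on CUDA being
# available; the tensor and variable names are hypothetical.
def _example_device_rng_round_trip():
    if not torch.cuda.is_available():
        return
    x = torch.randn(3, device="cuda")
    devices, states = get_device_states(x)  # stash per-device RNG states
    before = torch.cuda.get_rng_state()
    _ = torch.rand(3, device="cuda")  # advances the CUDA RNG
    set_device_states(devices, states, device_type="cuda")
    assert torch.equal(torch.cuda.get_rng_state(), before)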


def _get_autocast_kwargs(device_type="cuda"):
    if torch.amp.is_autocast_available(device_type):
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(device_type),
            "dtype": torch.get_autocast_dtype(device_type),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    else:
        device_autocast_kwargs = None

    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled("cpu"),
        "dtype": torch.get_autocast_dtype("cpu"),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    return device_autocast_kwargs, cpu_autocast_kwargs
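

# Illustrative sketch: the kwargs captured by ``_get_autocast_kwargs`` are what
# the backward pass re-applies via ``torch.amp.autocast`` so that recomputation
# runs under the same autocast settings as the original forward. This example
# uses the CPU autocast state only; the dtype choice is arbitrary.
def _example_autocast_kwargs_round_trip():
    with torch.amp.autocast("cpu", dtype=torch.bfloat16):
        _device_kwargs, cpu_kwargs = _get_autocast_kwargs("cpu")
    # Re-entering autocast with the captured kwargs reproduces that state.
    with torch.amp.autocast("cpu", **cpu_kwargs):
        assert torch.is_autocast_enabled("cpu")
        assert torch.get_autocast_dtype("cpu") == torch.bfloat16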


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device_type = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device_type
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device_type)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
            detached_inputs = detach_variable(tuple(inputs))

            device_autocast_ctx = torch.amp.autocast(
                device_type=ctx.device_type, **ctx.device_autocast_kwargs
            ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads
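

# Illustrative sketch: how the reentrant path drives ``CheckpointFunction``;
# this is what ``checkpoint(..., use_reentrant=True)`` reduces to. The helper
# and tensor names are hypothetical.
def _example_reentrant_checkpoint_function():
    def _run_fn(a, b):
        return (a * b).sin()

    a = torch.randn(4, requires_grad=True)
    b = torch.randn(4, requires_grad=True)
    # forward() runs _run_fn under no_grad; backward() recomputes it with
    # gradients enabled and returns (None, None) + grads for (a, b).
    out = CheckpointFunction.apply(_run_fn, True, a, b)
    out.sum().backward()
    assert a.grad is not None and b.grad is not None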


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()
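

# Illustrative sketch: a user-provided ``context_fn`` mirrors ``noop_context_fn``
# but returns real context managers; the first wraps the original forward and
# the second wraps the recomputation (only supported with ``use_reentrant=False``),
# e.g. ``checkpoint(fn, x, use_reentrant=False, context_fn=_example_context_fn)``.
# The names below are hypothetical.
def _example_context_fn():
    @contextlib.contextmanager
    def _tagged(tag):
        # Stand-in for whatever should be active during the forward
        # ("forward") or the recomputation ("recompute").
        yield

    return _tagged("forward"), _tagged("recompute")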


# TorchDynamo does not step inside the utils.checkpoint function. The flow
# looks like this:
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then the Dynamo-generated Fx graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an error to be
        raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc.) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint. Note that under
            torch.compile, this flag doesn't take effect and we always
            preserve the RNG state.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.5 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants.",
            stacklevel=2
        )
        use_reentrant = True

    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        ret = function(*args, **kwargs)
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            return ret
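

# Illustrative sketch: typical use of ``checkpoint`` with the recommended
# ``use_reentrant=False`` variant. The layer, tensor shapes, and names below
# are hypothetical.
def _example_checkpoint_usage():
    layer = torch.nn.Linear(8, 8)

    def _block(x):
        return layer(x).relu()

    x = torch.randn(2, 8, requires_grad=True)
    # Activations inside ``_block`` are recomputed during backward instead of
    # being kept alive, trading extra compute for lower memory use.
    out = checkpoint(_block, x, use_reentrant=False)
    out.sum().backward()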


def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.5 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
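

# Illustrative sketch: splitting a small ``nn.Sequential`` into two checkpointed
# segments; each checkpointed segment keeps only its input for the backward
# pass. The layer sizes below are hypothetical.
def _example_checkpoint_sequential():
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
    )
    x = torch.randn(4, 16, requires_grad=True)
    out = checkpoint_sequential(model, 2, x, use_reentrant=False)
    out.sum().backward()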


def _internal_assert(cond):
    if not cond:
        raise AssertionError(
            "Something went unexpectedly wrong in activation checkpoint. "
            "Please report this bug by filing an issue to PyTorch."
        )


# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to their
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
# recompute.
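#
# For example (an illustrative sketch, not taken from the original note):
#
# def inner(x):
#   return x.sin().cos()    # tensors saved here are managed by inner
#
# def outer(x):
#   y = checkpoint(inner, x, use_reentrant=False)
#   return y.exp()          # tensors saved here are managed by outer
#
# out = checkpoint(outer, inp, use_reentrant=False)
# out.sum().backward()
#
# Unpacking a tensor saved inside inner requires recomputing inner; inner's
# input is itself treated as a tensor saved to outer (Rule 2), so outer must
# be recomputed first.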
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation on the stack. See examples in Rule 6 for more context.
#
# * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception if the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-except to catch this specific exception. See
# Rule #6 below for some examples.
#
# Rule 6. We support doing backward inside checkpoint context
#
# [ retain_graph is True ]
#
# def fn(x):
#   y = x.sin()
#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
#   return gx, z
#
# out = checkpoint(fn, inp)
# out.backward()
#
# Because z is saved by cos while checkpoint is enabled, it would not be
# actually saved, and so the .grad() call inside must trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populate the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally, performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ retain_graph is False ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
713*da0073e9SAndroid Build Coastguard Worker# 714*da0073e9SAndroid Build Coastguard Worker# def fn(x): 715*da0073e9SAndroid Build Coastguard Worker# y = x.sin() # (1) 716*da0073e9SAndroid Build Coastguard Worker# z = y.cos() # (2) 717*da0073e9SAndroid Build Coastguard Worker# gx, = torch.autograd.grad(z, x) # (3) 718*da0073e9SAndroid Build Coastguard Worker# return x.cos() * gx # (4) 719*da0073e9SAndroid Build Coastguard Worker# 720*da0073e9SAndroid Build Coastguard Worker# out = checkpoint(fn)(inp) 721*da0073e9SAndroid Build Coastguard Worker# out.backward() # (5) 722*da0073e9SAndroid Build Coastguard Worker# 723*da0073e9SAndroid Build Coastguard Worker# 1, 2. Don't save x and y since we are inside a checkpoint. 724*da0073e9SAndroid Build Coastguard Worker# 3. Trigger a recompute of fn since x and y weren't saved. 725*da0073e9SAndroid Build Coastguard Worker# And depending on whether early stop is enabled, either stop at (2) or 726*da0073e9SAndroid Build Coastguard Worker# continue running the function. 727*da0073e9SAndroid Build Coastguard Worker# Because we are running backward with retain_graph=False, we clear x and y's 728*da0073e9SAndroid Build Coastguard Worker# holders. 729*da0073e9SAndroid Build Coastguard Worker# 4. Don't save x since we are inside a checkpoint. 730*da0073e9SAndroid Build Coastguard Worker# 5. Calling backward triggers another recompute of fn. During recompute, we see 731*da0073e9SAndroid Build Coastguard Worker# that x and y have already been cleared in the original graph as indicated 732*da0073e9SAndroid Build Coastguard Worker# by holder=None. We skip over them. We still save x at (4) (since its holder 733*da0073e9SAndroid Build Coastguard Worker# is still alive.) 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker_enable_checkpoint_early_stop = True 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 739*da0073e9SAndroid Build Coastguard Workerdef set_checkpoint_early_stop(enable: bool): 740*da0073e9SAndroid Build Coastguard Worker """Context manager that sets whether checkpoint should stop recomputation early. 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker By default, non-reentrant checkpoint stops recomputation as soon as it 743*da0073e9SAndroid Build Coastguard Worker has computed all needed Tensors. This context manager can be used to disable 744*da0073e9SAndroid Build Coastguard Worker that feature if it is problematic for your specific application. 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker This context manager only needs to be active when forward is run. It does 747*da0073e9SAndroid Build Coastguard Worker not need to be active during backward. 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker Example:: 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP(failing) 752*da0073e9SAndroid Build Coastguard Worker >>> message = "saved tensors default hooks are disabled" 753*da0073e9SAndroid Build Coastguard Worker >>> with set_checkpoint_early_stop(False): 754*da0073e9SAndroid Build Coastguard Worker ... # Any checkpoint under this context manager will respect this 755*da0073e9SAndroid Build Coastguard Worker ... # context manager, even if its backward is performed outside. 756*da0073e9SAndroid Build Coastguard Worker ... 
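#
# A minimal sketch of the Rule 4 pattern described above (illustration only,
# not part of the implementation): a WeakKeyDictionary entry dies as soon as
# its key does, so by packing a container that holds the key we can drop the
# entry by hand at unpack time, even when retain_graph=True keeps the packed
# object itself alive:
#
#   import weakref
#
#   class _Key:
#       pass
#
#   recomputed = weakref.WeakKeyDictionary()
#   key = _Key()
#   recomputed[key] = recomputed_tensor   # populated during recomputation
#   holder = {"key": key}                 # what the pack hook would return
#
#   value = recomputed[holder["key"]]     # unpack
#   holder["key"] = None                  # drop the key by hand -> WKD entry freed
#
# Here `_Key`, `recomputed_tensor`, and the dict-based holder are stand-ins for
# the _Handle/_Holder machinery defined below.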

_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. This context manager can be used to disable
    that feature if it is problematic for your specific application.

    This context manager only needs to be active when forward is run. It does
    not need to be active during backward.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    try:
        prev = _enable_checkpoint_early_stop
        _enable_checkpoint_early_stop = enable
        yield
    finally:
        _enable_checkpoint_early_stop = prev


class _Handle:
    pass


class _Holder:
    def __init__(self):
        self.handles: Dict[int, Optional[_Handle]] = {}


class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx
        tensor_indices, tensors = zip(
            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
        )
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            ret = [
                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
                for i, o in enumerate(args)
            ]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")
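
# Illustration only (this assumes a leading dummy tensor, as used by the
# checkpoint frame below; not a public API): _NoopSaveInputs stashes the tensor
# inputs on the autograd graph while keeping non-tensor inputs around, and
# get_args reassembles the original argument list on demand.
#
#   dummy = torch.empty((0,), requires_grad=True)  # ensures a grad_fn exists even
#                                                  # when there are no tensor args
#   saver = _NoopSaveInputs.apply(dummy, torch.randn(2), "mode", 3)
#   ctx = saver.grad_fn
#   ctx.get_args(ctx.saved_tensors)   # -> [tensor([...]), "mode", 3]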

class _CheckpointFrame:
    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
        ] = defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop

        # Debugging
        self.metadata_fn = metadata_fn
        self.unpack_error_cb = unpack_error_cb
        self.x_metadatas = []
        self.forward_completed = False
        self.ignore_saved_mismatch = False

    def check_recomputed_tensors_match(self, gid):
        if self.ignore_saved_mismatch:
            # TODO: we can probably make this check stricter by checking that
            #       the metadata of the first tensors still match.
            return
        # NOTE [ Error handling for checkpoint ]
        #
        # At a high level, we need to check that the tensors saved during the
        # original forward match the tensors saved during recompute.
        # This means handling 3 cases:
        #
        # 1. During recompute, more tensors were saved.
        #
        #    Usually this is hidden by the StopRecomputationError; if early stop
        #    is not enabled, we would have errored anyway because there aren't
        #    enough weak_holders. But we do want to have a nice error. See the
        #    _recomputation_hook for details.
        if not len(self.weak_holders) == self.recomp_counter[gid]:
            # 2. During recompute, fewer tensors were saved.
            #
            # We know that every time we save something during the original
            # forward we append to weak_holders, and every time we save a
            # tensor during recompute we increment recomp_counter.
            raise CheckpointError(
                "torch.utils.checkpoint: A different number of tensors was saved "
                "during the original forward and recomputation.\n"
                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
            )

        # 3. During recompute, the same tensors were saved, but they
        #    have different metadata
        nb_meta_different = []
        for idx, weak_holder in enumerate(self.weak_holders):
            holder = weak_holder()
            if holder is None:
                continue
            # We've seen all holders since we iterate over them in order
            # For every holder that is still alive now, it must've been
            # alive when we saw it during recompute, therefore, the
            # gid must be set.
            _internal_assert(gid in holder.handles)
            # We know this is the first unpack, so it couldn't have been set
            # to None yet.
            _internal_assert(holder.handles[gid] is not None)
            # We always set these together in the recomputation hook
            _internal_assert(holder.handles[gid] in self.recomputed[gid])
            # see pack hook, x_metadatas is 1:1 with weak_holders.
            x_meta = self.x_metadatas[idx]
            recomputed_x = self.recomputed[gid][holder.handles[gid]]
            if x_meta != self.metadata_fn(recomputed_x):
                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))

        if len(nb_meta_different) > 0:
            mismatched_tensors = ""
            for idx, x_meta, recomputed_meta in nb_meta_different:
                mismatched_tensors += (
                    f"tensor at position {idx}:\n"
                    f"saved metadata: {x_meta}\n"
                    f"recomputed metadata: {recomputed_meta}\n"
                )
            raise CheckpointError(
                "torch.utils.checkpoint: Recomputed values for the following tensors "
                "have different metadata than during the forward pass.\n"
                f"{mismatched_tensors}"
            )


_checkpoint_error_template = """ \
An error happened while unpacking tensors; dumping logs of latest computation
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
Scroll all the way down for guidance on how to navigate these logs.

+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        1. Stack traces of the operators that ran in the original forward     |
+------------------------------------------------------------------------------+

{forward_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        2. Stack traces of the operators that ran during recomputation        |
+------------------------------------------------------------------------------+

{recompute_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        3. Log of operators in the original forward and recomputation         |
+------------------------------------------------------------------------------+
(Scroll up to correlate stack traces with each operation listed below. This
 helps identify their source in the code.)

IMPORTANT: Differences in "detach" calls between the original forward and the
           recomputation are expected. They are introduced by the checkpointing
           mechanism and can be ignored.

Operations executed during the original forward:

{forward_ops}

Operations executed during recomputation:

{recompute_ops}

+------------------------------------------------------------------------------+
 ERROR: Detected non-determinism while running activation checkpointing

 You are seeing this error because you passed `debug=True` to checkpoint and
 the tensors saved during the original forward differ from those saved during
 recomputation. This can happen if different operators were run in the
 original forward and in the recomputation.

 To identify where the mismatch may be coming from, you can do the following:

 1) Compare the operators run during the original forward and recomputation to
    see where they differ. These operators are printed above in the order they
    were executed.

 2) Review the stack trace for each operator to locate its invocation source.
    Each operator's stack trace is printed in their execution order.

 Note that the logs can be quite long. Here's how they are structured:
 (Tip: you can Ctrl-f for these headers)

 1. Stack traces of the operators that ran in the original forward
 2. Stack traces of the operators that ran during recomputation
 3. Log of operators in the original forward and recomputation
 4. Error message                                             <--- You are here
--------------------------------------------------------------------------------
"""

class CheckpointError(RuntimeError):
    pass


def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
    # This function returns the context_fn and error_cb to be used by the
    # checkpointing mechanism. error_cb is invoked when an error is detected
    # during unpack.

    # record_context_cpp is not supported on non-Linux, non-x86_64 platforms
    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'

    class CaptureLogs:
        def __init__(self):
            self.logs = None
            self.tbs = None

        def get_context_manager(self):
            @contextlib.contextmanager
            def logging_mode():
                with LoggingTensorMode(), \
                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
                    self.logs, self.tbs = logs_and_tb
                    yield logs_and_tb
            return logging_mode()

    capture_logs_fwd = CaptureLogs()
    capture_logs_recompute = CaptureLogs()

    def unpack_error_cb(e: CheckpointError):
        def get_str_tb(label, capture_logs):
            out = ""
            total_len = len(capture_logs.logs)
            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
                out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
                found_torch_dispatch = False
                for line in tb:
                    # Start printing stack trace only after __torch_dispatch__ is found
                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
                    if not found_torch_dispatch and not is_torch_dispatch:
                        continue
                    elif is_torch_dispatch:
                        found_torch_dispatch = True
                        continue
                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
                out += "\n\n"
            return out
        assert capture_logs_fwd.logs is not None
        assert capture_logs_recompute.logs is not None
        raise CheckpointError(
            _checkpoint_error_template.format(
                forward_traces=get_str_tb("original", capture_logs_fwd),
                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
                forward_ops="\n".join(capture_logs_fwd.logs),
                recompute_ops="\n".join(capture_logs_recompute.logs)
            )
        ) from e

    def context_fn():
        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()

    return context_fn, unpack_error_cb
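
# Example of how the debug machinery above is surfaced through the public API
# (illustrative sketch; `fn` and `x` are placeholders):
#
#   x = torch.randn(2, 2, requires_grad=True)
#   out = torch.utils.checkpoint.checkpoint(fn, x, use_reentrant=False, debug=True)
#   out.sum().backward()   # on a recompute mismatch, raises CheckpointError
#                          # formatted with _checkpoint_error_template above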

def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
    # These properties are fast to check, easy to understand
    return {
        "shape": x.shape,
        "dtype": x.dtype,
        "device": x.device
    }

_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
    "none": lambda _: None,
}
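
# For example (values shown are illustrative):
#
#   t = torch.randn(2, 3)
#   _default_meta_extractor(t)
#   # -> {"shape": torch.Size([2, 3]), "dtype": torch.float32, "device": device(type='cpu')}
#
# Passing determinism_check="none" to checkpoint selects the lambda above,
# which disables the comparison entirely.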

# See Rule 5
class _StopRecomputationError(Exception):
    pass


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                assert not target_frame.early_stop
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and
                    # grad is called within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass."
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x

            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError
            # See Rule 6: [ retain_graph is True ] above
            return x

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)


class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack_hook(x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            # Save metadata to detect non-determinism
            if frame.metadata_fn is not None:
                with torch.no_grad():
                    frame.x_metadatas.append(frame.metadata_fn(x))
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

            _internal_assert(gid in holder.handles)

            if holder.handles[gid] is None:
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
                    "so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
            holder.handles[gid] = None
            return ret

        if frame.unpack_error_cb is not None:
            def unpack_hook_with_error_cb(holder):
                try:
                    return unpack_hook(holder)
                except CheckpointError as e:
                    frame.unpack_error_cb(e)
            super().__init__(pack_hook, unpack_hook_with_error_cb)
        else:
            super().__init__(pack_hook, unpack_hook)


def _is_compiling(func, args, kwargs):
    # Check if we are under AOTAutograd tracing
    # There should probably be a better way to do this...
    # TODO: unify _is_compiling across all compile stacks
    for arg in args:
        if isinstance(arg, torch.Tensor) and is_fun(arg):
            return True
    return False


class _VersionWrapper:
    # Check that cached tensors are not mutated.
    def __init__(self, val):
        self.val: Union[torch.Tensor, Any] = val
        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None

    def get_val(self, allow_cache_entry_mutation):
        if self.version is not None and not allow_cache_entry_mutation:
            if self.val._version != self.version:
                # Can we give user a stack trace of where the mutation happened?
                raise RuntimeError(
                    "Tensor cached during selective activation checkpoint has been mutated"
                )
        return self.val
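
# Illustration only: the wrapper snapshots the tensor's version counter at cache
# time and refuses to hand the value back if the tensor was mutated in between.
#
#   t = torch.randn(3)
#   wrapped = _VersionWrapper(t)
#   t.add_(1)                                            # in-place op bumps t._version
#   wrapped.get_val(allow_cache_entry_mutation=False)    # raises RuntimeError
#   wrapped.get_val(allow_cache_entry_mutation=True)     # returns the mutated tensor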

def _maybe_detach(x, any_ret_has_alias_info):
    # We detach for two separate reasons:
    # - For view ops, we need to ensure that when the tensor is returned from
    #   CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
    # - Avoid reference cycles
    # For case 1, it is not enough to check whether x has differentiable dtype
    # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
    # when the tensor is a view.
    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
            # Ensure that view performed beneath autograd properly propagates
            # version counter. TODO: Use reentrant_dispatch instead of
            # manually manipulating dispatch keys. Using reentrant_dispatch
            # would respect inference_mode, though that is not relevant for
            # this case.
            x = x.detach()
    return x


class SelectiveCheckpointContext:
    """
    Context passed to policy function during selective checkpointing.

    This class is used to pass relevant metadata to the policy function during
    selective checkpointing. The metadata includes whether the current invocation
    of the policy function is during recomputation or not.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    def __init__(self, *, is_recompute):
        self.is_recompute = is_recompute


class CheckpointPolicy(enum.Enum):
    """
    Enum for specifying the policy for checkpointing during backpropagation.

    The following policies are supported:

    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
      pass and will not be recomputed during the backward pass
    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
      forward pass and will be recomputed during the backward pass

    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
    by other subsystems like `torch.compile`.

    .. note::
        A policy function that always returns ``PREFER_RECOMPUTE`` is
        equivalent to vanilla checkpointing.

        A policy function that returns ``PREFER_SAVE`` for every op is
        NOT equivalent to not using checkpointing. Using such a policy would
        save additional tensors not limited to ones that are actually needed for
        gradient computation.
    """
    MUST_SAVE = 0
    PREFER_SAVE = 1
    MUST_RECOMPUTE = 2
    PREFER_RECOMPUTE = 3


def _policy_from_bool(b):
    # For backward compatibility
    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
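
# For example, a policy function may still return plain booleans; the dispatch
# modes below convert them via _policy_from_bool (sketch only; `ops_to_save` is
# a placeholder list of OpOverloads):
#
#   def bool_policy_fn(ctx, op, *args, **kwargs):
#       # True -> MUST_SAVE, False -> PREFER_RECOMPUTE
#       return op in ops_to_save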

SAC_IGNORED_OPS = {
    # AC inserts different number of detach during forward and recompute.
    torch.ops.aten.detach.default,
    # AC's determinism check invokes additional metadata ops during forward.
    # With subclasses involved, these metadata ops become dispatchable, and
    # this can result in incorrectness if these ops are selected to be cached.
    torch.ops.prim.device.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)


class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        return out

class _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachingTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out
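
# Illustration of how the two modes cooperate (sketch only; real usage goes
# through create_selective_checkpoint_contexts / checkpoint below):
#
#   policy_fn = lambda ctx, op, *args, **kwargs: CheckpointPolicy.MUST_SAVE
#   storage: Dict[Any, List[Any]] = defaultdict(list)
#   x = torch.randn(4, 4)
#
#   with _CachingTorchDispatchMode(policy_fn, storage):
#       y = torch.mm(x, x)        # output cached under storage[torch.ops.aten.mm.default]
#
#   with _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation=False):
#       y_again = torch.mm(x, x)  # served from the cache (popped), not re-executed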
1327*da0073e9SAndroid Build Coastguard Worker ) 1328*da0073e9SAndroid Build Coastguard Worker out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) 1329*da0073e9SAndroid Build Coastguard Worker else: 1330*da0073e9SAndroid Build Coastguard Worker out = func(*args, **kwargs) 1331*da0073e9SAndroid Build Coastguard Worker return out 1332*da0073e9SAndroid Build Coastguard Worker 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Workerdef create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): 1335*da0073e9SAndroid Build Coastguard Worker """ 1336*da0073e9SAndroid Build Coastguard Worker Helper to avoid recomputing certain ops during activation checkpointing. 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker Use this with `torch.utils.checkpoint.checkpoint` to control which 1339*da0073e9SAndroid Build Coastguard Worker operations are recomputed during the backward pass. 1340*da0073e9SAndroid Build Coastguard Worker 1341*da0073e9SAndroid Build Coastguard Worker Args: 1342*da0073e9SAndroid Build Coastguard Worker policy_fn_or_list (Callable or List): 1343*da0073e9SAndroid Build Coastguard Worker - If a policy function is provided, it should accept a 1344*da0073e9SAndroid Build Coastguard Worker :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and 1345*da0073e9SAndroid Build Coastguard Worker kwargs to the op, and return a :class:`CheckpointPolicy` enum value 1346*da0073e9SAndroid Build Coastguard Worker indicating whether the execution of the op should be recomputed or not. 1347*da0073e9SAndroid Build Coastguard Worker - If a list of operations is provided, it is equivalent to a policy 1348*da0073e9SAndroid Build Coastguard Worker returning `CheckpointPolicy.MUST_SAVE` for the specified 1349*da0073e9SAndroid Build Coastguard Worker operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other 1350*da0073e9SAndroid Build Coastguard Worker operations. 1351*da0073e9SAndroid Build Coastguard Worker allow_cache_entry_mutation (bool, optional): By default, an error is 1352*da0073e9SAndroid Build Coastguard Worker raised if any tensors cached by selective activation checkpoint are 1353*da0073e9SAndroid Build Coastguard Worker mutated in order to ensure correctness. If set to `True`, this check 1354*da0073e9SAndroid Build Coastguard Worker is disabled. 1355*da0073e9SAndroid Build Coastguard Worker Returns: 1356*da0073e9SAndroid Build Coastguard Worker A tuple of two context managers. 
1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker Example: 1359*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(LINUX) 1360*da0073e9SAndroid Build Coastguard Worker >>> import functools 1361*da0073e9SAndroid Build Coastguard Worker >>> 1362*da0073e9SAndroid Build Coastguard Worker >>> x = torch.rand(10, 10, requires_grad=True) 1363*da0073e9SAndroid Build Coastguard Worker >>> y = torch.rand(10, 10, requires_grad=True) 1364*da0073e9SAndroid Build Coastguard Worker >>> 1365*da0073e9SAndroid Build Coastguard Worker >>> ops_to_save = [ 1366*da0073e9SAndroid Build Coastguard Worker >>> torch.ops.aten.mm.default, 1367*da0073e9SAndroid Build Coastguard Worker >>> ] 1368*da0073e9SAndroid Build Coastguard Worker >>> 1369*da0073e9SAndroid Build Coastguard Worker >>> def policy_fn(ctx, op, *args, **kwargs): 1370*da0073e9SAndroid Build Coastguard Worker >>> if op in ops_to_save: 1371*da0073e9SAndroid Build Coastguard Worker >>> return CheckpointPolicy.MUST_SAVE 1372*da0073e9SAndroid Build Coastguard Worker >>> else: 1373*da0073e9SAndroid Build Coastguard Worker >>> return CheckpointPolicy.PREFER_RECOMPUTE 1374*da0073e9SAndroid Build Coastguard Worker >>> 1375*da0073e9SAndroid Build Coastguard Worker >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 1376*da0073e9SAndroid Build Coastguard Worker >>> 1377*da0073e9SAndroid Build Coastguard Worker >>> # or equivalently 1378*da0073e9SAndroid Build Coastguard Worker >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) 1379*da0073e9SAndroid Build Coastguard Worker >>> 1380*da0073e9SAndroid Build Coastguard Worker >>> def fn(x, y): 1381*da0073e9SAndroid Build Coastguard Worker >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y 1382*da0073e9SAndroid Build Coastguard Worker >>> 1383*da0073e9SAndroid Build Coastguard Worker >>> out = torch.utils.checkpoint.checkpoint( 1384*da0073e9SAndroid Build Coastguard Worker >>> fn, x, y, 1385*da0073e9SAndroid Build Coastguard Worker >>> use_reentrant=False, 1386*da0073e9SAndroid Build Coastguard Worker >>> context_fn=context_fn, 1387*da0073e9SAndroid Build Coastguard Worker >>> ) 1388*da0073e9SAndroid Build Coastguard Worker """ 1389*da0073e9SAndroid Build Coastguard Worker # NB: If grad_mode is disabled, checkpoint would not run forward under 1390*da0073e9SAndroid Build Coastguard Worker # context_fn anyway, so proceed as usual. 1391*da0073e9SAndroid Build Coastguard Worker if isinstance(policy_fn_or_list, list): 1392*da0073e9SAndroid Build Coastguard Worker for op in policy_fn_or_list: 1393*da0073e9SAndroid Build Coastguard Worker if not isinstance(op, torch._ops.OpOverload): 1394*da0073e9SAndroid Build Coastguard Worker _extra_msg = ( 1395*da0073e9SAndroid Build Coastguard Worker "Please update the OpOverloadPacket to a specific OpOverload." 1396*da0073e9SAndroid Build Coastguard Worker "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." 1397*da0073e9SAndroid Build Coastguard Worker ) if isinstance(op, torch._ops.OpOverloadPacket) else "" 1398*da0073e9SAndroid Build Coastguard Worker raise ValueError( 1399*da0073e9SAndroid Build Coastguard Worker f"Expected op in `op_list` to be an OpOverload but got: {op} " 1400*da0073e9SAndroid Build Coastguard Worker f"of type {type(op)}. 
{_extra_msg}" 1401*da0073e9SAndroid Build Coastguard Worker ) 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Worker def policy_fn(ctx, op, *args, **kwargs): 1404*da0073e9SAndroid Build Coastguard Worker if op in policy_fn_or_list: 1405*da0073e9SAndroid Build Coastguard Worker return CheckpointPolicy.MUST_SAVE 1406*da0073e9SAndroid Build Coastguard Worker else: 1407*da0073e9SAndroid Build Coastguard Worker return CheckpointPolicy.PREFER_RECOMPUTE 1408*da0073e9SAndroid Build Coastguard Worker elif callable(policy_fn_or_list): 1409*da0073e9SAndroid Build Coastguard Worker policy_fn = policy_fn_or_list 1410*da0073e9SAndroid Build Coastguard Worker else: 1411*da0073e9SAndroid Build Coastguard Worker raise TypeError("policy_fn_or_list must be either a function or a list of ops.") 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker storage: Dict[Any, List[Any]] = defaultdict(list) 1414*da0073e9SAndroid Build Coastguard Worker return ( 1415*da0073e9SAndroid Build Coastguard Worker _CachingTorchDispatchMode(policy_fn, storage), 1416*da0073e9SAndroid Build Coastguard Worker _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), 1417*da0073e9SAndroid Build Coastguard Worker ) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker# NB: this helper wraps fn before calling checkpoint_impl. kwargs and 1420*da0073e9SAndroid Build Coastguard Worker# saving/restoring of global state is handled here. 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Workerdef _checkpoint_without_reentrant_generator( 1423*da0073e9SAndroid Build Coastguard Worker fn, 1424*da0073e9SAndroid Build Coastguard Worker preserve_rng_state=True, 1425*da0073e9SAndroid Build Coastguard Worker context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, 1426*da0073e9SAndroid Build Coastguard Worker determinism_check: str = _DEFAULT_DETERMINISM_MODE, 1427*da0073e9SAndroid Build Coastguard Worker debug: bool = False, 1428*da0073e9SAndroid Build Coastguard Worker *args, 1429*da0073e9SAndroid Build Coastguard Worker **kwargs 1430*da0073e9SAndroid Build Coastguard Worker): 1431*da0073e9SAndroid Build Coastguard Worker """Checkpointing without reentrant autograd. 1432*da0073e9SAndroid Build Coastguard Worker 1433*da0073e9SAndroid Build Coastguard Worker Args: 1434*da0073e9SAndroid Build Coastguard Worker function: describes what to run in the forward pass of the model or 1435*da0073e9SAndroid Build Coastguard Worker part of the model. It should also know how to handle the inputs 1436*da0073e9SAndroid Build Coastguard Worker passed as the tuple. For example, in LSTM, if user passes 1437*da0073e9SAndroid Build Coastguard Worker ``(activation, hidden)``, :attr:`function` should correctly use the 1438*da0073e9SAndroid Build Coastguard Worker first input as ``activation`` and the second input as ``hidden`` 1439*da0073e9SAndroid Build Coastguard Worker preserve_rng_state(bool, optional): Omit stashing and restoring 1440*da0073e9SAndroid Build Coastguard Worker the RNG state during each checkpoint. 1441*da0073e9SAndroid Build Coastguard Worker Default: ``True`` 1442*da0073e9SAndroid Build Coastguard Worker context_fn(Callable, optional): A callable returning a tuple of two 1443*da0073e9SAndroid Build Coastguard Worker context managers. 
    unpack_error_cb = None

    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device_type = _infer_device_type(*args)
    device_module = _get_device_module(device_type)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)
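    # Note: the autocast kwargs captured above are re-applied inside
    # recompute_fn below, so the recomputation runs under the same autocast
    # configuration as the original forward.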

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)

            device_autocast_ctx = torch.amp.autocast(
                device_type=device_type, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                fn(*args, **kwargs)
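
    # Note: recompute_fn re-runs ``fn`` only for the side effect of
    # regenerating intermediate activations; its return value is discarded.
    # The frame below bundles recompute_fn with the early-stop flag, the
    # optional unpack error callback, and the metadata_fn used for the
    # determinism check.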

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
            preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
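

# Illustrative sketch (hypothetical usage): with the default
# ``determinism_check="default"``, a mismatch between the metadata of saved and
# recomputed tensors (e.g. a data-dependent shape that differs between the two
# passes) typically surfaces as a CheckpointError during backward. Passing
# ``determinism_check="none"`` disables that comparison:
#
#     out = torch.utils.checkpoint.checkpoint(
#         fn, x, use_reentrant=False, determinism_check="none",
#     )
#     out.sum().backward()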