# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from typing import *  # noqa: F403
import enum
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
    "CheckpointPolicy",
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """
    Context manager that sets whether checkpoint should print additional debug
    information when running. See the ``debug`` flag for
    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
    when set, this context manager overrides the value of ``debug`` passed to
    checkpoint. To defer to the local setting, pass ``None`` to this context.

    Args:
        enabled (bool): Whether checkpoint should print debug information.
            Default is ``None``.
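
    Example (a minimal usage sketch; ``fn`` and ``inp`` are placeholders)::

        >>> # xdoctest: +SKIP
        >>> with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
        ...     out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=False)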
    """
    global _checkpoint_debug_enabled
    try:
        prev = _checkpoint_debug_enabled
        _checkpoint_debug_enabled = enabled
        yield
    finally:
        _checkpoint_debug_enabled = prev


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only a tuple of tensors is supported. Got unsupported input type: ",
            type(inputs).__name__,
        )


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )


def _get_device_module(device="cuda"):
    if device == "meta":
        return torch.device("meta")
    device_module = getattr(torch, device)
    return device_module


class DefaultDeviceType:
    r"""
    A class that manages the default device type for checkpointing.

    If no non-CPU tensors are present, the default device type will
    be used. The default value is 'cuda'. The device type is used in
    the checkpointing process when determining which device states
    to save and restore for recomputation.
    """

    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
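
        Example (a minimal usage sketch; ``"xpu"`` is just an illustrative device string)::

            >>> # xdoctest: +SKIP
            >>> DefaultDeviceType.set_device_type("xpu")
            >>> DefaultDeviceType.get_device_type()
            'xpu'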
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type


def _infer_device_type(*args):
    device_types = []

    def add_device_types(arg):
        nonlocal device_types
        if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
            device_types.append(arg.device.type)
    tree_map(add_device_types, args)

    device_types_set = set(device_types)
    if len(device_types_set) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
            f"\nDevice types: {sorted(device_types_set)}, first device type: {device_types[0]}"
        )
    if len(device_types) == 0:
        return DefaultDeviceType.get_device_type()
    elif "cuda" in device_types_set:
        return "cuda"
    else:
        return device_types[0]


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could, out of paranoia, stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the devices of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_device_ids = []

    def add_device_ids(arg):
        nonlocal fwd_device_ids
        if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
            fwd_device_ids.append(arg.get_device())
    tree_map(add_device_ids, args)

    fwd_device_states = []
    device_module = _get_device_module(_infer_device_type(*args))
    for device_id in fwd_device_ids:
        with device_module.device(device_id):
            fwd_device_states.append(device_module.get_rng_state())

    return fwd_device_ids, fwd_device_states


def set_device_states(devices, states, *, device_type=None) -> None:
    """Sets random number generator states for the specified devices.

    Args:
        devices: Device ids to set states for.
        states: States to set.
        device_type: ``device_type`` of the devices to set states for. Default
            is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
            which is ``cuda`` if not changed by calling ``DefaultDeviceType.set_device_type()``.
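
    Example (a minimal round-trip sketch; assumes a CUDA device is available and
    ``x`` is just a placeholder tensor)::

        >>> # xdoctest: +SKIP
        >>> x = torch.randn(3, device="cuda")
        >>> devices, states = get_device_states(x)
        >>> _ = torch.rand(1, device="cuda")  # advances the CUDA RNG
        >>> set_device_states(devices, states, device_type="cuda")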
    """
    if device_type is None:
        device_type = DefaultDeviceType.get_device_type()
    if device_type == "meta":
        return
    device_module = _get_device_module(device_type)
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)


def _get_autocast_kwargs(device_type="cuda"):
    if torch.amp.is_autocast_available(device_type):
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(device_type),
            "dtype": torch.get_autocast_dtype(device_type),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    else:
        device_autocast_kwargs = None

    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled('cpu'),
        "dtype": torch.get_autocast_dtype('cpu'),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    return device_autocast_kwargs, cpu_autocast_kwargs


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device_type = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device_type
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device_type)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
            detached_inputs = detach_variable(tuple(inputs))

            device_autocast_ctx = torch.amp.autocast(
                device_type=ctx.device_type, **ctx.device_autocast_kwargs
            ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()
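
# A custom ``context_fn`` for :func:`checkpoint` only needs to return a pair of
# context managers: the first wraps the original forward, the second wraps the
# recomputation. A minimal sketch (only supported with use_reentrant=False;
# ``torch.autograd.profiler.record_function`` is used here purely as an
# illustrative way of labeling the two passes, e.g. in profiler traces):
#
#     def labeled_context_fn():
#         return (
#             torch.autograd.profiler.record_function("checkpoint_forward"),
#             torch.autograd.profiler.record_function("checkpoint_recompute"),
#         )
#
#     out = checkpoint(fn, *args, use_reentrant=False, context_fn=labeled_context_fn)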

# TorchDynamo does not step inside utils.checkpoint function.  The flow
# looks like this:
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then the Dynamo-generated FX graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an
        error being raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass (see the sketch after this list).

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.
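
        As a minimal sketch of the point about backward APIs above (``x`` is
        just a placeholder input tensor), ``torch.autograd.grad`` works as
        expected with the non-reentrant variant::

            >>> # xdoctest: +SKIP
            >>> x = torch.randn(4, requires_grad=True)
            >>> y = torch.utils.checkpoint.checkpoint(torch.sin, x, use_reentrant=False)
            >>> grad_x, = torch.autograd.grad(y.sum(), x)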

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint. Note that under
            torch.compile, this flag doesn't take effect and we always preserve
            RNG state. Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
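
    Example (a minimal usage sketch; ``run_fn`` is just a placeholder)::

        >>> # xdoctest: +SKIP
        >>> def run_fn(a, b):
        ...     return (a * b).sum()
        >>> a = torch.randn(8, 8, requires_grad=True)
        >>> b = torch.randn(8, 8, requires_grad=True)
        >>> out = torch.utils.checkpoint.checkpoint(run_fn, a, b, use_reentrant=False)
        >>> out.backward()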
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.5 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants.",
            stacklevel=2
        )
        use_reentrant = True

    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        ret = function(*args, **kwargs)
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            return ret


def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, ``checkpoint_sequential`` accepts only one Tensor as
        the input and as the intermediate outputs, just like
        :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.5 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be run without checkpointing
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _internal_assert(cond):
    if not cond:
        raise AssertionError(
            "Something went unexpectedly wrong in activation checkpoint. "
            "Please report this bug by filing an issue to PyTorch."
        )


# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by the inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of an inner checkpoint are treated as tensors saved to its
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (Rule 1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (Rule 2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed (unless we have
# already done so in that backward for some other saved tensor).
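#
# A minimal sketch in comments (names are placeholders, not part of the API):
#
#   def inner(x):
#       return x.sin()
#
#   def outer(x):
#       return checkpoint(inner, x.cos(), use_reentrant=False)
#
#   out = checkpoint(outer, inp, use_reentrant=False)
#
# Here x.cos(), the input of the inner checkpoint, is saved to the outer
# checkpoint (Rule 2), while anything saved inside `inner` is managed only by
# the inner checkpoint (Rule 1). To recompute a tensor saved inside `inner`
# during backward, the outer checkpoint must be recomputed first to reproduce
# the inner checkpoint's input.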
626*da0073e9SAndroid Build Coastguard Worker#
627*da0073e9SAndroid Build Coastguard Worker# In practice, we use a noop autograd Function to save inputs as saved tensors.
628*da0073e9SAndroid Build Coastguard Worker# During unpack calling ctx.saved_tensor triggers the parent checkpoint to
629*da0073e9SAndroid Build Coastguard Worker# recompute.
630*da0073e9SAndroid Build Coastguard Worker#
631*da0073e9SAndroid Build Coastguard Worker# Rule 3. We should start recomputation as if there are no checkpoints currently
632*da0073e9SAndroid Build Coastguard Worker#         active. Checkpoints encountered during recomputation are still
633*da0073e9SAndroid Build Coastguard Worker#         respected.
634*da0073e9SAndroid Build Coastguard Worker#
635*da0073e9SAndroid Build Coastguard Worker# When we start recomputation, we push the saved variable hook meant for
636*da0073e9SAndroid Build Coastguard Worker# recomputation on the stack. See examples in Rule 6 for more context.
637*da0073e9SAndroid Build Coastguard Worker#
638*da0073e9SAndroid Build Coastguard Worker#                                  * * * *
639*da0073e9SAndroid Build Coastguard Worker#
640*da0073e9SAndroid Build Coastguard Worker# Beyond the basic semantics specific to nested checkpoint, we impose several
641*da0073e9SAndroid Build Coastguard Worker# more constraints that may apply to checkpointing in general.
642*da0073e9SAndroid Build Coastguard Worker#
643*da0073e9SAndroid Build Coastguard Worker# Rule 4. Lifetime of recomputed tensors
644*da0073e9SAndroid Build Coastguard Worker#
645*da0073e9SAndroid Build Coastguard Worker#         Recomputed tensors are considered specific to particular invocations
646*da0073e9SAndroid Build Coastguard Worker#         of backward and are always cleared immediately as they are unpacked
647*da0073e9SAndroid Build Coastguard Worker#         Particularly, we require this to happen even if retain_graph=True.
648*da0073e9SAndroid Build Coastguard Worker#
649*da0073e9SAndroid Build Coastguard Worker# [ Implementation details of Rule 4 ]
650*da0073e9SAndroid Build Coastguard Worker#
651*da0073e9SAndroid Build Coastguard Worker# If we were okay with recomputed tensors staying alive after backward is run
652*da0073e9SAndroid Build Coastguard Worker# with retain_graph=True, we would store recomputed variables as the values of a
653*da0073e9SAndroid Build Coastguard Worker# WeakKeyDictionary and pack strong references to the keys, so that as we
654*da0073e9SAndroid Build Coastguard Worker# backward, those packed keys would be cleared as long as retain_graph=False.
655*da0073e9SAndroid Build Coastguard Worker# Clearing the packed key clears the corresponding entry in the WKD.
656*da0073e9SAndroid Build Coastguard Worker#
657*da0073e9SAndroid Build Coastguard Worker# If we wish recomputed variables to be immediately cleared as we unpack them in
658*da0073e9SAndroid Build Coastguard Worker# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
659*da0073e9SAndroid Build Coastguard Worker# backward automatically. Instead of packing the strong reference to the key
660*da0073e9SAndroid Build Coastguard Worker# directly, we pack a container object, which we manually clear as we unpack.
661*da0073e9SAndroid Build Coastguard Worker#
662*da0073e9SAndroid Build Coastguard Worker# An important detail is that if a second backward happens, the second
663*da0073e9SAndroid Build Coastguard Worker# recomputation needs to reset the container with a newly created key.
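#
# Sketch of the container pattern (see _Handle, _Holder, and the hooks below):
#
#   # pack, during recomputation:
#   handle = _Handle()                              # key into the WeakKeyDictionary
#   holder.handles[gid] = handle                    # strong ref held by the container
#   target_frame.recomputed[gid][handle] = tensor   # weak key -> recomputed tensor
#
#   # unpack, during backward:
#   ret = frame.recomputed[gid][holder.handles[gid]]
#   holder.handles[gid] = None                      # manual clear drops the WKD entry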
664*da0073e9SAndroid Build Coastguard Worker#
665*da0073e9SAndroid Build Coastguard Worker# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
666*da0073e9SAndroid Build Coastguard Worker#         know we need.
667*da0073e9SAndroid Build Coastguard Worker#
668*da0073e9SAndroid Build Coastguard Worker# [ Implementation details of Rule 5 ]
669*da0073e9SAndroid Build Coastguard Worker#
670*da0073e9SAndroid Build Coastguard Worker# During recomputation, raise an exception if the number of recomputed tensors
671*da0073e9SAndroid Build Coastguard Worker# matches the number of tensors that we expected to recompute. We wrap the
672*da0073e9SAndroid Build Coastguard Worker# recomputation call with a try-catch to catch this specific exception. See
673*da0073e9SAndroid Build Coastguard Worker# Rule #6 below for some examples.
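#
# Sketch (see _recomputation_hook.pack_hook and _checkpoint_hook.unpack_hook):
#
#   # in the recomputation pack hook:
#   if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
#       target_frame.weak_holders
#   ):
#       raise _StopRecomputationError
#
#   # around the recomputation call:
#   try:
#       frame.recompute_fn(*args)
#   except _StopRecomputationError:
#       pass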
674*da0073e9SAndroid Build Coastguard Worker#
675*da0073e9SAndroid Build Coastguard Worker# Rule 6. We support doing backward inside checkpoint context
676*da0073e9SAndroid Build Coastguard Worker#
# [ retain_graph is True ]
678*da0073e9SAndroid Build Coastguard Worker#
679*da0073e9SAndroid Build Coastguard Worker# def fn(x):
680*da0073e9SAndroid Build Coastguard Worker#   y = x.sin()
681*da0073e9SAndroid Build Coastguard Worker#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
683*da0073e9SAndroid Build Coastguard Worker#   return gx, z
684*da0073e9SAndroid Build Coastguard Worker#
# out = checkpoint(fn, inp)
686*da0073e9SAndroid Build Coastguard Worker# out.backward()
687*da0073e9SAndroid Build Coastguard Worker#
# Because z is saved by cos while checkpointing is enabled, it is not actually
# saved, and so the .grad() call inside fn must trigger a recomputation.
690*da0073e9SAndroid Build Coastguard Worker#
691*da0073e9SAndroid Build Coastguard Worker# During recomputation the "inner pack hook" has two responsibilities:
692*da0073e9SAndroid Build Coastguard Worker#
# 1) Populate the WeakKeyDictionary storing recomputed tensors, as usual.
694*da0073e9SAndroid Build Coastguard Worker# 2) Pack the actual tensor (detached) so that one may perform backward on the
695*da0073e9SAndroid Build Coastguard Worker#    recomputed graph. The tensors saved to this graph will live until the end
696*da0073e9SAndroid Build Coastguard Worker#    of recomputation, or die earlier if someone performs backward with
697*da0073e9SAndroid Build Coastguard Worker#    retain_graph=False.
698*da0073e9SAndroid Build Coastguard Worker#
# More generally, performing backward on the recomputed graph occurs in the
700*da0073e9SAndroid Build Coastguard Worker# following cases:
701*da0073e9SAndroid Build Coastguard Worker# - If backward is performed inside forward,
702*da0073e9SAndroid Build Coastguard Worker#   - During the original forward IF early-stop is disabled
703*da0073e9SAndroid Build Coastguard Worker#   - During the original backward
704*da0073e9SAndroid Build Coastguard Worker# - If there are multiple .grad()/.backward() calls, we would perform backward
705*da0073e9SAndroid Build Coastguard Worker#   on the recomputed graph even if early-stop is enabled (see the example below)
706*da0073e9SAndroid Build Coastguard Worker#
707*da0073e9SAndroid Build Coastguard Worker# [ retain_graph is False ]
708*da0073e9SAndroid Build Coastguard Worker#
709*da0073e9SAndroid Build Coastguard Worker# The example below shows what happens if during recomputation we find that some
710*da0073e9SAndroid Build Coastguard Worker# of the tensors we are trying to recompute have already been cleared.
711*da0073e9SAndroid Build Coastguard Worker#
712*da0073e9SAndroid Build Coastguard Worker# Spoiler: we don't do anything special, we just skip over them!
713*da0073e9SAndroid Build Coastguard Worker#
714*da0073e9SAndroid Build Coastguard Worker# def fn(x):
715*da0073e9SAndroid Build Coastguard Worker#   y = x.sin()                           # (1)
716*da0073e9SAndroid Build Coastguard Worker#   z = y.cos()                           # (2)
717*da0073e9SAndroid Build Coastguard Worker#   gx, = torch.autograd.grad(z, x)       # (3)
718*da0073e9SAndroid Build Coastguard Worker#   return x.cos() * gx                   # (4)
719*da0073e9SAndroid Build Coastguard Worker#
# out = checkpoint(fn, inp)
721*da0073e9SAndroid Build Coastguard Worker# out.backward()                          # (5)
722*da0073e9SAndroid Build Coastguard Worker#
723*da0073e9SAndroid Build Coastguard Worker# 1, 2. Don't save x and y since we are inside a checkpoint.
724*da0073e9SAndroid Build Coastguard Worker# 3. Trigger a recompute of fn since x and y weren't saved.
725*da0073e9SAndroid Build Coastguard Worker#    And depending on whether early stop is enabled, either stop at (2) or
726*da0073e9SAndroid Build Coastguard Worker#    continue running the function.
727*da0073e9SAndroid Build Coastguard Worker#    Because we are running backward with retain_graph=False, we clear x and y's
728*da0073e9SAndroid Build Coastguard Worker#    holders.
729*da0073e9SAndroid Build Coastguard Worker# 4. Don't save x since we are inside a checkpoint.
730*da0073e9SAndroid Build Coastguard Worker# 5. Calling backward triggers another recompute of fn. During recompute, we see
731*da0073e9SAndroid Build Coastguard Worker#    that x and y have already been cleared in the original graph as indicated
732*da0073e9SAndroid Build Coastguard Worker#    by holder=None. We skip over them. We still save x at (4) (since its holder
733*da0073e9SAndroid Build Coastguard Worker#    is still alive.)
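#
# The "skip" in step 5 is just a None check in the recomputation pack hook
# (see _recomputation_hook below): a dead holder means the corresponding
# SavedVariable was already cleared, so there is nothing left to populate.
#
#   holder = target_frame.weak_holders[recomp_idx]()
#   if holder is not None:
#       holder.handles[gid] = _Handle()
#       target_frame.recomputed[gid][holder.handles[gid]] = x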
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker_enable_checkpoint_early_stop = True
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager
739*da0073e9SAndroid Build Coastguard Workerdef set_checkpoint_early_stop(enable: bool):
740*da0073e9SAndroid Build Coastguard Worker    """Context manager that sets whether checkpoint should stop recomputation early.
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    By default, non-reentrant checkpoint stops recomputation as soon as it
743*da0073e9SAndroid Build Coastguard Worker    has computed all needed Tensors. This context manager can be used to disable
744*da0073e9SAndroid Build Coastguard Worker    that feature if it is problematic for your specific application.
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker    This context manager only needs to be active when forward is run. It does
747*da0073e9SAndroid Build Coastguard Worker    not need to be active during backward.
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker    Example::
750*da0073e9SAndroid Build Coastguard Worker
        >>> # xdoctest: +SKIP(failing)
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
759*da0073e9SAndroid Build Coastguard Worker    """
760*da0073e9SAndroid Build Coastguard Worker    global _enable_checkpoint_early_stop
761*da0073e9SAndroid Build Coastguard Worker    try:
762*da0073e9SAndroid Build Coastguard Worker        prev = _enable_checkpoint_early_stop
763*da0073e9SAndroid Build Coastguard Worker        _enable_checkpoint_early_stop = enable
764*da0073e9SAndroid Build Coastguard Worker        yield
765*da0073e9SAndroid Build Coastguard Worker    finally:
766*da0073e9SAndroid Build Coastguard Worker        _enable_checkpoint_early_stop = prev
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Workerclass _Handle:
770*da0073e9SAndroid Build Coastguard Worker    pass
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Workerclass _Holder:
774*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
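        # Maps graph-task id (gid) -> the _Handle used as the key into
        # _CheckpointFrame.recomputed for that backward. Set during
        # recomputation and reset to None once the tensor is unpacked (Rule 4).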
775*da0073e9SAndroid Build Coastguard Worker        self.handles: Dict[int, Optional[_Handle]] = {}
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Workerclass _NoopSaveInputs(torch.autograd.Function):
779*da0073e9SAndroid Build Coastguard Worker    @staticmethod
780*da0073e9SAndroid Build Coastguard Worker    def forward(*args):
781*da0073e9SAndroid Build Coastguard Worker        return torch.empty((0,))
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker    @staticmethod
784*da0073e9SAndroid Build Coastguard Worker    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
785*da0073e9SAndroid Build Coastguard Worker        # Only tensors can be saved with ctx.save_for_backward, everything else
786*da0073e9SAndroid Build Coastguard Worker        # is captured by get_args, which is saved directly on ctx
787*da0073e9SAndroid Build Coastguard Worker        tensor_indices, tensors = zip(
788*da0073e9SAndroid Build Coastguard Worker            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
789*da0073e9SAndroid Build Coastguard Worker        )
790*da0073e9SAndroid Build Coastguard Worker        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
791*da0073e9SAndroid Build Coastguard Worker        # args but with tensors replaced with None as placeholders
792*da0073e9SAndroid Build Coastguard Worker        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker        def get_args(saved_tensors):
795*da0073e9SAndroid Build Coastguard Worker            # restore the placeholders with the original tensors grabbed from
796*da0073e9SAndroid Build Coastguard Worker            # ctx.saved_tensors (which may be saved on a parent checkpoint if
797*da0073e9SAndroid Build Coastguard Worker            # this checkpoint is nested, and that would trigger a recursive
798*da0073e9SAndroid Build Coastguard Worker            # unpack!)
799*da0073e9SAndroid Build Coastguard Worker            ret = [
800*da0073e9SAndroid Build Coastguard Worker                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
801*da0073e9SAndroid Build Coastguard Worker                for i, o in enumerate(args)
802*da0073e9SAndroid Build Coastguard Worker            ]
803*da0073e9SAndroid Build Coastguard Worker            # grab the tail since we also saved the dummy to avoid having to explicitly
804*da0073e9SAndroid Build Coastguard Worker            # handle the case where there are no tensor inputs
805*da0073e9SAndroid Build Coastguard Worker            return ret[1:]
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        ctx.get_args = get_args
808*da0073e9SAndroid Build Coastguard Worker        ctx.save_for_backward(*tensors)
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker    @staticmethod
811*da0073e9SAndroid Build Coastguard Worker    def backward(ctx, *grad_outputs):
812*da0073e9SAndroid Build Coastguard Worker        raise AssertionError("Did not expect to backward on this graph")
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Workerclass _CheckpointFrame:
816*da0073e9SAndroid Build Coastguard Worker    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
817*da0073e9SAndroid Build Coastguard Worker        self.recompute_fn = recompute_fn
818*da0073e9SAndroid Build Coastguard Worker        self.input_saver = None
819*da0073e9SAndroid Build Coastguard Worker        self.weak_holders: List[ReferenceType] = []
        # We store this as a WeakKeyDictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder,
        # which is removed when the corresponding SavedVariable is cleared.
823*da0073e9SAndroid Build Coastguard Worker        self.recomputed: DefaultDict[
824*da0073e9SAndroid Build Coastguard Worker            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
825*da0073e9SAndroid Build Coastguard Worker        ] = defaultdict(weakref.WeakKeyDictionary)
826*da0073e9SAndroid Build Coastguard Worker        # We need both recomp_counter and recomputed since they can diverge
827*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
828*da0073e9SAndroid Build Coastguard Worker        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
829*da0073e9SAndroid Build Coastguard Worker        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        # See Rule 5
832*da0073e9SAndroid Build Coastguard Worker        self.early_stop = early_stop
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker        # Debugging
835*da0073e9SAndroid Build Coastguard Worker        self.metadata_fn = metadata_fn
836*da0073e9SAndroid Build Coastguard Worker        self.unpack_error_cb = unpack_error_cb
837*da0073e9SAndroid Build Coastguard Worker        self.x_metadatas = []
838*da0073e9SAndroid Build Coastguard Worker        self.forward_completed = False
839*da0073e9SAndroid Build Coastguard Worker        self.ignore_saved_mismatch = False
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker    def check_recomputed_tensors_match(self, gid):
842*da0073e9SAndroid Build Coastguard Worker        if self.ignore_saved_mismatch:
843*da0073e9SAndroid Build Coastguard Worker            # TODO: we can probably make this check stricter by checking that
844*da0073e9SAndroid Build Coastguard Worker            #       the metadata of the first tensors still match.
845*da0073e9SAndroid Build Coastguard Worker            return
846*da0073e9SAndroid Build Coastguard Worker        # NOTE [ Error handling for checkpoint ]
847*da0073e9SAndroid Build Coastguard Worker        #
848*da0073e9SAndroid Build Coastguard Worker        # At a high level, we need to check that the tensors saved
        # during the original forward match the tensors saved during recompute.
        # This means handling 3 cases:
851*da0073e9SAndroid Build Coastguard Worker        #
852*da0073e9SAndroid Build Coastguard Worker        # 1. During recompute, more tensors were saved.
853*da0073e9SAndroid Build Coastguard Worker        #
        #    Usually this case is hidden by the StopRecomputationError
        #    when early stop is enabled. If early stop is not enabled,
        #    we would have errored anyway because there aren't enough
        #    weak_holders. Either way, we want to raise a nice error.
        #    See the _recomputation_hook for details.
        if len(self.weak_holders) != self.recomp_counter[gid]:
            # 2. During recompute, fewer tensors were saved.
            #
            # We know that every time we save something during the original
            # forward we append to weak_holders, and every time we save a
            # tensor during recompute we increment recomp_counter.
865*da0073e9SAndroid Build Coastguard Worker            raise CheckpointError(
866*da0073e9SAndroid Build Coastguard Worker                "torch.utils.checkpoint: A different number of tensors was saved "
867*da0073e9SAndroid Build Coastguard Worker                "during the original forward and recomputation.\n"
868*da0073e9SAndroid Build Coastguard Worker                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
869*da0073e9SAndroid Build Coastguard Worker                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
870*da0073e9SAndroid Build Coastguard Worker            )
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker        # 3. During recompute, the same tensors were saved, but they
873*da0073e9SAndroid Build Coastguard Worker        #    have different metadata
874*da0073e9SAndroid Build Coastguard Worker        nb_meta_different = []
875*da0073e9SAndroid Build Coastguard Worker        for idx, weak_holder in enumerate(self.weak_holders):
876*da0073e9SAndroid Build Coastguard Worker            holder = weak_holder()
877*da0073e9SAndroid Build Coastguard Worker            if holder is None:
878*da0073e9SAndroid Build Coastguard Worker                continue
879*da0073e9SAndroid Build Coastguard Worker            # We've seen all holders since we iterate over them in order
880*da0073e9SAndroid Build Coastguard Worker            # For every holder that is still alive now, it must've been
881*da0073e9SAndroid Build Coastguard Worker            # alive when we saw it during recompute, therefore, the
882*da0073e9SAndroid Build Coastguard Worker            # gid must be set.
883*da0073e9SAndroid Build Coastguard Worker            _internal_assert(gid in holder.handles)
884*da0073e9SAndroid Build Coastguard Worker            # We know this is the first unpack, so it couldn't have been set
885*da0073e9SAndroid Build Coastguard Worker            # to None yet.
886*da0073e9SAndroid Build Coastguard Worker            _internal_assert(holder.handles[gid] is not None)
887*da0073e9SAndroid Build Coastguard Worker            # We always set these together in the recomputation hook
888*da0073e9SAndroid Build Coastguard Worker            _internal_assert(holder.handles[gid] in self.recomputed[gid])
            # See pack_hook: x_metadatas is 1:1 with weak_holders.
890*da0073e9SAndroid Build Coastguard Worker            x_meta = self.x_metadatas[idx]
891*da0073e9SAndroid Build Coastguard Worker            recomputed_x = self.recomputed[gid][holder.handles[gid]]
892*da0073e9SAndroid Build Coastguard Worker            if x_meta != self.metadata_fn(recomputed_x):
893*da0073e9SAndroid Build Coastguard Worker                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker        if len(nb_meta_different) > 0:
896*da0073e9SAndroid Build Coastguard Worker            mismatched_tensors = ""
897*da0073e9SAndroid Build Coastguard Worker            for idx, x_meta, recomputed_meta in nb_meta_different:
898*da0073e9SAndroid Build Coastguard Worker                mismatched_tensors += (
899*da0073e9SAndroid Build Coastguard Worker                    f"tensor at position {idx}:\n"
900*da0073e9SAndroid Build Coastguard Worker                    f"saved metadata: {x_meta}\n"
901*da0073e9SAndroid Build Coastguard Worker                    f"recomputed metadata: {recomputed_meta}\n"
902*da0073e9SAndroid Build Coastguard Worker                )
903*da0073e9SAndroid Build Coastguard Worker            raise CheckpointError(
904*da0073e9SAndroid Build Coastguard Worker                "torch.utils.checkpoint: Recomputed values for the following tensors "
905*da0073e9SAndroid Build Coastguard Worker                "have different metadata than during the forward pass.\n"
906*da0073e9SAndroid Build Coastguard Worker                f"{mismatched_tensors}"
907*da0073e9SAndroid Build Coastguard Worker            )
908*da0073e9SAndroid Build Coastguard Worker
909*da0073e9SAndroid Build Coastguard Worker
910*da0073e9SAndroid Build Coastguard Worker_checkpoint_error_template = """ \
911*da0073e9SAndroid Build Coastguard WorkerAn error happened while unpacking tensors; dumping logs of latest computation
912*da0073e9SAndroid Build Coastguard Workerbecause you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
913*da0073e9SAndroid Build Coastguard WorkerScroll all the way down for guidance on how to navigate these logs.
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
916*da0073e9SAndroid Build Coastguard Worker|        1. Stack traces of the operators that ran in the original forward     |
917*da0073e9SAndroid Build Coastguard Worker+------------------------------------------------------------------------------+
918*da0073e9SAndroid Build Coastguard Worker
919*da0073e9SAndroid Build Coastguard Worker{forward_traces}
920*da0073e9SAndroid Build Coastguard Worker+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
921*da0073e9SAndroid Build Coastguard Worker|        2. Stack traces of the operators that ran during recomputation        |
922*da0073e9SAndroid Build Coastguard Worker+------------------------------------------------------------------------------+
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker{recompute_traces}
925*da0073e9SAndroid Build Coastguard Worker+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
926*da0073e9SAndroid Build Coastguard Worker|       3. Log of operators in the original forward and recomputation          |
927*da0073e9SAndroid Build Coastguard Worker+------------------------------------------------------------------------------+
928*da0073e9SAndroid Build Coastguard Worker(Scroll up to correlate stack traces with each operation listed below. This
929*da0073e9SAndroid Build Coastguard Worker helps identify their source in the code.)
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard WorkerIMPORTANT: Differences in "detach" calls between the original forward and the
932*da0073e9SAndroid Build Coastguard Worker           recomputation are expected. They are introduced by the checkpointing
933*da0073e9SAndroid Build Coastguard Worker           mechanism and can be ignored.
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard WorkerOperations executed during the original forward:
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker{forward_ops}
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard WorkerOperations executed during recomputation:
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker{recompute_ops}
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker+------------------------------------------------------------------------------+
944*da0073e9SAndroid Build Coastguard Worker ERROR: Detected non-determinism while running activation checkpointing
945*da0073e9SAndroid Build Coastguard Worker
 You are seeing this error because you passed `debug=True` to checkpoint and
 the tensors saved during the original forward differ from those saved during
 recomputation. This can happen if different operators were run in the
 original forward and in the recomputation.
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker To identify where the mismatch may be coming from, you can do the following:
952*da0073e9SAndroid Build Coastguard Worker
 1) Compare the operators run during the original forward and recomputation to
954*da0073e9SAndroid Build Coastguard Worker    see where they differ. These operators are printed above in the order they
955*da0073e9SAndroid Build Coastguard Worker    were executed.
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker 2) Review the stack trace for each operator to locate its invocation source.
    Each operator's stack trace is printed in execution order.
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker Note that the logs can be quite long. Here's how they are structured:
961*da0073e9SAndroid Build Coastguard Worker (Tip: you can Ctrl-f for these headers)
962*da0073e9SAndroid Build Coastguard Worker
963*da0073e9SAndroid Build Coastguard Worker 1. Stack traces of the operators that ran in the original forward
964*da0073e9SAndroid Build Coastguard Worker 2. Stack traces of the operators that ran during recomputation
965*da0073e9SAndroid Build Coastguard Worker 3. Log of operators in the original forward and recomputation
966*da0073e9SAndroid Build Coastguard Worker 4. Error message                                             <--- You are here
967*da0073e9SAndroid Build Coastguard Worker--------------------------------------------------------------------------------
968*da0073e9SAndroid Build Coastguard Worker"""
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Workerclass CheckpointError(RuntimeError):
971*da0073e9SAndroid Build Coastguard Worker    pass
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker
974*da0073e9SAndroid Build Coastguard Workerdef _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
975*da0073e9SAndroid Build Coastguard Worker    # This function returns the context_fn and error_cb to be used by the
976*da0073e9SAndroid Build Coastguard Worker    # checkpointing mechanism. error_cb is invoked when an error is detected
977*da0073e9SAndroid Build Coastguard Worker    # during unpack.
978*da0073e9SAndroid Build Coastguard Worker
    # record_context_cpp is only supported on Linux x86_64 platforms
980*da0073e9SAndroid Build Coastguard Worker    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker    class CaptureLogs:
983*da0073e9SAndroid Build Coastguard Worker        def __init__(self):
984*da0073e9SAndroid Build Coastguard Worker            self.logs = None
985*da0073e9SAndroid Build Coastguard Worker            self.tbs = None
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker        def get_context_manager(self):
988*da0073e9SAndroid Build Coastguard Worker            @contextlib.contextmanager
989*da0073e9SAndroid Build Coastguard Worker            def logging_mode():
990*da0073e9SAndroid Build Coastguard Worker                with LoggingTensorMode(), \
991*da0073e9SAndroid Build Coastguard Worker                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
992*da0073e9SAndroid Build Coastguard Worker                    self.logs, self.tbs = logs_and_tb
993*da0073e9SAndroid Build Coastguard Worker                    yield logs_and_tb
994*da0073e9SAndroid Build Coastguard Worker            return logging_mode()
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker    capture_logs_fwd = CaptureLogs()
997*da0073e9SAndroid Build Coastguard Worker    capture_logs_recompute = CaptureLogs()
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker    def unpack_error_cb(e: CheckpointError):
1000*da0073e9SAndroid Build Coastguard Worker        def get_str_tb(label, capture_logs):
1001*da0073e9SAndroid Build Coastguard Worker            out = ""
1002*da0073e9SAndroid Build Coastguard Worker            total_len = len(capture_logs.logs)
1003*da0073e9SAndroid Build Coastguard Worker            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
1004*da0073e9SAndroid Build Coastguard Worker                out += f"{log}   ({i + 1} of {total_len} in {label})\n\n"
1005*da0073e9SAndroid Build Coastguard Worker                found_torch_dispatch = False
1006*da0073e9SAndroid Build Coastguard Worker                for line in tb:
1007*da0073e9SAndroid Build Coastguard Worker                    # Start printing stack trace only after __torch_dispatch__ is found
1008*da0073e9SAndroid Build Coastguard Worker                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
1009*da0073e9SAndroid Build Coastguard Worker                    if not found_torch_dispatch and not is_torch_dispatch:
1010*da0073e9SAndroid Build Coastguard Worker                        continue
1011*da0073e9SAndroid Build Coastguard Worker                    elif is_torch_dispatch:
1012*da0073e9SAndroid Build Coastguard Worker                        found_torch_dispatch = True
1013*da0073e9SAndroid Build Coastguard Worker                        continue
1014*da0073e9SAndroid Build Coastguard Worker                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
1015*da0073e9SAndroid Build Coastguard Worker                out += "\n\n"
1016*da0073e9SAndroid Build Coastguard Worker            return out
1017*da0073e9SAndroid Build Coastguard Worker        assert capture_logs_fwd.logs is not None
1018*da0073e9SAndroid Build Coastguard Worker        assert capture_logs_recompute.logs is not None
1019*da0073e9SAndroid Build Coastguard Worker        raise CheckpointError(
1020*da0073e9SAndroid Build Coastguard Worker            _checkpoint_error_template.format(
1021*da0073e9SAndroid Build Coastguard Worker                forward_traces=get_str_tb("original", capture_logs_fwd),
1022*da0073e9SAndroid Build Coastguard Worker                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
1023*da0073e9SAndroid Build Coastguard Worker                forward_ops="\n".join(capture_logs_fwd.logs),
1024*da0073e9SAndroid Build Coastguard Worker                recompute_ops="\n".join(capture_logs_recompute.logs)
1025*da0073e9SAndroid Build Coastguard Worker            )
1026*da0073e9SAndroid Build Coastguard Worker        ) from e
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker    def context_fn():
1029*da0073e9SAndroid Build Coastguard Worker        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    return context_fn, unpack_error_cb
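
# A rough sketch of how the helpers above are wired together (illustrative; the
# actual wiring happens inside ``checkpoint`` when ``debug=True`` is passed):
#
#   context_fn, unpack_error_cb = _get_debug_context_and_cb()
#   # context_fn() returns (forward_context, recompute_context); the logs and
#   # stack traces they capture are rendered into _checkpoint_error_template by
#   # unpack_error_cb if a CheckpointError is raised while unpacking.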
1032*da0073e9SAndroid Build Coastguard Worker
1033*da0073e9SAndroid Build Coastguard Workerdef _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
1034*da0073e9SAndroid Build Coastguard Worker    # These properties are fast to check, easy to understand
1035*da0073e9SAndroid Build Coastguard Worker    return {
1036*da0073e9SAndroid Build Coastguard Worker        "shape": x.shape,
1037*da0073e9SAndroid Build Coastguard Worker        "dtype": x.dtype,
1038*da0073e9SAndroid Build Coastguard Worker        "device": x.device
1039*da0073e9SAndroid Build Coastguard Worker    }
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
1042*da0073e9SAndroid Build Coastguard Worker    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
1043*da0073e9SAndroid Build Coastguard Worker    "none": lambda _: None,
1044*da0073e9SAndroid Build Coastguard Worker}
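
# The ``determinism_check`` argument of ``checkpoint`` selects one of the
# entries above, e.g. (illustrative):
#
#   out = checkpoint(fn, x, use_reentrant=False, determinism_check="default")
#
# "default" compares the shape/dtype/device of each tensor saved during the
# original forward against its recomputed counterpart; "none" skips the
# comparison entirely.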
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker# See Rule 5
1047*da0073e9SAndroid Build Coastguard Workerclass _StopRecomputationError(Exception):
1048*da0073e9SAndroid Build Coastguard Worker    pass
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Workerclass _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
1052*da0073e9SAndroid Build Coastguard Worker    def __init__(self, target_frame_ref: ReferenceType, gid: int):
1053*da0073e9SAndroid Build Coastguard Worker        def pack_hook(x):
1054*da0073e9SAndroid Build Coastguard Worker            x = x.detach() if x.requires_grad else x
1055*da0073e9SAndroid Build Coastguard Worker            target_frame = target_frame_ref()
1056*da0073e9SAndroid Build Coastguard Worker            assert target_frame is not None  # appease mypy
1057*da0073e9SAndroid Build Coastguard Worker            recomp_idx = target_frame.recomp_counter[gid]
1058*da0073e9SAndroid Build Coastguard Worker            target_frame.recomp_counter[gid] += 1
1059*da0073e9SAndroid Build Coastguard Worker
1060*da0073e9SAndroid Build Coastguard Worker            if recomp_idx >= len(target_frame.weak_holders):
1061*da0073e9SAndroid Build Coastguard Worker                assert not target_frame.early_stop
1062*da0073e9SAndroid Build Coastguard Worker                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and
                    # the user calls grad within checkpoint. We need to set this
                    # flag so we don't error out later when we check whether the
                    # number of tensors saved during forward and recomputation
                    # match.
1068*da0073e9SAndroid Build Coastguard Worker                    target_frame.ignore_saved_mismatch = True
1069*da0073e9SAndroid Build Coastguard Worker                    return x
1070*da0073e9SAndroid Build Coastguard Worker                raise CheckpointError(
1071*da0073e9SAndroid Build Coastguard Worker                    "torch.utils.checkpoint: trying to save more tensors during "
1072*da0073e9SAndroid Build Coastguard Worker                    "recomputation than during the original forward pass."
1073*da0073e9SAndroid Build Coastguard Worker                )
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker            holder = target_frame.weak_holders[recomp_idx]()
1076*da0073e9SAndroid Build Coastguard Worker
1077*da0073e9SAndroid Build Coastguard Worker            # This holder may have been cleared because someone may have called
1078*da0073e9SAndroid Build Coastguard Worker            # backward within forward. If so, we don't need to save.
1079*da0073e9SAndroid Build Coastguard Worker            if holder is not None:
1080*da0073e9SAndroid Build Coastguard Worker                _internal_assert(holder.handles.get(gid, None) is None)
1081*da0073e9SAndroid Build Coastguard Worker                holder.handles[gid] = _Handle()
1082*da0073e9SAndroid Build Coastguard Worker                target_frame.recomputed[gid][holder.handles[gid]] = x
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
1085*da0073e9SAndroid Build Coastguard Worker                target_frame.weak_holders
1086*da0073e9SAndroid Build Coastguard Worker            ):
1087*da0073e9SAndroid Build Coastguard Worker                raise _StopRecomputationError
1088*da0073e9SAndroid Build Coastguard Worker            # See Rule 6: [ retain_graph is True ] above
1089*da0073e9SAndroid Build Coastguard Worker            return x
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker        def unpack_hook(x):
1092*da0073e9SAndroid Build Coastguard Worker            # See Rule 6: [ retain_graph is True ] above for an example of when
1093*da0073e9SAndroid Build Coastguard Worker            # the graph created during recomputation could be backwarded.
1094*da0073e9SAndroid Build Coastguard Worker            return x
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker        super().__init__(pack_hook, unpack_hook)
1097*da0073e9SAndroid Build Coastguard Worker
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Workerclass _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
1100*da0073e9SAndroid Build Coastguard Worker    def __init__(self, frame):
1101*da0073e9SAndroid Build Coastguard Worker        def pack_hook(x):
1102*da0073e9SAndroid Build Coastguard Worker            # See Rule 4 above
1103*da0073e9SAndroid Build Coastguard Worker            holder = _Holder()
1104*da0073e9SAndroid Build Coastguard Worker            frame.weak_holders.append(weakref.ref(holder))
1105*da0073e9SAndroid Build Coastguard Worker            # Save metadata to detect non-determinism
1106*da0073e9SAndroid Build Coastguard Worker            if frame.metadata_fn is not None:
1107*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
1108*da0073e9SAndroid Build Coastguard Worker                    frame.x_metadatas.append(frame.metadata_fn(x))
1109*da0073e9SAndroid Build Coastguard Worker            return holder
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker        def unpack_hook(holder):
1112*da0073e9SAndroid Build Coastguard Worker            gid = torch._C._current_graph_task_id()
1113*da0073e9SAndroid Build Coastguard Worker            if gid == -1:
1114*da0073e9SAndroid Build Coastguard Worker                # generate a temporary id if we trigger unpack outside of a backward call
1115*da0073e9SAndroid Build Coastguard Worker                gid = int(uuid.uuid4())
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker            if not frame.is_recomputed[gid]:
1118*da0073e9SAndroid Build Coastguard Worker                ctx = frame.input_saver.grad_fn
1119*da0073e9SAndroid Build Coastguard Worker                args = ctx.get_args(ctx.saved_tensors)
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker                try:
1122*da0073e9SAndroid Build Coastguard Worker                    with _recomputation_hook(
1123*da0073e9SAndroid Build Coastguard Worker                        weakref.ref(frame), gid
1124*da0073e9SAndroid Build Coastguard Worker                    ), torch.autograd.enable_grad():
1125*da0073e9SAndroid Build Coastguard Worker                        frame.recompute_fn(*args)
1126*da0073e9SAndroid Build Coastguard Worker                except _StopRecomputationError:
1127*da0073e9SAndroid Build Coastguard Worker                    pass
1128*da0073e9SAndroid Build Coastguard Worker                frame.is_recomputed[gid] = True
1129*da0073e9SAndroid Build Coastguard Worker                frame.check_recomputed_tensors_match(gid)
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker            _internal_assert(gid in holder.handles)
1132*da0073e9SAndroid Build Coastguard Worker
1133*da0073e9SAndroid Build Coastguard Worker            if holder.handles[gid] is None:
1134*da0073e9SAndroid Build Coastguard Worker                raise CheckpointError(
1135*da0073e9SAndroid Build Coastguard Worker                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
1136*da0073e9SAndroid Build Coastguard Worker                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
1137*da0073e9SAndroid Build Coastguard Worker                    "so only once. Otherwise please open an issue with details on your use case."
1138*da0073e9SAndroid Build Coastguard Worker                )
1139*da0073e9SAndroid Build Coastguard Worker            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
1140*da0073e9SAndroid Build Coastguard Worker            ret = frame.recomputed[gid][holder.handles[gid]]
1141*da0073e9SAndroid Build Coastguard Worker            holder.handles[gid] = None
1142*da0073e9SAndroid Build Coastguard Worker            return ret
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker        if frame.unpack_error_cb is not None:
1145*da0073e9SAndroid Build Coastguard Worker            def unpack_hook_with_error_cb(holder):
1146*da0073e9SAndroid Build Coastguard Worker                try:
1147*da0073e9SAndroid Build Coastguard Worker                    return unpack_hook(holder)
1148*da0073e9SAndroid Build Coastguard Worker                except CheckpointError as e:
1149*da0073e9SAndroid Build Coastguard Worker                    frame.unpack_error_cb(e)
1150*da0073e9SAndroid Build Coastguard Worker            super().__init__(pack_hook, unpack_hook_with_error_cb)
1151*da0073e9SAndroid Build Coastguard Worker        else:
1152*da0073e9SAndroid Build Coastguard Worker            super().__init__(pack_hook, unpack_hook)
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Workerdef _is_compiling(func, args, kwargs):
1156*da0073e9SAndroid Build Coastguard Worker    # Check if we are under AOTAutograd tracing
1157*da0073e9SAndroid Build Coastguard Worker    # There should probably be a better way to do this...
1158*da0073e9SAndroid Build Coastguard Worker    # TODO: unify _is_compiling across all compile stacks
1159*da0073e9SAndroid Build Coastguard Worker    for arg in args:
1160*da0073e9SAndroid Build Coastguard Worker        if isinstance(arg, torch.Tensor) and is_fun(arg):
1161*da0073e9SAndroid Build Coastguard Worker            return True
1162*da0073e9SAndroid Build Coastguard Worker    return False
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Workerclass _VersionWrapper:
1166*da0073e9SAndroid Build Coastguard Worker    # Check that cached tensors are not mutated.
1167*da0073e9SAndroid Build Coastguard Worker    def __init__(self, val):
1168*da0073e9SAndroid Build Coastguard Worker        self.val: Union[torch.Tensor, Any] = val
1169*da0073e9SAndroid Build Coastguard Worker        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker    def get_val(self, allow_cache_entry_mutation):
1172*da0073e9SAndroid Build Coastguard Worker        if self.version is not None and not allow_cache_entry_mutation:
1173*da0073e9SAndroid Build Coastguard Worker            if self.val._version != self.version:
1174*da0073e9SAndroid Build Coastguard Worker                # Can we give user a stack trace of where the mutation happened?
1175*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1176*da0073e9SAndroid Build Coastguard Worker                    "Tensor cached during selective activation checkpoint has been mutated"
1177*da0073e9SAndroid Build Coastguard Worker                )
1178*da0073e9SAndroid Build Coastguard Worker        return self.val
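
# For example (illustrative), mutating a cached tensor in place bumps its
# version counter, which get_val then detects:
#
#   wrapped = _VersionWrapper(t)
#   t.add_(1)                                            # increments t._version
#   wrapped.get_val(allow_cache_entry_mutation=False)    # raises RuntimeError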
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Workerdef _maybe_detach(x, any_ret_has_alias_info):
1182*da0073e9SAndroid Build Coastguard Worker    # We detach for two separate reasons:
1183*da0073e9SAndroid Build Coastguard Worker    # - For view ops, we need to ensure that when the tensor is returned from
    #   _CachedTorchDispatchMode, as_view sees that the AutogradMeta is nullptr
1185*da0073e9SAndroid Build Coastguard Worker    # - Avoid reference cycles
1186*da0073e9SAndroid Build Coastguard Worker    # For case 1, it is not enough to check whether x has differentiable dtype
1187*da0073e9SAndroid Build Coastguard Worker    # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
1188*da0073e9SAndroid Build Coastguard Worker    # when the tensor is a view.
1189*da0073e9SAndroid Build Coastguard Worker    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
1190*da0073e9SAndroid Build Coastguard Worker        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
1191*da0073e9SAndroid Build Coastguard Worker            # Ensure that view performed beneath autograd properly propagates
1192*da0073e9SAndroid Build Coastguard Worker            # version counter. TODO: Use reentrant_dispatch instead of
1193*da0073e9SAndroid Build Coastguard Worker            # manually manipulating dispatch keys. Using reentrant_dispatch
1194*da0073e9SAndroid Build Coastguard Worker            # would respect inference_mode, though that is not relevant for
1195*da0073e9SAndroid Build Coastguard Worker            # this case.
1196*da0073e9SAndroid Build Coastguard Worker            x = x.detach()
1197*da0073e9SAndroid Build Coastguard Worker    return x
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Workerclass SelectiveCheckpointContext:
1201*da0073e9SAndroid Build Coastguard Worker    """
1202*da0073e9SAndroid Build Coastguard Worker    Context passed to policy function during selective checkpointing.
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker    This class is used to pass relevant metadata to the policy function during
1205*da0073e9SAndroid Build Coastguard Worker    selective checkpointing. The metadata includes whether the current invocation
1206*da0073e9SAndroid Build Coastguard Worker    of the policy function is during recomputation or not.
1207*da0073e9SAndroid Build Coastguard Worker
1208*da0073e9SAndroid Build Coastguard Worker    Example:
1209*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP(stub)
1210*da0073e9SAndroid Build Coastguard Worker        >>>
1211*da0073e9SAndroid Build Coastguard Worker        >>> def policy_fn(ctx, op, *args, **kwargs):
1212*da0073e9SAndroid Build Coastguard Worker        >>>    print(ctx.is_recompute)
1213*da0073e9SAndroid Build Coastguard Worker        >>>
1214*da0073e9SAndroid Build Coastguard Worker        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
1215*da0073e9SAndroid Build Coastguard Worker        >>>
1216*da0073e9SAndroid Build Coastguard Worker        >>> out = torch.utils.checkpoint.checkpoint(
1217*da0073e9SAndroid Build Coastguard Worker        >>>     fn, x, y,
1218*da0073e9SAndroid Build Coastguard Worker        >>>     use_reentrant=False,
1219*da0073e9SAndroid Build Coastguard Worker        >>>     context_fn=context_fn,
1220*da0073e9SAndroid Build Coastguard Worker        >>> )
1221*da0073e9SAndroid Build Coastguard Worker    """
1222*da0073e9SAndroid Build Coastguard Worker    def __init__(self, *, is_recompute):
1223*da0073e9SAndroid Build Coastguard Worker        self.is_recompute = is_recompute
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Workerclass CheckpointPolicy(enum.Enum):
1227*da0073e9SAndroid Build Coastguard Worker    """
1228*da0073e9SAndroid Build Coastguard Worker    Enum for specifying the policy for checkpointing during backpropagation.
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker    The following policies are supported:
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
1233*da0073e9SAndroid Build Coastguard Worker      pass and will not be recomputed during the backward pass
1234*da0073e9SAndroid Build Coastguard Worker    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
1235*da0073e9SAndroid Build Coastguard Worker      forward pass and will be recomputed during the backward pass
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
1238*da0073e9SAndroid Build Coastguard Worker    by other subsystems like `torch.compile`.
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker    .. note::
1241*da0073e9SAndroid Build Coastguard Worker        A policy function that always returns ``PREFER_RECOMPUTE`` is
1242*da0073e9SAndroid Build Coastguard Worker        equivalent to vanilla checkpointing.
1243*da0073e9SAndroid Build Coastguard Worker
        A policy function that returns ``PREFER_SAVE`` for every op is
1245*da0073e9SAndroid Build Coastguard Worker        NOT equivalent to not using checkpointing. Using such a policy would
1246*da0073e9SAndroid Build Coastguard Worker        save additional tensors not limited to ones that are actually needed for
1247*da0073e9SAndroid Build Coastguard Worker        gradient computation.
1248*da0073e9SAndroid Build Coastguard Worker    """
1249*da0073e9SAndroid Build Coastguard Worker    MUST_SAVE = 0
1250*da0073e9SAndroid Build Coastguard Worker    PREFER_SAVE = 1
1251*da0073e9SAndroid Build Coastguard Worker    MUST_RECOMPUTE = 2
1252*da0073e9SAndroid Build Coastguard Worker    PREFER_RECOMPUTE = 3
1253*da0073e9SAndroid Build Coastguard Worker
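# For example (illustrative), a policy that saves matmul outputs and recomputes
# everything else:
#
#   def policy_fn(ctx, op, *args, **kwargs):
#       if op == torch.ops.aten.mm.default:
#           return CheckpointPolicy.MUST_SAVE
#       return CheckpointPolicy.PREFER_RECOMPUTE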
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Workerdef _policy_from_bool(b):
    # For backward compatibility
1257*da0073e9SAndroid Build Coastguard Worker    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard WorkerSAC_IGNORED_OPS = {
    # AC inserts a different number of detach calls during forward and recompute.
1262*da0073e9SAndroid Build Coastguard Worker    torch.ops.aten.detach.default,
1263*da0073e9SAndroid Build Coastguard Worker    # AC's determinism check invokes additional metadata ops during forward.
    # With subclasses involved, these metadata ops become dispatchable; this
    # can result in incorrectness if these ops are selected to be cached.
1266*da0073e9SAndroid Build Coastguard Worker    torch.ops.prim.device.default,
1267*da0073e9SAndroid Build Coastguard Worker} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Workerclass _CachingTorchDispatchMode(TorchDispatchMode):
1271*da0073e9SAndroid Build Coastguard Worker    # Used together with _CachedTorchDispatchMode to implement SAC.
1272*da0073e9SAndroid Build Coastguard Worker    def __init__(self, policy_fn, storage):
1273*da0073e9SAndroid Build Coastguard Worker        self.policy_fn = policy_fn
1274*da0073e9SAndroid Build Coastguard Worker        self.storage = storage
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1277*da0073e9SAndroid Build Coastguard Worker        if func in SAC_IGNORED_OPS:
1278*da0073e9SAndroid Build Coastguard Worker            return func(*args, **kwargs)
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker        kwargs = {} if kwargs is None else kwargs
1281*da0073e9SAndroid Build Coastguard Worker        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
1282*da0073e9SAndroid Build Coastguard Worker                                func, *args, **kwargs)
1283*da0073e9SAndroid Build Coastguard Worker        if isinstance(policy, bool):
1284*da0073e9SAndroid Build Coastguard Worker            policy = _policy_from_bool(policy)
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker        is_compiling = _is_compiling(func, args, kwargs)
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        if is_compiling:
1289*da0073e9SAndroid Build Coastguard Worker            # Overwrite each node's "recompute" tag to add in the user annotation.
1290*da0073e9SAndroid Build Coastguard Worker            fx_traceback.current_meta["recompute"] = policy
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker        out = func(*args, **kwargs)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
1297*da0073e9SAndroid Build Coastguard Worker            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
1298*da0073e9SAndroid Build Coastguard Worker        return out
1299*da0073e9SAndroid Build Coastguard Worker
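# The ``storage`` dict shared by the two modes maps each op to a FIFO list of
# outputs cached during forward; the cached mode below pops them back in the
# same order during recomputation. Roughly (assuming storage is created as a
# defaultdict(list) by create_selective_checkpoint_contexts):
#
#   # forward, in _CachingTorchDispatchMode:  storage[func].append(wrapped_out)
#   # backward, in _CachedTorchDispatchMode:  out = storage[func].pop(0)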
1300*da0073e9SAndroid Build Coastguard Workerclass _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachingTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
          - If a policy function is provided, it should accept a
            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
            indicating whether the execution of the op should be recomputed or not.
          - If a list of operations is provided, it is equivalent to a policy
            returning `CheckpointPolicy.MUST_SAVE` for the specified
            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
            operations.
        allow_cache_entry_mutation (bool, optional): By default, an error is
            raised if any tensors cached by selective activation checkpoint are
            mutated, in order to ensure correctness. If set to `True`, this
            check is disabled.
    Returns:
        A tuple of two context managers.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>    torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    if op in ops_to_save:
        >>>        return CheckpointPolicy.MUST_SAVE
        >>>    else:
        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    #     context_fn anyway, so proceed as usual.
    if isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if not isinstance(op, torch._ops.OpOverload):
                _extra_msg = (
                    "Please update the OpOverloadPacket to a specific OpOverload. "
                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
                ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
                raise ValueError(
                    f"Expected op in `op_list` to be an OpOverload but got: {op} "
                    f"of type {type(op)}. {_extra_msg}"
                )

        def policy_fn(ctx, op, *args, **kwargs):
            if op in policy_fn_or_list:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

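    # `storage` maps each saved OpOverload to a FIFO list of outputs recorded during
    # the original forward; the two returned modes share it (see comment above).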
    storage: Dict[Any, List[Any]] = defaultdict(list)
    return (
        _CachingTorchDispatchMode(policy_fn, storage),
        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
    )

# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
#     saving/restoring of global state are handled here.
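#
# For orientation only (a rough sketch, not an additional API): the non-reentrant
# `checkpoint(..., use_reentrant=False)` entry point is expected to drive this
# generator roughly as follows, running `fn` itself between the two `next()` calls
# so that the forward executes under the checkpoint hook and `forward_context`
# set up below:
#
#     gen = _checkpoint_without_reentrant_generator(
#         fn, preserve_rng_state, context_fn, determinism_check, debug, *args, **kwargs
#     )
#     next(gen)                  # pre-forward setup (frame, contexts, RNG stashing)
#     out = fn(*args, **kwargs)  # user forward, recorded by the saved-tensor hooks
#     next(gen, None)            # post-forward checks (forward_completed, device state)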

def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in an LSTM, if the user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``.
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as during the recomputation.
        *args: Arguments to pass into the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.
    """
    unpack_error_cb = None

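    # The set_checkpoint_debug_enabled() context manager, when active (non-None),
    # takes precedence over the `debug` argument passed here.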
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device_type = _infer_device_type(*args)
    device_module = _get_device_module(device_type)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
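    # The autocast settings captured here are re-applied inside recompute_fn below so
    # that the recomputation runs under the same autocast state as the original forward.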
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

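    # Note: `inputs` unpacked below is (kwargs_dict, *positional_args), matching how
    # they are packed into _NoopSaveInputs.apply(dummy, kwargs, *args) further down.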
    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)

            device_autocast_ctx = torch.amp.autocast(
                device_type=device_type, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
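    # The dummy tensor (requires_grad=True) ensures that _NoopSaveInputs.apply produces
    # a grad_fn whenever grad mode is enabled, even if no actual input requires grad;
    # when grad mode is off, grad_fn stays None and checkpointing is skipped below.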
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return