# Owner(s): ["oncall: distributed"]

import unittest
from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import CheckpointError


class MemoryDelta(ContextDecorator):
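    """
    Snapshots ``active_bytes.all.current`` from the CUDA caching allocator on
    enter and exit; ``delta()`` returns the change in bytes. On non-CUDA
    devices both snapshots are 0, so the delta is always 0.
    """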
    def __init__(self, device: torch.device):
        self.device: torch.device = device
        self.active_memory_enter: int = 0
        self.active_memory_exit: int = 0

    def __enter__(self):
        self.active_memory_enter = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )
        return self

    def __exit__(self, *exc):
        self.active_memory_exit = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )

    def delta(self) -> int:
        return self.active_memory_exit - self.active_memory_enter


class ToyModel(nn.Module):
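    """Small feed-forward model; ``seq`` is the submodule that the tensor-only tests wrap with ``checkpoint``."""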
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(self.l1(x))


class RandomModel(nn.Module):
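    """Draws a fresh random tensor in ``forward``, so recomputation only matches the original forward if the RNG state is restored."""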
    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.randn(100, 100))

    def forward(self, x):
        y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device))
        return torch.matmul(x, y)


class MultiOutputModel(nn.Module):
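    """Returns a tuple of two tensors, so the checkpointed backward receives multiple gradient inputs."""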
    def __init__(self, device: torch.device):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        z = x @ self.w1
        z = nn.functional.relu(z)
        z = z @ self.w2
        return z.sin(), z.cos()


class MultiInputModel(nn.Module):
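    """Takes a tuple of two tensors as its single argument, mirroring the output of ``MultiOutputModel``."""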
    def __init__(self, device: torch.device):
        super().__init__()
        self.w = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
        x, y = xs
        z = x + y
        z = z @ self.w
        return nn.functional.relu(z)


class TestCheckpoint(TestCase):
    def _get_graph_size(self, out: torch.Tensor) -> int:
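        """Traverses the autograd graph starting at ``out.grad_fn`` and counts the visited functions."""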
        q = deque([out.grad_fn])
        num_functions = 0
        while len(q):
            fn = q.pop()
            num_functions += 1
            for next_fn, _ in fn.next_functions:
                if next_fn:
                    q.append(next_fn)

        return num_functions

    def _test_tensor_only(
        self,
        net: nn.Module,
        x: torch.Tensor,
    ) -> None:
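        """
        Runs ``net`` as-is and a deep copy with ``checkpoint`` applied to its
        ``seq`` submodule, then checks that gradients match and, on CUDA, that
        the checkpointed run keeps fewer activation bytes alive.
        """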
        x1 = x.clone()
        x2 = x.clone()
        x1.requires_grad = True
        x2.requires_grad = True

        net1 = net
        net2 = deepcopy(net)

        # no checkpoint
        with MemoryDelta(x.device) as mem1:
            loss1 = net1(x1).sum()
        graph_size1 = self._get_graph_size(loss1)
        loss1.backward()

        # with checkpoint
        checkpoint(net2.seq)
        with MemoryDelta(x.device) as mem2:
            loss2 = net2(x2).sum()
        loss2.backward()

        if x.is_cuda:
            self.assertTrue(mem2.delta() < mem1.delta())

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_tensor_only_cpu(self):
        x = torch.randn(20, 100)
        net = ToyModel()
        self._test_tensor_only(net, x)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_tensor_only_gpu(self):
        x = torch.randn(20, 100, device="cuda:0")
        net = ToyModel().to("cuda:0")
        self._test_tensor_only(net, x)

    def test_random_cpu(self):
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()

        net1 = RandomModel()
        net2 = deepcopy(net1)

        cpu_rng_state = torch.get_rng_state()
        net1(x1).sum().backward()
        torch.set_rng_state(cpu_rng_state)
        checkpoint(net2)(x2).sum().backward()

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_multi_args(self):
        """
        Tests checkpoint for modules with multiple output args and hence
        multiple backward function input args.
        """
        device = torch.device("cpu")
        net1 = nn.Sequential(
            MultiOutputModel(device),
            MultiInputModel(device),
            MultiOutputModel(device),
            MultiInputModel(device),
        )
        net2 = deepcopy(net1)
        checkpoint(net2[0])
        checkpoint(net2[2])
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()
        net1(x1).sum().backward()
        net2(x2).sum().backward()
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_clears_state_on_error_in_forward(self):
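        """Checks that ``checkpoint`` clears its saved ``_ac_generator`` state when the wrapped forward raises, both during recomputation and during the initial forward."""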
        class MyModel(torch.nn.Module):
            def __init__(self, raise_in_recomp):
                super().__init__()
                self.fwd_count = 0
                self.raise_in_recomp = raise_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_in_recomp and self.fwd_count == 1:
                    raise RuntimeError("foo")
                else:
                    if not self.raise_in_recomp:
                        # raise in the first forward
                        raise RuntimeError("foo")
                    self.fwd_count += 1
                    return self.a(x)

        m = MyModel(raise_in_recomp=True)
        m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
        checkpoint(m_seq.m)
        inp = torch.randn(1, 2)
        out = m_seq(inp).sum()
        # Should raise in forward recomputation
        with self.assertRaisesRegex(RuntimeError, "foo"):
            out.backward()

        # Check that _ac_generator is cleared out
        self.assertEqual(None, checkpoint.state(m)._ac_generator)

        m = MyModel(raise_in_recomp=False)
        checkpoint(m)
        inp = torch.randn(1, 2)
        # Should raise in first forward
        with self.assertRaises(RuntimeError):
            m(inp)

        self.assertEqual(None, checkpoint.state(m)._ac_generator)

    def test_checkpoint_kwargs(self):
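        """
        Exercises the keyword arguments forwarded to the underlying checkpoint
        implementation: rejection of ``use_reentrant=True`` and of unknown
        kwargs, custom ``context_fn`` pairs for forward and recomputation,
        ``determinism_check``, and ``preserve_rng_state``.
        """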
        class MyModel(torch.nn.Module):
            def __init__(self, raise_exp: bool, change_shape_in_recomp: bool):
                super().__init__()
                self.fwd_count = 0
                self.raise_exp = raise_exp
                self.change_shape_in_recomp = change_shape_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_exp and self.fwd_count == 0:
                    raise RuntimeError("foo")
                if self.raise_exp and self.fwd_count == 1:
                    raise RuntimeError("bar")
                if self.change_shape_in_recomp and self.fwd_count == 1:
                    x.relu_()
                random_tensor = torch.randn(1, 2)
                x = self.a(x + random_tensor)
                self.fwd_count += 1
                return x

        m = MyModel(True, False)
        m0, m1, m2, m3 = (deepcopy(m) for _ in range(4))

        # composable checkpoint does not support use_reentrant=True
        with self.assertRaisesRegex(
            NotImplementedError,
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead.",
        ):
            checkpoint(m, use_reentrant=True)

        # check giving an unsupported kwarg
        with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"):
            checkpoint(m0, foo="bar")

        handled_fwd_exp = False
        handled_recomp_exp = False

        @contextmanager
        def fwd_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_fwd_exp
                handled_fwd_exp = True
                mod.raise_exp = True

        @contextmanager
        def recomp_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_recomp_exp
                handled_recomp_exp = True
                mod.raise_exp = True

        # Test different context functions
        x = torch.randn(1, 2, requires_grad=True)
        checkpoint(
            m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)())
        )
        m1(x.clone()).sum().backward()
        self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True))

        checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)()))
        with self.assertRaisesRegex(RuntimeError, "foo"):
            m2(x.clone())

        handled_fwd_exp = False  # Reset flag
        checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext()))
        with self.assertRaisesRegex(RuntimeError, "bar"):
            m3(x.clone()).sum().backward()
        self.assertEqual(handled_fwd_exp, True)

        # Test determinism check failure
        m4 = MyModel(False, True)
        m5 = deepcopy(m4)
        # Determinism check should not throw an error,
        # but autograd should throw a RuntimeError
        checkpoint(m4, determinism_check="none")
        with self.assertRaises(RuntimeError):
            m4(x.clone()).sum().backward()

        # Determinism check should throw a CheckpointError
        checkpoint(m5, determinism_check="default")
        with self.assertRaises(CheckpointError):
            m5(x.clone()).sum().backward()

        # Test preserving random state
        m6 = MyModel(False, False)
        m7, m8 = (deepcopy(m6) for _ in range(2))
        checkpoint(m7, preserve_rng_state=False)
        checkpoint(m8, preserve_rng_state=True)

        for mi in (m6, m7, m8):
            torch.manual_seed(42)
            loss = mi(x.clone()).sum()
            torch.manual_seed(41)
            loss.backward()
        # check that m6 and m7 have at least one different grad; compare lists
        # of grads so the tensors themselves are compared rather than the
        # generator objects
        self.assertNotEqual(
            [p1.grad for p1 in m6.parameters()], [p2.grad for p2 in m7.parameters()]
        )
        # check that m6 and m8 have identical grads
        for p1, p2 in zip(m6.parameters(), m8.parameters()):
            self.assertEqual(p1.grad, p2.grad)


if __name__ == "__main__":
    run_tests()