1# Owner(s): ["oncall: distributed"] 2 3import unittest 4from collections import deque, OrderedDict 5from contextlib import ContextDecorator, contextmanager, nullcontext 6from copy import deepcopy 7from functools import partial 8from typing import Tuple 9 10import torch 11import torch.nn as nn 12from torch.distributed._composable import checkpoint 13from torch.testing._internal.common_cuda import TEST_CUDA 14from torch.testing._internal.common_utils import run_tests, TestCase 15from torch.utils.checkpoint import CheckpointError 16 17 18class MemoryDelta(ContextDecorator): 19 def __init__(self, device: torch.device): 20 self.device: torch.device = device 21 self.active_memory_enter: int = 0 22 self.active_memory_exit: int = 0 23 24 def __enter__(self): 25 self.active_memory_enter = ( 26 torch.cuda.memory_stats()["active_bytes.all.current"] 27 if self.device.type == "cuda" 28 else 0 29 ) 30 return self 31 32 def __exit__(self, *exc): 33 self.active_memory_exit = ( 34 torch.cuda.memory_stats()["active_bytes.all.current"] 35 if self.device.type == "cuda" 36 else 0 37 ) 38 39 def delta(self) -> int: 40 return self.active_memory_exit - self.active_memory_enter 41 42 43class ToyModel(nn.Module): 44 def __init__(self) -> None: 45 super().__init__() 46 self.l1 = nn.Linear(100, 100) 47 self.seq = nn.Sequential( 48 nn.ReLU(), 49 nn.Linear(100, 100), 50 nn.ReLU(), 51 ) 52 53 def forward(self, x): 54 return self.seq(self.l1(x)) 55 56 57class RandomModel(nn.Module): 58 def __init__(self) -> None: 59 super().__init__() 60 self.p = nn.Parameter(torch.randn(100, 100)) 61 62 def forward(self, x): 63 y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device)) 64 return torch.matmul(x, y) 65 66 67class MultiOutputModel(nn.Module): 68 def __init__(self, device: torch.device): 69 super().__init__() 70 self.w1 = nn.Parameter(torch.randn((100, 100), device=device)) 71 self.w2 = nn.Parameter(torch.randn((100, 100), device=device)) 72 73 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 74 z = x @ self.w1 75 z = nn.functional.relu(z) 76 z = z @ self.w2 77 return z.sin(), z.cos() 78 79 80class MultiInputModel(nn.Module): 81 def __init__(self, device: torch.device): 82 super().__init__() 83 self.w = nn.Parameter(torch.randn((100, 100), device=device)) 84 85 def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: 86 assert len(xs) == 2, f"Expects 2 args but got {len(xs)}" 87 x, y = xs 88 z = x + y 89 z = z @ self.w 90 return nn.functional.relu(z) 91 92 93class TestCheckpoint(TestCase): 94 def _get_graph_size(self, out: torch.Tensor) -> int: 95 q = deque([out.grad_fn]) 96 num_functions = 0 97 while len(q): 98 fn = q.pop() 99 num_functions += 1 100 for next_fn, _ in fn.next_functions: 101 if next_fn: 102 q.append(next_fn) 103 104 return num_functions 105 106 def _test_tensor_only( 107 self, 108 net: nn.Module, 109 x: torch.Tensor, 110 ) -> None: 111 x1 = x.clone() 112 x2 = x.clone() 113 x1.requires_grad = True 114 x2.requires_grad = True 115 116 net1 = net 117 net2 = deepcopy(net) 118 119 # no checkpoint 120 with MemoryDelta(x.device) as mem1: 121 loss1 = net1(x1).sum() 122 graph_size1 = self._get_graph_size(loss1) 123 loss1.backward() 124 125 # with checkpoint 126 checkpoint(net2.seq) 127 with MemoryDelta(x.device) as mem2: 128 loss2 = net2(x2).sum() 129 loss2.backward() 130 131 if x.is_cuda: 132 self.assertTrue(mem2.delta() < mem1.delta()) 133 134 for p1, p2 in zip(net1.parameters(), net2.parameters()): 135 self.assertEqual(p1.grad, p2.grad) 136 137 def test_tensor_only_cpu(self): 138 x = torch.randn(20, 100) 139 net = ToyModel() 140 self._test_tensor_only(net, x) 141 142 @unittest.skipIf(not TEST_CUDA, "no cuda") 143 def test_tensor_only_gpu(self): 144 x = torch.randn(20, 100, device="cuda:0") 145 net = ToyModel().to("cuda:0") 146 self._test_tensor_only(net, x) 147 148 def test_random_cpu(self): 149 x1 = torch.randn(20, 100, requires_grad=True) 150 x2 = x1.clone() 151 152 net1 = RandomModel() 153 net2 = deepcopy(net1) 154 155 cpu_rng_state = torch.get_rng_state() 156 net1(x1).sum().backward() 157 torch.set_rng_state(cpu_rng_state) 158 checkpoint(net2)(x2).sum().backward() 159 160 for p1, p2 in zip(net1.parameters(), net2.parameters()): 161 self.assertEqual(p1.grad, p2.grad) 162 163 def test_multi_args(self): 164 """ 165 Tests checkpoint for modules with multiple output args and hence 166 multiple backward function input args. 167 """ 168 device = torch.device("cpu") 169 net1 = nn.Sequential( 170 MultiOutputModel(device), 171 MultiInputModel(device), 172 MultiOutputModel(device), 173 MultiInputModel(device), 174 ) 175 net2 = deepcopy(net1) 176 checkpoint(net2[0]) 177 checkpoint(net2[2]) 178 x1 = torch.randn(20, 100, requires_grad=True) 179 x2 = x1.clone() 180 net1(x1).sum().backward() 181 net2(x2).sum().backward() 182 for p1, p2 in zip(net1.parameters(), net2.parameters()): 183 self.assertEqual(p1.grad, p2.grad) 184 185 def test_clears_state_on_error_in_forward(self): 186 class MyModel(torch.nn.Module): 187 def __init__(self, raise_in_recomp): 188 super().__init__() 189 self.fwd_count = 0 190 self.raise_in_recomp = raise_in_recomp 191 self.a = torch.nn.Linear(2, 2) 192 193 def forward(self, x): 194 if self.raise_in_recomp and self.fwd_count == 1: 195 raise RuntimeError("foo") 196 else: 197 if not self.raise_in_recomp: 198 # raise in the first forward 199 raise RuntimeError("foo") 200 self.fwd_count += 1 201 return self.a(x) 202 203 m = MyModel(raise_in_recomp=True) 204 m_seq = torch.nn.Sequential(OrderedDict({"m": m})) 205 checkpoint(m_seq.m) 206 inp = torch.randn(1, 2) 207 out = m_seq(inp).sum() 208 # Should raise in forward recomputation 209 with self.assertRaisesRegex(RuntimeError, "foo"): 210 out.backward() 211 212 # Check that _ac_generator is cleared out 213 self.assertEqual(None, checkpoint.state(m)._ac_generator) 214 215 m = MyModel(raise_in_recomp=False) 216 checkpoint(m) 217 inp = torch.randn(1, 2) 218 # Should raise in first forward 219 with self.assertRaises(RuntimeError): 220 m(inp) 221 222 self.assertEqual(None, checkpoint.state(m)._ac_generator) 223 224 def test_checkpoint_kwargs(self): 225 class MyModel(torch.nn.Module): 226 def __init__(self, raise_exp: bool, change_shape_in_recomp: bool): 227 super().__init__() 228 self.fwd_count = 0 229 self.raise_exp = raise_exp 230 self.change_shape_in_recomp = change_shape_in_recomp 231 self.a = torch.nn.Linear(2, 2) 232 233 def forward(self, x): 234 if self.raise_exp and self.fwd_count == 0: 235 raise RuntimeError("foo") 236 if self.raise_exp and self.fwd_count == 1: 237 raise RuntimeError("bar") 238 if self.change_shape_in_recomp and self.fwd_count == 1: 239 x.relu_() 240 random_tensor = torch.randn(1, 2) 241 x = self.a(x + random_tensor) 242 self.fwd_count += 1 243 return x 244 245 m = MyModel(True, False) 246 m0, m1, m2, m3 = (deepcopy(m) for _ in range(4)) 247 248 # composable checkpoint does not support use_reentrant=True 249 with self.assertRaisesRegex( 250 NotImplementedError, 251 "use_reentrant=True is not supported in composable checkpoint. " 252 "Please use torch.utils.checkpoint.checkpoint instead.", 253 ): 254 checkpoint(m, use_reentrant=True) 255 256 # check giving an unsupported kwarg 257 with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"): 258 checkpoint(m0, foo="bar") 259 260 handled_fwd_exp = False 261 handled_recomp_exp = False 262 263 @contextmanager 264 def fwd_ctx(mod: MyModel): 265 try: 266 mod.raise_exp = False 267 yield 268 finally: 269 nonlocal handled_fwd_exp 270 handled_fwd_exp = True 271 mod.raise_exp = True 272 273 @contextmanager 274 def recomp_ctx(mod: MyModel): 275 try: 276 mod.raise_exp = False 277 yield 278 finally: 279 nonlocal handled_recomp_exp 280 handled_recomp_exp = True 281 mod.raise_exp = True 282 283 # Test different context functions 284 x = torch.randn(1, 2, requires_grad=True) 285 checkpoint( 286 m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)()) 287 ) 288 m1(x.clone()).sum().backward() 289 self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True)) 290 291 checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)())) 292 with self.assertRaisesRegex(RuntimeError, "foo"): 293 m2(x.clone()) 294 295 handled_fwd_exp = False # Reset flag 296 checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext())) 297 with self.assertRaisesRegex(RuntimeError, "bar"): 298 m3(x.clone()).sum().backward() 299 self.assertEqual(handled_fwd_exp, True) 300 301 # Test determinism check failure 302 m4 = MyModel(False, True) 303 m5 = deepcopy(m4) 304 # Determinism check should not throw an error, 305 # but autograd should throw a RuntimeError 306 checkpoint(m4, determinism_check="none") 307 with self.assertRaises(RuntimeError): 308 m4(x.clone()).sum().backward() 309 310 # Determinism check should throw a CheckpointError 311 checkpoint(m5, determinism_check="default") 312 with self.assertRaises(CheckpointError): 313 m5(x.clone()).sum().backward() 314 315 # Test preserving random state 316 m6 = MyModel(False, False) 317 m7, m8 = (deepcopy(m6) for _ in range(2)) 318 checkpoint(m7, preserve_rng_state=False) 319 checkpoint(m8, preserve_rng_state=True) 320 321 for mi in (m6, m7, m8): 322 torch.manual_seed(42) 323 loss = mi(x.clone()).sum() 324 torch.manual_seed(41) 325 loss.backward() 326 # check that m6 and m7 have at least one different grad 327 self.assertNotEqual( 328 (p1.grad for p1 in m6.parameters()), (p2.grad for p2 in m7.parameters()) 329 ) 330 # check that m6 and m8 have identical grads 331 for p1, p2 in zip(m6.parameters(), m8.parameters()): 332 self.assertEqual(p1.grad, p2.grad) 333 334 335if __name__ == "__main__": 336 run_tests() 337