# Owner(s): ["oncall: distributed"]

import copy
import sys
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
    CompositeModel,
    CompositeParamModel,
    UnitModule,
)
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPCheckpoint(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    # TODO: Define `use_same_inputs_across_ranks` for now for BC since some
    # test model configs do not have a simple base model to compare against. In
    # those cases, we use the same inputs across ranks so that the averaged
    # gradient equals the local gradient to check for parity. This means that
    # the gradient reduction is unchecked.
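    # For example (a reasoning sketch assuming 2 ranks): if every rank sees the
    # same input, every rank computes the same local gradient g, so the
    # averaged gradient (g + g) / 2 equals g, and parity with the base model
    # holds even if the gradient reduction itself were broken.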
    def _test_parity(
        self,
        base_model: nn.Module,
        test_model: nn.Module,
        inp_size: torch.Size,
        inp_device: torch.device,
        grad_to_none: bool,
        use_same_inputs_across_ranks: bool,
    ):
        LR = 0.01
        base_optim = torch.optim.Adam(base_model.parameters(), lr=LR)
        test_optim = torch.optim.Adam(test_model.parameters(), lr=LR)

        for _ in range(5):
            if use_same_inputs_across_ranks:
                torch.manual_seed(0)
            x = torch.randn(inp_size, device=inp_device)
            test_loss = test_model(x).sum()
            base_loss = base_model(x).sum()

            self.assertEqual(test_loss, base_loss)

            test_loss.backward()
            test_optim.step()
            test_optim.zero_grad(set_to_none=grad_to_none)

            base_loss.backward()
            base_optim.step()
            base_optim.zero_grad(set_to_none=grad_to_none)

    @skip_if_lt_x_gpu(2)
    def test_wrap_same_submodule(self):
        model = UnitModule(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        # compose checkpoint and fully_shard
        test_model.seq = checkpoint(test_model.seq)
        test_model.seq = fully_shard(
            test_model.seq,
            policy=ModuleWrapPolicy({nn.Linear}),
        )

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    def _test_checkpoint_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1 = fully_shard(test_model.u1, policy=None)
        test_model.u2 = fully_shard(test_model.u2)

        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_non_reentrant(self):
        self._test_checkpoint_fsdp_submodules()

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fully_shard_cast_forward_inputs(self):
        self.run_subtests(
            {
                "checkpoint_strict_submodule": [False, True],
            },
            self._test_checkpoint_fully_shard_cast_forward_inputs,
        )

    def _test_checkpoint_fully_shard_cast_forward_inputs(
        self, checkpoint_strict_submodule: bool
    ):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        fp16_mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        fp32_mp = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

        fully_shard(model.c2, mixed_precision=fp16_mp)
        if checkpoint_strict_submodule:
            checkpoint(model.c2.l)
        else:
            checkpoint(model.c2)
        fully_shard(model, mixed_precision=fp32_mp)

        loss = model(x).sum()
        loss.backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        # Notably, check that the (recomputed) forward for `model.c2` sees fp16
        # inputs, since its `fully_shard` instance uses fp16 mixed precision
        # with `cast_forward_inputs=True` while the root casts inputs to fp32
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_correct_replicate_params(self):
        model = CompositeParamModel(device=torch.device("cuda"))
        # Shard Linears within UnitModule
        fully_shard(model.u1, policy=ModuleWrapPolicy({nn.Linear}))
        fully_shard(model.u2, policy=ModuleWrapPolicy({nn.Linear}))
        # replicate the rest
        replicate(model)
        # Run fwd + bwd to initialize DDP
        inp = torch.randn(2, 100, device="cuda")
        model(inp).sum().backward()
        # Ensure that the parameter names tracked by `replicate` are as
        # expected, i.e. the immediate parameters of `model` and the parameters
        # of `model`'s non-UnitModule submodules are replicated
        param_names = replicate.state(model)._param_names
        replicated_modules = [
            (name, mod)
            for (name, mod) in model.named_children()
            if mod not in [model.u1, model.u2]
        ]
        replicated_param_names = [
            f"{module_name}.{n}"
            for module_name, mod in replicated_modules
            for n, _ in mod.named_parameters()
        ]
        replicated_param_names.extend(
            [n for n, _ in model.named_parameters(recurse=False)]
        )
        self.assertEqual(set(param_names), set(replicated_param_names))

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param_no_shard(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_composable_fsdp_replicate(self):
        # Verify how the APIs can be composed, e.g. applying both `fully_shard`
        # and `replicate` to the same module should raise an exception.
        model = CompositeModel(device=torch.device("cuda"))
        fully_shard(model.l1)
        with self.assertRaisesRegex(RuntimeError, "Cannot apply .*replicate"):
            replicate(model.l1)
        replicate(model.l2)  # should not raise

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_composability(self):
        """
        Tests composing ``fully_shard`` and ``replicate``. To save unit test
        time, we run the different configs in subtests.
        """
        self.run_subtests(
            {
                "config": [
                    "1fm,1r",
                    "1r,1fm",
                    "1r,1fa",
                    "1r1fm,1fm",
                    "1r1fa,1fm",
                    "1fm1fm,1r1r,1fm",
                ]
            },
            self._test_replicate_in_fully_shard,
        )

    def _test_replicate_in_fully_shard(self, config: str):
        """
        To interpret the config, each comma delineates a level in the module
        tree ordered bottom-up; 'r' means ``replicate``; 'f' means
        ``fully_shard``; 'a' means auto wrap; and 'm' means manual wrap.
        """
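        # For example, "1fm1fm,1r1r,1fm" (handled in the last branch below)
        # reads bottom-up as: manually `fully_shard` two leaf submodules
        # (u1.seq and u2.seq), then `replicate` their parents (u1 and u2), and
        # finally manually `fully_shard` the root model.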
        # Set the seed to ensure that all ranks initialize the same model
        torch.manual_seed(0)
        if config == "1fm,1r":
            base_model = CompositeModel(device=torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.l1)
            replicate(test_model)
        elif config == "1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model)
        elif config == "1r,1fa":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
        elif config == "1r1fm,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2)
            fully_shard(test_model)
        elif config == "1r1fa,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
            fully_shard(test_model)
        elif config == "1fm1fm,1r1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.u1.seq)
            fully_shard(test_model.u2.seq)
            replicate(test_model.u1)
            replicate(test_model.u2)
            fully_shard(test_model)
        else:
            raise ValueError(f"Unknown config: {config}")
        # Apply data parallelism to the base model for parity since we apply
        # data parallelism to the test model
        replicate(base_model)

        # Set the seed to ensure that ranks get different input data
        torch.manual_seed(self.rank + 1)
        self._test_parity(
            base_model,
            test_model,
            torch.Size((2, 100)),
            torch.device("cuda"),
            True,
            False,
        )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        full_shard_args = {"strategy": ShardingStrategy.FULL_SHARD}
        no_shard_args = {"strategy": ShardingStrategy.NO_SHARD}

        model.u1 = fully_shard(model.u1, **full_shard_args)
        model.u2 = fully_shard(model.u2, **no_shard_args)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )

        state_dict = model.state_dict()
        for fqn, tensor in state_dict.items():
            if "u1" in fqn:
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)
        # Ensure that get_state_dict_type can still correctly get the settings.
        _ = FSDP.get_state_dict_type(model)


instantiate_parametrized_tests(TestFSDPCheckpoint)


if __name__ == "__main__":
    run_tests()