# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
from enum import auto, Enum
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp._wrap_utils import _validate_frozen_params
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    _or_policy,
    _Policy,
    _wrap_module_cls_individually,
    always_wrap_policy,
    CustomPolicy,
    enable_wrap,
    ModuleWrapPolicy,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
    wrap,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.modules.batchnorm import _BatchNorm
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _maybe_cuda,
    CUDAInitMode,
    DummyProcessGroup,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
    find_free_port,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_CUDA,
    TestCase,
)


class BatchNormNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)
        self.bn1 = nn.BatchNorm1d(10)
        self.bn2 = nn.BatchNorm2d(10)
        self.bn3 = nn.BatchNorm3d(10)
        self.sync_bn = nn.SyncBatchNorm(10)


class LoraModel(nn.Module):
    """This is a toy LoRA decoder model."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 32)
        self.layers = nn.ModuleList([LoraDecoder() for _ in range(4)])
        self.norm = nn.LayerNorm(32)
        self.embed_tokens.weight.requires_grad_(False)
        self.norm.weight.requires_grad_(False)
        self.norm.bias.requires_grad_(False)


class LoraDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = LoraAttention()
        self.mlp = LoraMLP()
        self.inp_layernorm = nn.LayerNorm(32)
        self.post_attn_layernorm = nn.LayerNorm(32)
        self.inp_layernorm.weight.requires_grad_(False)
        self.inp_layernorm.bias.requires_grad_(False)
        self.post_attn_layernorm.weight.requires_grad_(False)
        self.post_attn_layernorm.bias.requires_grad_(False)


class LoraAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.q_proj = nn.Linear(32, 32, bias=False)
        self.lora_A = nn.Linear(32, 8, bias=False)
        self.lora_B = nn.Linear(8, 32, bias=False)
        self.k_proj = nn.Linear(32, 32, bias=False)
        self.v_proj = nn.Linear(32, 32, bias=False)
        self.o_proj = nn.Linear(32, 32, bias=False)
        self.q_proj.weight.requires_grad_(False)
        self.k_proj.weight.requires_grad_(False)
        self.v_proj.weight.requires_grad_(False)
        self.o_proj.weight.requires_grad_(False)


class LoraMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(32, 128, bias=False)
        self.proj2 = nn.Linear(128, 32, bias=False)
        self.proj1.weight.requires_grad_(False)
        self.proj2.weight.requires_grad_(False)


class WrapMethod(Enum):
    FSDP_CTOR = auto()
    # FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss
    # any use cases and fix them to work with FSDP_CTOR over time.
    WRAP_API = auto()


class TestFSDPWrap(FSDPTest):
    """
    Tests the main API for auto wrapping with FSDP, which is to pass
    ``auto_wrap_policy`` into the FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], FSDP))
            cls.assertTrue(isinstance(model.module[1], FSDP))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(
        self, cuda_init_mode=CUDAInitMode.CUDA_BEFORE, nested=False
    ) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
                # If nested=True, the FSDP module will be nested one layer deep,
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
                    )
                self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, cuda_init_mode):
        """
        Test that an error is raised if we attempt to auto wrap when submodules
        are already wrapped with FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, cuda_init_mode=cuda_init_mode
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        wrapped_module_name = "lin1.1" if nested else "lin1"
        with self.assertRaisesRegex(
            ValueError,
            "FSDP auto wrapping requires modules to not already have FSDP "
            f"applied but found {wrapped_module_name} in",
        ):
            FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_or_policy", [True, False])
    def test_wrap_batchnorm_individually(self, use_or_policy):
        def never_wrap_policy(*args, **kwargs):
            return False

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )
        policy = (
            functools.partial(
                _or_policy, policies=[never_wrap_policy, wrap_batchnorm_individually]
            )
            if use_or_policy
            else wrap_batchnorm_individually
        )
        model = BatchNormNet()
        fsdp = FSDP(model, auto_wrap_policy=policy)
        # Batchnorms should be wrapped
        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
            self.assertTrue(isinstance(layer, FSDP))

        self.assertFalse(isinstance(fsdp.lin, FSDP))

    @skip_if_lt_x_gpu(2)
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that, when using ``_or_policy`` with
        ``_wrap_module_cls_individually``, even if the other policy results in a
        module containing a BN unit being wrapped, the contained BN unit is
        still individually wrapped.
        """

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )

        my_policy = functools.partial(
            _or_policy, policies=[wrap_bn_container, wrap_batchnorm_individually]
        )
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter, the inner
        # BN is not individually wrapped.)

        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertTrue(isinstance(bn, FSDP))

        # If we just wrapped the BN container, the individual batchnorms are
        # not wrapped.
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=wrap_bn_container)
        self.assertTrue(isinstance(mod.bn_container, FSDP))
        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertFalse(isinstance(bn, FSDP))

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE],
    )
    @parametrize("forward_prefetch", [False, True])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_main_wrap_api(
        self,
        cpu_offload: CPUOffload,
        backward_prefetch: BackwardPrefetch,
        forward_prefetch: bool,
        cuda_init_mode: CUDAInitMode,
    ):
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # CPU offload and CUDA_AFTER do not work together, so skip this case.
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4,
            wrapped_model,
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()


class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                auto_wrap_policy=functools.partial(
                    size_based_auto_wrap_policy, min_num_params=1
                ),
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is
        passed into FSDP, all submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(
            seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
        )
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy(self):
        """Tests the ``ModuleWrapPolicy``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy_callable(self):
        """Tests the ``ModuleWrapPolicy`` as a ``Callable``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        callable_policy = functools.partial(_or_policy, policies=[auto_wrap_policy])
        self._test_transformer_wrapping(callable_policy)

    def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy]):
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if (
                module is fsdp_model
                or module in encoder_layers
                or module in decoder_layers
            ):
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_custom_policy(self):
        """
        Tests ``CustomPolicy`` with both a lambda function that uses uniform
        kwargs (so only returns ``False`` or ``True``) and a lambda function
        that uses non-uniform kwargs (so returns a dict to override the root
        kwargs).
        """
        for use_uniform_kwargs in [False, True]:
            self._test_custom_policy(use_uniform_kwargs)

    def _test_custom_policy(self, use_uniform_kwargs: bool):
        print(f"use_uniform_kwargs={use_uniform_kwargs}")
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )

        if use_uniform_kwargs:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return True
                elif isinstance(
                    module, (TransformerEncoderLayer, TransformerDecoderLayer)
                ):
                    return True
                return False

        else:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return {"sharding_strategy": ShardingStrategy.NO_SHARD}
                elif isinstance(module, TransformerEncoderLayer):
                    return True
                elif isinstance(module, TransformerDecoderLayer):
                    return {
                        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
                        "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
                    }
                return False

        policy = CustomPolicy(lambda_fn)
        # Use a size-2 dummy PG to avoid clamping the sharding strategy to
        # `NO_SHARD` as for a size-1 PG
        process_group = DummyProcessGroup(rank=0, size=2)
        fp16_mp = MixedPrecision(param_dtype=torch.float16)
        fp32_mp = MixedPrecision()
        model = FSDP(
            model,
            process_group=process_group,
            auto_wrap_policy=policy,
            mixed_precision=fp16_mp,
        )
        encoder_layers = set(model.module.transformer.encoder.layers)
        decoder_layers = set(model.module.transformer.decoder.layers)
        bn = model.module.bn
        bn_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.NO_SHARD
        )
        bn_prefetch = BackwardPrefetch.BACKWARD_PRE
        encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
        encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
        decoder_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.SHARD_GRAD_OP
        )
        decoder_prefetch = (
            BackwardPrefetch.BACKWARD_PRE
            if use_uniform_kwargs
            else BackwardPrefetch.BACKWARD_POST
        )
        for module in model.modules():
            if module is bn:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, bn_strategy)
                self.assertEqual(module.backward_prefetch, bn_prefetch)
                # We currently override batch norm modules to use fp32
                self.assertEqual(module.mixed_precision, fp32_mp)
            elif module in encoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, encoder_strategy)
                self.assertEqual(module.backward_prefetch, encoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, decoder_strategy)
                self.assertEqual(module.backward_prefetch, decoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module is model:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, root_strategy)
                self.assertEqual(module.backward_prefetch, root_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that, with auto wrapping, child modules are wrapped
        correctly based on ``min_num_params``.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but
        combined they do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether
        the total param size exceeds ``min_num_params``. The
        ``size_based_auto_wrap_policy`` excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )

        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are
        wrapped if their param size exceeds ``min_num_params``.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped and their children are
        not wrapped either. The ``size_based_auto_wrap_policy`` forces leaf
        modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert that the children of the multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        # The model was wrapped in FSDP since no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
        # CPU offload and CUDA_AFTER do not work together as expected.
        if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (
            torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
        )

        # Use a random port in case the next test runs quickly; reusing the same
        # port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto the GPU, but its
        # shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                auto_wrap_policy=my_auto_wrap_policy,
                device_id=device_id,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy does not exceed the parameter threshold before the inner
        # sequential (`sequential[2]`) anymore; hence, it flattens
        # `sequential[0]` and `sequential[2][0]` into `model` and leaves
        # `sequential[1]` and `sequential[2][1]` as-is since they are ignored.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_frozen_params(self):
        """
        Tests that mixing frozen/non-frozen parameters in an FSDP instance
        raises for ``use_orig_params=False`` and warns for ``True``.
        """
        module_classes = (LoraAttention, LoraMLP, LoraDecoder)
        module_wrap_policy = ModuleWrapPolicy(module_classes)

        def lambda_fn_uniform(module: nn.Module):
            return isinstance(module, module_classes)

        def lambda_fn_nonuniform(module: nn.Module):
            if isinstance(module, LoraAttention):
                return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
            elif isinstance(module, module_classes):
                return True
            return False

        lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
        lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

        for use_orig_params, policy in itertools.product(
            [True, False],
            [
                module_wrap_policy,
                lambda_wrap_policy_uniform,
                lambda_wrap_policy_nonuniform,
            ],
        ):
            self._test_frozen_params(use_orig_params, policy)

    def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
        model = LoraModel().cuda()
        msg = "layers.0.attn has both parameters with requires_grad=True and False. "
        if use_orig_params:
            msg += "We do not recommend wrapping such modules"
            ctx = self.assertWarnsRegex(UserWarning, msg)
        else:
            msg += "FSDP does not support wrapping such modules when use_orig_params=False."
            ctx = self.assertRaisesRegex(ValueError, msg)
        with ctx:
            FSDP(
                model,
                process_group=self.process_group,
                auto_wrap_policy=policy,
                use_orig_params=use_orig_params,
            )


class TestWrapUtils(TestCase):
    def test_validate_frozen_params(self):
        """Tests the method ``_validate_frozen_params()``."""
        for use_orig_params in [True, False]:
            self._test_validate_frozen_params(use_orig_params)

    def _test_validate_frozen_params(self, use_orig_params: bool):
        model = LoraModel()
        # Wrap only LoRA modules
        modules_to_wrap = {
            module
            for module_name, module in model.named_modules()
            if "lora_A" in module_name or "lora_B" in module_name
        }
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Additionally wrap attention
        for module in model.modules():
            if isinstance(module, LoraAttention):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Additionally wrap decoders
        for module in model.modules():
            if isinstance(module, LoraDecoder):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Do not wrap the LoRA-A modules (meaning mixed frozen/non-frozen)
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                modules_to_wrap.remove(module)
        regex = "layers.0.attn has both parameters with requires_grad=True and False."
        if use_orig_params:
            # Wrapping the attention manages all parameters except those from
            # the LoRA-B module, which is wrapped separately and fully non-frozen
            lorab_numel = sum(
                p.numel() for p in model.layers[0].attn.lora_B.parameters()
            )
            attn_frozen_param_numel = sum(
                p.numel()
                for p in model.layers[0].attn.parameters()
                if not p.requires_grad
            )
            attn_nonfrozen_param_numel = (
                sum(
                    p.numel()
                    for p in model.layers[0].attn.parameters()
                    if p.requires_grad
                )
                - lorab_numel
            )
            attn_total_param_numel = (
                attn_frozen_param_numel + attn_nonfrozen_param_numel
            )
            regex += (
                " We do not recommend wrapping such modules since the "
                r"gradient memory usage will be higher than expected \("
                f"{attn_total_param_numel} numel instead of {attn_nonfrozen_param_numel} numel "
                r"before sharding via reduce-scatter\). "
            )
        else:
            regex += " FSDP does not support wrapping such modules when use_orig_params=False. "
        regex += "If possible, wrap the frozen parameters with FSDP separately.\n"
        regex += (
            "The following parameters have requires_grad=True:\n"
            r"\['layers.0.attn.lora_A.weight'\]\n"
            "The following parameters have requires_grad=False:\n"
            r"\['layers.0.attn.q_proj.weight', 'layers.0.attn.k_proj.weight', "
            r"'layers.0.attn.v_proj.weight', 'layers.0.attn.o_proj.weight'\]"
        )
        if use_orig_params:
            ctx = self.assertWarnsRegex(UserWarning, regex)
        else:
            ctx = self.assertRaisesRegex(ValueError, regex)
        with ctx:
            _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Now ignore those LoRA-A modules' parameters
        ignored_params = set()
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                ignored_params.update(module.parameters())
        _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)


instantiate_parametrized_tests(TestFSDPWrap)
instantiate_parametrized_tests(TestAutoWrap)

if __name__ == "__main__":
    run_tests()