# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
from enum import auto, Enum
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp._wrap_utils import _validate_frozen_params
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    _or_policy,
    _Policy,
    _wrap_module_cls_individually,
    always_wrap_policy,
    CustomPolicy,
    enable_wrap,
    ModuleWrapPolicy,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
    wrap,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.modules.batchnorm import _BatchNorm
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _maybe_cuda,
    CUDAInitMode,
    DummyProcessGroup,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
    find_free_port,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_CUDA,
    TestCase,
)

class BatchNormNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)
        self.bn1 = nn.BatchNorm1d(10)
        self.bn2 = nn.BatchNorm2d(10)
        self.bn3 = nn.BatchNorm3d(10)
        self.sync_bn = nn.SyncBatchNorm(10)


class LoraModel(nn.Module):
    """This is a toy LoRA decoder model."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 32)
        self.layers = nn.ModuleList([LoraDecoder() for _ in range(4)])
        self.norm = nn.LayerNorm(32)
        self.embed_tokens.weight.requires_grad_(False)
        self.norm.weight.requires_grad_(False)
        self.norm.bias.requires_grad_(False)


class LoraDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = LoraAttention()
        self.mlp = LoraMLP()
        self.inp_layernorm = nn.LayerNorm(32)
        self.post_attn_layernorm = nn.LayerNorm(32)
        self.inp_layernorm.weight.requires_grad_(False)
        self.inp_layernorm.bias.requires_grad_(False)
        self.post_attn_layernorm.weight.requires_grad_(False)
        self.post_attn_layernorm.bias.requires_grad_(False)


class LoraAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.q_proj = nn.Linear(32, 32, bias=False)
        self.lora_A = nn.Linear(32, 8, bias=False)
        self.lora_B = nn.Linear(8, 32, bias=False)
        self.k_proj = nn.Linear(32, 32, bias=False)
        self.v_proj = nn.Linear(32, 32, bias=False)
        self.o_proj = nn.Linear(32, 32, bias=False)
        self.q_proj.weight.requires_grad_(False)
        self.k_proj.weight.requires_grad_(False)
        self.v_proj.weight.requires_grad_(False)
        self.o_proj.weight.requires_grad_(False)


class LoraMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(32, 128, bias=False)
        self.proj2 = nn.Linear(128, 32, bias=False)
        self.proj1.weight.requires_grad_(False)
        self.proj2.weight.requires_grad_(False)


class WrapMethod(Enum):
    FSDP_CTOR = auto()
    # FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss
    # any use cases and fix them to work with FSDP_CTOR over time.
    WRAP_API = auto()


class TestFSDPWrap(FSDPTest):
    """
    Tests the main API for wrapping with FSDP, which is to pass an
    ``auto_wrap_policy`` into the FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], FSDP))
            cls.assertTrue(isinstance(model.module[1], FSDP))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(
        self, cuda_init_mode=CUDAInitMode.CUDA_BEFORE, nested=False
    ) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
                # if nested=True, the FSDP module will be nested one layer deep
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
                    )
                self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, cuda_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, cuda_init_mode=cuda_init_mode
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        wrapped_module_name = "lin1.1" if nested else "lin1"
        with self.assertRaisesRegex(
            ValueError,
            "FSDP auto wrapping requires modules to not already have FSDP "
            f"applied but found {wrapped_module_name} in",
        ):
            FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_or_policy", [True, False])
    def test_wrap_batchnorm_individually(self, use_or_policy):
        def never_wrap_policy(*args, **kwargs):
            return False

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )
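        # Wrapping each batch norm module individually gives it its own FSDP
        # unit (e.g. so that batch norm can run in fp32 under mixed precision).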
        policy = (
            functools.partial(
                _or_policy, policies=[never_wrap_policy, wrap_batchnorm_individually]
            )
            if use_or_policy
            else wrap_batchnorm_individually
        )
        model = BatchNormNet()
        fsdp = FSDP(model, auto_wrap_policy=policy)
        # Batchnorms should be wrapped
        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
            self.assertTrue(isinstance(layer, FSDP))

        self.assertFalse(isinstance(fsdp.lin, FSDP))

    @skip_if_lt_x_gpu(2)
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that by using _or_policy with _wrap_module_cls_individually, even
        if the other policy results in a module containing a BN unit being
        wrapped, the contained BN unit will still be individually wrapped.
        """

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )

        my_policy = functools.partial(
            _or_policy, policies=[wrap_bn_container, wrap_batchnorm_individually]
        )
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter the inner
        # BN is not individually wrapped.)

        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertTrue(isinstance(bn, FSDP))

        # If we only wrap the BN container, the individual batch norms are not
        # wrapped.
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=wrap_bn_container)
        self.assertTrue(isinstance(mod.bn_container, FSDP))
        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertFalse(isinstance(bn, FSDP))

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE],
    )
    @parametrize("forward_prefetch", [False, True])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_main_wrap_api(
        self,
        cpu_offload: CPUOffload,
        backward_prefetch: BackwardPrefetch,
        forward_prefetch: bool,
        cuda_init_mode: CUDAInitMode,
    ):
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # CPU offload is expected not to work with moving the model to
            # CUDA after FSDP construction, so skip this combination.
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4,
            wrapped_model,
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run model a few times for sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()


class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                auto_wrap_policy=functools.partial(
                    size_based_auto_wrap_policy, min_num_params=1
                ),
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

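        # `self.lin` was wrapped outside of any `enable_wrap` context, where
        # `wrap` is a no-op, so it should remain a plain `nn.Linear`.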
        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is
        passed into FSDP, all submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(
            seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
        )
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy(self):
        """Tests the ``ModuleWrapPolicy``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy_callable(self):
        """Tests the ``ModuleWrapPolicy`` as a ``Callable``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        callable_policy = functools.partial(_or_policy, policies=[auto_wrap_policy])
        self._test_transformer_wrapping(callable_policy)

    def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy]):
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if (
                module is fsdp_model
                or module in encoder_layers
                or module in decoder_layers
            ):
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_custom_policy(self):
        """
        Tests ``CustomPolicy`` with both a lambda function that uses uniform
        kwargs (so only returns ``False`` or ``True``) and a lambda function
        that uses non-uniform kwargs (so returns a dict to override the root
        kwargs).
        """
        for use_uniform_kwargs in [False, True]:
            self._test_custom_policy(use_uniform_kwargs)

    def _test_custom_policy(self, use_uniform_kwargs: bool):
        print(f"use_uniform_kwargs={use_uniform_kwargs}")
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )

        if use_uniform_kwargs:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return True
                elif isinstance(
                    module, (TransformerEncoderLayer, TransformerDecoderLayer)
                ):
                    return True
                return False

        else:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return {"sharding_strategy": ShardingStrategy.NO_SHARD}
                elif isinstance(module, TransformerEncoderLayer):
                    return True
                elif isinstance(module, TransformerDecoderLayer):
                    return {
                        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
                        "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
                    }
                return False

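        # For `CustomPolicy`, returning `False` leaves the module unwrapped,
        # returning `True` wraps it with the root FSDP kwargs, and returning a
        # dict wraps it while overriding those kwargs for that module.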
        policy = CustomPolicy(lambda_fn)
        # Use a size-2 dummy PG to avoid clamping the sharding strategy to
        # `NO_SHARD` as for a size-1 PG
        process_group = DummyProcessGroup(rank=0, size=2)
        fp16_mp = MixedPrecision(param_dtype=torch.float16)
        fp32_mp = MixedPrecision()
        model = FSDP(
            model,
            process_group=process_group,
            auto_wrap_policy=policy,
            mixed_precision=fp16_mp,
        )
        encoder_layers = set(model.module.transformer.encoder.layers)
        decoder_layers = set(model.module.transformer.decoder.layers)
        bn = model.module.bn
        bn_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.NO_SHARD
        )
        bn_prefetch = BackwardPrefetch.BACKWARD_PRE
        encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
        encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
        decoder_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.SHARD_GRAD_OP
        )
        decoder_prefetch = (
            BackwardPrefetch.BACKWARD_PRE
            if use_uniform_kwargs
            else BackwardPrefetch.BACKWARD_POST
        )
        for module in model.modules():
            if module is bn:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, bn_strategy)
                self.assertEqual(module.backward_prefetch, bn_prefetch)
                # We currently override batch norm modules to use fp32
                self.assertEqual(module.mixed_precision, fp32_mp)
            elif module in encoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, encoder_strategy)
                self.assertEqual(module.backward_prefetch, encoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, decoder_strategy)
                self.assertEqual(module.backward_prefetch, decoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module is model:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, root_strategy)
                self.assertEqual(module.backward_prefetch, root_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly
        based on ``min_num_params``. A single ``nn.Linear(5, 5)`` (30
        parameters) does not exceed the threshold of 40, but the nested
        ``nn.Sequential`` of two such layers (60 parameters) does.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether
        the total parameter size exceeds ``min_num_params``. The
        ``size_based_auto_wrap_policy`` excludes {nn.ModuleList, nn.ModuleDict}
        from being wrapped themselves.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )

        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are
        wrapped if their parameter size exceeds ``min_num_params``.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped and that their
        children are not wrapped either. The ``size_based_auto_wrap_policy``
        forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
        # CPU offload and moving to CUDA after FSDP construction are expected
        # not to work together, so skip this combination.
        if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (
            torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
        )
        # Use a random port in case the next test runs quickly, since reusing
        # the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )
        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto the GPU, but its
        # shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                auto_wrap_policy=my_auto_wrap_policy,
                device_id=device_id,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since `sequential[1]` and `sequential[2][0]` are ignored, no
        # remaining submodule meets the parameter threshold, so auto wrapping
        # flattens `sequential[0]` and `sequential[2][1]` into `model` and
        # leaves the ignored `sequential[1]` and `sequential[2][0]` as-is
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_frozen_params(self):
        """
        Tests that mixing frozen/non-frozen parameters in an FSDP instance
        raises for ``use_orig_params=False`` and warns for ``True``.
        """
        module_classes = (LoraAttention, LoraMLP, LoraDecoder)
        module_wrap_policy = ModuleWrapPolicy(module_classes)
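        # `ModuleWrapPolicy` wraps every module that is an instance of one of
        # the given classes, using the root FSDP kwargs for each.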

        def lambda_fn_uniform(module: nn.Module):
            return isinstance(module, module_classes)

        def lambda_fn_nonuniform(module: nn.Module):
            if isinstance(module, LoraAttention):
                return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
            elif isinstance(module, module_classes):
                return True
            return False

        lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
        lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

        for use_orig_params, policy in itertools.product(
            [True, False],
            [
                module_wrap_policy,
                lambda_wrap_policy_uniform,
                lambda_wrap_policy_nonuniform,
            ],
        ):
            self._test_frozen_params(use_orig_params, policy)

    def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
        model = LoraModel().cuda()
        msg = "layers.0.attn has both parameters with requires_grad=True and False. "
        if use_orig_params:
            msg += "We do not recommend wrapping such modules"
            ctx = self.assertWarnsRegex(UserWarning, msg)
        else:
            msg += "FSDP does not support wrapping such modules when use_orig_params=False."
            ctx = self.assertRaisesRegex(ValueError, msg)
        with ctx:
            FSDP(
                model,
                process_group=self.process_group,
                auto_wrap_policy=policy,
                use_orig_params=use_orig_params,
            )


class TestWrapUtils(TestCase):
    def test_validate_frozen_params(self):
        """Tests the method ``_validate_frozen_params()``."""
        for use_orig_params in [True, False]:
            self._test_validate_frozen_params(use_orig_params)

    def _test_validate_frozen_params(self, use_orig_params: bool):
        model = LoraModel()
        # Wrap only LoRA modules
        modules_to_wrap = {
            module
            for module_name, module in model.named_modules()
            if "lora_A" in module_name or "lora_B" in module_name
        }
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
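        # Here and in the next two cases, each planned FSDP unit manages
        # uniformly frozen or uniformly non-frozen parameters, so validation
        # passes without raising or warning.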
        # Additionally wrap attention
        for module in model.modules():
            if isinstance(module, LoraAttention):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Additionally wrap decoders
        for module in model.modules():
            if isinstance(module, LoraDecoder):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Do not wrap the LoRA-A modules (meaning mixed frozen/non-frozen)
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                modules_to_wrap.remove(module)
        regex = "layers.0.attn has both parameters with requires_grad=True and False."
        if use_orig_params:
            # Wrapping the attention manages all parameters except those from
            # the LoRA-B module, which is separately wrapped and all nonfrozen
            lorab_numel = sum(
                p.numel() for p in model.layers[0].attn.lora_B.parameters()
            )
            attn_frozen_param_numel = sum(
                p.numel()
                for p in model.layers[0].attn.parameters()
                if not p.requires_grad
            )
            attn_nonfrozen_param_numel = (
                sum(
                    p.numel()
                    for p in model.layers[0].attn.parameters()
                    if p.requires_grad
                )
                - lorab_numel
            )
            attn_total_param_numel = (
                attn_frozen_param_numel + attn_nonfrozen_param_numel
            )
            regex += (
                " We do not recommend wrapping such modules since the "
                r"gradient memory usage will be higher than expected \("
                f"{attn_total_param_numel} numel instead of {attn_nonfrozen_param_numel} numel "
                r"before sharding via reduce-scatter\). "
            )
        else:
            regex += " FSDP does not support wrapping such modules when use_orig_params=False. "
        regex += "If possible, wrap the frozen parameters with FSDP separately.\n"
        regex += (
            "The following parameters have requires_grad=True:\n"
            r"\['layers.0.attn.lora_A.weight'\]\n"
            "The following parameters have requires_grad=False:\n"
            r"\['layers.0.attn.q_proj.weight', 'layers.0.attn.k_proj.weight', "
            r"'layers.0.attn.v_proj.weight', 'layers.0.attn.o_proj.weight'\]"
        )
        if use_orig_params:
            ctx = self.assertWarnsRegex(UserWarning, regex)
        else:
            ctx = self.assertRaisesRegex(ValueError, regex)
        with ctx:
            _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Now ignore those LoRA-A modules' parameters
        ignored_params = set()
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                ignored_params.update(module.parameters())
        _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)


instantiate_parametrized_tests(TestFSDPWrap)
instantiate_parametrized_tests(TestAutoWrap)

if __name__ == "__main__":
    run_tests()
957