# Owner(s): ["oncall: distributed"]
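# Tests saving and loading HSDP (ShardingStrategy.HYBRID_SHARD) sharded state
# dicts with torch.distributed.checkpoint, including loading an HSDP checkpoint
# into an FSDP model sharded over a 1-D device mesh.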
from copy import deepcopy

import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


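# Variant of SimpleModel whose layer sizes do not shard evenly across ranks;
# selected via the is_even_sharded_model parametrization below.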
class SimpleModelUneven(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class TestHSDPCheckpoint(DTensorTestBase):
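    # Register gloo for CPU tensors alongside NCCL for CUDA tensors; checkpoint
    # coordination can run CPU collectives.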
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

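        # 2-D device mesh for HYBRID_SHARD: dim 0 replicates across shard groups,
        # dim 1 shards parameters within each group.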
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

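        # With SHARDED_STATE_DICT, each state_dict value is a DTensor holding only
        # this rank's shard over the device mesh.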
        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {"model": model.state_dict()}
        state_dict_to_save = deepcopy(state_dict)

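        # Save the sharded state_dict; each rank writes only its own shards to CHECKPOINT_DIR.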
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so the current model state_dict differs from state_dict_to_save.
        model(model.get_input()).sum().backward()
        optim.step()

        # At this point, the current state dict is different from state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), model.state_dict().items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertNotEqual(v1.to_local(), v2.to_local())

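        # Load the checkpoint back into state_dict_to_save in place, then restore it into the model.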
        dist_cp.load(
            state_dict=state_dict_to_save,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_dict_to_save["model"])

        state_dict_after_load = model.state_dict()
        # After loading, the current model state dict should be the same as state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertEqual(v1.to_local(), v2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Save the HSDP model state_dict.
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        hsdp_model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        FSDP.set_state_dict_type(
            hsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        hsdp_state_dict = {"model": hsdp_model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=hsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Initialize an FSDP model to load the checkpoint into.
        mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
        fsdp_model = FSDP(
            simple_model().cuda(),
            device_mesh=mesh_1d,
        )
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = {"model": fsdp_model.state_dict()}

        # At this point, the HSDP model parameters are different from the FSDP model parameters.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
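            # Redistribute to Replicate() on every mesh dim so full tensors can be
            # compared by value across the two different meshes.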
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())

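        # dist_cp reshards at load time: shards saved from the 2-D HSDP mesh are
        # read into the 1-D placements of fsdp_state_dict.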
        # Load the FSDP state_dict from storage.
        dist_cp.load_state_dict(
            state_dict=fsdp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        fsdp_model.load_state_dict(fsdp_state_dict["model"])

        state_dict_after_load = fsdp_model.state_dict()
        # After loading, the current model state dict should be the same as hsdp_state_dict.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())


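# Generate the parametrized (even / uneven sharded model) test variants.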
instantiate_parametrized_tests(TestHSDPCheckpoint)
if __name__ == "__main__":
    run_tests()