# Owner(s): ["oncall: distributed"]
from copy import deepcopy

import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(torch.nn.Module):
    """Three-layer MLP (5 -> 8 -> 4 -> 12) used when ``is_even_sharded_model`` is True."""

    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        # Registered but not called in forward(), which uses F.relu instead.
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each linear layer followed by a ReLU and return the result."""
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self) -> torch.Tensor:
        """Return a random ``(4, 5)`` batch allocated on the CUDA device."""
        return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(torch.nn.Module):
    """Four-layer MLP (5 -> 10 -> 15 -> 30 -> 5) used when ``is_even_sharded_model`` is False."""

    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 10)
        # Registered but not called in forward(), which uses F.relu instead.
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each linear layer followed by a ReLU and return the result."""
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self) -> torch.Tensor:
        """Return a random ``(4, 5)`` batch allocated on the CUDA device."""
        return torch.rand(4, 5, device="cuda")


class TestHSDPCheckpoint(DTensorTestBase):
    """Distributed-checkpoint save/load tests for FSDP models using HYBRID_SHARD."""

    @property
    def backend(self) -> str:
        """Backend spec handed to the test harness: gloo for CPU, nccl for CUDA."""
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
        """Save an HSDP sharded state dict, mutate the model, then restore it.

        The checkpoint is written before an optimizer step; after the step the
        live parameters must differ from the saved copy, and after loading the
        checkpoint back they must match it again.
        """
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {"model": model.state_dict()}
        # Keep an independent copy so the optimizer step below cannot alter
        # the values we compare against.
        state_dict_to_save = deepcopy(state_dict)

        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so that the current model state_dict differs
        # from state_dict_to_save.
        model(model.get_input()).sum().backward()
        optim.step()

        # At this point, the current state dict is different from state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), model.state_dict().items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertNotEqual(v1.to_local(), v2.to_local())

        dist_cp.load(
            state_dict=state_dict_to_save,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_dict_to_save["model"])

        state_dict_after_load = model.state_dict()
        # After loading, the current model state dict should be the same as state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertEqual(v1.to_local(), v2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
        """Save from an HSDP model on a 2D mesh and load into an FSDP model on a 1D mesh.

        Both models are built from the same architecture but separately
        initialized, so their fully replicated parameters must differ before
        the load and match after it.
        """
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Save the HSDP model state_dict.
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        hsdp_model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        FSDP.set_state_dict_type(
            hsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        hsdp_state_dict = {"model": hsdp_model.state_dict()}
        # dist_cp.save replaces the deprecated dist_cp.save_state_dict.
        dist_cp.save(
            state_dict=hsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Initialize an FSDP model to load the checkpoint into.
        mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
        fsdp_model = FSDP(
            simple_model().cuda(),
            device_mesh=mesh_1d,
        )
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = {"model": fsdp_model.state_dict()}

        # At this point, the HSDP model parameters are different from the
        # FSDP model parameters.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            # Replicate on every mesh dimension so the local tensors hold the
            # full parameter and can be compared across the two layouts.
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())

        # Load the FSDP state_dict from storage.
        # dist_cp.load replaces the deprecated dist_cp.load_state_dict.
        dist_cp.load(
            state_dict=fsdp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        fsdp_model.load_state_dict(fsdp_state_dict["model"])

        state_dict_after_load = fsdp_model.state_dict()
        # After loading, the current model state dict should be the same as hsdp_state_dict.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())


instantiate_parametrized_tests(TestHSDPCheckpoint)
if __name__ == "__main__":
    run_tests()