# Owner(s): ["oncall: distributed"]

import os
import sys
from typing import cast, List, Optional, Union

import torch
import torch.distributed as dist
import torch.futures
import torch.nn
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.checkpoint import (
    CheckpointException,
    load_state_dict,
    save_state_dict,
    StorageReader,
    StorageWriter,
)
from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    SavePlan,
    SavePlanner,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        self.extra_sharded: Optional[ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None
        self._register_state_dict_hook(state_dict_hook)

    def spec(self) -> ChunkShardingSpec:
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )


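# Illustrative sketch of the metadata layout exercised by the tests below. The
# example dict is hypothetical; the field names come straight from the
# assertions in TestDistributedCheckpointing:
#
#   md = _create_default_local_metadata({"t": torch.rand(4), "b": [1, 2]})
#   md.state_dict_metadata["t"]  # TensorStorageMetadata (size, properties.dtype, chunks)
#   md.state_dict_metadata["b"]  # BytesStorageMetadata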
class TestDistributedCheckpointing(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)

        md = _create_default_local_metadata({"st": st})

        st_md = md.state_dict_metadata["st"]
        self.assertEqual(1, len(st_md.chunks))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_default_metadata(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            "sharded": sharded_tensor.rand(
                spec,
                (
                    10,
                    10,
                ),
            ),
            "replicated": torch.rand(4, device=device),
            "bytes": [1, 2, 3, 4],
        }

        metadata = _create_default_local_metadata(state_dict)
        self.assertTrue("bytes" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["bytes"], BytesStorageMetadata
        )

        self.assertTrue("replicated" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["replicated"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["replicated"]
        self.assertEqual(md.size, state_dict["replicated"].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

        self.assertTrue("sharded" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["sharded"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["sharded"]
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(md.size, state_dict["sharded"].size())
        self.assertEqual(2, len(md.chunks))


class TestStorageBase:
    """Fault-injection helper: hooks named in fail_conf fail on the listed ranks."""

    def __init__(self, fail_conf):
        self.fail_conf = fail_conf
        self.rank = 0 if not dist.is_initialized() else dist.get_rank()

    def _get_ranks(self, name):
        return self.fail_conf[name] if name in self.fail_conf else None

    def _fail_rank(self, name):
        ranks = self._get_ranks(name)
        if ranks is not None and self.rank in ranks:
            raise ValueError(f"rank fail {self.rank} for {name}")

    def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(result)
        return fut


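# Hypothetical usage sketch (not run by this file): fail_conf maps a hook name
# to the ranks that should fail it. For example,
#
#   writer = FaultyStorageWriter({"fail_write_data": [1]})
#
# makes rank 1 raise a ValueError from write_data() via _fail_rank, while a
# "fail_write_data_async" entry would instead surface the error through the
# Future returned by _fail_rank_async; all other ranks act as a no-op writer.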
class FaultyStorageWriter(TestStorageBase, StorageWriter):
    def __init__(self, fail_conf):
        super().__init__(fail_conf)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_writer")

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        self._fail_rank("fail_finish")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


class FaultyStorageReader(TestStorageBase, StorageReader):
    def __init__(self, metadata, fail_conf):
        super().__init__(fail_conf)
        self.metadata = metadata

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_reader")

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
        return self.metadata

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


class TestDistributedFailure(ShardedTensorTestBase):
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_writer_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }
        metadata = _create_default_local_metadata(state_dict)

        load_state_dict(state_dict, FaultyStorageReader(metadata, {}))

    def _test_dist_failure(self, callback, kwargs):
        bad_ranks = next(iter(kwargs.values())) if len(kwargs) > 0 else []

        # Empty bad_ranks means the callback must succeed
        if len(bad_ranks) == 0:
            callback()
        else:
            with self.assertRaises(CheckpointException) as cm:
                callback()
            e = cast(CheckpointException, cm.exception)
            for rank, wrapped_ex in e.failures.items():
                ex = wrapped_ex[0]
                self.assertTrue(rank in bad_ranks, msg=f"{rank} did not fail")
                if not kwargs.get("ignore_exception_type", False):
                    self.assertEqual(ValueError, type(ex), str(ex))

            failed_ranks = e.failures.keys()
            for rank in bad_ranks:
                self.assertTrue(
                    rank in failed_ranks,
                    msg=f"{rank} was supposed to fail but did not",
                )

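    # The helpers below pass their keyword arguments both as the fail_conf for
    # FaultyStorageWriter/FaultyStorageReader and to _test_dist_failure, which
    # treats the first keyword's value as the list of ranks expected to fail.
    # For example (illustrative), self._test_save(state_dict, fail_write_data=[2])
    # expects save_state_dict to raise a CheckpointException whose failures map
    # exactly rank 2 to a ValueError.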
    def _test_save(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _save():
            save_state_dict(
                state_dict,
                storage_writer=FaultyStorageWriter(kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_save, kwargs)

    def _test_load(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _load():
            metadata = _create_default_local_metadata(state_dict)
            load_state_dict(
                state_dict,
                storage_reader=FaultyStorageReader(metadata, kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_load, kwargs)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_save_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[2])
        self._test_save(state_dict, fail_write_data_async=[3])

        self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])

    def test_save_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

        self.assertFalse(dist.is_initialized())

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[0])
        self._test_save(state_dict, fail_write_data_async=[0])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_load_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[1])
        self._test_load(state_dict, fail_read_data=[3])
        self._test_load(state_dict, fail_read_data_async=[1])

        self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_data=[0])
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])

    def test_load_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
    run_tests()