# Owner(s): ["oncall: distributed"]

import os
import sys
from typing import cast, List, Optional, Union

import torch
import torch.distributed as dist
import torch.futures
import torch.nn
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.checkpoint import (
    CheckpointException,
    load_state_dict,
    save_state_dict,
    StorageReader,
    StorageWriter,
)
from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    SavePlan,
    SavePlanner,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


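# A toy module that mixes a ShardedTensor with a regular Parameter (plus
# optional extras) and registers the ShardedTensor state_dict hook so the
# sharded attribute is included in state_dict().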
class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        self.extra_sharded: Optional[ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None
        self._register_state_dict_hook(state_dict_hook)

    def spec(self) -> ChunkShardingSpec:
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )


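# Sanity checks for the metadata produced by the default planner: sharded,
# replicated, and raw-bytes entries should each get the expected storage
# metadata and chunk counts.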
class TestDistributedCheckpointing(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)

        md = _create_default_local_metadata({"st": st})

        st_md = md.state_dict_metadata["st"]
        self.assertEqual(1, len(st_md.chunks))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_default_metadata(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            "sharded": sharded_tensor.rand(
                spec,
                (
                    10,
                    10,
                ),
            ),
            "replicated": torch.rand(4, device=device),
            "bytes": [1, 2, 3, 4],
        }

        metadata = _create_default_local_metadata(state_dict)
        self.assertTrue("bytes" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["bytes"], BytesStorageMetadata
        )

        self.assertTrue("replicated" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["replicated"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["replicated"]
        self.assertEqual(md.size, state_dict["replicated"].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

        self.assertTrue("sharded" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["sharded"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["sharded"]
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(md.size, state_dict["sharded"].size())
        self.assertEqual(2, len(md.chunks))


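# Fault-injection scaffolding. `fail_conf` maps a failure-point name (e.g.
# "fail_write_data") to the list of ranks that should fail there:
#
#   FaultyStorageWriter({"fail_write_data": [2]})  # rank 2 raises in write_data
#
# `_fail_rank` raises a ValueError synchronously, while `_fail_rank_async`
# returns a Future carrying the exception instead.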
class TestStorageBase:
    def __init__(self, fail_conf):
        self.fail_conf = fail_conf
        self.rank = 0 if not dist.is_initialized() else dist.get_rank()

    def _get_ranks(self, name):
        return self.fail_conf.get(name)

    def _fail_rank(self, name):
        ranks = self._get_ranks(name)
        if ranks is not None and self.rank in ranks:
            raise ValueError(f"rank fail {self.rank} for {name}")

    def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(result)
        return fut


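# A StorageWriter that persists nothing: every hook only triggers the
# configured failure (if any) for the calling rank and otherwise succeeds.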
class FaultyStorageWriter(TestStorageBase, StorageWriter):
    def __init__(self, fail_conf):
        super().__init__(fail_conf)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_writer")

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        self._fail_rank("fail_finish")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


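# The read-side counterpart: serves a pre-built Metadata object instead of
# reading real storage, injecting failures at the configured points.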
class FaultyStorageReader(TestStorageBase, StorageReader):
    def __init__(self, metadata, fail_conf):
        super().__init__(fail_conf)
        self.metadata = metadata

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_reader")

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
        return self.metadata

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


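# End-to-end error handling: when a failure is injected on specific ranks,
# save/load must raise a CheckpointException whose per-rank failure map
# matches exactly the ranks configured to fail.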
class TestDistributedFailure(ShardedTensorTestBase):
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_writer_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }
        metadata = _create_default_local_metadata(state_dict)

        load_state_dict(state_dict, FaultyStorageReader(metadata, {}))

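    # Runs `callback` and, when `kwargs` names ranks that are expected to
    # fail, verifies that the resulting CheckpointException reports exactly
    # those ranks (each with a ValueError unless ignore_exception_type is set).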
    def _test_dist_failure(self, callback, kwargs):
        bad_ranks = next(iter(kwargs.values())) if len(kwargs) > 0 else []

        # An empty bad_ranks list means the callback must succeed
        if len(bad_ranks) == 0:
            callback()
        else:
            with self.assertRaises(CheckpointException) as cm:
                callback()
            e = cast(CheckpointException, cm.exception)
            for rank, wrapped_ex in e.failures.items():
                ex = wrapped_ex[0]
                self.assertTrue(
                    rank in bad_ranks, msg=f"rank {rank} was not expected to fail"
                )
                if not kwargs.get("ignore_exception_type", False):
                    self.assertEqual(ValueError, type(ex), str(ex))

            failed_ranks = e.failures.keys()
            for rank in bad_ranks:
                self.assertTrue(
                    rank in failed_ranks,
                    msg=f"rank {rank} was supposed to fail but did not",
                )

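    # Thin wrappers that route a state_dict through the faulty writer/reader;
    # the keyword arguments double as the fault configuration.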
    def _test_save(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _save():
            save_state_dict(
                state_dict,
                storage_writer=FaultyStorageWriter(kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_save, kwargs)

    def _test_load(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _load():
            metadata = _create_default_local_metadata(state_dict)
            load_state_dict(
                state_dict,
                storage_reader=FaultyStorageReader(metadata, kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_load, kwargs)

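    # Each call below injects a failure at a single storage hook on a single
    # rank and asserts that only that rank is reported as failed.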
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_save_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[2])
        self._test_save(state_dict, fail_write_data_async=[3])

        self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])

    def test_save_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

        self.assertFalse(dist.is_initialized())

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[0])
        self._test_save(state_dict, fail_write_data_async=[0])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_load_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[1])
        self._test_load(state_dict, fail_read_data=[3])
        self._test_load(state_dict, fail_read_data_async=[1])

        self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_data=[0])
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])

    def test_load_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
    run_tests()