from concurrent.futures import Future
from typing import Any, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)


__all__: List[str] = []


class _Checkpointer:
    """This base class specifies a high-level API for saving and loading
    distributed ``state_dict`` objects. It provides an abstraction over the
    low-level APIs provided by :py:mod:`torch.distributed.checkpoint.storage`,
    essentially calling :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`
    and :py:meth:`torch.distributed.checkpoint.state_dict_loader.load` with the
    provided storage readers and writers.

    .. warning::
        This feature is experimental and subject to removal/change.

    """

    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
        """Initializes the Checkpointer instance.

        Args:
            storage_writer: Instance of StorageWriter used to perform writes.
            storage_reader: Instance of StorageReader used to perform reads.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank used to coordinate the checkpoint. Rank 0 is used by default.
            no_dist: If ``True``, the distributed checkpoint will not be saved or loaded in SPMD style. (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`, using the values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
        """
        Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.async_save`, using the values passed during initialization.

        Returns:
            Future: A future holding the resultant Metadata object from ``save``.
        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )

    def load(self, state_dict: Dict[str, Any]) -> None:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_loader.load`, using the values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )
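

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). _Checkpointer
# is a private, experimental API, so this is only a minimal single-process
# example of how the class above could be exercised. The toy model, the
# temporary checkpoint directory, and the use of FileSystemWriter /
# FileSystemReader are assumptions made purely for demonstration; exact
# behaviour may vary across torch versions.
if __name__ == "__main__":
    import tempfile

    import torch.nn as nn
    from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

    checkpoint_dir = tempfile.mkdtemp()
    model = nn.Linear(4, 4)

    # no_dist=True because no process group is initialized in this sketch.
    checkpointer = _Checkpointer(
        storage_writer=FileSystemWriter(checkpoint_dir),
        storage_reader=FileSystemReader(checkpoint_dir),
        no_dist=True,
    )

    # Blocking save; returns the Metadata describing the written checkpoint.
    metadata = checkpointer.save(model.state_dict())

    # Load back into a freshly materialized state_dict and restore the model.
    # (load() mutates the passed-in state_dict in place.)
    state_dict = model.state_dict()
    checkpointer.load(state_dict)
    model.load_state_dict(state_dict)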