from concurrent.futures import Future
from typing import Any, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)


__all__: List[str] = []


class _Checkpointer:
    """This base class specifies a high-level API for saving and loading
    distributed ``state_dict`` objects. It provides an abstraction over the
    low-level APIs provided by :py:mod:`torch.distributed.checkpoint.storage`,
    essentially calling :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`
    and :py:meth:`torch.distributed.checkpoint.state_dict_loader.load` with the
    provided storage readers and writers.

    .. warning::
        This feature is experimental and subject to removal/change.
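
    Example (a minimal sketch, not a normative recipe; it assumes the
    filesystem-based ``FileSystemWriter`` / ``FileSystemReader`` shipped with
    :py:mod:`torch.distributed.checkpoint` and a single process, hence
    ``no_dist=True``)::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.distributed.checkpoint as dcp
        >>> checkpointer = _Checkpointer(
        ...     storage_writer=dcp.FileSystemWriter("/tmp/checkpoint"),
        ...     storage_reader=dcp.FileSystemReader("/tmp/checkpoint"),
        ...     no_dist=True,  # no process group is initialized in this sketch
        ... )
        >>> metadata = checkpointer.save({"weight": torch.zeros(4)})
        >>> # ``load`` restores values in place into the provided state_dict.
        >>> checkpointer.load({"weight": torch.empty(4)})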
30    """

    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
43        """Initializes the Checkpointer instance.
44
45        Args:
46            storage_writer: Instance of StorageWrite use to perform writes.
47            storage_reader: StorageReader used to load data from.
48            process_group: ProcessGroup to be used for cross-rank synchronization.
49            coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
50            no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
51            loader_planner: Instance of LoadPlanner to use when loading.
52            save_planner: Instance of SavePlanner to use when saving.
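
        Example (illustrative sketch; it assumes the default planners from
        :py:mod:`torch.distributed.checkpoint.default_planner` and that
        ``writer`` / ``reader`` are already-constructed storage objects)::

            >>> # xdoctest: +SKIP
            >>> from torch.distributed.checkpoint.default_planner import (
            ...     DefaultLoadPlanner,
            ...     DefaultSavePlanner,
            ... )
            >>> checkpointer = _Checkpointer(
            ...     writer,
            ...     reader,
            ...     save_planner=DefaultSavePlanner(),
            ...     load_planner=DefaultLoadPlanner(),
            ... )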
53        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
66        """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
80        """
81        Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization.
82
83        Returns:
84            Future: A future holding the resultant Metadata object from `save`.
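
        Example (illustrative only; the checkpoint is written in the
        background, and ``result()`` blocks until it completes)::

            >>> # xdoctest: +SKIP
            >>> future = checkpointer.async_save(model.state_dict())
            >>> # ... continue training while the checkpoint is persisted ...
            >>> metadata = future.result()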
85        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )

    def load(self, state_dict: Dict[str, Any]) -> None:
94        """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )