xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/optimizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2
3import dataclasses
4from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
5
6import torch
7import torch.distributed as dist
8from torch._utils import _get_device_module
9from torch.distributed._shard.sharded_tensor.api import ShardedTensor
10from torch.distributed._shard.sharded_tensor.metadata import (
11    TensorProperties as ShardTensorProperties,
12)
13from torch.distributed._shard.sharded_tensor.shard import Shard
14from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
15from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
16from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
17from torch.distributed.checkpoint.metadata import (
18    BytesStorageMetadata,
19    ChunkStorageMetadata,
20    Metadata,
21    MetadataIndex,
22    STATE_DICT_TYPE,
23    TensorProperties,
24    TensorStorageMetadata,
25)
26from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
27from torch.distributed.checkpoint.planner_helpers import (
28    _create_read_items,
29    create_read_items_for_chunk_list,
30)
31from torch.distributed.checkpoint.state_dict_loader import load_state_dict
32from torch.distributed.checkpoint.storage import StorageReader
33from torch.distributed.checkpoint.utils import (
34    _element_wise_add,
35    _element_wise_sub,
36    _normalize_device_info,
37)
38from torch.distributed.distributed_c10d import _get_default_group
39from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
40from torch.distributed.remote_device import _remote_device
41from torch.distributed.tensor import DTensor
42
43
44STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
45
46
47# TODO: Update docstrings for optimizer.py
48__all__ = [
49    "load_sharded_optimizer_state_dict",
50]
51
52
53def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
54    if device_type == "cpu":
55        return "cpu"
56    device_module = _get_device_module(device_type)
57    if device_module.is_available():
58        return _normalize_device_info(
59            device_type, global_rank % device_module.device_count()
60        )
61    return "cpu"
62
63
64def _create_colwise_spec(
65    pg: Optional[dist.ProcessGroup] = None,
66) -> ChunkShardingSpec:
67    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
68    if pg is None:
69        placements = [
70            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
71            for idx in range(dist.get_world_size())
72        ]
73    else:
74        placements = [
75            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
76            for idx in range(pg.size())
77        ]
78    return ChunkShardingSpec(
79        dim=0,
80        placements=cast(List[Union[_remote_device, str]], placements),
81    )
82
83
84def _is_nested_tensor(val: torch.Tensor) -> bool:
85    if type(val) is ShardedTensor:
86        if len(val.local_shards()) == 0:
87            return False
88        if type(val.local_shards()[0].tensor) is ShardedTensor:
89            return True
90        if type(val.local_shards()[0].tensor) is DTensor:
91            raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
92    elif type(val) is DTensor and (
93        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
94    ):
95        raise ValueError("Cannot handle nested DTensor")
96    return False
97
98
99def _alloc_tensor(
100    props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
101) -> torch.Tensor:
102    if device_type == "cpu":
103        device = cast(torch.device, _get_device_module(device_type).current_device())
104    else:
105        device = torch.device(
106            device_type, _get_device_module(device_type).current_device()
107        )
108
109    return torch.empty(
110        size=size,
111        dtype=props.dtype,
112        layout=props.layout,
113        requires_grad=props.requires_grad,
114        pin_memory=props.pin_memory,
115        device=device,
116    )
117
118
119def _get_state_dict_2d_layout(
120    state_dict: STATE_DICT_TYPE,
121) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
122    """
123    Load the right TP slice of the optimizer state.
124
125    This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata.
126    We take advantage of the model state_dict producing a sliced ST to figure out what we need to load.
127    This is pretty fragile and it might be easier for FSDP to compute this info for us.
128    Returns a dictionary where keys are the same of the state_dict and the value is a tuple of
129    (offset, size) for the current rank TP slice.
130    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
131    """
132    specs: STATE_DICT_2D_LAYOUT = {}
133    dp_pg: Optional[dist.ProcessGroup] = None
134    for key, value in state_dict.items():
135        specs[key] = (None, value.size())
136        if _is_nested_tensor(value):
137            assert (
138                len(value.local_shards()) == 1
139            ), "Cannot handle ST with multiple shards"
140            assert isinstance(
141                value, ShardedTensor
142            ), "Can only handle nested ShardedTensor"
143            shard = value.local_shards()[0]
144            specs[key] = (
145                shard.metadata.shard_offsets,
146                shard.metadata.shard_sizes,
147            )
148            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]
149
150    return (
151        specs,
152        dp_pg,
153    )
154
155
156class _ReaderWithOffset(DefaultLoadPlanner):
157    translation: Dict[MetadataIndex, MetadataIndex]
158    state_dict: STATE_DICT_TYPE
159    metadata: Metadata
160
161    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
162        super().__init__()
163        self.fqn_to_offset = fqn_to_offset
164        self.metadata = Metadata({})
165        self.state_dict = {}
166        self.translation = {}
167
168    def create_local_plan(self) -> LoadPlan:
169        requests = []
170        self.translation = {}
171        for fqn, obj in self.state_dict.items():
172            md = self.metadata.state_dict_metadata[fqn]
173            if not isinstance(obj, ShardedTensor):
174                requests += _create_read_items(fqn, md, obj)
175                continue
176
177            if fqn not in self.fqn_to_offset:
178                requests += _create_read_items(fqn, md, obj)
179                continue
180
181            offset = self.fqn_to_offset[fqn]
182
183            assert len(obj.local_shards()) == 1
184            original_shard = obj.local_shards()[0]
185            local_chunks = [
186                ChunkStorageMetadata(
187                    offsets=torch.Size(
188                        _element_wise_add(original_shard.metadata.shard_offsets, offset)
189                    ),
190                    sizes=torch.Size(original_shard.metadata.shard_sizes),
191                )
192            ]
193
194            reqs = create_read_items_for_chunk_list(
195                fqn, cast(TensorStorageMetadata, md), local_chunks
196            )
197            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
198            # TODO: we should change _create_sharded_read_items to have more ergonomic API
199            for ri in reqs:
200                assert ri.dest_index.offset is not None
201                original_offset = _element_wise_sub(ri.dest_index.offset, offset)
202                original_index = dataclasses.replace(
203                    ri.dest_index, offset=torch.Size(original_offset)
204                )
205                self.translation[ri.dest_index] = original_index
206
207            requests += reqs
208        return LoadPlan(requests)
209
210    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
211        return super().lookup_tensor(self.translation.get(index, index))
212
213
def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    Args:
        model_state_dict: the model's sharded state_dict (must come from
            ``FSDP.sharded_state_dict``); used to infer the per-rank TP
            slicing of each optimizer tensor.
        optimizer_key: key under which the optimizer state was saved.
        storage_reader: reader used to access the checkpoint.
        planner: optional load planner; only used when no data-parallel
            process group is detected (otherwise an internal offset-aware
            planner is used).

    Returns:
        The unflattened optimizer portion of the checkpoint.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> # Save
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>> # Save
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader(checkpoint_file),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model_tp.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
    >>>        model, optim, optim_state["optimizer"]
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    # Infer per-key TP (offset, size) slicing and the data-parallel process
    # group from the model's sharded state_dict.
    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(dp_pg_device_type)

    if dp_pg is None:
        # No nested ST found (pure FSDP, no TP): shard dim 0 across the
        # whole default process group, one placement per rank.
        placements = []
        for i in range(dist.get_world_size()):
            device_info = _normalize_device_info(
                dp_pg_device_type, i % device_module.device_count()
            )
            placements.append(f"rank:{i}/{device_info}")
        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
    else:
        sharding_spec = _create_colwise_spec(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    # Keys whose local shard is displaced by a TP offset; consumed by
    # _ReaderWithOffset below.
    fqn_to_offset: Dict[str, Sequence[int]] = {}
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        # Only materialize entries belonging to the requested optimizer.
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            # Placeholder; the actual bytes are filled in by load_state_dict.
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            # Scalar state (e.g. "step"): allocate directly, no sharding.
            state_dict[key] = _alloc_tensor(
                value.properties, value.size, dp_pg_device_type
            )
        elif dp_pg is None:
            # Pure FSDP: chunk the full tensor across the default group.
            state_dict[key] = _create_chunk_sharded_tensor(
                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
        else:
            # 2D (FSDP + TP): allocate only this rank's shards of the
            # TP-local slice described by layout_specs.
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            # Translate checkpoint TensorProperties into the ShardedTensor
            # flavor expected by sharding_spec.build_metadata.
            properties = ShardTensorProperties(
                dtype=value.properties.dtype,
                layout=value.properties.layout,
                requires_grad=value.properties.requires_grad,
                memory_format=value.properties.memory_format,
                pin_memory=value.properties.pin_memory,
            )

            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                # Only allocate shards placed on this rank.
                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(
                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                        ),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            # Record the TP offset so reads land in the right checkpoint region.
            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st

    # Whether we unflatten before or after doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict
357