1# Copyright (c) Meta Platforms, Inc. and affiliates 2 3import dataclasses 4from typing import cast, Dict, List, Optional, Sequence, Tuple, Union 5 6import torch 7import torch.distributed as dist 8from torch._utils import _get_device_module 9from torch.distributed._shard.sharded_tensor.api import ShardedTensor 10from torch.distributed._shard.sharded_tensor.metadata import ( 11 TensorProperties as ShardTensorProperties, 12) 13from torch.distributed._shard.sharded_tensor.shard import Shard 14from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec 15from torch.distributed.checkpoint._nested_dict import unflatten_state_dict 16from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner 17from torch.distributed.checkpoint.metadata import ( 18 BytesStorageMetadata, 19 ChunkStorageMetadata, 20 Metadata, 21 MetadataIndex, 22 STATE_DICT_TYPE, 23 TensorProperties, 24 TensorStorageMetadata, 25) 26from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner 27from torch.distributed.checkpoint.planner_helpers import ( 28 _create_read_items, 29 create_read_items_for_chunk_list, 30) 31from torch.distributed.checkpoint.state_dict_loader import load_state_dict 32from torch.distributed.checkpoint.storage import StorageReader 33from torch.distributed.checkpoint.utils import ( 34 _element_wise_add, 35 _element_wise_sub, 36 _normalize_device_info, 37) 38from torch.distributed.distributed_c10d import _get_default_group 39from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor 40from torch.distributed.remote_device import _remote_device 41from torch.distributed.tensor import DTensor 42 43 44STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] 45 46 47# TODO: Update docstrings for optimizer.py 48__all__ = [ 49 "load_sharded_optimizer_state_dict", 50] 51 52 53def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: 54 if device_type == "cpu": 55 return "cpu" 56 device_module = _get_device_module(device_type) 57 if device_module.is_available(): 58 return _normalize_device_info( 59 device_type, global_rank % device_module.device_count() 60 ) 61 return "cpu" 62 63 64def _create_colwise_spec( 65 pg: Optional[dist.ProcessGroup] = None, 66) -> ChunkShardingSpec: 67 pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type 68 if pg is None: 69 placements = [ 70 f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" 71 for idx in range(dist.get_world_size()) 72 ] 73 else: 74 placements = [ 75 f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" 76 for idx in range(pg.size()) 77 ] 78 return ChunkShardingSpec( 79 dim=0, 80 placements=cast(List[Union[_remote_device, str]], placements), 81 ) 82 83 84def _is_nested_tensor(val: torch.Tensor) -> bool: 85 if type(val) is ShardedTensor: 86 if len(val.local_shards()) == 0: 87 return False 88 if type(val.local_shards()[0].tensor) is ShardedTensor: 89 return True 90 if type(val.local_shards()[0].tensor) is DTensor: 91 raise ValueError("Cannot handle DTensor nested insided ShardedTensor") 92 elif type(val) is DTensor and ( 93 type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor 94 ): 95 raise ValueError("Cannot handle nested DTensor") 96 return False 97 98 99def _alloc_tensor( 100 props: TensorProperties, size: Sequence[int], device_type: str = "cuda" 101) -> torch.Tensor: 102 if device_type == "cpu": 103 device = cast(torch.device, _get_device_module(device_type).current_device()) 104 else: 105 device = torch.device( 106 device_type, _get_device_module(device_type).current_device() 107 ) 108 109 return torch.empty( 110 size=size, 111 dtype=props.dtype, 112 layout=props.layout, 113 requires_grad=props.requires_grad, 114 pin_memory=props.pin_memory, 115 device=device, 116 ) 117 118 119def _get_state_dict_2d_layout( 120 state_dict: STATE_DICT_TYPE, 121) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]: 122 """ 123 Load the right TP slice of the optimizer state. 124 125 This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata. 126 We take advantage of the model state_dict producing a sliced ST to figure out what we need to load. 127 This is pretty fragile and it might be easier for FSDP to compute this info for us. 128 Returns a dictionary where keys are the same of the state_dict and the value is a tuple of 129 (offset, size) for the current rank TP slice. 130 N.B. The state_dict *MUST* come from FSDP.sharded_state_dict. 131 """ 132 specs: STATE_DICT_2D_LAYOUT = {} 133 dp_pg: Optional[dist.ProcessGroup] = None 134 for key, value in state_dict.items(): 135 specs[key] = (None, value.size()) 136 if _is_nested_tensor(value): 137 assert ( 138 len(value.local_shards()) == 1 139 ), "Cannot handle ST with multiple shards" 140 assert isinstance( 141 value, ShardedTensor 142 ), "Can only handle nested ShardedTensor" 143 shard = value.local_shards()[0] 144 specs[key] = ( 145 shard.metadata.shard_offsets, 146 shard.metadata.shard_sizes, 147 ) 148 dp_pg = shard.tensor._process_group # type: ignore[attr-defined] 149 150 return ( 151 specs, 152 dp_pg, 153 ) 154 155 156class _ReaderWithOffset(DefaultLoadPlanner): 157 translation: Dict[MetadataIndex, MetadataIndex] 158 state_dict: STATE_DICT_TYPE 159 metadata: Metadata 160 161 def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None: 162 super().__init__() 163 self.fqn_to_offset = fqn_to_offset 164 self.metadata = Metadata({}) 165 self.state_dict = {} 166 self.translation = {} 167 168 def create_local_plan(self) -> LoadPlan: 169 requests = [] 170 self.translation = {} 171 for fqn, obj in self.state_dict.items(): 172 md = self.metadata.state_dict_metadata[fqn] 173 if not isinstance(obj, ShardedTensor): 174 requests += _create_read_items(fqn, md, obj) 175 continue 176 177 if fqn not in self.fqn_to_offset: 178 requests += _create_read_items(fqn, md, obj) 179 continue 180 181 offset = self.fqn_to_offset[fqn] 182 183 assert len(obj.local_shards()) == 1 184 original_shard = obj.local_shards()[0] 185 local_chunks = [ 186 ChunkStorageMetadata( 187 offsets=torch.Size( 188 _element_wise_add(original_shard.metadata.shard_offsets, offset) 189 ), 190 sizes=torch.Size(original_shard.metadata.shard_sizes), 191 ) 192 ] 193 194 reqs = create_read_items_for_chunk_list( 195 fqn, cast(TensorStorageMetadata, md), local_chunks 196 ) 197 # TODO: The ReadItems will have a displaced MetadataIndex, fix it. 198 # TODO: we should change _create_sharded_read_items to have more ergonomic API 199 for ri in reqs: 200 assert ri.dest_index.offset is not None 201 original_offset = _element_wise_sub(ri.dest_index.offset, offset) 202 original_index = dataclasses.replace( 203 ri.dest_index, offset=torch.Size(original_offset) 204 ) 205 self.translation[ri.dest_index] = original_index 206 207 requests += reqs 208 return LoadPlan(requests) 209 210 def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: 211 return super().lookup_tensor(self.translation.get(index, index)) 212 213 214def load_sharded_optimizer_state_dict( 215 model_state_dict: STATE_DICT_TYPE, 216 optimizer_key: str, 217 storage_reader: StorageReader, 218 planner: Optional[LoadPlanner] = None, 219) -> STATE_DICT_TYPE: 220 """ 221 Load a state_dict in conjunction with FSDP sharded optimizer state. 222 223 This is the current recommended way to checkpoint FSDP. 224 >>> # xdoctest: +SKIP 225 >>> import torch.distributed.checkpoint as dist_cp 226 >>> # Save 227 >>> model: torch.nn.Model 228 >>> optim_params = model.parameters() 229 >>> optim = torch.optim.SGD(optim_params, lr=0.01) 230 >>> # Save 231 >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 232 >>> state_dict = { 233 >>> "optimizer": FSDP.optim_state_dict(model, optim), 234 >>> "model": model.state_dict() 235 >>> } 236 >>> dist_cp.save_state_dict( 237 >>> state_dict=optim_state, 238 >>> storage_writer=dist_cp.FileSystemWriter("checkpoint"), 239 >>> planner=dist_cp.DefaultSavePlanner(), 240 >>> ) 241 >>> 242 >>> # Load 243 >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): 244 >>> model_state_dict = model_tp.state_dict() 245 >>> checkpoint = { 246 >>> "model": model_state_dict 247 >>> } 248 >>> dist_cp.load_state_dict( 249 >>> state_dict=checkpoint, 250 >>> storage_reader=dist_cp.FileSystemReader(checkpoint_file), 251 >>> planner=dist_cp.DefaultLoadPlanner(), 252 >>> ) 253 >>> model.load_state_dict(checkpoint["model_state"]) 254 >>> 255 >>> optim_state = dist_cp.load_sharded_optimizer_state_dict( 256 >>> model_state_dict, 257 >>> optimizer_key="optimizer", 258 >>> storage_reader=dist_cp.FileSystemReader("checkpoint"), 259 >>> ) 260 >>> 261 >>> flattened_osd = FSDP.optim_state_dict_to_load( 262 >>> model, optim, optim_state["optimizer"] 263 >>> ) 264 >>> 265 >>> optim.load_state_dict(flattened_osd) 266 """ 267 metadata = storage_reader.read_metadata() 268 269 layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) 270 dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type 271 device_module = _get_device_module(dp_pg_device_type) 272 273 if dp_pg is None: 274 placements = [] 275 for i in range(dist.get_world_size()): 276 device_info = _normalize_device_info( 277 dp_pg_device_type, i % device_module.device_count() 278 ) 279 placements.append(f"rank:{i}/{device_info}") 280 sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] 281 else: 282 sharding_spec = _create_colwise_spec(dp_pg) 283 284 # Create a state_dict for optimizer state 285 state_dict: STATE_DICT_TYPE = {} 286 287 fqn_to_offset: Dict[str, Sequence[int]] = {} 288 for key, value in metadata.state_dict_metadata.items(): 289 key_path = metadata.planner_data[key] 290 if key_path[0] != optimizer_key: 291 continue 292 293 if isinstance(value, BytesStorageMetadata): 294 state_dict[key] = "<bytes_io>" 295 continue 296 297 # value: TensorStorageMetadata 298 if value.size.numel() == 1: 299 state_dict[key] = _alloc_tensor( 300 value.properties, value.size, dp_pg_device_type 301 ) 302 elif dp_pg is None: 303 state_dict[key] = _create_chunk_sharded_tensor( 304 _alloc_tensor(value.properties, value.size, dp_pg_device_type), 305 rank=dist.get_rank(), 306 world_size=dist.get_world_size(), 307 num_devices_per_node=device_module.device_count(), 308 pg=_get_default_group(), 309 ) 310 else: 311 spec_key = key_path[2] 312 alloc_size = layout_specs.get(spec_key, (None, value.size))[1] 313 314 properties = ShardTensorProperties( 315 dtype=value.properties.dtype, 316 layout=value.properties.layout, 317 requires_grad=value.properties.requires_grad, 318 memory_format=value.properties.memory_format, 319 pin_memory=value.properties.pin_memory, 320 ) 321 322 st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties) 323 local_shards = [] 324 current_rank = dist.get_rank(dp_pg) 325 for shard_md in st_md.shards_metadata: 326 if cast(_remote_device, shard_md.placement).rank() != current_rank: 327 continue 328 local_shards.append( 329 Shard( 330 tensor=_alloc_tensor( 331 value.properties, shard_md.shard_sizes, dp_pg_device_type 332 ), 333 metadata=shard_md, 334 ) 335 ) 336 337 st = ShardedTensor._init_from_local_shards_and_global_metadata( 338 local_shards, st_md, process_group=dp_pg 339 ) 340 341 if spec_key in layout_specs and layout_specs[spec_key][0] is not None: 342 fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0]) 343 344 state_dict[key] = st 345 346 # Whether we unflatten before or after doesn't matter 347 load_state_dict( 348 state_dict=state_dict, 349 storage_reader=storage_reader, 350 # FIXME the type of planner is wrong in load_state_dict 351 planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner, 352 ) 353 354 state_dict = unflatten_state_dict(state_dict, metadata.planner_data) 355 356 return state_dict 357