# mypy: allow-untyped-defs
from typing import List, Tuple

from torch.distributed.checkpoint.metadata import ChunkStorageMetadata


__all__: List[str] = []


def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
):
    """
    Tell whether two shards overlap.

    Two shards are disjoint as soon as, along any single dimension, one of
    them begins at or past the position where the other one ends (in 2D:
    one shard lies entirely above, below, left or right of the other).
    The dimensions inspected are those listed in ``shard1.offsets``.

    Returns ``False`` if such a separating dimension exists, ``True``
    otherwise.
    """
    for dim, start1 in enumerate(shard1.offsets):
        start2 = shard2.offsets[dim]
        # ``or`` short-circuits, so shard1's size is only read when the
        # first comparison did not already prove the shards disjoint.
        separated = (
            start1 >= start2 + shard2.sizes[dim]
            or start2 >= start1 + shard1.sizes[dim]
        )
        if separated:
            return False

    return True


def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Describe, per dimension, the region shared by ``saved_shard`` and ``current_shard``.

    One tuple is produced for every dimension, in order:
        (dimension, offset into `saved_shard`, offset into `current_shard`, length)

    Both offsets are relative to the start of their own shard, so along
    each dimension at least one of the two is always 0.

    No overlap check is performed here: if the shards do not overlap along
    a dimension, the reported length for it is zero or negative.
    """
    per_dim = zip(
        saved_shard.offsets,
        current_shard.offsets,
        saved_shard.sizes,
        current_shard.sizes,
    )

    narrows = []
    for dim, (saved_start, current_start, saved_size, current_size) in enumerate(
        per_dim
    ):
        # The shared range runs from the later start to the earlier end.
        overlap_end = min(saved_start + saved_size, current_start + current_size)
        length = overlap_end - max(current_start, saved_start)

        # Whichever shard starts later begins exactly at the shared range
        # (offset 0); the other one has to skip the gap between the starts.
        if saved_start > current_start:
            relative_offsets = (0, saved_start - current_start)
        else:
            relative_offsets = (current_start - saved_start, 0)

        narrows.append((dim, *relative_offsets, length))

    return narrows