xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/resharding.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import List, Tuple
3
4from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
5
6
# Every helper in this module is private (underscore-prefixed); the empty
# ``__all__`` means ``from ... import *`` exports nothing from here.
__all__: List[str] = []
8
9
def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
):
    """
    Return whether two shards overlap.

    The shards overlap unless some dimension separates them, that is, one
    shard starts at or past the end of the other along that dimension. Shards
    that merely touch (one ends exactly where the other begins) do not
    overlap. The number of dimensions examined is taken from ``shard1``. With
    zero dimensions, the shards are reported as overlapping.
    """
    # A single separating dimension is enough to rule out an overlap. For a 2D
    # shard, that means one shard lies entirely above/below or entirely
    # left/right of the other.
    return not any(
        shard1.offsets[d] >= shard2.offsets[d] + shard2.sizes[d]
        or shard2.offsets[d] >= shard1.offsets[d] + shard1.sizes[d]
        for d in range(len(shard1.offsets))
    )
26
27
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Return the overlapping region between ``saved_shard`` and ``current_shard``.

    The returned list has one element per tensor dimension. Each element is a
    tuple with the following contents:
        (dimension, ``saved_shard`` offset, ``current_shard`` offset, length)

    Both offsets are relative to the start of their own shard. They locate
    where the overlap begins inside ``saved_shard`` and inside
    ``current_shard``, respectively. ``length`` is the extent of the overlap
    along that dimension.

    The two shards are expected to overlap, e.g. as checked by
    ``_check_shard_metadata_pair_overlap``. If they do not, ``length`` is zero
    or negative for at least one dimension and the result does not describe a
    valid region.
    """
    narrows = []
    for dim, (saved_offset, current_offset, saved_size, current_size) in enumerate(
        zip(
            saved_shard.offsets,
            current_shard.offsets,
            saved_shard.sizes,
            current_shard.sizes,
        )
    ):
        # Overlap along this dim, in global tensor coordinates, is [start, end).
        start = max(saved_offset, current_offset)
        end = min(saved_offset + saved_size, current_offset + current_size)
        # Re-express the start of the overlap relative to each shard's origin;
        # the shard that starts later gets a relative offset of 0.
        narrows.append(
            (dim, start - saved_offset, start - current_offset, end - start)
        )

    return narrows
73