1# Owner(s): ["oncall: distributed"] 2 3import sys 4 5import torch 6from torch.distributed._shard.sharded_tensor import ( 7 Shard, 8 ShardedTensor, 9 ShardedTensorMetadata, 10 ShardMetadata, 11) 12from torch.distributed._shard.sharded_tensor.metadata import TensorProperties 13from torch.distributed.c10d_logger import _c10d_logger 14from torch.distributed.checkpoint.logger import _dcp_logger 15from torch.distributed.checkpoint.metadata import MetadataIndex 16from torch.distributed.checkpoint.utils import find_state_dict_object 17from torch.testing._internal.common_utils import ( 18 run_tests, 19 TEST_WITH_DEV_DBG_ASAN, 20 TestCase, 21) 22from torch.testing._internal.distributed.distributed_utils import with_fake_comms 23 24 25if TEST_WITH_DEV_DBG_ASAN: 26 print( 27 "Skip dev-asan as torch + multiprocessing spawn have known issues", 28 file=sys.stderr, 29 ) 30 sys.exit(0) 31 32 33def create_sharded_tensor(rank, world_size, shards_per_rank): 34 shards_metadata = [] 35 local_shards = [] 36 for idx in range(0, world_size * shards_per_rank): 37 shard_rank = idx // shards_per_rank 38 shard_md = ShardMetadata( 39 shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu" 40 ) 41 shards_metadata.append(shard_md) 42 if shard_rank == rank: 43 shard = Shard.from_tensor_and_offsets( 44 torch.rand(*shard_md.shard_sizes), 45 shard_offsets=shard_md.shard_offsets, 46 rank=rank, 47 ) 48 local_shards.append(shard) 49 50 sharded_tensor_md = ShardedTensorMetadata( 51 shards_metadata=shards_metadata, 52 size=torch.Size([8 * len(shards_metadata)]), 53 tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)), 54 ) 55 56 return ShardedTensor._init_from_local_shards_and_global_metadata( 57 local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md 58 ) 59 60 61class TestMedatadaIndex(TestCase): 62 def test_init_convert_offset(self): 63 a = MetadataIndex("foo", [1, 2]) 64 b = MetadataIndex("foo", torch.Size([1, 2])) 65 self.assertEqual(a, b) 66 67 def test_index_hint_ignored_on_equals(self): 68 a = MetadataIndex("foo") 69 b = MetadataIndex("foo", index=99) 70 self.assertEqual(a, b) 71 72 def test_index_hint_ignored_on_hash(self): 73 a = MetadataIndex("foo") 74 b = MetadataIndex("foo", index=99) 75 self.assertEqual(hash(a), hash(b)) 76 77 def test_flat_data(self): 78 state_dict = { 79 "a": torch.rand(10), 80 "b": [1, 2, 3], 81 } 82 83 a = find_state_dict_object(state_dict, MetadataIndex("a")) 84 self.assertEqual(a, state_dict["a"]) 85 a = find_state_dict_object(state_dict, MetadataIndex("a", [0])) 86 self.assertEqual(a, state_dict["a"]) 87 a = find_state_dict_object(state_dict, MetadataIndex("a", index=99)) 88 self.assertEqual(a, state_dict["a"]) 89 90 b = find_state_dict_object(state_dict, MetadataIndex("b")) 91 self.assertEqual(b, state_dict["b"]) 92 b = find_state_dict_object(state_dict, MetadataIndex("b", index=1)) 93 self.assertEqual(b, state_dict["b"]) 94 95 with self.assertRaisesRegex(ValueError, "FQN"): 96 find_state_dict_object(state_dict, MetadataIndex("c")) 97 with self.assertRaisesRegex(ValueError, "ShardedTensor"): 98 find_state_dict_object(state_dict, MetadataIndex("b", [1])) 99 100 @with_fake_comms(rank=0, world_size=2) 101 def test_sharded_tensor_lookup(self): 102 st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3) 103 state_dict = {"st": st} 104 105 obj = find_state_dict_object(state_dict, MetadataIndex("st", [8])) 106 self.assertEqual(obj, st.local_shards()[1].tensor) 107 108 # good hint 109 obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=1)) 110 self.assertEqual(obj, st.local_shards()[1].tensor) 111 112 # bad hint 113 obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=2)) 114 self.assertEqual(obj, st.local_shards()[1].tensor) 115 116 # broken hint 117 obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=99)) 118 self.assertEqual(obj, st.local_shards()[1].tensor) 119 120 with self.assertRaisesRegex(ValueError, "no offset was provided"): 121 find_state_dict_object(state_dict, MetadataIndex("st")) 122 123 with self.assertRaisesRegex(ValueError, "Could not find shard"): 124 find_state_dict_object(state_dict, MetadataIndex("st", [1])) 125 126 def test_dcp_logger(self): 127 self.assertTrue(_c10d_logger is not _dcp_logger) 128 self.assertEqual(1, len(_c10d_logger.handlers)) 129 130 131if __name__ == "__main__": 132 run_tests() 133