# Owner(s): ["module: meta tensors"] import copy import gc import random import threading import unittest import torch from torch.testing._internal.common_utils import ( find_library_location, IS_FBCODE, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, run_tests, TestCase, ) from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary def C(): return torch.randn(1) # These tests are ported from cpython/Lib/test/test_weakref.py, # but adapted to use tensor rather than object class WeakTest(TestCase): COUNT = 10 def test_make_weak_keyed_dict_from_dict(self): o = torch.randn(2) dict = WeakIdKeyDictionary({o: 364}) self.assertEqual(dict[o], 364) def test_make_weak_keyed_dict_from_weak_keyed_dict(self): o = torch.randn(3) dict = WeakIdKeyDictionary({o: 364}) dict2 = WeakIdKeyDictionary(dict) self.assertEqual(dict[o], 364) def check_popitem(self, klass, key1, value1, key2, value2): weakdict = klass() weakdict[key1] = value1 weakdict[key2] = value2 self.assertEqual(len(weakdict), 2) k, v = weakdict.popitem() self.assertEqual(len(weakdict), 1) if k is key1: self.assertIs(v, value1) else: self.assertIs(v, value2) k, v = weakdict.popitem() self.assertEqual(len(weakdict), 0) if k is key1: self.assertIs(v, value1) else: self.assertIs(v, value2) def test_weak_keyed_dict_popitem(self): self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2") def check_setdefault(self, klass, key, value1, value2): self.assertIsNot( value1, value2, "invalid test -- value parameters must be distinct objects", ) weakdict = klass() o = weakdict.setdefault(key, value1) self.assertIs(o, value1) self.assertIn(key, weakdict) self.assertIs(weakdict.get(key), value1) self.assertIs(weakdict[key], value1) o = weakdict.setdefault(key, value2) self.assertIs(o, value1) self.assertIn(key, weakdict) self.assertIs(weakdict.get(key), value1) self.assertIs(weakdict[key], value1) def test_weak_keyed_dict_setdefault(self): self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2") def check_update(self, klass, dict): # # This exercises d.update(), len(d), d.keys(), k in d, # d.get(), d[]. # weakdict = klass() weakdict.update(dict) self.assertEqual(len(weakdict), len(dict)) for k in weakdict.keys(): self.assertIn(k, dict, "mysterious new key appeared in weak dict") v = dict.get(k) self.assertIs(v, weakdict[k]) self.assertIs(v, weakdict.get(k)) for k in dict.keys(): self.assertIn(k, weakdict, "original key disappeared in weak dict") v = dict[k] self.assertIs(v, weakdict[k]) self.assertIs(v, weakdict.get(k)) def test_weak_keyed_dict_update(self): self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3}) def test_weak_keyed_delitem(self): d = WeakIdKeyDictionary() o1 = torch.randn(1) o2 = torch.randn(2) d[o1] = "something" d[o2] = "something" self.assertEqual(len(d), 2) del d[o1] self.assertEqual(len(d), 1) self.assertEqual(list(d.keys()), [o2]) def test_weak_keyed_union_operators(self): try: {} | {} except TypeError: self.skipTest("dict union not supported in this Python") o1 = C() o2 = C() o3 = C() wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2}) wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4}) wkd3 = wkd1.copy() d1 = {o2: "5", o3: "6"} pairs = [(o2, 7), (o3, 8)] tmp1 = wkd1 | wkd2 # Between two WeakKeyDictionaries self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2)) self.assertIs(type(tmp1), WeakIdKeyDictionary) wkd1 |= wkd2 self.assertEqual(wkd1, tmp1) tmp2 = wkd2 | d1 # Between WeakKeyDictionary and mapping self.assertEqual(dict(tmp2), dict(wkd2) | d1) self.assertIs(type(tmp2), WeakIdKeyDictionary) wkd2 |= d1 self.assertEqual(wkd2, tmp2) tmp3 = wkd3.copy() # Between WeakKeyDictionary and iterable key, value tmp3 |= pairs self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs)) self.assertIs(type(tmp3), WeakIdKeyDictionary) tmp4 = d1 | wkd3 # Testing .__ror__ self.assertEqual(dict(tmp4), d1 | dict(wkd3)) self.assertIs(type(tmp4), WeakIdKeyDictionary) del o1 self.assertNotIn(4, tmp1.values()) self.assertNotIn(4, tmp2.values()) self.assertNotIn(1, tmp3.values()) self.assertNotIn(1, tmp4.values()) def test_weak_keyed_bad_delitem(self): d = WeakIdKeyDictionary() o = torch.randn(1) # An attempt to delete an object that isn't there should raise # KeyError. It didn't before 2.3. self.assertRaises(KeyError, d.__delitem__, o) self.assertRaises(KeyError, d.__getitem__, o) # If a key isn't of a weakly referencable type, __getitem__ and # __setitem__ raise TypeError. __delitem__ should too. self.assertRaises(TypeError, d.__delitem__, 13) self.assertRaises(TypeError, d.__getitem__, 13) self.assertRaises(TypeError, d.__setitem__, 13, 13) def test_make_weak_keyed_dict_repr(self): dict = WeakIdKeyDictionary() self.assertRegex(repr(dict), "") def check_threaded_weak_dict_copy(self, type_, deepcopy): # `deepcopy` should be either True or False. exc = [] # Cannot give these slots as weakrefs weren't supported # on these objects until later versions of Python class DummyKey: # noqa: B903 def __init__(self, ctr): self.ctr = ctr class DummyValue: # noqa: B903 def __init__(self, ctr): self.ctr = ctr def dict_copy(d, exc): try: if deepcopy is True: _ = copy.deepcopy(d) else: _ = d.copy() except Exception as ex: exc.append(ex) def pop_and_collect(lst): gc_ctr = 0 while lst: i = random.randint(0, len(lst) - 1) gc_ctr += 1 lst.pop(i) if gc_ctr % 10000 == 0: gc.collect() # just in case d = type_() keys = [] values = [] # Initialize d with many entries for i in range(70000): k, v = DummyKey(i), DummyValue(i) keys.append(k) values.append(v) d[k] = v del k del v t_copy = threading.Thread(target=dict_copy, args=(d, exc)) t_collect = threading.Thread(target=pop_and_collect, args=(keys,)) t_copy.start() t_collect.start() t_copy.join() t_collect.join() # Test exceptions if exc: raise exc[0] def test_threaded_weak_key_dict_copy(self): # Issue #35615: Weakref keys or values getting GC'ed during dict # copying should not result in a crash. self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False) def test_threaded_weak_key_dict_deepcopy(self): # Issue #35615: Weakref keys or values getting GC'ed during dict # copying should not result in a crash. self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True) # Adapted from cpython/Lib/test/mapping_tests.py class WeakKeyDictionaryTestCase(TestCase): __ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3} type2test = WeakIdKeyDictionary def _reference(self): return self.__ref.copy() def _empty_mapping(self): """Return an empty mapping object""" return self.type2test() def _full_mapping(self, data): """Return a mapping object with the value contained in data dictionary""" x = self._empty_mapping() for key, value in data.items(): x[key] = value return x def __init__(self, *args, **kw): unittest.TestCase.__init__(self, *args, **kw) self.reference = self._reference().copy() # A (key, value) pair not in the mapping key, value = self.reference.popitem() self.other = {key: value} # A (key, value) pair in the mapping key, value = self.reference.popitem() self.inmapping = {key: value} self.reference[key] = value def test_read(self): # Test for read only operations on mapping p = self._empty_mapping() p1 = dict(p) # workaround for singleton objects d = self._full_mapping(self.reference) if d is p: p = p1 # Indexing for key, value in self.reference.items(): self.assertEqual(d[key], value) knownkey = next(iter(self.other.keys())) self.assertRaises(KeyError, lambda: d[knownkey]) # len self.assertEqual(len(p), 0) self.assertEqual(len(d), len(self.reference)) # __contains__ for k in self.reference: self.assertIn(k, d) for k in self.other: self.assertNotIn(k, d) # cmp self.assertTrue( p == p ) # NB: don't use assertEqual, that doesn't actually use == self.assertTrue(d == d) self.assertTrue(p != d) self.assertTrue(d != p) # bool if p: self.fail("Empty mapping must compare to False") if not d: self.fail("Full mapping must compare to True") # keys(), items(), iterkeys() ... def check_iterandlist(iter, lst, ref): self.assertTrue(hasattr(iter, "__next__")) self.assertTrue(hasattr(iter, "__iter__")) x = list(iter) self.assertTrue(set(x) == set(lst) == set(ref)) check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys()) check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) check_iterandlist(iter(d.values()), list(d.values()), self.reference.values()) check_iterandlist(iter(d.items()), list(d.items()), self.reference.items()) # get key, value = next(iter(d.items())) knownkey, knownvalue = next(iter(self.other.items())) self.assertEqual(d.get(key, knownvalue), value) self.assertEqual(d.get(knownkey, knownvalue), knownvalue) self.assertNotIn(knownkey, d) def test_write(self): # Test for write operations on mapping p = self._empty_mapping() # Indexing for key, value in self.reference.items(): p[key] = value self.assertEqual(p[key], value) for key in self.reference.keys(): del p[key] self.assertRaises(KeyError, lambda: p[key]) p = self._empty_mapping() # update p.update(self.reference) self.assertEqual(dict(p), self.reference) items = list(p.items()) p = self._empty_mapping() p.update(items) self.assertEqual(dict(p), self.reference) d = self._full_mapping(self.reference) # setdefault key, value = next(iter(d.items())) knownkey, knownvalue = next(iter(self.other.items())) self.assertEqual(d.setdefault(key, knownvalue), value) self.assertEqual(d[key], value) self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) self.assertEqual(d[knownkey], knownvalue) # pop self.assertEqual(d.pop(knownkey), knownvalue) self.assertNotIn(knownkey, d) self.assertRaises(KeyError, d.pop, knownkey) default = 909 d[knownkey] = knownvalue self.assertEqual(d.pop(knownkey, default), knownvalue) self.assertNotIn(knownkey, d) self.assertEqual(d.pop(knownkey, default), default) # popitem key, value = d.popitem() self.assertNotIn(key, d) self.assertEqual(value, self.reference[key]) p = self._empty_mapping() self.assertRaises(KeyError, p.popitem) def test_constructor(self): self.assertEqual(self._empty_mapping(), self._empty_mapping()) def test_bool(self): self.assertTrue(not self._empty_mapping()) self.assertTrue(self.reference) self.assertTrue(bool(self._empty_mapping()) is False) self.assertTrue(bool(self.reference) is True) def test_keys(self): d = self._empty_mapping() self.assertEqual(list(d.keys()), []) d = self.reference self.assertIn(next(iter(self.inmapping.keys())), d.keys()) self.assertNotIn(next(iter(self.other.keys())), d.keys()) self.assertRaises(TypeError, d.keys, None) def test_values(self): d = self._empty_mapping() self.assertEqual(list(d.values()), []) self.assertRaises(TypeError, d.values, None) def test_items(self): d = self._empty_mapping() self.assertEqual(list(d.items()), []) self.assertRaises(TypeError, d.items, None) def test_len(self): d = self._empty_mapping() self.assertEqual(len(d), 0) def test_getitem(self): d = self.reference self.assertEqual( d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values())) ) self.assertRaises(TypeError, d.__getitem__) def test_update(self): # mapping argument d = self._empty_mapping() d.update(self.other) self.assertEqual(list(d.items()), list(self.other.items())) # No argument d = self._empty_mapping() d.update() self.assertEqual(d, self._empty_mapping()) # item sequence d = self._empty_mapping() d.update(self.other.items()) self.assertEqual(list(d.items()), list(self.other.items())) # Iterator d = self._empty_mapping() d.update(self.other.items()) self.assertEqual(list(d.items()), list(self.other.items())) # FIXME: Doesn't work with UserDict # self.assertRaises((TypeError, AttributeError), d.update, None) self.assertRaises((TypeError, AttributeError), d.update, 42) outerself = self class SimpleUserDict: def __init__(self) -> None: self.d = outerself.reference def keys(self): return self.d.keys() def __getitem__(self, i): return self.d[i] d.clear() d.update(SimpleUserDict()) i1 = sorted((id(k), v) for k, v in d.items()) i2 = sorted((id(k), v) for k, v in self.reference.items()) self.assertEqual(i1, i2) class Exc(Exception): pass d = self._empty_mapping() class FailingUserDict: def keys(self): raise Exc self.assertRaises(Exc, d.update, FailingUserDict()) d.clear() class FailingUserDict: def keys(self): class BogonIter: def __init__(self) -> None: self.i = 1 def __iter__(self): return self def __next__(self): if self.i: self.i = 0 return "a" raise Exc return BogonIter() def __getitem__(self, key): return key self.assertRaises(Exc, d.update, FailingUserDict()) class FailingUserDict: def keys(self): class BogonIter: def __init__(self) -> None: self.i = ord("a") def __iter__(self): return self def __next__(self): if self.i <= ord("z"): rtn = chr(self.i) self.i += 1 return rtn raise StopIteration return BogonIter() def __getitem__(self, key): raise Exc self.assertRaises(Exc, d.update, FailingUserDict()) d = self._empty_mapping() class badseq: def __iter__(self): return self def __next__(self): raise Exc self.assertRaises(Exc, d.update, badseq()) self.assertRaises(ValueError, d.update, [(1, 2, 3)]) # no test_fromkeys or test_copy as both os.environ and selves don't support it def test_get(self): d = self._empty_mapping() self.assertTrue(d.get(next(iter(self.other.keys()))) is None) self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) d = self.reference self.assertTrue(d.get(next(iter(self.other.keys()))) is None) self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) self.assertEqual( d.get(next(iter(self.inmapping.keys()))), next(iter(self.inmapping.values())), ) self.assertEqual( d.get(next(iter(self.inmapping.keys())), 3), next(iter(self.inmapping.values())), ) self.assertRaises(TypeError, d.get) self.assertRaises(TypeError, d.get, None, None, None) def test_setdefault(self): d = self._empty_mapping() self.assertRaises(TypeError, d.setdefault) def test_popitem(self): d = self._empty_mapping() self.assertRaises(KeyError, d.popitem) self.assertRaises(TypeError, d.popitem, 42) def test_pop(self): d = self._empty_mapping() k, v = next(iter(self.inmapping.items())) d[k] = v self.assertRaises(KeyError, d.pop, next(iter(self.other.keys()))) self.assertEqual(d.pop(k), v) self.assertEqual(len(d), 0) self.assertRaises(KeyError, d.pop, k) # Adapted from cpython/Lib/test/mapping_tests.py class WeakKeyDictionaryScriptObjectTestCase(TestCase): def _reference(self): self.__ref = { torch.classes._TorchScriptTesting._Foo(1, 2): 1, torch.classes._TorchScriptTesting._Foo(2, 3): 2, torch.classes._TorchScriptTesting._Foo(3, 4): 3, } return self.__ref.copy() def _empty_mapping(self): """Return an empty mapping object""" return WeakIdKeyDictionary(ref_type=_WeakHashRef) def _full_mapping(self, data): """Return a mapping object with the value contained in data dictionary""" x = self._empty_mapping() for key, value in data.items(): x[key] = value return x def setUp(self): if IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") def __init__(self, *args, **kw): unittest.TestCase.__init__(self, *args, **kw) if IS_SANDCASTLE or IS_FBCODE: torch.ops.load_library( "//caffe2/test/cpp/jit:test_custom_class_registrations" ) elif IS_MACOS: # don't load the library, just skip the tests in setUp return else: lib_file_path = find_library_location("libtorchbind_test.so") if IS_WINDOWS: lib_file_path = find_library_location("torchbind_test.dll") torch.ops.load_library(str(lib_file_path)) self.reference = self._reference().copy() # A (key, value) pair not in the mapping key, value = self.reference.popitem() self.other = {key: value} # A (key, value) pair in the mapping key, value = self.reference.popitem() self.inmapping = {key: value} self.reference[key] = value def test_read(self): # Test for read only operations on mapping p = self._empty_mapping() p1 = dict(p) # workaround for singleton objects d = self._full_mapping(self.reference) if d is p: p = p1 # Indexing for key, value in self.reference.items(): self.assertEqual(d[key], value) knownkey = next(iter(self.other.keys())) self.assertRaises(KeyError, lambda: d[knownkey]) # len self.assertEqual(len(p), 0) self.assertEqual(len(d), len(self.reference)) # __contains__ for k in self.reference: self.assertIn(k, d) for k in self.other: self.assertNotIn(k, d) # cmp self.assertTrue( p == p ) # NB: don't use assertEqual, that doesn't actually use == self.assertTrue(d == d) self.assertTrue(p != d) self.assertTrue(d != p) # bool if p: self.fail("Empty mapping must compare to False") if not d: self.fail("Full mapping must compare to True") # keys(), items(), iterkeys() ... def check_iterandlist(iter, lst, ref): self.assertTrue(hasattr(iter, "__next__")) self.assertTrue(hasattr(iter, "__iter__")) x = list(iter) self.assertTrue(set(x) == set(lst) == set(ref)) check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys()) check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) check_iterandlist(iter(d.values()), list(d.values()), self.reference.values()) check_iterandlist(iter(d.items()), list(d.items()), self.reference.items()) # get key, value = next(iter(d.items())) knownkey, knownvalue = next(iter(self.other.items())) self.assertEqual(d.get(key, knownvalue), value) self.assertEqual(d.get(knownkey, knownvalue), knownvalue) self.assertNotIn(knownkey, d) def test_write(self): # Test for write operations on mapping p = self._empty_mapping() # Indexing for key, value in self.reference.items(): p[key] = value self.assertEqual(p[key], value) for key in self.reference.keys(): del p[key] self.assertRaises(KeyError, lambda: p[key]) p = self._empty_mapping() # update p.update(self.reference) self.assertEqual(dict(p), self.reference) items = list(p.items()) p = self._empty_mapping() p.update(items) self.assertEqual(dict(p), self.reference) d = self._full_mapping(self.reference) # setdefault key, value = next(iter(d.items())) knownkey, knownvalue = next(iter(self.other.items())) self.assertEqual(d.setdefault(key, knownvalue), value) self.assertEqual(d[key], value) self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) self.assertEqual(d[knownkey], knownvalue) # pop self.assertEqual(d.pop(knownkey), knownvalue) self.assertNotIn(knownkey, d) self.assertRaises(KeyError, d.pop, knownkey) default = 909 d[knownkey] = knownvalue self.assertEqual(d.pop(knownkey, default), knownvalue) self.assertNotIn(knownkey, d) self.assertEqual(d.pop(knownkey, default), default) # popitem key, value = d.popitem() self.assertNotIn(key, d) self.assertEqual(value, self.reference[key]) p = self._empty_mapping() self.assertRaises(KeyError, p.popitem) def test_constructor(self): self.assertEqual(self._empty_mapping(), self._empty_mapping()) def test_bool(self): self.assertTrue(not self._empty_mapping()) self.assertTrue(self.reference) self.assertTrue(bool(self._empty_mapping()) is False) self.assertTrue(bool(self.reference) is True) def test_keys(self): d = self._empty_mapping() self.assertEqual(list(d.keys()), []) d = self.reference self.assertIn(next(iter(self.inmapping.keys())), d.keys()) self.assertNotIn(next(iter(self.other.keys())), d.keys()) self.assertRaises(TypeError, d.keys, None) def test_values(self): d = self._empty_mapping() self.assertEqual(list(d.values()), []) self.assertRaises(TypeError, d.values, None) def test_items(self): d = self._empty_mapping() self.assertEqual(list(d.items()), []) self.assertRaises(TypeError, d.items, None) def test_len(self): d = self._empty_mapping() self.assertEqual(len(d), 0) def test_getitem(self): d = self.reference self.assertEqual( d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values())) ) self.assertRaises(TypeError, d.__getitem__) def test_update(self): # mapping argument d = self._empty_mapping() d.update(self.other) self.assertEqual(list(d.items()), list(self.other.items())) # No argument d = self._empty_mapping() d.update() self.assertEqual(d, self._empty_mapping()) # item sequence d = self._empty_mapping() d.update(self.other.items()) self.assertEqual(list(d.items()), list(self.other.items())) # Iterator d = self._empty_mapping() d.update(self.other.items()) self.assertEqual(list(d.items()), list(self.other.items())) # FIXME: Doesn't work with UserDict # self.assertRaises((TypeError, AttributeError), d.update, None) self.assertRaises((TypeError, AttributeError), d.update, 42) outerself = self class SimpleUserDict: def __init__(self) -> None: self.d = outerself.reference def keys(self): return self.d.keys() def __getitem__(self, i): return self.d[i] d.clear() d.update(SimpleUserDict()) i1 = sorted((id(k), v) for k, v in d.items()) i2 = sorted((id(k), v) for k, v in self.reference.items()) self.assertEqual(i1, i2) class Exc(Exception): pass d = self._empty_mapping() class FailingUserDict: def keys(self): raise Exc self.assertRaises(Exc, d.update, FailingUserDict()) d.clear() class FailingUserDict: def keys(self): class BogonIter: def __init__(self) -> None: self.i = 1 def __iter__(self): return self def __next__(self): if self.i: self.i = 0 return "a" raise Exc return BogonIter() def __getitem__(self, key): return key self.assertRaises(Exc, d.update, FailingUserDict()) class FailingUserDict: def keys(self): class BogonIter: def __init__(self) -> None: self.i = ord("a") def __iter__(self): return self def __next__(self): if self.i <= ord("z"): rtn = chr(self.i) self.i += 1 return rtn raise StopIteration return BogonIter() def __getitem__(self, key): raise Exc self.assertRaises(Exc, d.update, FailingUserDict()) d = self._empty_mapping() class badseq: def __iter__(self): return self def __next__(self): raise Exc self.assertRaises(Exc, d.update, badseq()) self.assertRaises(ValueError, d.update, [(1, 2, 3)]) # no test_fromkeys or test_copy as both os.environ and selves don't support it def test_get(self): d = self._empty_mapping() self.assertTrue(d.get(next(iter(self.other.keys()))) is None) self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) d = self.reference self.assertTrue(d.get(next(iter(self.other.keys()))) is None) self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) self.assertEqual( d.get(next(iter(self.inmapping.keys()))), next(iter(self.inmapping.values())), ) self.assertEqual( d.get(next(iter(self.inmapping.keys())), 3), next(iter(self.inmapping.values())), ) self.assertRaises(TypeError, d.get) self.assertRaises(TypeError, d.get, None, None, None) def test_setdefault(self): d = self._empty_mapping() self.assertRaises(TypeError, d.setdefault) def test_popitem(self): d = self._empty_mapping() self.assertRaises(KeyError, d.popitem) self.assertRaises(TypeError, d.popitem, 42) def test_pop(self): d = self._empty_mapping() k, v = next(iter(self.inmapping.items())) d[k] = v self.assertRaises(KeyError, d.pop, next(iter(self.other.keys()))) self.assertEqual(d.pop(k), v) self.assertEqual(len(d), 0) self.assertRaises(KeyError, d.pop, k) if __name__ == "__main__": run_tests()