1import unittest 2from weakref import WeakSet 3import copy 4import string 5from collections import UserString as ustr 6from collections.abc import Set, MutableSet 7import gc 8import contextlib 9from test import support 10 11 12class Foo: 13 pass 14 15class RefCycle: 16 def __init__(self): 17 self.cycle = self 18 19class WeakSetSubclass(WeakSet): 20 pass 21 22class WeakSetWithSlots(WeakSet): 23 __slots__ = ('x', 'y') 24 25 26class TestWeakSet(unittest.TestCase): 27 28 def setUp(self): 29 # need to keep references to them 30 self.items = [ustr(c) for c in ('a', 'b', 'c')] 31 self.items2 = [ustr(c) for c in ('x', 'y', 'z')] 32 self.ab_items = [ustr(c) for c in 'ab'] 33 self.abcde_items = [ustr(c) for c in 'abcde'] 34 self.def_items = [ustr(c) for c in 'def'] 35 self.ab_weakset = WeakSet(self.ab_items) 36 self.abcde_weakset = WeakSet(self.abcde_items) 37 self.def_weakset = WeakSet(self.def_items) 38 self.letters = [ustr(c) for c in string.ascii_letters] 39 self.s = WeakSet(self.items) 40 self.d = dict.fromkeys(self.items) 41 self.obj = ustr('F') 42 self.fs = WeakSet([self.obj]) 43 44 def test_methods(self): 45 weaksetmethods = dir(WeakSet) 46 for method in dir(set): 47 if method == 'test_c_api' or method.startswith('_'): 48 continue 49 self.assertIn(method, weaksetmethods, 50 "WeakSet missing method " + method) 51 52 def test_new_or_init(self): 53 self.assertRaises(TypeError, WeakSet, [], 2) 54 55 def test_len(self): 56 self.assertEqual(len(self.s), len(self.d)) 57 self.assertEqual(len(self.fs), 1) 58 del self.obj 59 support.gc_collect() # For PyPy or other GCs. 60 self.assertEqual(len(self.fs), 0) 61 62 def test_contains(self): 63 for c in self.letters: 64 self.assertEqual(c in self.s, c in self.d) 65 # 1 is not weakref'able, but that TypeError is caught by __contains__ 66 self.assertNotIn(1, self.s) 67 self.assertIn(self.obj, self.fs) 68 del self.obj 69 support.gc_collect() # For PyPy or other GCs. 70 self.assertNotIn(ustr('F'), self.fs) 71 72 def test_union(self): 73 u = self.s.union(self.items2) 74 for c in self.letters: 75 self.assertEqual(c in u, c in self.d or c in self.items2) 76 self.assertEqual(self.s, WeakSet(self.items)) 77 self.assertEqual(type(u), WeakSet) 78 self.assertRaises(TypeError, self.s.union, [[]]) 79 for C in set, frozenset, dict.fromkeys, list, tuple: 80 x = WeakSet(self.items + self.items2) 81 c = C(self.items2) 82 self.assertEqual(self.s.union(c), x) 83 del c 84 self.assertEqual(len(u), len(self.items) + len(self.items2)) 85 self.items2.pop() 86 gc.collect() 87 self.assertEqual(len(u), len(self.items) + len(self.items2)) 88 89 def test_or(self): 90 i = self.s.union(self.items2) 91 self.assertEqual(self.s | set(self.items2), i) 92 self.assertEqual(self.s | frozenset(self.items2), i) 93 94 def test_intersection(self): 95 s = WeakSet(self.letters) 96 i = s.intersection(self.items2) 97 for c in self.letters: 98 self.assertEqual(c in i, c in self.items2 and c in self.letters) 99 self.assertEqual(s, WeakSet(self.letters)) 100 self.assertEqual(type(i), WeakSet) 101 for C in set, frozenset, dict.fromkeys, list, tuple: 102 x = WeakSet([]) 103 self.assertEqual(i.intersection(C(self.items)), x) 104 self.assertEqual(len(i), len(self.items2)) 105 self.items2.pop() 106 gc.collect() 107 self.assertEqual(len(i), len(self.items2)) 108 109 def test_isdisjoint(self): 110 self.assertTrue(self.s.isdisjoint(WeakSet(self.items2))) 111 self.assertTrue(not self.s.isdisjoint(WeakSet(self.letters))) 112 113 def test_and(self): 114 i = self.s.intersection(self.items2) 115 self.assertEqual(self.s & set(self.items2), i) 116 self.assertEqual(self.s & frozenset(self.items2), i) 117 118 def test_difference(self): 119 i = self.s.difference(self.items2) 120 for c in self.letters: 121 self.assertEqual(c in i, c in self.d and c not in self.items2) 122 self.assertEqual(self.s, WeakSet(self.items)) 123 self.assertEqual(type(i), WeakSet) 124 self.assertRaises(TypeError, self.s.difference, [[]]) 125 126 def test_sub(self): 127 i = self.s.difference(self.items2) 128 self.assertEqual(self.s - set(self.items2), i) 129 self.assertEqual(self.s - frozenset(self.items2), i) 130 131 def test_symmetric_difference(self): 132 i = self.s.symmetric_difference(self.items2) 133 for c in self.letters: 134 self.assertEqual(c in i, (c in self.d) ^ (c in self.items2)) 135 self.assertEqual(self.s, WeakSet(self.items)) 136 self.assertEqual(type(i), WeakSet) 137 self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) 138 self.assertEqual(len(i), len(self.items) + len(self.items2)) 139 self.items2.pop() 140 gc.collect() 141 self.assertEqual(len(i), len(self.items) + len(self.items2)) 142 143 def test_xor(self): 144 i = self.s.symmetric_difference(self.items2) 145 self.assertEqual(self.s ^ set(self.items2), i) 146 self.assertEqual(self.s ^ frozenset(self.items2), i) 147 148 def test_sub_and_super(self): 149 self.assertTrue(self.ab_weakset <= self.abcde_weakset) 150 self.assertTrue(self.abcde_weakset <= self.abcde_weakset) 151 self.assertTrue(self.abcde_weakset >= self.ab_weakset) 152 self.assertFalse(self.abcde_weakset <= self.def_weakset) 153 self.assertFalse(self.abcde_weakset >= self.def_weakset) 154 self.assertTrue(set('a').issubset('abc')) 155 self.assertTrue(set('abc').issuperset('a')) 156 self.assertFalse(set('a').issubset('cbs')) 157 self.assertFalse(set('cbs').issuperset('a')) 158 159 def test_lt(self): 160 self.assertTrue(self.ab_weakset < self.abcde_weakset) 161 self.assertFalse(self.abcde_weakset < self.def_weakset) 162 self.assertFalse(self.ab_weakset < self.ab_weakset) 163 self.assertFalse(WeakSet() < WeakSet()) 164 165 def test_gt(self): 166 self.assertTrue(self.abcde_weakset > self.ab_weakset) 167 self.assertFalse(self.abcde_weakset > self.def_weakset) 168 self.assertFalse(self.ab_weakset > self.ab_weakset) 169 self.assertFalse(WeakSet() > WeakSet()) 170 171 def test_gc(self): 172 # Create a nest of cycles to exercise overall ref count check 173 s = WeakSet(Foo() for i in range(1000)) 174 for elem in s: 175 elem.cycle = s 176 elem.sub = elem 177 elem.set = WeakSet([elem]) 178 179 def test_subclass_with_custom_hash(self): 180 # Bug #1257731 181 class H(WeakSet): 182 def __hash__(self): 183 return int(id(self) & 0x7fffffff) 184 s=H() 185 f=set() 186 f.add(s) 187 self.assertIn(s, f) 188 f.remove(s) 189 f.add(s) 190 f.discard(s) 191 192 def test_init(self): 193 s = WeakSet() 194 s.__init__(self.items) 195 self.assertEqual(s, self.s) 196 s.__init__(self.items2) 197 self.assertEqual(s, WeakSet(self.items2)) 198 self.assertRaises(TypeError, s.__init__, s, 2); 199 self.assertRaises(TypeError, s.__init__, 1); 200 201 def test_constructor_identity(self): 202 s = WeakSet(self.items) 203 t = WeakSet(s) 204 self.assertNotEqual(id(s), id(t)) 205 206 def test_hash(self): 207 self.assertRaises(TypeError, hash, self.s) 208 209 def test_clear(self): 210 self.s.clear() 211 self.assertEqual(self.s, WeakSet([])) 212 self.assertEqual(len(self.s), 0) 213 214 def test_copy(self): 215 dup = self.s.copy() 216 self.assertEqual(self.s, dup) 217 self.assertNotEqual(id(self.s), id(dup)) 218 219 def test_add(self): 220 x = ustr('Q') 221 self.s.add(x) 222 self.assertIn(x, self.s) 223 dup = self.s.copy() 224 self.s.add(x) 225 self.assertEqual(self.s, dup) 226 self.assertRaises(TypeError, self.s.add, []) 227 self.fs.add(Foo()) 228 support.gc_collect() # For PyPy or other GCs. 229 self.assertTrue(len(self.fs) == 1) 230 self.fs.add(self.obj) 231 self.assertTrue(len(self.fs) == 1) 232 233 def test_remove(self): 234 x = ustr('a') 235 self.s.remove(x) 236 self.assertNotIn(x, self.s) 237 self.assertRaises(KeyError, self.s.remove, x) 238 self.assertRaises(TypeError, self.s.remove, []) 239 240 def test_discard(self): 241 a, q = ustr('a'), ustr('Q') 242 self.s.discard(a) 243 self.assertNotIn(a, self.s) 244 self.s.discard(q) 245 self.assertRaises(TypeError, self.s.discard, []) 246 247 def test_pop(self): 248 for i in range(len(self.s)): 249 elem = self.s.pop() 250 self.assertNotIn(elem, self.s) 251 self.assertRaises(KeyError, self.s.pop) 252 253 def test_update(self): 254 retval = self.s.update(self.items2) 255 self.assertEqual(retval, None) 256 for c in (self.items + self.items2): 257 self.assertIn(c, self.s) 258 self.assertRaises(TypeError, self.s.update, [[]]) 259 260 def test_update_set(self): 261 self.s.update(set(self.items2)) 262 for c in (self.items + self.items2): 263 self.assertIn(c, self.s) 264 265 def test_ior(self): 266 self.s |= set(self.items2) 267 for c in (self.items + self.items2): 268 self.assertIn(c, self.s) 269 270 def test_intersection_update(self): 271 retval = self.s.intersection_update(self.items2) 272 self.assertEqual(retval, None) 273 for c in (self.items + self.items2): 274 if c in self.items2 and c in self.items: 275 self.assertIn(c, self.s) 276 else: 277 self.assertNotIn(c, self.s) 278 self.assertRaises(TypeError, self.s.intersection_update, [[]]) 279 280 def test_iand(self): 281 self.s &= set(self.items2) 282 for c in (self.items + self.items2): 283 if c in self.items2 and c in self.items: 284 self.assertIn(c, self.s) 285 else: 286 self.assertNotIn(c, self.s) 287 288 def test_difference_update(self): 289 retval = self.s.difference_update(self.items2) 290 self.assertEqual(retval, None) 291 for c in (self.items + self.items2): 292 if c in self.items and c not in self.items2: 293 self.assertIn(c, self.s) 294 else: 295 self.assertNotIn(c, self.s) 296 self.assertRaises(TypeError, self.s.difference_update, [[]]) 297 self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) 298 299 def test_isub(self): 300 self.s -= set(self.items2) 301 for c in (self.items + self.items2): 302 if c in self.items and c not in self.items2: 303 self.assertIn(c, self.s) 304 else: 305 self.assertNotIn(c, self.s) 306 307 def test_symmetric_difference_update(self): 308 retval = self.s.symmetric_difference_update(self.items2) 309 self.assertEqual(retval, None) 310 for c in (self.items + self.items2): 311 if (c in self.items) ^ (c in self.items2): 312 self.assertIn(c, self.s) 313 else: 314 self.assertNotIn(c, self.s) 315 self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) 316 317 def test_ixor(self): 318 self.s ^= set(self.items2) 319 for c in (self.items + self.items2): 320 if (c in self.items) ^ (c in self.items2): 321 self.assertIn(c, self.s) 322 else: 323 self.assertNotIn(c, self.s) 324 325 def test_inplace_on_self(self): 326 t = self.s.copy() 327 t |= t 328 self.assertEqual(t, self.s) 329 t &= t 330 self.assertEqual(t, self.s) 331 t -= t 332 self.assertEqual(t, WeakSet()) 333 t = self.s.copy() 334 t ^= t 335 self.assertEqual(t, WeakSet()) 336 337 def test_eq(self): 338 # issue 5964 339 self.assertTrue(self.s == self.s) 340 self.assertTrue(self.s == WeakSet(self.items)) 341 self.assertFalse(self.s == set(self.items)) 342 self.assertFalse(self.s == list(self.items)) 343 self.assertFalse(self.s == tuple(self.items)) 344 self.assertFalse(self.s == WeakSet([Foo])) 345 self.assertFalse(self.s == 1) 346 347 def test_ne(self): 348 self.assertTrue(self.s != set(self.items)) 349 s1 = WeakSet() 350 s2 = WeakSet() 351 self.assertFalse(s1 != s2) 352 353 def test_weak_destroy_while_iterating(self): 354 # Issue #7105: iterators shouldn't crash when a key is implicitly removed 355 # Create new items to be sure no-one else holds a reference 356 items = [ustr(c) for c in ('a', 'b', 'c')] 357 s = WeakSet(items) 358 it = iter(s) 359 next(it) # Trigger internal iteration 360 # Destroy an item 361 del items[-1] 362 gc.collect() # just in case 363 # We have removed either the first consumed items, or another one 364 self.assertIn(len(list(it)), [len(items), len(items) - 1]) 365 del it 366 # The removal has been committed 367 self.assertEqual(len(s), len(items)) 368 369 def test_weak_destroy_and_mutate_while_iterating(self): 370 # Issue #7105: iterators shouldn't crash when a key is implicitly removed 371 items = [ustr(c) for c in string.ascii_letters] 372 s = WeakSet(items) 373 @contextlib.contextmanager 374 def testcontext(): 375 try: 376 it = iter(s) 377 # Start iterator 378 yielded = ustr(str(next(it))) 379 # Schedule an item for removal and recreate it 380 u = ustr(str(items.pop())) 381 if yielded == u: 382 # The iterator still has a reference to the removed item, 383 # advance it (issue #20006). 384 next(it) 385 gc.collect() # just in case 386 yield u 387 finally: 388 it = None # should commit all removals 389 390 with testcontext() as u: 391 self.assertNotIn(u, s) 392 with testcontext() as u: 393 self.assertRaises(KeyError, s.remove, u) 394 self.assertNotIn(u, s) 395 with testcontext() as u: 396 s.add(u) 397 self.assertIn(u, s) 398 t = s.copy() 399 with testcontext() as u: 400 s.update(t) 401 self.assertEqual(len(s), len(t)) 402 with testcontext() as u: 403 s.clear() 404 self.assertEqual(len(s), 0) 405 406 def test_len_cycles(self): 407 N = 20 408 items = [RefCycle() for i in range(N)] 409 s = WeakSet(items) 410 del items 411 it = iter(s) 412 try: 413 next(it) 414 except StopIteration: 415 pass 416 gc.collect() 417 n1 = len(s) 418 del it 419 gc.collect() 420 gc.collect() # For PyPy or other GCs. 421 n2 = len(s) 422 # one item may be kept alive inside the iterator 423 self.assertIn(n1, (0, 1)) 424 self.assertEqual(n2, 0) 425 426 def test_len_race(self): 427 # Extended sanity checks for len() in the face of cyclic collection 428 self.addCleanup(gc.set_threshold, *gc.get_threshold()) 429 for th in range(1, 100): 430 N = 20 431 gc.collect(0) 432 gc.set_threshold(th, th, th) 433 items = [RefCycle() for i in range(N)] 434 s = WeakSet(items) 435 del items 436 # All items will be collected at next garbage collection pass 437 it = iter(s) 438 try: 439 next(it) 440 except StopIteration: 441 pass 442 n1 = len(s) 443 del it 444 n2 = len(s) 445 self.assertGreaterEqual(n1, 0) 446 self.assertLessEqual(n1, N) 447 self.assertGreaterEqual(n2, 0) 448 self.assertLessEqual(n2, n1) 449 450 def test_repr(self): 451 assert repr(self.s) == repr(self.s.data) 452 453 def test_abc(self): 454 self.assertIsInstance(self.s, Set) 455 self.assertIsInstance(self.s, MutableSet) 456 457 def test_copying(self): 458 for cls in WeakSet, WeakSetWithSlots: 459 s = cls(self.items) 460 s.x = ['x'] 461 s.z = ['z'] 462 463 dup = copy.copy(s) 464 self.assertIsInstance(dup, cls) 465 self.assertEqual(dup, s) 466 self.assertIsNot(dup, s) 467 self.assertIs(dup.x, s.x) 468 self.assertIs(dup.z, s.z) 469 self.assertFalse(hasattr(dup, 'y')) 470 471 dup = copy.deepcopy(s) 472 self.assertIsInstance(dup, cls) 473 self.assertEqual(dup, s) 474 self.assertIsNot(dup, s) 475 self.assertEqual(dup.x, s.x) 476 self.assertIsNot(dup.x, s.x) 477 self.assertEqual(dup.z, s.z) 478 self.assertIsNot(dup.z, s.z) 479 self.assertFalse(hasattr(dup, 'y')) 480 481 482if __name__ == "__main__": 483 unittest.main() 484