1import unittest
2from weakref import WeakSet
3import copy
4import string
5from collections import UserString as ustr
6from collections.abc import Set, MutableSet
7import gc
8import contextlib
9from test import support
10
11
class Foo:
    """Bare user-defined object: weakly referenceable and open to new attributes."""
14
class RefCycle:
    """Object that holds a strong reference to itself.

    Because of that self-reference, CPython's reference counting alone
    cannot free an instance; it is reclaimed only by the cyclic garbage
    collector.
    """

    def __init__(self):
        # Point the ``cycle`` attribute back at the instance itself.
        self.cycle = self
18
class WeakSetSubclass(WeakSet):
    """Trivial WeakSet subclass that overrides nothing."""
21
class WeakSetWithSlots(WeakSet):
    """WeakSet subclass that declares the slots ``x`` and ``y``.

    The base WeakSet still provides an instance ``__dict__``, so attributes
    other than the slots can be set as well.
    """

    __slots__ = 'x', 'y'
24
25
class TestWeakSet(unittest.TestCase):
    """Tests for weakref.WeakSet.

    Covers the set API and operators, and checks that members disappear from
    a WeakSet (and from sets derived from it) once their last strong
    reference is dropped and the garbage collector has run.
    """

    def setUp(self):
        # A WeakSet holds only weak references, so every object stored in one
        # of the sets below is also kept alive by a list on ``self``.
        # UserString is used because built-in str instances cannot be weakly
        # referenced.
        self.items = [ustr(c) for c in ('a', 'b', 'c')]
        self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
        self.ab_items = [ustr(c) for c in 'ab']
        self.abcde_items = [ustr(c) for c in 'abcde']
        self.def_items = [ustr(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.letters = [ustr(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        # Plain dict with the same keys as self.s, used as a reference model.
        self.d = dict.fromkeys(self.items)
        # self.fs has exactly one member; tests delete self.obj to kill it.
        self.obj = ustr('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the built-in set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                         "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference removes the member.
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertNotIn(ustr('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        # union() is exercised with several kinds of iterable argument.
        expected = WeakSet(self.items + self.items2)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            c = C(self.items2)
            self.assertEqual(self.s.union(c), expected)
            del c
        # u must not keep a member alive after its owner drops it.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        # i holds only x/y/z, so intersecting it with a/b/c gives an empty
        # set whatever kind of iterable the argument is.
        empty = WeakSet([])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            self.assertEqual(i.intersection(C(self.items)), empty)
        # i must not keep a member alive after its owner drops it.
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        # i must not keep a member alive after its owner drops it.
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        # Comparison operators between WeakSets.
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # The named methods, called on WeakSets (not the built-in set) with
        # plain lists of weak-referenceable objects as the argument.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.abcde_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.ab_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for _ in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        # Re-running __init__ discards the previous contents.
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        # Building a WeakSet from another one yields a distinct, equal set.
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)
        self.assertEqual(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = ustr('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The Foo() temporary has no other owner, so it vanishes again.
        self.fs.add(Foo())
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = ustr('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = ustr('a'), ustr('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must work when both operands are the same set.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == WeakSet([Foo]))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [ustr(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        support.gc_collect()  # just in case; also for PyPy or other GCs.
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [ustr(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            # Yields a fresh copy of an element whose original has been
            # dropped while an iterator over s is still in progress.
            try:
                it = iter(s)
                # Start iterator
                yielded = ustr(str(next(it)))
                # Schedule an item for removal and recreate it
                u = ustr(str(items.pop()))
                if yielded == u:
                    # The iterator still has a reference to the removed item,
                    # advance it (issue #20006).
                    next(it)
                support.gc_collect()  # just in case; also for PyPy.
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for _ in range(N)]
        s = WeakSet(items)
        del items
        # Hold a partially consumed iterator while the cycles are collected.
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        support.gc_collect()  # For PyPy or other GCs.
        n1 = len(s)
        del it
        support.gc_collect()  # For PyPy or other GCs.
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        N = 20
        for th in range(1, 100):
            # Low thresholds make automatic collections fire at varying
            # points during the statements below.
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for _ in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)

    def test_repr(self):
        # Use an assert method rather than a bare assert, which -O strips.
        self.assertEqual(repr(self.s), repr(self.s.data))

    def test_abc(self):
        self.assertIsInstance(self.s, Set)
        self.assertIsInstance(self.s, MutableSet)

    def test_copying(self):
        # 'x' is a slot on WeakSetWithSlots but an ordinary attribute on
        # WeakSet; 'z' is an ordinary attribute on both; 'y' is never set.
        for cls in WeakSet, WeakSetWithSlots:
            s = cls(self.items)
            s.x = ['x']
            s.z = ['z']

            # Shallow copy shares attribute values.
            dup = copy.copy(s)
            self.assertIsInstance(dup, cls)
            self.assertEqual(dup, s)
            self.assertIsNot(dup, s)
            self.assertIs(dup.x, s.x)
            self.assertIs(dup.z, s.z)
            self.assertFalse(hasattr(dup, 'y'))

            # Deep copy duplicates attribute values.
            dup = copy.deepcopy(s)
            self.assertIsInstance(dup, cls)
            self.assertEqual(dup, s)
            self.assertIsNot(dup, s)
            self.assertEqual(dup.x, s.x)
            self.assertIsNot(dup.x, s.x)
            self.assertEqual(dup.z, s.z)
            self.assertIsNot(dup.z, s.z)
            self.assertFalse(hasattr(dup, 'y'))
481
if __name__ == "__main__":
    # When this module is executed directly, run its TestCase classes via
    # unittest's command-line entry point.
    unittest.main()
484