1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from test.support import os_helper
10from test.support import socket_helper
11from time import sleep
12import unittest
13import unittest.mock
14import tempfile
15from time import monotonic as time
16try:
17    import resource
18except ImportError:
19    resource = None
20
21
# Every test below builds socket pairs, which these platforms cannot do,
# so skip the whole module instead of failing each test individually.
if support.is_wasi or support.is_emscripten:
    raise unittest.SkipTest("Cannot create socketpair on Emscripten/WASI.")
24
25
# Use the native socket.socketpair() when available; otherwise emulate it
# by connecting a client to a throwaway loopback listener.
if not hasattr(socket, 'socketpair'):
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Return a pair of connected sockets ``(client, server_side)``.

        Fallback for platforms without socket.socketpair(). Raises OSError
        if binding, connecting or accepting fails; the client socket is
        closed before the error propagates.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((socket_helper.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                client_addr = client.getsockname()
                while True:
                    server_side, peer_addr = listener.accept()
                    # Only accept the connection that came from our client;
                    # discard any other peer that connected in between.
                    if peer_addr == client_addr:
                        return client, server_side
                    server_side.close()
            except OSError:
                client.close()
                raise
else:
    socketpair = socket.socketpair
46
47
def find_ready_matching(ready, flag):
    """Return the file objects from *ready* whose event mask shares a bit with *flag*.

    *ready* is an iterable of ``(key, events)`` pairs such as the list
    returned by a selector's select(); order is preserved.
    """
    return [key.fileobj for key, events in ready if events & flag]
54
55
class BaseSelectorTestCase:
    """Conformance tests shared by every selectors implementation.

    This is a mixin: concrete subclasses also inherit from
    unittest.TestCase and set ``SELECTOR`` to the selector class under test.
    """

    def make_socketpair(self):
        # Create a connected socket pair that is closed automatically at
        # the end of the test.
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        # Unregistering by FD number must succeed even after the sockets
        # owning those FDs have been closed.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        # Same as above, but the closed FD numbers are reused (via dup2())
        # by other sockets before unregister() is called.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        # Unregistering by socket object must succeed after close().
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify() with unchanged events (only new data) must take a
        # shortcut and not go through unregister()/register().
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_modify_unregister(self):
        # Make sure the fd is unregister()ed in case of error on
        # modify(): http://bugs.python.org/issue30014
        name = self.SELECTOR.__name__
        if name not in ('EpollSelector', 'PollSelector', 'DevpollSelector'):
            self.skipTest("test only applies to Epoll/Poll/DevpollSelector")
        # Replace the selector's underlying poll-object factory with a mock
        # whose modify() fails.
        patch = unittest.mock.patch(f'selectors.{name}._selector_cls')

        with patch as m:
            m.return_value.modify = unittest.mock.Mock(
                side_effect=ZeroDivisionError)
            s = self.SELECTOR()
            self.addCleanup(s.close)
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            self.assertEqual(len(s._map), 1)
            with self.assertRaises(ZeroDivisionError):
                s.modify(rd, selectors.EVENT_WRITE)
            self.assertEqual(len(s._map), 0)

    def test_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)

    def test_get_key(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        # Nothing was written, so only the writer end is ready.
        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_context_manager(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        # Leaving the with block closes the selector.
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        # fileno() is optional; only check it when the selector has one.
        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        # Repeatedly send on one ready writer, check that exactly its peer
        # becomes readable, then retire that pair.
        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for _ in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end and
                # the read end is reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        # A timeout of 0 or a negative value must not block.
        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        # A positive timeout with nothing ready must wait about that long.
        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertGreaterEqual(dt, 0.8)
        self.assertLessEqual(dt, 2.0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal which raises an exception
            with self.assertRaises(InterruptSelect):
                s.select(30)
            # select() was interrupted before the timeout of 30 seconds
            self.assertLess(time() - t, 5.0)
        finally:
            signal.alarm(0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal, but the signal handler doesn't
            # raise an exception, so select() should be retried with a
            # recomputed timeout
            self.assertFalse(s.select(1.5))
            self.assertGreaterEqual(time() - t, 1.0)
        finally:
            signal.alarm(0)
445
446
class ScalableSelectorMixIn:
    """Extra tests for selectors that should scale past FD_SETSIZE.

    Mixed in next to BaseSelectorTestCase, whose make_socketpair() and
    SELECTOR attribute this test relies on.
    """

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors. Since we don't know the value, we just
        # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
        max_fds = 2**16
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            NUM_FDS = min(hard, max_fds)
        except (OSError, ValueError):
            # Apply the same cap to the soft limit, which may be far larger
            # than we want to open.
            NUM_FDS = min(soft, max_fds)

        # guard for already allocated FDs (stdin, stdout...)
        NUM_FDS -= 32
        # With no FDs registered, the untimed select() below would block
        # forever, so skip when there is no room for even one socket pair.
        if NUM_FDS < 2:
            self.skipTest("RLIMIT_NOFILE is too low for this test")

        s = self.SELECTOR()
        self.addCleanup(s.close)

        for _ in range(NUM_FDS // 2):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                s.register(rd, selectors.EVENT_READ)
                s.register(wr, selectors.EVENT_WRITE)
            except OSError as e:
                if e.errno == errno.ENOSPC:
                    # this can be raised by epoll if we go over
                    # fs.epoll.max_user_watches sysctl
                    self.skipTest("FD limit reached")
                raise

        try:
            fds = s.select()
        except OSError as e:
            if e.errno == errno.EINVAL and sys.platform == 'darwin':
                # unexplainable errors on macOS don't need to fail the test
                self.skipTest("Invalid argument error calling poll()")
            raise
        # Only the write ends are ready: one event per socket pair.
        self.assertEqual(NUM_FDS // 2, len(fds))
498
499
class DefaultSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against selectors.DefaultSelector."""

    SELECTOR = selectors.DefaultSelector
503
504
class SelectSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against selectors.SelectSelector."""

    SELECTOR = selectors.SelectSelector
508
509
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(
        BaseSelectorTestCase,
        ScalableSelectorMixIn,
        unittest.TestCase,
):
    """Run the shared and scalability tests against PollSelector."""

    # getattr() keeps the class definable where PollSelector is missing;
    # the decorator then skips it.
    SELECTOR = getattr(selectors, 'PollSelector', None)
516
517
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                            unittest.TestCase):
    """Run the shared and scalability tests against EpollSelector."""

    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a file to watch
        s = self.SELECTOR()
        # Close the selector at the end of the test instead of leaking it.
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
534
535
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                             unittest.TestCase):
    """Run the shared and scalability tests against KqueueSelector."""

    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # Close the selector at the end of the test instead of leaking it.
        self.addCleanup(s.close)
        bad_f = os_helper.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)

    def test_empty_select_timeout(self):
        # Issues #23009, #29255: Make sure timeout is applied when no fds
        # are registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)

        t0 = time()
        self.assertEqual(s.select(1), [])
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertGreaterEqual(dt, 0.8)
        self.assertLessEqual(dt, 2.0)
567
568
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(
        BaseSelectorTestCase,
        ScalableSelectorMixIn,
        unittest.TestCase,
):
    """Run the shared and scalability tests against DevpollSelector."""

    # getattr() keeps the class definable where DevpollSelector is missing;
    # the decorator then skips it.
    SELECTOR = getattr(selectors, 'DevpollSelector', None)
575
576
def tearDownModule():
    """Module-level teardown hook: reap child processes via test.support."""
    support.reap_children()
579
580
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
583