1"""Tests for proactor_events.py"""
2
3import io
4import socket
5import unittest
6import sys
7from unittest import mock
8
9import asyncio
10from asyncio.proactor_events import BaseProactorEventLoop
11from asyncio.proactor_events import _ProactorSocketTransport
12from asyncio.proactor_events import _ProactorWritePipeTransport
13from asyncio.proactor_events import _ProactorDuplexPipeTransport
14from asyncio.proactor_events import _ProactorDatagramTransport
15from test.support import os_helper
16from test.support import socket_helper
17from test.test_asyncio import utils as test_utils
18
19
def tearDownModule():
    """Reset the process-wide event loop policy after this module's tests."""
    # Passing None restores asyncio's default event loop policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
22
23
def close_transport(transport):
    """Close the raw socket behind *transport* and detach it.

    transport.close() is deliberately avoided because the event loop and
    the IOCP proactor are mocked in these tests.  Calling this on a
    transport whose socket was already detached does nothing.
    """
    sock = transport._sock
    if sock is not None:
        sock.close()
        transport._sock = None
31
32
class ProactorSocketTransportTests(test_utils.TestCase):
    """Tests for _ProactorSocketTransport driven by a mocked IOCP proactor."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)
        # Every proactor request (recv_into, send, ...) goes to this mock so
        # the tests can inspect the calls instead of performing real I/O.
        self.proactor = mock.Mock()
        self.loop._proactor = self.proactor
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.sock = mock.Mock(socket.socket)
        # Size of the bytearray the transport is expected to hand to
        # proactor.recv_into().
        self.buffer_size = 65536

    def socket_transport(self, waiter=None):
        """Return a _ProactorSocketTransport over the mocked socket.

        The transport's socket is closed at cleanup with close_transport(),
        because transport.close() cannot be used with the mocked
        loop/proactor.
        """
        transport = _ProactorSocketTransport(self.loop, self.sock,
                                             self.protocol, waiter=waiter)
        self.addCleanup(close_transport, transport)
        return transport

    def test_ctor(self):
        fut = self.loop.create_future()
        tr = self.socket_transport(waiter=fut)
        test_utils.run_briefly(self.loop)
        self.assertIsNone(fut.result())
        # Assert on the mocked callback rather than invoking it: the
        # transport itself must have called protocol.connection_made(tr).
        self.protocol.connection_made.assert_called_once_with(tr)
        self.proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size))

    def test_loop_reading(self):
        tr = self.socket_transport()
        tr._loop_reading()
        self.loop._proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size))
        self.assertFalse(self.protocol.data_received.called)
        self.assertFalse(self.protocol.eof_received.called)

    def test_loop_reading_data(self):
        buf = b'data'
        res = self.loop.create_future()
        res.set_result(len(buf))

        tr = self.socket_transport()
        tr._read_fut = res
        tr._data[:len(buf)] = buf
        tr._loop_reading(res)
        called_buf = bytearray(self.buffer_size)
        called_buf[:len(buf)] = buf
        self.loop._proactor.recv_into.assert_called_with(self.sock, called_buf)
        self.protocol.data_received.assert_called_with(bytearray(buf))

    @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode")
    def test_loop_reading_no_data(self):
        res = self.loop.create_future()
        res.set_result(0)

        tr = self.socket_transport()
        # Before res is installed as tr._read_fut, _loop_reading() trips an
        # internal assertion (hence the skip when assertions are disabled).
        self.assertRaises(AssertionError, tr._loop_reading, res)

        tr.close = mock.Mock()
        tr._read_fut = res
        tr._loop_reading(res)
        self.assertFalse(self.loop._proactor.recv_into.called)
        self.assertTrue(self.protocol.eof_received.called)
        self.assertTrue(tr.close.called)

    def test_loop_reading_aborted(self):
        err = self.loop._proactor.recv_into.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_loop_reading_aborted_closing(self):
        self.loop._proactor.recv_into.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        tr._closing = True
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertFalse(tr._fatal_error.called)

    def test_loop_reading_aborted_is_fatal(self):
        self.loop._proactor.recv_into.side_effect = ConnectionAbortedError()
        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertTrue(tr._fatal_error.called)

    def test_loop_reading_conn_reset_lost(self):
        err = self.loop._proactor.recv_into.side_effect = ConnectionResetError()

        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._force_close = mock.Mock()
        tr._loop_reading()
        self.assertFalse(tr._fatal_error.called)
        tr._force_close.assert_called_with(err)

    def test_loop_reading_exception(self):
        err = self.loop._proactor.recv_into.side_effect = OSError()

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_write(self):
        tr = self.socket_transport()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        self.assertIsNone(tr._buffer)
        tr._loop_writing.assert_called_with(data=b'data')

    def test_write_no_data(self):
        tr = self.socket_transport()
        tr.write(b'')
        self.assertFalse(tr._buffer)

    def test_write_more(self):
        tr = self.socket_transport()
        tr._write_fut = mock.Mock()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        self.assertEqual(tr._buffer, b'data')
        self.assertFalse(tr._loop_writing.called)

    def test_loop_writing(self):
        tr = self.socket_transport()
        tr._buffer = bytearray(b'data')
        tr._loop_writing()
        self.loop._proactor.send.assert_called_with(self.sock, b'data')
        send_fut = self.loop._proactor.send.return_value
        send_fut.add_done_callback.assert_called_with(tr._loop_writing)

    @mock.patch('asyncio.proactor_events.logger')
    def test_loop_writing_err(self, m_log):
        err = self.loop._proactor.send.side_effect = OSError()
        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._buffer = [b'da', b'ta']
        tr._loop_writing()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal write error on pipe transport')
        tr._conn_lost = 1

        # Writes after the connection was lost are not buffered; repeating
        # them should eventually log a warning.
        for _ in range(5):
            tr.write(b'data')
        self.assertIsNone(tr._buffer)
        m_log.warning.assert_called_with('socket.send() raised exception.')

    def test_loop_writing_stop(self):
        fut = self.loop.create_future()
        fut.set_result(b'data')

        tr = self.socket_transport()
        tr._write_fut = fut
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)

    def test_loop_writing_closing(self):
        fut = self.loop.create_future()
        fut.set_result(1)

        tr = self.socket_transport()
        tr._write_fut = fut
        tr.close()
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_abort(self):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr.abort()
        tr._force_close.assert_called_with(None)

    def test_close(self):
        tr = self.socket_transport()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertTrue(tr.is_closing())
        self.assertEqual(tr._conn_lost, 1)

        # A second close() must not notify the protocol again.
        self.protocol.connection_lost.reset_mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_write_fut(self):
        tr = self.socket_transport()
        tr._write_fut = mock.Mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_buffer(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_invalid_sockobj(self):
        tr = self.socket_transport()
        self.sock.fileno.return_value = -1
        tr.close()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertFalse(self.sock.shutdown.called)

    @mock.patch('asyncio.base_events.logger')
    def test_fatal_error(self, m_logging):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr._fatal_error(None)
        self.assertTrue(tr._force_close.called)
        self.assertTrue(m_logging.error.called)

    def test_force_close(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        read_fut = tr._read_fut = mock.Mock()
        write_fut = tr._write_fut = mock.Mock()
        tr._force_close(None)

        read_fut.cancel.assert_called_with()
        write_fut.cancel.assert_called_with()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertIsNone(tr._buffer)
        self.assertEqual(tr._conn_lost, 1)

    def test_loop_writing_force_close(self):
        exc_handler = mock.Mock()
        self.loop.set_exception_handler(exc_handler)
        fut = self.loop.create_future()
        fut.set_result(1)
        self.proactor.send.return_value = fut

        tr = self.socket_transport()
        tr.write(b'data')
        tr._force_close(None)
        test_utils.run_briefly(self.loop)
        exc_handler.assert_not_called()

    def test_force_close_idempotent(self):
        tr = self.socket_transport()
        tr._closing = True
        tr._force_close(None)
        test_utils.run_briefly(self.loop)
        # See https://github.com/python/cpython/issues/89237
        # `protocol.connection_lost` should be called even if
        # the transport was closed forcefully, otherwise
        # the resources held by the protocol will never be freed
        # and waiters will never be notified, leading to a hang.
        self.assertTrue(self.protocol.connection_lost.called)

    def test_force_close_protocol_connection_lost_once(self):
        tr = self.socket_transport()
        self.assertFalse(self.protocol.connection_lost.called)
        tr._closing = True
        # Calling _force_close twice should not call
        # protocol.connection_lost twice
        tr._force_close(None)
        tr._force_close(None)
        test_utils.run_briefly(self.loop)
        self.assertEqual(1, self.protocol.connection_lost.call_count)

    def test_close_protocol_connection_lost_once(self):
        tr = self.socket_transport()
        self.assertFalse(self.protocol.connection_lost.called)
        # Calling close twice should not call
        # protocol.connection_lost twice
        tr.close()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertEqual(1, self.protocol.connection_lost.call_count)

    def test_fatal_error_2(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        tr._force_close(None)

        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertIsNone(tr._buffer)

    def test_call_connection_lost(self):
        tr = self.socket_transport()
        tr._call_connection_lost(None)
        self.assertTrue(self.protocol.connection_lost.called)
        self.assertTrue(self.sock.close.called)

    def test_write_eof(self):
        tr = self.socket_transport()
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        # A second write_eof() must not shut the socket down again.
        tr.write_eof()
        self.assertEqual(self.sock.shutdown.call_count, 1)
        tr.close()

    def test_write_eof_buffer(self):
        tr = self.socket_transport()
        f = self.loop.create_future()
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        self.assertTrue(tr._eof_written)
        # The shutdown is deferred until the pending send completes.
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        tr.close()

    def test_write_eof_write_pipe(self):
        tr = _ProactorWritePipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_buffer_write_pipe(self):
        tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
        f = self.loop.create_future()
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_duplex_pipe(self):
        tr = _ProactorDuplexPipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertFalse(tr.can_write_eof())
        with self.assertRaises(NotImplementedError):
            tr.write_eof()
        close_transport(tr)

    def test_pause_resume_reading(self):
        tr = self.socket_transport()
        msgs = [b'data1', b'data2', b'data3', b'data4', b'data5', b'']
        reversed_msgs = list(reversed(msgs))

        def recv_into(sock, data):
            # Fake proactor.recv_into(): hand back an already-completed
            # future whose result() first copies the next message into the
            # transport's buffer and then returns the message length.
            f = self.loop.create_future()
            msg = reversed_msgs.pop()

            result = f.result
            def monkey():
                data[:len(msg)] = msg
                return result()
            f.result = monkey

            f.set_result(len(msg))
            return f

        self.loop._proactor.recv_into.side_effect = recv_into
        self.loop._run_once()
        self.assertFalse(tr._paused)
        self.assertTrue(tr.is_reading())

        for msg in msgs[:2]:
            self.loop._run_once()
            self.protocol.data_received.assert_called_with(bytearray(msg))

        # Pausing twice is allowed; while paused no new data is delivered.
        tr.pause_reading()
        tr.pause_reading()
        self.assertTrue(tr._paused)
        self.assertFalse(tr.is_reading())
        for _ in range(10):
            self.loop._run_once()
        self.protocol.data_received.assert_called_with(bytearray(msgs[1]))

        tr.resume_reading()
        tr.resume_reading()
        self.assertFalse(tr._paused)
        self.assertTrue(tr.is_reading())

        for msg in msgs[2:4]:
            self.loop._run_once()
            self.protocol.data_received.assert_called_with(bytearray(msg))

        # A pause immediately followed by a resume must not raise anything
        # through the loop's exception handler.
        tr.pause_reading()
        tr.resume_reading()
        self.loop.call_exception_handler = mock.Mock()
        self.loop._run_once()
        self.loop.call_exception_handler.assert_not_called()
        self.protocol.data_received.assert_called_with(bytearray(msgs[4]))
        tr.close()

        self.assertFalse(tr.is_reading())

    def test_pause_reading_connection_made(self):
        tr = self.socket_transport()
        self.protocol.connection_made.side_effect = lambda _: tr.pause_reading()
        test_utils.run_briefly(self.loop)
        self.assertFalse(tr.is_reading())
        self.loop.assert_no_reader(7)

        tr.resume_reading()
        self.assertTrue(tr.is_reading())

        tr.close()
        self.assertFalse(tr.is_reading())

    def pause_writing_transport(self, high):
        """Return a transport with a write-buffer high-water mark of *high*.

        Also checks the starting state: an empty write buffer and no flow
        control callbacks invoked yet.
        """
        tr = self.socket_transport()
        tr.set_write_buffer_limits(high=high)

        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)
        self.assertFalse(self.protocol.resume_writing.called)
        return tr

    def test_pause_resume_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk, must pause writing
        fut = self.loop.create_future()
        self.loop._proactor.send.return_value = fut
        tr.write(b'large data')
        self.loop._run_once()
        self.assertTrue(self.protocol.pause_writing.called)

        # flush the buffer
        fut.set_result(None)
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertTrue(self.protocol.resume_writing.called)

    def test_pause_writing_2write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (3 <= 4)
        fut1 = self.loop.create_future()
        self.loop._proactor.send.return_value = fut1
        tr.write(b'123')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_pause_writing_3write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (1 <= 4)
        fut = self.loop.create_future()
        self.loop._proactor.send.return_value = fut
        tr.write(b'1')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 1)
        self.assertFalse(self.protocol.pause_writing.called)

        # second short write, the buffer is not full (3 <= 4)
        tr.write(b'23')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_dont_pause_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk which completes immediately,
        # it should not pause writing
        fut = self.loop.create_future()
        fut.set_result(None)
        self.loop._proactor.send.return_value = fut
        tr.write(b'very large data')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)
538
539
class ProactorDatagramTransportTests(test_utils.TestCase):
    """Tests for _ProactorDatagramTransport driven by a mocked IOCP proactor."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)
        # All sends go through this mock (proactor.send / proactor.sendto),
        # not through the socket object itself.
        self.proactor = mock.Mock()
        self.loop._proactor = self.proactor
        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
        self.sock = mock.Mock(spec_set=socket.socket)
        self.sock.fileno.return_value = 7

    def datagram_transport(self, address=None):
        """Return a _ProactorDatagramTransport over the mocked socket.

        When *address* is falsy, sock.getpeername() is made to raise
        OSError (presumably mimicking an unconnected socket); otherwise it
        succeeds.  The socket is closed at cleanup via close_transport().
        """
        self.sock.getpeername.side_effect = None if address else OSError
        transport = _ProactorDatagramTransport(self.loop, self.sock,
                                               self.protocol,
                                               address=address)
        self.addCleanup(close_transport, transport)
        return transport

    def test_sendto(self):
        data = b'data'
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.proactor.sendto.called)
        self.proactor.sendto.assert_called_with(
            self.sock, data, addr=('0.0.0.0', 1234))

    def test_sendto_bytearray(self):
        data = bytearray(b'data')
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.proactor.sendto.called)
        self.proactor.sendto.assert_called_with(
            self.sock, b'data', addr=('0.0.0.0', 1234))

    def test_sendto_memoryview(self):
        data = memoryview(b'data')
        transport = self.datagram_transport()
        transport.sendto(data, ('0.0.0.0', 1234))
        self.assertTrue(self.proactor.sendto.called)
        self.proactor.sendto.assert_called_with(
            self.sock, b'data', addr=('0.0.0.0', 1234))

    def test_sendto_no_data(self):
        transport = self.datagram_transport()
        transport._buffer.append((b'data', ('0.0.0.0', 12345)))
        transport.sendto(b'', ())
        # Sends go through the proactor, not the socket, so that is the
        # mock that must stay untouched for an empty payload.
        self.assertFalse(self.proactor.sendto.called)
        self.assertEqual(
            [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))

    def test_sendto_buffer(self):
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport._write_fut = object()
        transport.sendto(b'data2', ('0.0.0.0', 12345))
        self.assertFalse(self.proactor.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))

    def test_sendto_buffer_bytearray(self):
        data2 = bytearray(b'data2')
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport._write_fut = object()
        transport.sendto(data2, ('0.0.0.0', 12345))
        self.assertFalse(self.proactor.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))
        # Mutable input must be buffered as an immutable bytes copy.
        self.assertIsInstance(transport._buffer[1][0], bytes)

    def test_sendto_buffer_memoryview(self):
        data2 = memoryview(b'data2')
        transport = self.datagram_transport()
        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
        transport._write_fut = object()
        transport.sendto(data2, ('0.0.0.0', 12345))
        self.assertFalse(self.proactor.sendto.called)
        self.assertEqual(
            [(b'data1', ('0.0.0.0', 12345)),
             (b'data2', ('0.0.0.0', 12345))],
            list(transport._buffer))
        self.assertIsInstance(transport._buffer[1][0], bytes)

    @mock.patch('asyncio.proactor_events.logger')
    def test_sendto_exception(self, m_log):
        data = b'data'
        err = self.proactor.sendto.side_effect = RuntimeError()

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport.sendto(data, ())

        self.assertTrue(transport._fatal_error.called)
        transport._fatal_error.assert_called_with(
            err,
            'Fatal write error on datagram transport')
        transport._conn_lost = 1

        # Keep sending after the connection was lost; the transport should
        # eventually log a warning.
        transport._address = ('123',)
        for _ in range(5):
            transport.sendto(data)
        m_log.warning.assert_called_with('socket.sendto() raised exception.')

    def test_sendto_error_received(self):
        data = b'data'

        # The unconnected transport sends via proactor.sendto(), so the
        # error must be injected there for the error path to run at all.
        self.proactor.sendto.side_effect = ConnectionRefusedError

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport.sendto(data, ())

        self.assertEqual(transport._conn_lost, 0)
        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    def test_sendto_error_received_connected(self):
        data = b'data'

        self.proactor.send.side_effect = ConnectionRefusedError

        transport = self.datagram_transport(address=('0.0.0.0', 1))
        transport._fatal_error = mock.Mock()
        transport.sendto(data)

        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    def test_sendto_str(self):
        transport = self.datagram_transport()
        self.assertRaises(TypeError, transport.sendto, 'str', ())

    def test_sendto_connected_addr(self):
        transport = self.datagram_transport(address=('0.0.0.0', 1))
        self.assertRaises(
            ValueError, transport.sendto, b'str', ('0.0.0.0', 2))

    def test_sendto_closing(self):
        transport = self.datagram_transport(address=(1,))
        transport.close()
        self.assertEqual(transport._conn_lost, 1)
        transport.sendto(b'data', (1,))
        self.assertEqual(transport._conn_lost, 2)

    def test__loop_writing_closing(self):
        transport = self.datagram_transport()
        transport._closing = True
        transport._loop_writing()
        self.assertIsNone(transport._write_fut)
        test_utils.run_briefly(self.loop)
        self.sock.close.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    def test__loop_writing_exception(self):
        err = self.proactor.sendto.side_effect = RuntimeError()

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._loop_writing()

        transport._fatal_error.assert_called_with(
            err,
            'Fatal write error on datagram transport')

    def test__loop_writing_error_received(self):
        self.proactor.sendto.side_effect = ConnectionRefusedError

        transport = self.datagram_transport()
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._loop_writing()

        # A connection-refused error is reported to the protocol rather
        # than treated as fatal.
        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    def test__loop_writing_error_received_connection(self):
        self.proactor.send.side_effect = ConnectionRefusedError

        transport = self.datagram_transport(address=('0.0.0.0', 1))
        transport._fatal_error = mock.Mock()
        transport._buffer.append((b'data', ()))
        transport._loop_writing()

        self.assertFalse(transport._fatal_error.called)
        self.assertTrue(self.protocol.error_received.called)

    @mock.patch('asyncio.base_events.logger.error')
    def test_fatal_error_connected(self, m_exc):
        transport = self.datagram_transport(address=('0.0.0.0', 1))
        err = ConnectionRefusedError()
        transport._fatal_error(err)
        self.assertFalse(self.protocol.error_received.called)
        m_exc.assert_not_called()
739        m_exc.assert_not_called()
740
741
742class BaseProactorEventLoopTests(test_utils.TestCase):
743
744    def setUp(self):
745        super().setUp()
746
747        self.sock = test_utils.mock_nonblocking_socket()
748        self.proactor = mock.Mock()
749
750        self.ssock, self.csock = mock.Mock(), mock.Mock()
751
752        with mock.patch('asyncio.proactor_events.socket.socketpair',
753                        return_value=(self.ssock, self.csock)):
754            with mock.patch('signal.set_wakeup_fd'):
755                self.loop = BaseProactorEventLoop(self.proactor)
756        self.set_event_loop(self.loop)
757
758    @mock.patch('asyncio.proactor_events.socket.socketpair')
759    def test_ctor(self, socketpair):
760        ssock, csock = socketpair.return_value = (
761            mock.Mock(), mock.Mock())
762        with mock.patch('signal.set_wakeup_fd'):
763            loop = BaseProactorEventLoop(self.proactor)
764        self.assertIs(loop._ssock, ssock)
765        self.assertIs(loop._csock, csock)
766        self.assertEqual(loop._internal_fds, 1)
767        loop.close()
768
769    def test_close_self_pipe(self):
770        self.loop._close_self_pipe()
771        self.assertEqual(self.loop._internal_fds, 0)
772        self.assertTrue(self.ssock.close.called)
773        self.assertTrue(self.csock.close.called)
774        self.assertIsNone(self.loop._ssock)
775        self.assertIsNone(self.loop._csock)
776
777        # Don't call close(): _close_self_pipe() cannot be called twice
778        self.loop._closed = True
779
780    def test_close(self):
781        self.loop._close_self_pipe = mock.Mock()
782        self.loop.close()
783        self.assertTrue(self.loop._close_self_pipe.called)
784        self.assertTrue(self.proactor.close.called)
785        self.assertIsNone(self.loop._proactor)
786
787        self.loop._close_self_pipe.reset_mock()
788        self.loop.close()
789        self.assertFalse(self.loop._close_self_pipe.called)
790
791    def test_make_socket_transport(self):
792        tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
793        self.assertIsInstance(tr, _ProactorSocketTransport)
794        close_transport(tr)
795
796    def test_loop_self_reading(self):
797        self.loop._loop_self_reading()
798        self.proactor.recv.assert_called_with(self.ssock, 4096)
799        self.proactor.recv.return_value.add_done_callback.assert_called_with(
800            self.loop._loop_self_reading)
801
802    def test_loop_self_reading_fut(self):
803        fut = mock.Mock()
804        self.loop._self_reading_future = fut
805        self.loop._loop_self_reading(fut)
806        self.assertTrue(fut.result.called)
807        self.proactor.recv.assert_called_with(self.ssock, 4096)
808        self.proactor.recv.return_value.add_done_callback.assert_called_with(
809            self.loop._loop_self_reading)
810
811    def test_loop_self_reading_exception(self):
812        self.loop.call_exception_handler = mock.Mock()
813        self.proactor.recv.side_effect = OSError()
814        self.loop._loop_self_reading()
815        self.assertTrue(self.loop.call_exception_handler.called)
816
817    def test_write_to_self(self):
818        self.loop._write_to_self()
819        self.csock.send.assert_called_with(b'\0')
820
821    def test_process_events(self):
822        self.loop._process_events([])
823
824    @mock.patch('asyncio.base_events.logger')
825    def test_create_server(self, m_log):
826        pf = mock.Mock()
827        call_soon = self.loop.call_soon = mock.Mock()
828
829        self.loop._start_serving(pf, self.sock)
830        self.assertTrue(call_soon.called)
831
832        # callback
833        loop = call_soon.call_args[0][0]
834        loop()
835        self.proactor.accept.assert_called_with(self.sock)
836
837        # conn
838        fut = mock.Mock()
839        fut.result.return_value = (mock.Mock(), mock.Mock())
840
841        make_tr = self.loop._make_socket_transport = mock.Mock()
842        loop(fut)
843        self.assertTrue(fut.result.called)
844        self.assertTrue(make_tr.called)
845
846        # exception
847        fut.result.side_effect = OSError()
848        loop(fut)
849        self.assertTrue(self.sock.close.called)
850        self.assertTrue(m_log.error.called)
851
852    def test_create_server_cancel(self):
853        pf = mock.Mock()
854        call_soon = self.loop.call_soon = mock.Mock()
855
856        self.loop._start_serving(pf, self.sock)
857        loop = call_soon.call_args[0][0]
858
859        # cancelled
860        fut = self.loop.create_future()
861        fut.cancel()
862        loop(fut)
863        self.assertTrue(self.sock.close.called)
864
865    def test_stop_serving(self):
866        sock1 = mock.Mock()
867        future1 = mock.Mock()
868        sock2 = mock.Mock()
869        future2 = mock.Mock()
870        self.loop._accept_futures = {
871            sock1.fileno(): future1,
872            sock2.fileno(): future2
873        }
874
875        self.loop._stop_serving(sock1)
876        self.assertTrue(sock1.close.called)
877        self.assertTrue(future1.cancel.called)
878        self.proactor._stop_serving.assert_called_with(sock1)
879        self.assertFalse(sock2.close.called)
880        self.assertFalse(future2.cancel.called)
881
882    def datagram_transport(self):
883        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
884        return self.loop._make_datagram_transport(self.sock, self.protocol)
885
886    def test_make_datagram_transport(self):
887        tr = self.datagram_transport()
888        self.assertIsInstance(tr, _ProactorDatagramTransport)
889        self.assertIsInstance(tr, asyncio.DatagramTransport)
890        close_transport(tr)
891
892    def test_datagram_loop_writing(self):
893        tr = self.datagram_transport()
894        tr._buffer.appendleft((b'data', ('127.0.0.1', 12068)))
895        tr._loop_writing()
896        self.loop._proactor.sendto.assert_called_with(self.sock, b'data', addr=('127.0.0.1', 12068))
897        self.loop._proactor.sendto.return_value.add_done_callback.\
898            assert_called_with(tr._loop_writing)
899
900        close_transport(tr)
901
902    def test_datagram_loop_reading(self):
903        tr = self.datagram_transport()
904        tr._loop_reading()
905        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
906        self.assertFalse(self.protocol.datagram_received.called)
907        self.assertFalse(self.protocol.error_received.called)
908        close_transport(tr)
909
910    def test_datagram_loop_reading_data(self):
911        res = self.loop.create_future()
912        res.set_result((b'data', ('127.0.0.1', 12068)))
913
914        tr = self.datagram_transport()
915        tr._read_fut = res
916        tr._loop_reading(res)
917        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
918        self.protocol.datagram_received.assert_called_with(b'data', ('127.0.0.1', 12068))
919        close_transport(tr)
920
921    @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode")
922    def test_datagram_loop_reading_no_data(self):
923        res = self.loop.create_future()
924        res.set_result((b'', ('127.0.0.1', 12068)))
925
926        tr = self.datagram_transport()
927        self.assertRaises(AssertionError, tr._loop_reading, res)
928
929        tr.close = mock.Mock()
930        tr._read_fut = res
931        tr._loop_reading(res)
932        self.assertTrue(self.loop._proactor.recvfrom.called)
933        self.assertFalse(self.protocol.error_received.called)
934        self.assertFalse(tr.close.called)
935        close_transport(tr)
936
937    def test_datagram_loop_reading_aborted(self):
938        err = self.loop._proactor.recvfrom.side_effect = ConnectionAbortedError()
939
940        tr = self.datagram_transport()
941        tr._fatal_error = mock.Mock()
942        tr._protocol.error_received = mock.Mock()
943        tr._loop_reading()
944        tr._protocol.error_received.assert_called_with(err)
945        close_transport(tr)
946
947    def test_datagram_loop_writing_aborted(self):
948        err = self.loop._proactor.sendto.side_effect = ConnectionAbortedError()
949
950        tr = self.datagram_transport()
951        tr._fatal_error = mock.Mock()
952        tr._protocol.error_received = mock.Mock()
953        tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068)))
954        tr._loop_writing()
955        tr._protocol.error_received.assert_called_with(err)
956        close_transport(tr)
957
958
@unittest.skipIf(sys.platform != 'win32',
                 'Proactor is supported on Windows only')
class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase):
    """Checks for ProactorEventLoop._sock_sendfile_native() input validation.

    Each test connects a client socket to a local server and verifies that
    passing an object that is not a regular file raises
    SendfileNotAvailableError.  Despite the class name, these tests only run
    on Windows (see the skipIf above).
    """

    DATA = b"12345abcde" * 16 * 1024  # 160 KiB

    class MyProto(asyncio.Protocol):
        """Server-side protocol that records received data and signals
        connection loss through ``fut``."""

        def __init__(self, loop):
            self.started = False
            self.closed = False
            self.data = bytearray()
            self.fut = loop.create_future()
            self.transport = None

        def connection_made(self, transport):
            self.started = True
            self.transport = transport

        def data_received(self, data):
            self.data.extend(data)

        def connection_lost(self, exc):
            self.closed = True
            self.fut.set_result(None)

        async def wait_closed(self):
            await self.fut

    @classmethod
    def setUpClass(cls):
        # Write the payload file once for the whole class.
        with open(os_helper.TESTFN, 'wb') as fp:
            fp.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.loop = asyncio.ProactorEventLoop()
        self.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.file = open(os_helper.TESTFN, 'rb')
        self.addCleanup(self.file.close)
        super().setUp()

    def make_socket(self, cleanup=True):
        """Return a non-blocking TCP socket with 1 KiB send/receive buffers.

        When *cleanup* is true the socket is closed at test cleanup.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def run_loop(self, coro):
        return self.loop.run_until_complete(coro)

    def prepare(self):
        """Start a local server and connect a client socket to it.

        Returns ``(sock, proto)``: the connected client socket and the
        server-side protocol instance.  Transport and server shutdown are
        registered as a test cleanup.
        """
        sock = self.make_socket()
        proto = self.MyProto(self.loop)
        # The server takes ownership of srv_sock and closes it on
        # server.close(), so no separate cleanup is registered for it.
        srv_sock = self.make_socket(cleanup=False)
        try:
            # Bind to port 0 so the OS assigns a free port atomically;
            # probing with find_unused_port() and binding later is racy.
            srv_sock.bind(('127.0.0.1', 0))
            server = self.run_loop(self.loop.create_server(
                lambda: proto, sock=srv_sock))
        except BaseException:
            # The server never took ownership of the socket.
            srv_sock.close()
            raise

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        # Register before connecting so the server is shut down even if
        # sock_connect() fails.
        self.addCleanup(cleanup)
        self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname()))

        return sock, proto

    def test_sock_sendfile_not_a_file(self):
        sock, proto = self.prepare()
        f = object()
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_iobuffer(self):
        sock, proto = self.prepare()
        f = io.BytesIO()
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_not_regular_file(self):
        sock, proto = self.prepare()
        f = mock.Mock()
        f.fileno.return_value = -1
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)
1069
1070
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
1073