1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import time
16
17import flask
18import pytest
19from pytest_localserver.http import WSGIServer
20from six.moves import http_client
21
22from google.auth import exceptions
23from tests.transport import compliance
24
25
class RequestResponseTests(object):
    """Async compliance tests shared by transport ``Request`` implementations.

    This class is a mixin: concrete test classes subclass it and provide two
    methods that every test below calls:

    * ``make_request()`` -- returns the async request callable under test.
    * ``make_with_parameter_request()`` -- returns a request callable built
      with an explicitly supplied session/parameter.

    Each ``test_*`` coroutine awaits the returned callable with ``url`` /
    ``method`` (and optionally ``headers`` / ``timeout``) keyword arguments
    against a local Flask test server.
    """

    # Body served by the ``/basic`` route. Reads are sized from this value so
    # the expected length and the expected content cannot drift apart.
    _BASIC_CONTENT = b"Basic Content"
    # Body served by the ``/server_error`` route.
    _ERROR_CONTENT = b"Error"

    @pytest.fixture(scope="module")
    def server(self):
        """Provides a test HTTP server.

        Because the fixture is module-scoped, the server is started once, before
        the first test in the module that requests it. It is stopped after the
        last test in that module finishes, so all tests share one server. The
        server hosts a small Flask application with these routes:

        * ``/basic`` -- 200 with body ``Basic Content``; echoes the
          ``x-test-header`` request header (default ``value``).
        * ``/server_error`` -- 500 with body ``Error``.
        * ``/wait`` -- sleeps 3 seconds before answering, for timeout tests.
        """
        app = flask.Flask(__name__)
        app.debug = True

        # pylint: disable=unused-variable
        # (pylint thinks the flask routes are unused.)
        @app.route("/basic")
        def index():
            header_value = flask.request.headers.get("x-test-header", "value")
            headers = {"X-Test-Header": header_value}
            return "Basic Content", http_client.OK, headers

        @app.route("/server_error")
        def server_error():
            return "Error", http_client.INTERNAL_SERVER_ERROR

        @app.route("/wait")
        def wait():
            time.sleep(3)
            return "Waited"

        # pylint: enable=unused-variable

        server = WSGIServer(application=app.wsgi_app)
        server.start()
        yield server
        server.stop()

    async def _assert_basic_response(self, response, expected_header="value"):
        """Assert that *response* is the ``/basic`` route's successful reply.

        Args:
            response: The response returned by awaiting a request callable.
            expected_header: Value expected in the echoed ``x-test-header``.
        """
        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == expected_header

        # Read exactly as many bytes as the route writes into the stream.
        data = await response.data.read(len(self._BASIC_CONTENT))
        assert data == self._BASIC_CONTENT

    @pytest.mark.asyncio
    async def test_request_basic(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET")
        await self._assert_basic_response(response)

    @pytest.mark.asyncio
    async def test_request_basic_with_http(self, server):
        request = self.make_with_parameter_request()
        response = await request(url=server.url + "/basic", method="GET")
        await self._assert_basic_response(response)

    @pytest.mark.asyncio
    async def test_request_with_timeout_success(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET", timeout=2)
        await self._assert_basic_response(response)

    @pytest.mark.asyncio
    async def test_request_with_timeout_failure(self, server):
        request = self.make_request()

        # ``/wait`` sleeps for 3 seconds, well past the 1 second timeout.
        with pytest.raises(exceptions.TransportError):
            await request(url=server.url + "/wait", method="GET", timeout=1)

    @pytest.mark.asyncio
    async def test_request_headers(self, server):
        request = self.make_request()
        response = await request(
            url=server.url + "/basic",
            method="GET",
            headers={"x-test-header": "hello world"},
        )
        await self._assert_basic_response(response, expected_header="hello world")

    @pytest.mark.asyncio
    async def test_request_error(self, server):
        request = self.make_request()

        response = await request(url=server.url + "/server_error", method="GET")
        assert response.status == http_client.INTERNAL_SERVER_ERROR
        data = await response.data.read(len(self._ERROR_CONTENT))
        assert data == self._ERROR_CONTENT

    @pytest.mark.asyncio
    async def test_connection_error(self):
        request = self.make_request()

        # NXDOMAIN is a host that should never resolve, so connecting must fail.
        with pytest.raises(exceptions.TransportError):
            await request(url="http://{}".format(compliance.NXDOMAIN), method="GET")