# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import flask
import pytest
from pytest_localserver.http import WSGIServer
from six.moves import http_client

from google.auth import exceptions
from tests.transport import compliance

# Response bodies served by the test application. Tests read exactly
# len(<body>) bytes back from the response stream, so the expected length is
# derived from the expected content instead of repeated as a magic number.
_BASIC_CONTENT = b"Basic Content"
_ERROR_CONTENT = b"Error"


class RequestResponseTests(object):
    """Shared async compliance tests for transport request adapters.

    Under pytest's default ``python_classes`` setting this base class is not
    collected by itself, because its name does not start with ``Test``.
    Concrete subclasses must provide:

    * ``make_request()``: returns an async callable accepting ``url``,
      ``method`` and optional ``headers`` / ``timeout`` keyword arguments.
      The awaited result must expose ``status``, ``headers`` and a ``data``
      stream with an awaitable ``read(n)``.
    * ``make_with_parameter_request()``: returns a callable of the same
      shape. It is used only by ``test_request_basic_with_http``. It
      presumably builds the request from an explicitly supplied HTTP
      session; TODO confirm against the subclasses.
    """

    @pytest.fixture(scope="module")
    def server(self):
        """Provides a test HTTP server.

        Because the fixture is module-scoped, a single server is started the
        first time a test in the module requests it. It is shared by every
        test in that module and stopped after the module's last test
        finishes. The server hosts a small Flask application whose routes
        are used to verify requests.
        """
        app = flask.Flask(__name__)
        app.debug = True

        # pylint: disable=unused-variable
        # (pylint thinks the flask routes are unused.)
        @app.route("/basic")
        def index():
            # Echo the incoming x-test-header back, defaulting to "value" so
            # tests can check both the default and a custom header.
            header_value = flask.request.headers.get("x-test-header", "value")
            headers = {"X-Test-Header": header_value}
            return _BASIC_CONTENT, http_client.OK, headers

        @app.route("/server_error")
        def server_error():
            return _ERROR_CONTENT, http_client.INTERNAL_SERVER_ERROR

        @app.route("/wait")
        def wait():
            # Sleep longer than the timeout used in
            # test_request_with_timeout_failure so that request times out.
            time.sleep(3)
            return "Waited"

        # pylint: enable=unused-variable

        server = WSGIServer(application=app.wsgi_app)
        server.start()
        yield server
        server.stop()

    @pytest.mark.asyncio
    async def test_request_basic(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET")
        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == "value"

        data = await response.data.read(len(_BASIC_CONTENT))
        assert data == _BASIC_CONTENT

    @pytest.mark.asyncio
    async def test_request_basic_with_http(self, server):
        request = self.make_with_parameter_request()
        response = await request(url=server.url + "/basic", method="GET")
        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == "value"

        data = await response.data.read(len(_BASIC_CONTENT))
        assert data == _BASIC_CONTENT

    @pytest.mark.asyncio
    async def test_request_with_timeout_success(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET", timeout=2)

        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == "value"

        data = await response.data.read(len(_BASIC_CONTENT))
        assert data == _BASIC_CONTENT

    @pytest.mark.asyncio
    async def test_request_with_timeout_failure(self, server):
        request = self.make_request()

        # /wait sleeps for 3 seconds, well past the 1-second timeout.
        with pytest.raises(exceptions.TransportError):
            await request(url=server.url + "/wait", method="GET", timeout=1)

    @pytest.mark.asyncio
    async def test_request_headers(self, server):
        request = self.make_request()
        response = await request(
            url=server.url + "/basic",
            method="GET",
            headers={"x-test-header": "hello world"},
        )

        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == "hello world"

        data = await response.data.read(len(_BASIC_CONTENT))
        assert data == _BASIC_CONTENT

    @pytest.mark.asyncio
    async def test_request_error(self, server):
        request = self.make_request()

        # A 5xx response is still returned to the caller, not raised.
        response = await request(url=server.url + "/server_error", method="GET")
        assert response.status == http_client.INTERNAL_SERVER_ERROR
        data = await response.data.read(len(_ERROR_CONTENT))
        assert data == _ERROR_CONTENT

    @pytest.mark.asyncio
    async def test_connection_error(self):
        request = self.make_request()

        with pytest.raises(exceptions.TransportError):
            await request(url="http://{}".format(compliance.NXDOMAIN), method="GET")