# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the async OAuth 2.0 client helpers in ``google.oauth2._client_async``."""

import datetime
import json

import mock
import pytest
import six
from six.moves import http_client
from six.moves import urllib

from google.auth import _helpers
from google.auth import _jwt_async as jwt
from google.auth import exceptions
from google.oauth2 import _client as sync_client
from google.oauth2 import _client_async as _client
from tests.oauth2 import test__client as test_client


def make_request(response_data, status=http_client.OK):
    """Build a mock async ``request`` callable that returns a canned response.

    Args:
        response_data: JSON-serializable object used as the response body.
        status: HTTP status code reported by the mock response.

    Returns:
        mock.AsyncMock: awaiting a call to it yields a mock response whose
        ``status`` is ``status`` and whose ``data.read()`` and ``content()``
        coroutines both return ``response_data`` JSON-encoded as UTF-8 bytes.
    """
    response = mock.AsyncMock(spec=["transport.Response"])
    response.status = status
    data = json.dumps(response_data).encode("utf-8")
    # Both body accessors are exposed so the test does not depend on which one
    # the client reads.
    response.data = mock.AsyncMock(spec=["__call__", "read"])
    response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data)
    response.content = mock.AsyncMock(spec=["__call__"], return_value=data)
    request = mock.AsyncMock(spec=["transport.Request"])
    request.return_value = response
    return request


@pytest.mark.asyncio
async def test__token_endpoint_request():
    request = make_request({"test": "response"})

    result = await _client._token_endpoint_request(
        request, "http://example.com", {"test": "params"}
    )

    # Check request call: params are sent form-encoded by default.
    request.assert_called_with(
        method="POST",
        url="http://example.com",
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        body="test=params".encode("utf-8"),
    )

    # Check result
    assert result == {"test": "response"}


@pytest.mark.asyncio
async def test__token_endpoint_request_json():
    request = make_request({"test": "response"})
    access_token = "access_token"

    result = await _client._token_endpoint_request(
        request,
        "http://example.com",
        {"test": "params"},
        access_token=access_token,
        use_json=True,
    )

    # Check request call: use_json switches to a JSON body and the access
    # token is forwarded as a bearer Authorization header.
    request.assert_called_with(
        method="POST",
        url="http://example.com",
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer access_token",
        },
        body=b'{"test": "params"}',
    )

    # Check result
    assert result == {"test": "response"}


@pytest.mark.asyncio
async def test__token_endpoint_request_error():
    # A non-OK status must surface as RefreshError.
    request = make_request({}, status=http_client.BAD_REQUEST)

    with pytest.raises(exceptions.RefreshError):
        await _client._token_endpoint_request(request, "http://example.com", {})


@pytest.mark.asyncio
async def test__token_endpoint_request_internal_failure_error():
    # "internal_failure" in either the "error_description" or the "error"
    # field of a BAD_REQUEST response must still end in RefreshError.
    # NOTE(review): the request params passed below do not influence the
    # canned response; the mock returns the same body regardless.
    request = make_request(
        {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST
    )

    with pytest.raises(exceptions.RefreshError):
        await _client._token_endpoint_request(
            request, "http://example.com", {"error_description": "internal_failure"}
        )

    request = make_request(
        {"error": "internal_failure"}, status=http_client.BAD_REQUEST
    )

    with pytest.raises(exceptions.RefreshError):
        await _client._token_endpoint_request(
            request, "http://example.com", {"error": "internal_failure"}
        )


def verify_request_params(request, params):
    """Assert that the last call to ``request`` sent ``params`` in its form body.

    The ``body`` keyword argument of the most recent call is decoded as a
    URL-encoded form. Every key in ``params`` must appear exactly once with
    the expected value. Extra keys in the body are not checked.

    Args:
        request: The mock returned by :func:`make_request`, already awaited.
        params: Mapping of expected parameter names to string values.
    """
    request_body = request.call_args[1]["body"].decode("utf-8")
    request_params = urllib.parse.parse_qs(request_body)

    for key, value in six.iteritems(params):
        # parse_qs maps each key to a list of values. Requiring a one-item list
        # also catches a parameter that was accidentally sent more than once.
        assert request_params.get(key) == [value], key


@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_jwt_grant(unused_utcnow):
    request = make_request(
        {"access_token": "token", "expires_in": 500, "extra": "data"}
    )

    token, expiry, extra_data = await _client.jwt_grant(
        request, "http://example.com", "assertion_value"
    )

    # Check request call
    verify_request_params(
        request,
        {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
    )

    # Check result: expiry is the patched "now" plus expires_in seconds.
    assert token == "token"
    assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"


@pytest.mark.asyncio
async def test_jwt_grant_no_access_token():
    request = make_request(
        {
            # No access token.
            "expires_in": 500,
            "extra": "data",
        }
    )

    with pytest.raises(exceptions.RefreshError):
        await _client.jwt_grant(request, "http://example.com", "assertion_value")


@pytest.mark.asyncio
async def test_id_token_jwt_grant():
    now = _helpers.utcnow()
    id_token_expiry = _helpers.datetime_to_secs(now)
    id_token = jwt.encode(test_client.SIGNER, {"exp": id_token_expiry}).decode("utf-8")
    request = make_request({"id_token": id_token, "extra": "data"})

    token, expiry, extra_data = await _client.id_token_jwt_grant(
        request, "http://example.com", "assertion_value"
    )

    # Check request call
    verify_request_params(
        request,
        {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
    )

    # Check result
    assert token == id_token
    # JWT does not store microseconds: the expiry comes from the token's
    # whole-second "exp" claim.
    now = now.replace(microsecond=0)
    assert expiry == now
    assert extra_data["extra"] == "data"


@pytest.mark.asyncio
async def test_id_token_jwt_grant_no_access_token():
    request = make_request(
        {
            # No ID token.
            "expires_in": 500,
            "extra": "data",
        }
    )

    with pytest.raises(exceptions.RefreshError):
        await _client.id_token_jwt_grant(
            request, "http://example.com", "assertion_value"
        )


@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_refresh_grant(unused_utcnow):
    request = make_request(
        {
            "access_token": "token",
            "refresh_token": "new_refresh_token",
            "expires_in": 500,
            "extra": "data",
        }
    )

    token, refresh_token, expiry, extra_data = await _client.refresh_grant(
        request,
        "http://example.com",
        "refresh_token",
        "client_id",
        "client_secret",
        rapt_token="rapt_token",
    )

    # Check request call: rapt_token is sent as the "rapt" parameter.
    verify_request_params(
        request,
        {
            "grant_type": sync_client._REFRESH_GRANT_TYPE,
            "refresh_token": "refresh_token",
            "client_id": "client_id",
            "client_secret": "client_secret",
            "rapt": "rapt_token",
        },
    )

    # Check result
    assert token == "token"
    assert refresh_token == "new_refresh_token"
    assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"


@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_refresh_grant_with_scopes(unused_utcnow):
    request = make_request(
        {
            "access_token": "token",
            "refresh_token": "new_refresh_token",
            "expires_in": 500,
            "extra": "data",
            "scope": test_client.SCOPES_AS_STRING,
        }
    )

    token, refresh_token, expiry, extra_data = await _client.refresh_grant(
        request,
        "http://example.com",
        "refresh_token",
        "client_id",
        "client_secret",
        test_client.SCOPES_AS_LIST,
    )

    # Check request call: the scope list is sent as one space-joined string.
    verify_request_params(
        request,
        {
            "grant_type": sync_client._REFRESH_GRANT_TYPE,
            "refresh_token": "refresh_token",
            "client_id": "client_id",
            "client_secret": "client_secret",
            "scope": test_client.SCOPES_AS_STRING,
        },
    )

    # Check result.
    assert token == "token"
    assert refresh_token == "new_refresh_token"
    assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"


@pytest.mark.asyncio
async def test_refresh_grant_no_access_token():
    request = make_request(
        {
            # No access token.
            "refresh_token": "new_refresh_token",
            "expires_in": 500,
            "extra": "data",
        }
    )

    with pytest.raises(exceptions.RefreshError):
        await _client.refresh_grant(
            request, "http://example.com", "refresh_token", "client_id", "client_secret"
        )