# Copyright 2020 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ``google.auth._jwt_async`` (Credentials and OnDemandCredentials)."""

import datetime
import json

import mock
import pytest

from google.auth import _jwt_async as jwt_async
from google.auth import crypt
from google.auth import exceptions
from tests import test_jwt


@pytest.fixture
def signer():
    """Return an RSA signer built from the shared test private key, key id "1"."""
    return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1")


class TestCredentials(object):
    """Tests for ``jwt_async.Credentials`` constructed with an explicit audience."""

    # Must equal the ``client_email`` of the shared service account fixture:
    # test_signer_email and test_from_signing_credentials compare credentials
    # built with this value against ones built from SERVICE_ACCOUNT_INFO.
    SERVICE_ACCOUNT_EMAIL = test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
    SUBJECT = "subject"
    AUDIENCE = "audience"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        """Give every test fresh credentials; issuer and subject are both the SA email."""
        self.credentials = jwt_async.Credentials(
            signer,
            self.SERVICE_ACCOUNT_EMAIL,
            self.SERVICE_ACCOUNT_EMAIL,
            self.AUDIENCE,
        )

    def test_from_service_account_info(self):
        with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r", encoding="utf-8") as fh:
            info = json.load(fh)

        credentials = jwt_async.Credentials.from_service_account_info(
            info, audience=self.AUDIENCE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        # Without an explicit subject, the subject defaults to the issuer.
        assert credentials._subject == info["client_email"]
        assert credentials._audience == self.AUDIENCE

    def test_from_service_account_info_args(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.Credentials.from_service_account_info(
            info,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._audience == self.AUDIENCE
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.Credentials.from_service_account_file(
            test_jwt.SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]
        assert credentials._audience == self.AUDIENCE

    def test_from_service_account_file_args(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.Credentials.from_service_account_file(
            test_jwt.SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._audience == self.AUDIENCE
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        """Credentials derived from a signer match ones loaded from SA info."""
        jwt_from_signing = self.credentials.from_signing_credentials(
            self.credentials, audience=mock.sentinel.new_audience
        )
        jwt_from_info = jwt_async.Credentials.from_service_account_info(
            test_jwt.SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience
        )

        assert isinstance(jwt_from_signing, jwt_async.Credentials)
        assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
        assert jwt_from_signing._issuer == jwt_from_info._issuer
        assert jwt_from_signing._subject == jwt_from_info._subject
        assert jwt_from_signing._audience == jwt_from_info._audience

    def test_default_state(self):
        assert not self.credentials.valid
        # Expiration hasn't been set yet
        assert not self.credentials.expired

    def test_with_claims(self):
        """with_claims replaces only the audience and copies everything else."""
        new_audience = "new_audience"
        new_credentials = self.credentials.with_claims(audience=new_audience)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._audience == new_audience
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == self.credentials._quota_project_id

    def test_with_quota_project(self):
        """with_quota_project replaces only the quota project id."""
        quota_project_id = "project-foo"

        new_credentials = self.credentials.with_quota_project(quota_project_id)
        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._audience == self.credentials._audience
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        to_sign = b"123"
        signature = self.credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        assert (
            self.credentials.signer_email
            == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
        )

    def _verify_token(self, token):
        """Decode ``token`` with the test public cert, check its issuer, return the payload."""
        payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        self.credentials.refresh(None)
        assert self.credentials.valid
        assert not self.credentials.expired

    def test_expired(self):
        assert not self.credentials.expired

        self.credentials.refresh(None)
        assert not self.credentials.expired

        # Move "now" one day past the token's expiry.
        with mock.patch("google.auth._helpers.utcnow") as now:
            one_day = datetime.timedelta(days=1)
            now.return_value = self.credentials.expiry + one_day
            assert self.credentials.expired

    @pytest.mark.asyncio
    async def test_before_request(self):
        headers = {}

        self.credentials.refresh(None)
        await self.credentials.before_request(
            None, "GET", "http://example.com?a=1#3", headers
        )

        header_value = headers["authorization"]
        _, token = header_value.split(" ")

        # Since the audience is set, it should use the existing token.
        assert token.encode("utf-8") == self.credentials.token

        payload = self._verify_token(token)
        assert payload["aud"] == self.AUDIENCE

    @pytest.mark.asyncio
    async def test_before_request_refreshes(self):
        assert not self.credentials.valid
        await self.credentials.before_request(
            None, "GET", "http://example.com?a=1#3", {}
        )
        assert self.credentials.valid


class TestOnDemandCredentials(object):
    """Tests for ``jwt_async.OnDemandCredentials`` (audience taken from the request URL)."""

    # Must equal the ``client_email`` of the shared service account fixture:
    # test_signer_email and test_from_signing_credentials compare credentials
    # built with this value against ones built from SERVICE_ACCOUNT_INFO.
    SERVICE_ACCOUNT_EMAIL = test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
    SUBJECT = "subject"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        """Give every test fresh on-demand credentials with a two-entry token cache."""
        self.credentials = jwt_async.OnDemandCredentials(
            signer,
            self.SERVICE_ACCOUNT_EMAIL,
            self.SERVICE_ACCOUNT_EMAIL,
            max_cache_size=2,
        )

    def test_from_service_account_info(self):
        with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r", encoding="utf-8") as fh:
            info = json.load(fh)

        credentials = jwt_async.OnDemandCredentials.from_service_account_info(info)

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]

    def test_from_service_account_info_args(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.OnDemandCredentials.from_service_account_info(
            info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.OnDemandCredentials.from_service_account_file(
            test_jwt.SERVICE_ACCOUNT_JSON_FILE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]

    def test_from_service_account_file_args(self):
        info = test_jwt.SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt_async.OnDemandCredentials.from_service_account_file(
            test_jwt.SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        """Credentials derived from a signer match ones loaded from SA info."""
        jwt_from_signing = self.credentials.from_signing_credentials(self.credentials)
        jwt_from_info = jwt_async.OnDemandCredentials.from_service_account_info(
            test_jwt.SERVICE_ACCOUNT_INFO
        )

        assert isinstance(jwt_from_signing, jwt_async.OnDemandCredentials)
        assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
        assert jwt_from_signing._issuer == jwt_from_info._issuer
        assert jwt_from_signing._subject == jwt_from_info._subject

    def test_default_state(self):
        # Credentials are *always* valid.
        assert self.credentials.valid
        # Credentials *never* expire.
        assert not self.credentials.expired

    def test_with_claims(self):
        """with_claims replaces only the additional claims."""
        new_claims = {"meep": "moop"}
        new_credentials = self.credentials.with_claims(additional_claims=new_claims)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._additional_claims == new_claims

    def test_with_quota_project(self):
        """with_quota_project replaces only the quota project id."""
        quota_project_id = "project-foo"
        new_credentials = self.credentials.with_quota_project(quota_project_id)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        to_sign = b"123"
        signature = self.credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        assert (
            self.credentials.signer_email
            == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
        )

    def _verify_token(self, token):
        """Decode ``token`` with the test public cert, check its issuer, return the payload."""
        payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        # On-demand credentials cannot be refreshed explicitly.
        with pytest.raises(exceptions.RefreshError):
            self.credentials.refresh(None)

    def test_before_request(self):
        headers = {}

        # NOTE(review): called without ``await`` here, unlike in TestCredentials;
        # presumably OnDemandCredentials resolves to a synchronous
        # before_request. Confirm against _jwt_async.
        self.credentials.before_request(
            None, "GET", "http://example.com?a=1#3", headers
        )

        _, token = headers["authorization"].split(" ")
        payload = self._verify_token(token)

        # The audience is the request URL with query and fragment stripped.
        assert payload["aud"] == "http://example.com"

        # Making another request should re-use the same token.
        self.credentials.before_request(None, "GET", "http://example.com?b=2", headers)

        _, new_token = headers["authorization"].split(" ")

        assert new_token == token

    def test_expired_token(self):
        # Seed the cache with an entry whose expiry (datetime.min) is already
        # past, so the cached sentinel must not be returned.
        self.credentials._cache["audience"] = (
            mock.sentinel.token,
            datetime.datetime.min,
        )

        token = self.credentials._get_jwt_for_audience("audience")

        assert token != mock.sentinel.token