1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17
18import mock
19import pytest
20import six
21from six.moves import http_client
22from six.moves import urllib
23
24from google.auth import _helpers
25from google.auth import _jwt_async as jwt
26from google.auth import exceptions
27from google.oauth2 import _client as sync_client
28from google.oauth2 import _client_async as _client
29from tests.oauth2 import test__client as test_client
30
31
def make_request(response_data, status=http_client.OK):
    """Build an async mock transport request that yields a canned response.

    Args:
        response_data: JSON-serializable object used as the response body.
        status: HTTP status code placed on the mock response.

    Returns:
        An ``AsyncMock`` which, when called and awaited, returns a mock
        response whose ``status`` is ``status`` and whose ``data.read()`` and
        ``content()`` coroutines both resolve to the UTF-8 encoded JSON form
        of ``response_data``.
    """
    payload = json.dumps(response_data).encode("utf-8")

    # Body reader exposing an awaitable ``read()``.
    body_reader = mock.AsyncMock(spec=["__call__", "read"])
    body_reader.read = mock.AsyncMock(spec=["__call__"], return_value=payload)

    fake_response = mock.AsyncMock(spec=["transport.Response"])
    fake_response.status = status
    fake_response.data = body_reader
    fake_response.content = mock.AsyncMock(spec=["__call__"], return_value=payload)

    fake_request = mock.AsyncMock(spec=["transport.Request"])
    fake_request.return_value = fake_response
    return fake_request
42
43
@pytest.mark.asyncio
async def test__token_endpoint_request():
    """A form-encoded POST is issued and the decoded JSON body is returned."""
    fake_request = make_request({"test": "response"})

    response_data = await _client._token_endpoint_request(
        fake_request, "http://example.com", {"test": "params"}
    )

    # The params must go out as URL-encoded form data, as bytes.
    expected_headers = {"Content-Type": "application/x-www-form-urlencoded"}
    fake_request.assert_called_with(
        method="POST",
        url="http://example.com",
        headers=expected_headers,
        body=b"test=params",
    )

    assert response_data == {"test": "response"}
63
64
@pytest.mark.asyncio
async def test__token_endpoint_request_json():
    """With ``use_json=True`` the body is JSON and a bearer header is sent."""
    bearer = "access_token"
    fake_request = make_request({"test": "response"})

    response_data = await _client._token_endpoint_request(
        fake_request,
        "http://example.com",
        {"test": "params"},
        access_token=bearer,
        use_json=True,
    )

    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(bearer),
    }
    fake_request.assert_called_with(
        method="POST",
        url="http://example.com",
        headers=expected_headers,
        body=b'{"test": "params"}',
    )

    assert response_data == {"test": "response"}
92
93
@pytest.mark.asyncio
async def test__token_endpoint_request_error():
    """A 400 response with an empty JSON body raises RefreshError."""
    failing_request = make_request({}, status=http_client.BAD_REQUEST)

    with pytest.raises(exceptions.RefreshError):
        await _client._token_endpoint_request(
            failing_request, "http://example.com", {}
        )
100
101
@pytest.mark.asyncio
async def test__token_endpoint_request_internal_failure_error():
    """A persistent ``internal_failure`` error ends in RefreshError.

    The failure is reported once via ``error_description`` and once via
    ``error``; each case uses a fresh mock request returning HTTP 400.
    """
    for error_field in ("error_description", "error"):
        failing_request = make_request(
            {error_field: "internal_failure"}, status=http_client.BAD_REQUEST
        )

        with pytest.raises(exceptions.RefreshError):
            await _client._token_endpoint_request(
                failing_request,
                "http://example.com",
                {error_field: "internal_failure"},
            )
121
122
def verify_request_params(request, params):
    """Assert that the last call to ``request`` sent the given form parameters.

    Args:
        request: A mock whose most recent call passed a bytes ``body`` keyword
            argument containing URL-encoded form data.
        params: Mapping of parameter name to the single value expected for it
            in the decoded request body.

    Raises:
        AssertionError: If an expected parameter is absent, appears more than
            once, or carries a different value.
    """
    request_body = request.call_args[1]["body"].decode("utf-8")
    request_params = urllib.parse.parse_qs(request_body)

    for key, value in six.iteritems(params):
        # Report a missing key as a test assertion rather than a bare KeyError.
        assert key in request_params, "missing request parameter {!r}; got {!r}".format(
            key, request_params
        )
        # parse_qs maps each key to a list; requiring exactly one element also
        # catches a parameter that was accidentally sent more than once.
        assert request_params[key] == [value], "parameter {!r}: expected {!r}, got {!r}".format(
            key, [value], request_params[key]
        )
129
130
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_jwt_grant(utcnow):
    """jwt_grant posts the assertion and parses token, expiry and extras."""
    fake_request = make_request(
        {"access_token": "token", "expires_in": 500, "extra": "data"}
    )

    access_token, expiry, extra_data = await _client.jwt_grant(
        fake_request, "http://example.com", "assertion_value"
    )

    expected_params = {
        "grant_type": sync_client._JWT_GRANT_TYPE,
        "assertion": "assertion_value",
    }
    verify_request_params(fake_request, expected_params)

    assert access_token == "token"
    # Expected expiry: the patched clock plus the 500s ``expires_in``.
    assert expiry == utcnow() + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"
151    assert extra_data["extra"] == "data"
152
153
@pytest.mark.asyncio
async def test_jwt_grant_no_access_token():
    """A response lacking ``access_token`` makes jwt_grant raise RefreshError."""
    tokenless_response = {"expires_in": 500, "extra": "data"}
    fake_request = make_request(tokenless_response)

    with pytest.raises(exceptions.RefreshError):
        await _client.jwt_grant(
            fake_request, "http://example.com", "assertion_value"
        )
166
167
@pytest.mark.asyncio
async def test_id_token_jwt_grant():
    """id_token_jwt_grant returns the ID token and its ``exp`` as expiry."""
    issued_at = _helpers.utcnow()
    exp_claim = _helpers.datetime_to_secs(issued_at)
    encoded = jwt.encode(test_client.SIGNER, {"exp": exp_claim})
    signed_id_token = encoded.decode("utf-8")
    fake_request = make_request({"id_token": signed_id_token, "extra": "data"})

    token, expiry, extra_data = await _client.id_token_jwt_grant(
        fake_request, "http://example.com", "assertion_value"
    )

    verify_request_params(
        fake_request,
        {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
    )

    assert token == signed_id_token
    # The JWT ``exp`` claim has whole-second precision, so compare without
    # microseconds.
    assert expiry == issued_at.replace(microsecond=0)
    assert extra_data["extra"] == "data"
191
192
@pytest.mark.asyncio
async def test_id_token_jwt_grant_no_access_token():
    """A response without an ``id_token`` makes id_token_jwt_grant raise.

    The name mirrors the other ``*_no_access_token`` tests. The field that
    matters for this grant is ``id_token`` (the value a successful response
    carries in ``test_id_token_jwt_grant``), and it is what the payload omits.
    """
    request = make_request(
        {
            # No id_token (and no access_token either).
            "expires_in": 500,
            "extra": "data",
        }
    )

    with pytest.raises(exceptions.RefreshError):
        await _client.id_token_jwt_grant(
            request, "http://example.com", "assertion_value"
        )
207
208
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_refresh_grant(unused_utcnow):
    """refresh_grant sends client credentials and rapt, then parses the reply."""
    canned_response = {
        "access_token": "token",
        "refresh_token": "new_refresh_token",
        "expires_in": 500,
        "extra": "data",
    }
    fake_request = make_request(canned_response)

    result = await _client.refresh_grant(
        fake_request,
        "http://example.com",
        "refresh_token",
        "client_id",
        "client_secret",
        rapt_token="rapt_token",
    )
    token, refresh_token, expiry, extra_data = result

    # ``rapt_token`` is expected on the wire under the ``rapt`` key.
    expected_params = {
        "grant_type": sync_client._REFRESH_GRANT_TYPE,
        "refresh_token": "refresh_token",
        "client_id": "client_id",
        "client_secret": "client_secret",
        "rapt": "rapt_token",
    }
    verify_request_params(fake_request, expected_params)

    frozen_now = datetime.datetime.min
    assert token == "token"
    assert refresh_token == "new_refresh_token"
    assert expiry == frozen_now + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"
247
248
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@pytest.mark.asyncio
async def test_refresh_grant_with_scopes(unused_utcnow):
    """Scopes passed as a list are sent as ``SCOPES_AS_STRING`` in the body."""
    canned_response = {
        "access_token": "token",
        "refresh_token": "new_refresh_token",
        "expires_in": 500,
        "extra": "data",
        "scope": test_client.SCOPES_AS_STRING,
    }
    fake_request = make_request(canned_response)

    token, refresh_token, expiry, extra_data = await _client.refresh_grant(
        fake_request,
        "http://example.com",
        "refresh_token",
        "client_id",
        "client_secret",
        test_client.SCOPES_AS_LIST,
    )

    expected_params = {
        "grant_type": sync_client._REFRESH_GRANT_TYPE,
        "refresh_token": "refresh_token",
        "client_id": "client_id",
        "client_secret": "client_secret",
        "scope": test_client.SCOPES_AS_STRING,
    }
    verify_request_params(fake_request, expected_params)

    frozen_now = datetime.datetime.min
    assert token == "token"
    assert refresh_token == "new_refresh_token"
    assert expiry == frozen_now + datetime.timedelta(seconds=500)
    assert extra_data["extra"] == "data"
288
289
@pytest.mark.asyncio
async def test_refresh_grant_no_access_token():
    """A refresh response lacking ``access_token`` raises RefreshError."""
    tokenless_response = {
        "refresh_token": "new_refresh_token",
        "expires_in": 500,
        "extra": "data",
    }
    fake_request = make_request(tokenless_response)

    with pytest.raises(exceptions.RefreshError):
        await _client.refresh_grant(
            fake_request,
            "http://example.com",
            "refresh_token",
            "client_id",
            "client_secret",
        )
305