1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import copy
16
17import mock
18import pytest
19
20from google.auth import exceptions
21from google.oauth2 import _reauth_async
22from google.oauth2 import reauth
23
24
# Request object handed to the code under test. ``spec`` receives a list of
# attribute *names*, so attribute access on this mock is limited to a name
# literally spelled "transport.Request". NOTE(review): presumably
# ``spec=transport.Request`` was intended; in the tests below it appears only to
# be passed through to patched helpers and compared in call assertions.
MOCK_REQUEST = mock.AsyncMock(spec=["transport.Request"])

# A reauth start response that asks for a single, READY password challenge.
# Tests that need a variant deep-copy this template instead of mutating it.
CHALLENGES_RESPONSE_TEMPLATE = dict(
    status="CHALLENGE_REQUIRED",
    sessionId="123",
    challenges=[
        dict(
            status="READY",
            challengeId=1,
            challengeType="PASSWORD",
            securityKey={},
        )
    ],
)

# A response reporting success; its proof token is what _obtain_rapt returns.
CHALLENGES_RESPONSE_AUTHENTICATED = dict(
    status="AUTHENTICATED",
    sessionId="123",
    encodedProofOfReauthToken="new_rapt_token",
)
43
44
class MockChallenge:
    """Minimal stand-in for a reauth challenge entry in AVAILABLE_CHALLENGES.

    Args:
        name: Challenge name, e.g. "PASSWORD".
        locally_eligible: Value exposed as ``is_locally_eligible``.
        challenge_input: Canned value returned by ``obtain_challenge_input``.
    """

    def __init__(self, name, locally_eligible, challenge_input):
        self.name, self.is_locally_eligible, self.challenge_input = (
            name,
            locally_eligible,
            challenge_input,
        )

    def obtain_challenge_input(self, metadata):
        """Return the canned input; ``metadata`` is accepted but ignored."""
        del metadata  # unused: accepted only for signature compatibility
        return self.challenge_input
53
54
@pytest.mark.asyncio
async def test__get_challenges():
    """_get_challenges sends the supported challenge types to the ``:start`` endpoint."""
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request"
    ) as mock_token_endpoint_request:
        await _reauth_async._get_challenges(MOCK_REQUEST, ["SAML"], "token")
        # assert_awaited_once_with (unlike assert_called_with) also fails if the
        # request coroutine is created but never awaited, or is issued twice.
        mock_token_endpoint_request.assert_awaited_once_with(
            MOCK_REQUEST,
            reauth._REAUTH_API + ":start",
            {"supportedChallengeTypes": ["SAML"]},
            access_token="token",
            use_json=True,
        )
68
69
@pytest.mark.asyncio
async def test__get_challenges_with_scopes():
    """Requested scopes are forwarded as ``oauthScopesForDomainPolicyLookup``."""
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request"
    ) as mock_token_endpoint_request:
        await _reauth_async._get_challenges(
            MOCK_REQUEST, ["SAML"], "token", requested_scopes=["scope"]
        )
        # Awaited-once check: fails if the request coroutine is never awaited
        # or is issued more than once.
        mock_token_endpoint_request.assert_awaited_once_with(
            MOCK_REQUEST,
            reauth._REAUTH_API + ":start",
            {
                "supportedChallengeTypes": ["SAML"],
                "oauthScopesForDomainPolicyLookup": ["scope"],
            },
            access_token="token",
            use_json=True,
        )
88
89
@pytest.mark.asyncio
async def test__send_challenge_result():
    """_send_challenge_result sends a RESPOND action to the session's continue endpoint."""
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request"
    ) as mock_token_endpoint_request:
        await _reauth_async._send_challenge_result(
            MOCK_REQUEST, "123", "1", {"credential": "password"}, "token"
        )
        # Awaited-once check: fails if the request coroutine is never awaited
        # or is issued more than once.
        mock_token_endpoint_request.assert_awaited_once_with(
            MOCK_REQUEST,
            reauth._REAUTH_API + "/123:continue",
            {
                "sessionId": "123",
                "challengeId": "1",
                "action": "RESPOND",
                "proposalResponse": {"credential": "password"},
            },
            access_token="token",
            use_json=True,
        )
110
111
@pytest.mark.asyncio
async def test__run_next_challenge_not_ready():
    """A challenge whose status is not READY yields None."""
    response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
    first_challenge = response["challenges"][0]
    first_challenge["status"] = "STATUS_UNSPECIFIED"

    outcome = await _reauth_async._run_next_challenge(response, MOCK_REQUEST, "token")

    assert outcome is None
122
123
@pytest.mark.asyncio
async def test__run_next_challenge_not_supported():
    """An unknown challenge type raises ReauthFailError naming that type."""
    response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
    unsupported_type = "CHALLENGE_TYPE_UNSPECIFIED"
    response["challenges"][0]["challengeType"] = unsupported_type

    with pytest.raises(exceptions.ReauthFailError) as raised:
        await _reauth_async._run_next_challenge(response, MOCK_REQUEST, "token")

    assert raised.match(r"Unsupported challenge type CHALLENGE_TYPE_UNSPECIFIED")
133
134
@pytest.mark.asyncio
async def test__run_next_challenge_not_locally_eligible():
    """A registered challenge that is not locally eligible raises ReauthFailError."""
    ineligible = MockChallenge("PASSWORD", False, "challenge_input")
    registry = {"PASSWORD": ineligible}
    with mock.patch("google.oauth2.challenges.AVAILABLE_CHALLENGES", registry):
        with pytest.raises(exceptions.ReauthFailError) as raised:
            await _reauth_async._run_next_challenge(
                CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
            )
        assert raised.match(r"Challenge PASSWORD is not locally eligible")
146
147
@pytest.mark.asyncio
async def test__run_next_challenge_no_challenge_input():
    """A challenge that produces no input (None) yields None."""
    silent_challenge = MockChallenge("PASSWORD", True, None)
    registry = {"PASSWORD": silent_challenge}
    with mock.patch("google.oauth2.challenges.AVAILABLE_CHALLENGES", registry):
        outcome = await _reauth_async._run_next_challenge(
            CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
        )
        assert outcome is None
160
161
@pytest.mark.asyncio
async def test__run_next_challenge_success():
    """A READY, eligible challenge with input is answered via _send_challenge_result."""
    mock_challenge = MockChallenge("PASSWORD", True, {"credential": "password"})
    with mock.patch(
        "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
    ):
        with mock.patch(
            "google.oauth2._reauth_async._send_challenge_result",
            return_value=CHALLENGES_RESPONSE_AUTHENTICATED,
        ) as mock_send_challenge_result:
            result = await _reauth_async._run_next_challenge(
                CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
            )
            # Awaited-once check: fails if the send coroutine is never awaited
            # or the challenge is answered more than once.
            mock_send_challenge_result.assert_awaited_once_with(
                MOCK_REQUEST, "123", 1, {"credential": "password"}, "token"
            )
            # The server's reply to the answered challenge goes back to the caller.
            assert result == CHALLENGES_RESPONSE_AUTHENTICATED
177
178
@pytest.mark.asyncio
async def test__obtain_rapt_authenticated():
    """An AUTHENTICATED start response yields its encoded proof token."""
    get_challenges_patch = mock.patch(
        "google.oauth2._reauth_async._get_challenges",
        return_value=CHALLENGES_RESPONSE_AUTHENTICATED,
    )
    with get_challenges_patch:
        rapt = await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
        assert rapt == "new_rapt_token"
187
188
@pytest.mark.asyncio
async def test__obtain_rapt_authenticated_after_run_next_challenge():
    """A later AUTHENTICATED reply from _run_next_challenge supplies the rapt token."""
    challenge_replies = [
        CHALLENGES_RESPONSE_TEMPLATE,
        CHALLENGES_RESPONSE_AUTHENTICATED,
    ]
    with mock.patch(
        "google.oauth2._reauth_async._get_challenges",
        return_value=CHALLENGES_RESPONSE_TEMPLATE,
    ), mock.patch(
        "google.oauth2._reauth_async._run_next_challenge",
        side_effect=challenge_replies,
    ), mock.patch(
        "google.oauth2.reauth.is_interactive", return_value=True
    ):
        rapt = await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
        assert rapt == "new_rapt_token"
207
208
@pytest.mark.asyncio
async def test__obtain_rapt_unsupported_status():
    """An unexpected start-response status is reported as an API error."""
    response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
    response["status"] = "STATUS_UNSPECIFIED"

    with mock.patch(
        "google.oauth2._reauth_async._get_challenges", return_value=response
    ), pytest.raises(exceptions.ReauthFailError) as raised:
        await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)

    assert raised.match(r"API error: STATUS_UNSPECIFIED")
219
220
@pytest.mark.asyncio
async def test__obtain_rapt_not_interactive():
    """Without an interactive session, a pending challenge raises ReauthFailError."""
    with mock.patch(
        "google.oauth2._reauth_async._get_challenges",
        return_value=CHALLENGES_RESPONSE_TEMPLATE,
    ), mock.patch("google.oauth2.reauth.is_interactive", return_value=False):
        with pytest.raises(exceptions.ReauthFailError) as raised:
            await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
        assert raised.match(r"not in an interactive session")
231
232
@pytest.mark.asyncio
async def test__obtain_rapt_not_authenticated():
    """With RUN_CHALLENGE_RETRY_LIMIT at 0, a CHALLENGE_REQUIRED reply ends in failure."""
    with mock.patch(
        "google.oauth2._reauth_async._get_challenges",
        return_value=CHALLENGES_RESPONSE_TEMPLATE,
    ), mock.patch("google.oauth2.reauth.RUN_CHALLENGE_RETRY_LIMIT", 0):
        with pytest.raises(exceptions.ReauthFailError) as raised:
            await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
        assert raised.match(r"Reauthentication failed")
243
244
@pytest.mark.asyncio
async def test_get_rapt_token():
    """get_rapt_token gets a reauth-scoped access token, then uses it to obtain a rapt token."""
    with mock.patch(
        "google.oauth2._client_async.refresh_grant",
        return_value=("token", None, None, None),
    ) as mock_refresh_grant:
        with mock.patch(
            "google.oauth2._reauth_async._obtain_rapt", return_value="new_rapt_token"
        ) as mock_obtain_rapt:
            assert (
                await _reauth_async.get_rapt_token(
                    MOCK_REQUEST,
                    "client_id",
                    "client_secret",
                    "refresh_token",
                    "token_uri",
                )
                == "new_rapt_token"
            )
            # Awaited-once checks: both coroutines must actually be awaited,
            # and each exactly once.
            mock_refresh_grant.assert_awaited_once_with(
                request=MOCK_REQUEST,
                client_id="client_id",
                client_secret="client_secret",
                refresh_token="refresh_token",
                token_uri="token_uri",
                scopes=[reauth._REAUTH_SCOPE],
            )
            mock_obtain_rapt.assert_awaited_once_with(
                MOCK_REQUEST, "token", requested_scopes=None
            )
275
276
@pytest.mark.asyncio
async def test_refresh_grant_failed():
    """A non-reauth error from the token endpoint surfaces as RefreshError.

    Also checks that scopes are space-joined and the rapt token is sent in the
    request body.
    """
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request_no_throw"
    ) as mock_token_request:
        mock_token_request.return_value = (False, {"error": "Bad request"})
        with pytest.raises(exceptions.RefreshError) as excinfo:
            await _reauth_async.refresh_grant(
                MOCK_REQUEST,
                "token_uri",
                "refresh_token",
                "client_id",
                "client_secret",
                scopes=["foo", "bar"],
                rapt_token="rapt_token",
            )
        assert excinfo.match(r"Bad request")
        # Exactly one awaited request: a plain error must not trigger a retry.
        mock_token_request.assert_awaited_once_with(
            MOCK_REQUEST,
            "token_uri",
            {
                "grant_type": "refresh_token",
                "client_id": "client_id",
                "client_secret": "client_secret",
                "refresh_token": "refresh_token",
                "scope": "foo bar",
                "rapt": "rapt_token",
            },
        )
306
307
@pytest.mark.asyncio
async def test_refresh_grant_success():
    """A rapt_required rejection is recovered by fetching a rapt token and retrying."""
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request_no_throw"
    ) as mock_token_request:
        mock_token_request.side_effect = [
            (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
            (True, {"access_token": "access_token"}),
        ]
        with mock.patch(
            "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token"
        ) as mock_get_rapt_token:
            assert await _reauth_async.refresh_grant(
                MOCK_REQUEST,
                "token_uri",
                "refresh_token",
                "client_id",
                "client_secret",
                enable_reauth_refresh=True,
            ) == (
                "access_token",
                "refresh_token",
                None,
                {"access_token": "access_token"},
                "new_rapt_token",
            )
            mock_get_rapt_token.assert_awaited_once()
        # The first request is rejected; the retry must carry the freshly
        # obtained rapt token. Body is the third positional argument.
        assert mock_token_request.await_count == 2
        retry_body = mock_token_request.await_args[0][2]
        assert retry_body["rapt"] == "new_rapt_token"
334
335
@pytest.mark.asyncio
async def test_refresh_grant_reauth_refresh_disabled():
    """Without enable_reauth_refresh, a rapt_required rejection raises RefreshError."""
    with mock.patch(
        "google.oauth2._client_async._token_endpoint_request_no_throw"
    ) as mock_token_request:
        mock_token_request.side_effect = [
            (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
            (True, {"access_token": "access_token"}),
        ]
        # No `assert` on the call: it is expected to raise, so an assertion
        # there could never be evaluated and only obscures intent.
        with pytest.raises(exceptions.RefreshError) as excinfo:
            await _reauth_async.refresh_grant(
                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
            )
        assert excinfo.match(r"Reauthentication is needed")
        # Reauth refresh is disabled, so no retry may be attempted; the second
        # queued response must remain unused.
        mock_token_request.assert_awaited_once()
350