1 # Copyright 2014 Google Inc. All rights reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 
15 """Tests for oauth2client.contrib.xsrfutil."""
16 
17 import base64
18 
19 import mock
20 import unittest2
21 
22 from oauth2client import _helpers
23 from oauth2client.contrib import xsrfutil
24 
# Key passed to generate_token / validate_token (and hence to hmac.new).
TEST_KEY = b'test key'
# Jan. 17, 2008 22:40:42.081230 UTC (5:40PM US Eastern), expressed in
# microseconds since the epoch. Note that the tests add second-based
# offsets to it (e.g. "15 minutes later" is TEST_TIME + 15 * 60), so they
# effectively treat it as an opaque integer timestamp rather than as a
# true microsecond value.
TEST_TIME = 1200609642081230
TEST_USER_ID_1 = 123832983
TEST_USER_ID_2 = 938297432
TEST_ACTION_ID_1 = b'some_action'
TEST_ACTION_ID_2 = b'some_other_action'
TEST_EXTRA_INFO_1 = b'extra_info_1'
TEST_EXTRA_INFO_2 = b'more_extra_info'


__author__ = 'jcgregorio@google.com (Joe Gregorio)'
38 
39 
class Test_generate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.generate_token with hmac/time mocked out."""

    @staticmethod
    def _fake_digester(digest):
        """Build a stand-in for the object returned by ``hmac.new``.

        ``digest()`` is given its own *named* mock. Named mocks assigned as
        attributes are not attached to the parent, so only the ``update()``
        calls show up in the parent's ``method_calls``.
        """
        fake = mock.MagicMock()
        fake.digest = mock.MagicMock(name='digest', return_value=digest)
        return fake

    def _check_updates_and_token(self, digester, digest, token, when):
        """Assert the HMAC ``update()`` sequence and the resulting token.

        Args:
            digester: the fake returned by the patched ``hmac.new``.
            digest: the bytes that ``digester.digest()`` was set to return.
            token: the value produced by ``xsrfutil.generate_token``.
            when: the integer timestamp the token is expected to embed.
        """
        when_bytes = _helpers._to_bytes(str(when))
        self.assertEqual(digester.method_calls, [
            mock.call.update(_helpers._to_bytes(str(TEST_USER_ID_1))),
            mock.call.update(xsrfutil.DELIMITER),
            mock.call.update(TEST_ACTION_ID_1),
            mock.call.update(xsrfutil.DELIMITER),
            mock.call.update(when_bytes),
        ])
        self.assertEqual(
            token,
            base64.urlsafe_b64encode(digest + xsrfutil.DELIMITER + when_bytes))

    def test_bad_positional(self):
        # One positional argument is too few; three is too many.
        for args in ((None,), (None, None, None)):
            with self.assertRaises(TypeError):
                xsrfutil.generate_token(*args)

    def test_it(self):
        digest = b'foobar'
        digester = self._fake_digester(digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as fake_hmac:
            fake_hmac.new = mock.MagicMock(name='new', return_value=digester)
            token = xsrfutil.generate_token(TEST_KEY, TEST_USER_ID_1,
                                            action_id=TEST_ACTION_ID_1,
                                            when=TEST_TIME)
            fake_hmac.new.assert_called_once_with(TEST_KEY)
            digester.digest.assert_called_once_with()
            self._check_updates_and_token(digester, digest, token, TEST_TIME)

    def test_with_system_time(self):
        digest = b'foobar'
        curr_time = 1440449755.74
        digester = self._fake_digester(digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as fake_hmac, \
                mock.patch('oauth2client.contrib.xsrfutil.time') as fake_time:
            fake_hmac.new = mock.MagicMock(name='new', return_value=digester)
            fake_time.time = mock.MagicMock(name='time',
                                            return_value=curr_time)
            # when= is omitted, so the timestamp comes from time.time();
            # the expectations below use it truncated to an int.
            token = xsrfutil.generate_token(TEST_KEY, TEST_USER_ID_1,
                                            action_id=TEST_ACTION_ID_1)
            fake_hmac.new.assert_called_once_with(TEST_KEY)
            fake_time.time.assert_called_once_with()
            digester.digest.assert_called_once_with()
            self._check_updates_and_token(digester, digest, token,
                                          int(curr_time))
112 
113 
class Test_validate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.validate_token's individual reject paths."""

    # Timestamp embedded in the hand-built tokens used below.
    _TOKEN_TIME = 123456789

    @staticmethod
    def _encode_time(token_time):
        """Return a base64 token whose decoded payload is just the time."""
        return base64.b64encode(_helpers._to_bytes(str(token_time)))

    def _check_against_generated(self, check, token, generated_token,
                                 token_time):
        """Validate ``token`` while ``generate_token`` is mocked out.

        ``generate_token`` is patched to return ``generated_token``. The
        result of ``validate_token`` is handed to ``check`` (``assertTrue``
        or ``assertFalse``), and the mock must have been called exactly once
        with the time recovered from the token.
        """
        # One second inside the timeout, so the token is never too old.
        curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
        key, user_id, action_id = object(), object(), object()
        with mock.patch('oauth2client.contrib.xsrfutil.generate_token',
                        return_value=generated_token) as gen_tok:
            check(xsrfutil.validate_token(key, token, user_id,
                                          current_time=curr_time,
                                          action_id=action_id))
            gen_tok.assert_called_once_with(key, user_id,
                                            action_id=action_id,
                                            when=token_time)

    def test_bad_positional(self):
        # Two positional arguments are too few; four are too many.
        for args in ((None, None), (None, None, None, None)):
            with self.assertRaises(TypeError):
                xsrfutil.validate_token(*args)

    def test_no_token(self):
        self.assertFalse(xsrfutil.validate_token(None, None, None))

    def test_token_not_valid_base64(self):
        # A lone b'a' is not correctly padded base64.
        self.assertFalse(xsrfutil.validate_token(None, b'a', None))

    def test_token_non_integer(self):
        bad_time = base64.b64encode(b'abc' + xsrfutil.DELIMITER + b'xyz')
        self.assertFalse(xsrfutil.validate_token(None, bad_time, None))

    def test_token_too_old_implicit_current_time(self):
        stale_now = self._TOKEN_TIME + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
        token = self._encode_time(self._TOKEN_TIME)
        with mock.patch('oauth2client.contrib.xsrfutil.time') as fake_time:
            fake_time.time = mock.MagicMock(name='time',
                                            return_value=stale_now)
            self.assertFalse(xsrfutil.validate_token(None, token, None))
            fake_time.time.assert_called_once_with()

    def test_token_too_old_explicit_current_time(self):
        stale_now = self._TOKEN_TIME + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
        token = self._encode_time(self._TOKEN_TIME)
        self.assertFalse(xsrfutil.validate_token(None, token, None,
                                                 current_time=stale_now))

    def test_token_length_differs_from_generated(self):
        token = self._encode_time(self._TOKEN_TIME)
        generated_token = b'a'
        # The lengths differ, so a length comparison cannot pass.
        self.assertNotEqual(len(token), len(generated_token))
        self._check_against_generated(self.assertFalse, token,
                                      generated_token, self._TOKEN_TIME)

    def test_token_differs_from_generated_but_same_length(self):
        token = self._encode_time(self._TOKEN_TIME)
        # The token is b'MTIzNDU2Nzg5' (12 bytes), so this candidate has the
        # same length but different content.
        generated_token = b'M' * 12
        self.assertEqual(len(token), len(generated_token))
        self.assertNotEqual(token, generated_token)
        self._check_against_generated(self.assertFalse, token,
                                      generated_token, self._TOKEN_TIME)

    def test_success(self):
        token = self._encode_time(self._TOKEN_TIME)
        self._check_against_generated(self.assertTrue, token, token,
                                      self._TOKEN_TIME)
219 
220 
class XsrfUtilTests(unittest2.TestCase):
    """End-to-end checks pairing the real generate_token and validate_token."""

    def testGenerateAndValidateToken(self):
        """A generated token validates only under its original inputs."""
        token = xsrfutil.generate_token(TEST_KEY, TEST_USER_ID_1,
                                        action_id=TEST_ACTION_ID_1,
                                        when=TEST_TIME)
        later15mins = TEST_TIME + 15 * 60
        later2hours = TEST_TIME + 2 * 60 * 60

        # Accepted at issue time, and still accepted 15 minutes later.
        for now in (TEST_TIME, later15mins):
            self.assertTrue(xsrfutil.validate_token(
                TEST_KEY, token, TEST_USER_ID_1,
                action_id=TEST_ACTION_ID_1, current_time=now))

        # Each row alters one input relative to an accepted call.
        # Columns: (key, token, user_id, action_id, current_time).
        rejected_cases = [
            # Beyond the timeout.
            (TEST_KEY, token, TEST_USER_ID_1, TEST_ACTION_ID_1, later2hours),
            # Different key.
            ('another key', token, TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
            # Different user ID.
            (TEST_KEY, token, TEST_USER_ID_2, TEST_ACTION_ID_1, later15mins),
            # Different action ID.
            (TEST_KEY, token, TEST_USER_ID_1, TEST_ACTION_ID_2, later15mins),
            # Truncated token.
            (TEST_KEY, token[:-1], TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
            # Token with trailing garbage.
            (TEST_KEY, token + b'x', TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
        ]
        for key, candidate, user_id, action_id, now in rejected_cases:
            self.assertFalse(xsrfutil.validate_token(
                key, candidate, user_id,
                action_id=action_id, current_time=now))

        # A token of None is rejected; no current_time is supplied here.
        self.assertFalse(xsrfutil.validate_token(TEST_KEY, None,
                                                 TEST_USER_ID_1,
                                                 action_id=TEST_ACTION_ID_1))
294