# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for oauth2client.contrib.xsrfutil."""

import base64

import mock
import unittest2

from oauth2client import _helpers
from oauth2client.contrib import xsrfutil

# Secret key shared by the token generation / validation tests.
TEST_KEY = b'test key'
# Read as microseconds since the epoch this is roughly 22:40 UTC on
# Jan. 17, 2008. The tests only ever add second-sized offsets to it.
TEST_TIME = 1200609642081230
TEST_USER_ID_1 = 123832983
TEST_USER_ID_2 = 938297432
TEST_ACTION_ID_1 = b'some_action'
TEST_ACTION_ID_2 = b'some_other_action'
TEST_EXTRA_INFO_1 = b'extra_info_1'
TEST_EXTRA_INFO_2 = b'more_extra_info'


__author__ = 'jcgregorio@google.com (Joe Gregorio)'


def _time_only_token(token_time):
    """Return a base64 token whose decoded payload is only ``token_time``."""
    return base64.b64encode(_helpers._to_bytes(str(token_time)))


def _expected_update_calls(when):
    """Return the ``update`` calls expected on the digester for ``when``.

    The calls describe user id, delimiter, action id, delimiter and finally
    the timestamp, using TEST_USER_ID_1 and TEST_ACTION_ID_1.
    """
    user_bytes = _helpers._to_bytes(str(TEST_USER_ID_1))
    when_bytes = _helpers._to_bytes(str(when))
    return [
        mock.call.update(user_bytes),
        mock.call.update(xsrfutil.DELIMITER),
        mock.call.update(TEST_ACTION_ID_1),
        mock.call.update(xsrfutil.DELIMITER),
        mock.call.update(when_bytes),
    ]


def _expected_token(digest, when):
    """Return the url-safe base64 token expected for ``digest`` and ``when``."""
    payload = digest + xsrfutil.DELIMITER + _helpers._to_bytes(str(when))
    return base64.urlsafe_b64encode(payload)


class Test_generate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.generate_token with hmac (and time) mocked."""

    def test_bad_positional(self):
        """A wrong number of positional arguments raises TypeError."""
        # Need 2 positional arguments.
        with self.assertRaises(TypeError):
            xsrfutil.generate_token(None)
        # At most 2 positional arguments.
        with self.assertRaises(TypeError):
            xsrfutil.generate_token(None, None, None)

    def test_it(self):
        """An explicit ``when`` is fed to the digest and put in the token."""
        fake_digest = b'foobar'
        fake_digester = mock.MagicMock()
        # The replacement mocks are given names on purpose: a named mock is
        # not attached as a child, so calling digest() is not recorded in
        # fake_digester.method_calls, which is compared exactly below.
        fake_digester.digest = mock.MagicMock(name='digest',
                                              return_value=fake_digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as hmac_mock:
            hmac_mock.new = mock.MagicMock(name='new',
                                           return_value=fake_digester)
            result = xsrfutil.generate_token(TEST_KEY,
                                             TEST_USER_ID_1,
                                             action_id=TEST_ACTION_ID_1,
                                             when=TEST_TIME)
            hmac_mock.new.assert_called_once_with(TEST_KEY)
            fake_digester.digest.assert_called_once_with()

            self.assertEqual(fake_digester.method_calls,
                             _expected_update_calls(TEST_TIME))
            self.assertEqual(result,
                             _expected_token(fake_digest, TEST_TIME))

    def test_with_system_time(self):
        """Without ``when`` the (mocked) system time is truncated and used."""
        fake_digest = b'foobar'
        fake_now = 1440449755.74
        whole_seconds = int(fake_now)
        fake_digester = mock.MagicMock()
        # Named for the same reason as in test_it: keeps digest() out of
        # fake_digester.method_calls.
        fake_digester.digest = mock.MagicMock(name='digest',
                                              return_value=fake_digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as hmac_mock:
            hmac_mock.new = mock.MagicMock(name='new',
                                           return_value=fake_digester)

            with mock.patch('oauth2client.contrib.xsrfutil.time') as time_mock:
                time_mock.time = mock.MagicMock(name='time',
                                                return_value=fake_now)
                # when= is omitted
                result = xsrfutil.generate_token(TEST_KEY,
                                                 TEST_USER_ID_1,
                                                 action_id=TEST_ACTION_ID_1)

            hmac_mock.new.assert_called_once_with(TEST_KEY)
            time_mock.time.assert_called_once_with()
            fake_digester.digest.assert_called_once_with()

            self.assertEqual(fake_digester.method_calls,
                             _expected_update_calls(whole_seconds))
            self.assertEqual(result,
                             _expected_token(fake_digest, whole_seconds))


class Test_validate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.validate_token, one per decision point."""

    def _validate_against_generated(self, token, generated_token, token_time):
        """Validate ``token`` while generate_token is patched.

        generate_token is replaced by a mock returning ``generated_token``.
        The current time is chosen just inside the timeout so the age check
        passes, and the helper asserts that generate_token was asked to
        rebuild the token for ``token_time``.

        Returns:
            Whatever validate_token returned.
        """
        # Make sure it isn't too old.
        now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
        key, user_id, action_id = object(), object(), object()

        with mock.patch('oauth2client.contrib.xsrfutil.generate_token',
                        return_value=generated_token) as gen_tok:
            outcome = xsrfutil.validate_token(key, token, user_id,
                                              current_time=now,
                                              action_id=action_id)
            gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
                                            when=token_time)
        return outcome

    def test_bad_positional(self):
        """A wrong number of positional arguments raises TypeError."""
        # Need 3 positional arguments.
        with self.assertRaises(TypeError):
            xsrfutil.validate_token(None, None)
        # At most 3 positional arguments.
        with self.assertRaises(TypeError):
            xsrfutil.validate_token(None, None, None, None)

    def test_no_token(self):
        """A token of None is rejected."""
        self.assertFalse(xsrfutil.validate_token(None, None, None))

    def test_token_not_valid_base64(self):
        """A token that cannot be base64-decoded is rejected."""
        bad_token = b'a'  # Bad padding
        self.assertFalse(xsrfutil.validate_token(None, bad_token, None))

    def test_token_non_integer(self):
        """A token whose trailing field is not an integer is rejected."""
        payload = b'abc' + xsrfutil.DELIMITER + b'xyz'
        bad_token = base64.b64encode(payload)
        self.assertFalse(xsrfutil.validate_token(None, bad_token, None))

    def test_token_too_old_implicit_current_time(self):
        """An expired token is rejected when the time comes from time.time."""
        token_time = 123456789
        fake_now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1

        stale_token = _time_only_token(token_time)
        with mock.patch('oauth2client.contrib.xsrfutil.time') as time_mock:
            time_mock.time = mock.MagicMock(name='time',
                                            return_value=fake_now)
            self.assertFalse(
                xsrfutil.validate_token(None, stale_token, None))
            time_mock.time.assert_called_once_with()

    def test_token_too_old_explicit_current_time(self):
        """An expired token is rejected when current_time is passed in."""
        token_time = 123456789
        now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1

        stale_token = _time_only_token(token_time)
        self.assertFalse(xsrfutil.validate_token(None, stale_token, None,
                                                 current_time=now))

    def test_token_length_differs_from_generated(self):
        """A token with a different length than the regenerated one fails."""
        token_time = 123456789
        token = _time_only_token(token_time)
        generated_token = b'a'
        # Make sure the token length comparison will fail.
        self.assertNotEqual(len(token), len(generated_token))

        self.assertFalse(self._validate_against_generated(
            token, generated_token, token_time))

    def test_token_differs_from_generated_but_same_length(self):
        """A same-length token with different content fails."""
        token_time = 123456789
        token = _time_only_token(token_time)
        # It is encoded as b'MTIzNDU2Nzg5', which has length 12.
        generated_token = b'M' * 12
        # Make sure the token length comparison will succeed, but the token
        # comparison will fail.
        self.assertEqual(len(token), len(generated_token))
        self.assertNotEqual(token, generated_token)

        self.assertFalse(self._validate_against_generated(
            token, generated_token, token_time))

    def test_success(self):
        """A token identical to the regenerated one is accepted."""
        token_time = 123456789
        token = _time_only_token(token_time)
        self.assertTrue(self._validate_against_generated(
            token, token, token_time))


class XsrfUtilTests(unittest2.TestCase):
    """Test xsrfutil functions."""

    def testGenerateAndValidateToken(self):
        """Test generating and validating a token."""
        token = xsrfutil.generate_token(TEST_KEY,
                                        TEST_USER_ID_1,
                                        action_id=TEST_ACTION_ID_1,
                                        when=TEST_TIME)
        later15mins = TEST_TIME + 15 * 60
        later2hours = TEST_TIME + 2 * 60 * 60

        # Valid at issue time, and still valid 15 minutes later.
        for now in (TEST_TIME, later15mins):
            self.assertTrue(xsrfutil.validate_token(
                TEST_KEY, token, TEST_USER_ID_1,
                action_id=TEST_ACTION_ID_1, current_time=now))

        # Each row changes one input and must be rejected:
        # (key, token, user id, action id, current time).
        rejected = (
            # But not if beyond the timeout.
            (TEST_KEY, token, TEST_USER_ID_1, TEST_ACTION_ID_1, later2hours),
            # Or if the key is different.
            ('another key', token, TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
            # Or the user ID....
            (TEST_KEY, token, TEST_USER_ID_2, TEST_ACTION_ID_1, later15mins),
            # Or the action ID...
            (TEST_KEY, token, TEST_USER_ID_1, TEST_ACTION_ID_2, later15mins),
            # Invalid when truncated
            (TEST_KEY, token[:-1], TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
            # Invalid with extra garbage
            (TEST_KEY, token + b'x', TEST_USER_ID_1, TEST_ACTION_ID_1,
             later15mins),
        )
        for key, candidate, user_id, action_id, now in rejected:
            self.assertFalse(xsrfutil.validate_token(
                key, candidate, user_id,
                action_id=action_id, current_time=now))

        # Invalid with token of None
        self.assertFalse(xsrfutil.validate_token(TEST_KEY,
                                                 None,
                                                 TEST_USER_ID_1,
                                                 action_id=TEST_ACTION_ID_1))