# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
"""Tests for SQLite-specific hooks exposed by sqlite3.Connection.

Covers collations (create_collation), progress handlers
(set_progress_handler) and trace callbacks (set_trace_callback).
Every connection opened here is closed explicitly, either through the
memory_database() context manager or via addCleanup(), so the tests do
not leak connections.
"""

import contextlib
import sqlite3 as sqlite
import unittest

from test.support.os_helper import TESTFN, unlink

from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks


class CollationTests(unittest.TestCase):
    """Tests for Connection.create_collation()."""

    def test_create_collation_not_string(self):
        # The collation name must be a string.
        with memory_database() as con:
            with self.assertRaises(TypeError):
                con.create_collation(None, lambda x, y: (x > y) - (x < y))

    def test_create_collation_not_callable(self):
        # The collation itself must be callable (or None, see below).
        with memory_database() as con:
            with self.assertRaises(TypeError) as cm:
                con.create_collation("X", 42)
            self.assertEqual(str(cm.exception), 'parameter must be callable')

    def test_create_collation_not_ascii(self):
        # Non-ASCII collation names are accepted.
        with memory_database() as con:
            con.create_collation("collä", lambda x, y: (x > y) - (x < y))

    def test_create_collation_bad_upper(self):
        # A str subclass whose upper() returns None can still be used as the
        # collation name, and the collation is then usable from SQL.
        class BadUpperStr(str):
            def upper(self):
                return None

        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        with memory_database() as con:
            con.create_collation(BadUpperStr("mycoll"), mycoll)
            result = con.execute("""
                select x from (
                select 'a' as x
                union
                select 'b' as x
                ) order by x collate mycoll
                """).fetchall()
            self.assertEqual(result[0][0], 'b')
            self.assertEqual(result[1][0], 'a')

    def test_collation_is_used(self):
        # A registered collation drives ORDER BY; after unregistering it
        # (by passing None) the same query fails.
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        with memory_database() as con:
            con.create_collation("mycoll", mycoll)
            result = con.execute(sql).fetchall()
            self.assertEqual(result, [('c',), ('b',), ('a',)],
                             msg='the expected order was not returned')

            con.create_collation("mycoll", None)
            with self.assertRaises(sqlite.OperationalError) as cm:
                con.execute(sql).fetchall()
            self.assertEqual(str(cm.exception),
                             'no such collation sequence: mycoll')

    def test_collation_returns_large_integer(self):
        # The collation may return integers far outside the 32-bit range;
        # the expected result still reflects their sign.
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y)) * 2**32

        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        with memory_database() as con:
            con.create_collation("mycoll", mycoll)
            result = con.execute(sql).fetchall()
            self.assertEqual(result, [('c',), ('b',), ('a',)],
                             msg="the expected order was not returned")

    def test_collation_register_twice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        with memory_database() as con:
            con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
            con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
            result = con.execute("""
                select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
                """).fetchall()
            self.assertEqual(result[0][0], 'b')
            self.assertEqual(result[1][0], 'a')

    def test_deregister_collation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        with memory_database() as con:
            con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
            con.create_collation("mycoll", None)
            with self.assertRaises(sqlite.OperationalError) as cm:
                con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
            self.assertEqual(str(cm.exception),
                             'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def test_progress_handler_used(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        with memory_database() as con:
            progress_calls = []

            def progress():
                progress_calls.append(None)
                return 0

            con.set_progress_handler(progress, 1)
            con.execute("""
                create table foo(a, b)
                """)
            self.assertTrue(progress_calls)

    def test_opcode_count(self):
        """
        Test that the opcode argument is respected.
        """
        with memory_database() as con:
            progress_calls = []

            def progress():
                progress_calls.append(None)
                return 0

            con.set_progress_handler(progress, 1)
            curs = con.cursor()
            curs.execute("""
                create table foo (a, b)
                """)
            first_count = len(progress_calls)

            # Empty the list in place: the handler keeps appending to this
            # same object, which is what the second count measures.
            progress_calls.clear()
            con.set_progress_handler(progress, 2)
            curs.execute("""
                create table bar (a, b)
                """)
            second_count = len(progress_calls)
            # Calling the handler every 2 opcodes must not call it more
            # often than calling it every opcode.
            self.assertGreaterEqual(first_count, second_count)

    def test_cancel_operation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        with memory_database() as con:
            def progress():
                return 1

            con.set_progress_handler(progress, 1)
            curs = con.cursor()
            self.assertRaises(
                sqlite.OperationalError,
                curs.execute,
                "create table bar (a, b)")

    def test_clear_handler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        with memory_database() as con:
            action = 0

            def progress():
                nonlocal action
                action = 1
                return 0

            con.set_progress_handler(progress, 1)
            con.set_progress_handler(None, 1)
            con.execute("select 1 union select 2 union select 3").fetchall()
            self.assertEqual(action, 0, "progress handler was not cleared")

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler(self):
        # An exception raised inside the handler surfaces to the caller as
        # OperationalError.
        with memory_database() as con:
            def bad_progress():
                1 / 0

            con.set_progress_handler(bad_progress, 1)
            with self.assertRaises(sqlite.OperationalError):
                con.execute("""
                    create table foo(a, b)
                    """)

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler_result(self):
        # An exception raised while converting the handler's return value
        # to bool also surfaces as OperationalError.
        with memory_database() as con:
            class BadBool:
                def __bool__(self):
                    1 / 0

            def bad_progress():
                return BadBool()

            con.set_progress_handler(bad_progress, 1)
            with self.assertRaises(sqlite.OperationalError):
                con.execute("""
                    create table foo(a, b)
                    """)


class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    @contextlib.contextmanager
    def check_stmt_trace(self, cx, expected):
        """Assert that the statements traced on *cx* inside the block equal *expected*.

        The trace callback is always removed when the block exits. The
        comparison runs only if the block completes without raising, so a
        failure inside the block is reported as-is and is not hidden by an
        assertion error.
        """
        traced = []
        cx.set_trace_callback(traced.append)
        try:
            yield
        finally:
            cx.set_trace_callback(None)
        self.assertEqual(traced, expected)

    def test_trace_callback_used(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        with memory_database() as con:
            traced_statements = []

            def trace(statement):
                traced_statements.append(statement)

            con.set_trace_callback(trace)
            con.execute("create table foo(a, b)")
            self.assertTrue(traced_statements)
            self.assertTrue(any("create table foo" in stmt
                                for stmt in traced_statements))

    def test_clear_trace_callback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        with memory_database() as con:
            traced_statements = []

            def trace(statement):
                traced_statements.append(statement)

            con.set_trace_callback(trace)
            con.set_trace_callback(None)
            con.execute("create table foo(a, b)")
            self.assertFalse(traced_statements,
                             "trace callback was not cleared")

    def test_unicode_content(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        with memory_database() as con:
            traced_statements = []

            def trace(statement):
                traced_statements.append(statement)

            con.set_trace_callback(trace)
            con.execute("create table foo(x)")
            # The value is deliberately spliced into the SQL text (not bound
            # as a parameter) so that it appears in the traced statement.
            con.execute("insert into foo(x) values ('%s')" % unicode_value)
            con.commit()
            self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                            "Unicode data %s garbled in trace callback: %s"
                            % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    def test_trace_callback_content(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        # Cleanups run LIFO: both connections are closed before the
        # database file is unlinked. Registering each close right after its
        # connect means con1 is still closed if opening con2 fails.
        self.addCleanup(unlink, TESTFN)
        con1 = sqlite.connect(TESTFN, isolation_level=None)
        self.addCleanup(con1.close)
        con2 = sqlite.connect(TESTFN)
        self.addCleanup(con2.close)

        con1.set_trace_callback(trace)
        cur = con1.cursor()
        cur.execute(queries[0])
        # NOTE(review): presumably this schema change made through a second
        # connection forces con1 to re-prepare its next statement, which is
        # where duplicate trace output used to appear -- confirm.
        con2.execute("create table bar(x)")
        cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)

    def test_trace_expanded_sql(self):
        # Bound parameters are expanded into the traced SQL text.
        expected = [
            "create table t(t)",
            "BEGIN ",
            "insert into t values(0)",
            "insert into t values(1)",
            "insert into t values(2)",
            "COMMIT",
        ]
        with memory_database() as cx, self.check_stmt_trace(cx, expected):
            with cx:
                cx.execute("create table t(t)")
                cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

    @with_tracebacks(
        sqlite.DataError,
        regex="Expanded SQL string exceeds the maximum string length"
    )
    def test_trace_too_much_expanded_sql(self):
        # If the expanded string is too large, we'll fall back to the
        # unexpanded SQL statement (for SQLite 3.14.0 and newer).
        # The resulting string length is limited by the runtime limit
        # SQLITE_LIMIT_LENGTH.
        template = "select 1 as a where a="
        category = sqlite.SQLITE_LIMIT_LENGTH
        with memory_database() as cx, cx_limit(cx, category=category) as lim:
            ok_param = "a"
            bad_param = "a" * lim

            unexpanded_query = template + "?"
            expected = [unexpanded_query]
            if sqlite.sqlite_version_info < (3, 14, 0):
                expected = []
            with self.check_stmt_trace(cx, expected):
                cx.execute(unexpanded_query, (bad_param,))

            expanded_query = f"{template}'{ok_param}'"
            with self.check_stmt_trace(cx, [expanded_query]):
                cx.execute(unexpanded_query, (ok_param,))

    @with_tracebacks(ZeroDivisionError, regex="division by zero")
    def test_trace_bad_handler(self):
        # An exception raised by the trace callback does not abort the
        # statement; it is reported through the traceback machinery instead.
        with memory_database() as cx:
            cx.set_trace_callback(lambda stmt: 5/0)
            cx.execute("select 1")


if __name__ == "__main__":
    unittest.main()