1# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
2#
3# Copyright (C) 2006-2007 Gerhard Häring <[email protected]>
4#
5# This file is part of pysqlite.
6#
7# This software is provided 'as-is', without any express or implied
8# warranty.  In no event will the authors be held liable for any damages
9# arising from the use of this software.
10#
11# Permission is granted to anyone to use this software for any purpose,
12# including commercial applications, and to alter it and redistribute it
13# freely, subject to the following restrictions:
14#
15# 1. The origin of this software must not be misrepresented; you must not
16#    claim that you wrote the original software. If you use this software
17#    in a product, an acknowledgment in the product documentation would be
18#    appreciated but is not required.
19# 2. Altered source versions must be plainly marked as such, and must not be
20#    misrepresented as being the original software.
21# 3. This notice may not be removed or altered from any source distribution.
22
23import contextlib
24import sqlite3 as sqlite
25import unittest
26
27from test.support.os_helper import TESTFN, unlink
28
29from test.test_sqlite3.test_dbapi import memory_database, cx_limit
30from test.test_sqlite3.test_userfunctions import with_tracebacks
31
32
33class CollationTests(unittest.TestCase):
34    def test_create_collation_not_string(self):
35        con = sqlite.connect(":memory:")
36        with self.assertRaises(TypeError):
37            con.create_collation(None, lambda x, y: (x > y) - (x < y))
38
39    def test_create_collation_not_callable(self):
40        con = sqlite.connect(":memory:")
41        with self.assertRaises(TypeError) as cm:
42            con.create_collation("X", 42)
43        self.assertEqual(str(cm.exception), 'parameter must be callable')
44
45    def test_create_collation_not_ascii(self):
46        con = sqlite.connect(":memory:")
47        con.create_collation("collä", lambda x, y: (x > y) - (x < y))
48
49    def test_create_collation_bad_upper(self):
50        class BadUpperStr(str):
51            def upper(self):
52                return None
53        con = sqlite.connect(":memory:")
54        mycoll = lambda x, y: -((x > y) - (x < y))
55        con.create_collation(BadUpperStr("mycoll"), mycoll)
56        result = con.execute("""
57            select x from (
58            select 'a' as x
59            union
60            select 'b' as x
61            ) order by x collate mycoll
62            """).fetchall()
63        self.assertEqual(result[0][0], 'b')
64        self.assertEqual(result[1][0], 'a')
65
66    def test_collation_is_used(self):
67        def mycoll(x, y):
68            # reverse order
69            return -((x > y) - (x < y))
70
71        con = sqlite.connect(":memory:")
72        con.create_collation("mycoll", mycoll)
73        sql = """
74            select x from (
75            select 'a' as x
76            union
77            select 'b' as x
78            union
79            select 'c' as x
80            ) order by x collate mycoll
81            """
82        result = con.execute(sql).fetchall()
83        self.assertEqual(result, [('c',), ('b',), ('a',)],
84                         msg='the expected order was not returned')
85
86        con.create_collation("mycoll", None)
87        with self.assertRaises(sqlite.OperationalError) as cm:
88            result = con.execute(sql).fetchall()
89        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
90
91    def test_collation_returns_large_integer(self):
92        def mycoll(x, y):
93            # reverse order
94            return -((x > y) - (x < y)) * 2**32
95        con = sqlite.connect(":memory:")
96        con.create_collation("mycoll", mycoll)
97        sql = """
98            select x from (
99            select 'a' as x
100            union
101            select 'b' as x
102            union
103            select 'c' as x
104            ) order by x collate mycoll
105            """
106        result = con.execute(sql).fetchall()
107        self.assertEqual(result, [('c',), ('b',), ('a',)],
108                         msg="the expected order was not returned")
109
110    def test_collation_register_twice(self):
111        """
112        Register two different collation functions under the same name.
113        Verify that the last one is actually used.
114        """
115        con = sqlite.connect(":memory:")
116        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
117        con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
118        result = con.execute("""
119            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
120            """).fetchall()
121        self.assertEqual(result[0][0], 'b')
122        self.assertEqual(result[1][0], 'a')
123
124    def test_deregister_collation(self):
125        """
126        Register a collation, then deregister it. Make sure an error is raised if we try
127        to use it.
128        """
129        con = sqlite.connect(":memory:")
130        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
131        con.create_collation("mycoll", None)
132        with self.assertRaises(sqlite.OperationalError) as cm:
133            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
134        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
135
class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def _connect(self):
        """Return a new in-memory connection that is closed when the test ends."""
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        return con

    def test_progress_handler_used(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = self._connect()
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def test_opcode_count(self):
        """
        Test that the opcode argument is respected.
        """
        con = self._connect()
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Reset the very list the handler appends to.
        progress_calls.clear()
        # Invoking the handler every 2 opcodes instead of every opcode must
        # not result in more calls for a comparable statement.
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        self.assertGreaterEqual(first_count, second_count)

    def test_cancel_operation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = self._connect()
        def progress():
            return 1
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        with self.assertRaises(sqlite.OperationalError):
            curs.execute("create table bar (a, b)")

    def test_clear_handler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = self._connect()
        action = 0
        def progress():
            nonlocal action
            action = 1
            return 0
        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler(self):
        # An exception raised by the handler aborts the statement.
        con = self._connect()
        def bad_progress():
            1 / 0
        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler_result(self):
        # An exception raised while converting the handler's return value
        # to bool also aborts the statement.
        con = self._connect()
        class BadBool:
            def __bool__(self):
                1 / 0
        def bad_progress():
            return BadBool()
        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)
228
229
class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    @contextlib.contextmanager
    def check_stmt_trace(self, cx, expected):
        """Record the statements traced on *cx* while the with-block runs.

        The trace callback is always removed again, even if the block
        raises. The traced statements are compared with *expected* only
        when the block completes normally. This way an exception from the
        block is never masked by a trace mismatch, and a failing
        comparison cannot leave the callback installed.
        """
        traced = []
        cx.set_trace_callback(traced.append)
        try:
            yield
        finally:
            cx.set_trace_callback(None)
        self.assertEqual(traced, expected)

    def test_trace_callback_used(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        traced_statements = []
        with memory_database() as con:
            con.set_trace_callback(traced_statements.append)
            con.execute("create table foo(a, b)")
            self.assertTrue(traced_statements)
            self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def test_clear_trace_callback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        traced_statements = []
        with memory_database() as con:
            con.set_trace_callback(traced_statements.append)
            con.set_trace_callback(None)
            con.execute("create table foo(a, b)")
            self.assertFalse(traced_statements, "trace callback was not cleared")

    def test_unicode_content(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        traced_statements = []
        with memory_database() as con:
            con.set_trace_callback(traced_statements.append)
            con.execute("create table foo(x)")
            # The value is deliberately inlined as a SQL literal (not bound
            # as a parameter) so that it appears in the traced statement text.
            con.execute(f"insert into foo(x) values ('{unicode_value}')")
            con.commit()
            self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                            "Unicode data %s garbled in trace callback: %s"
                            % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    def test_trace_callback_content(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []
        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        self.addCleanup(unlink, TESTFN)
        # closing() guarantees con1 is closed even if opening con2 fails.
        with (contextlib.closing(sqlite.connect(TESTFN, isolation_level=None)) as con1,
              contextlib.closing(sqlite.connect(TESTFN)) as con2):
            con1.set_trace_callback(traced_statements.append)
            cur = con1.cursor()
            cur.execute(queries[0])
            # Presumably this schema change from a second connection forces
            # con1's next statement to be re-prepared (the bpo-26187 case).
            con2.execute("create table bar(x)")
            cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)

    def test_trace_expanded_sql(self):
        expected = [
            "create table t(t)",
            "BEGIN ",
            "insert into t values(0)",
            "insert into t values(1)",
            "insert into t values(2)",
            "COMMIT",
        ]
        with memory_database() as cx, self.check_stmt_trace(cx, expected):
            with cx:
                cx.execute("create table t(t)")
                cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

    @with_tracebacks(
        sqlite.DataError,
        regex="Expanded SQL string exceeds the maximum string length"
    )
    def test_trace_too_much_expanded_sql(self):
        # If the expanded string is too large, we'll fall back to the
        # unexpanded SQL statement (for SQLite 3.14.0 and newer).
        # The resulting string length is limited by the runtime limit
        # SQLITE_LIMIT_LENGTH.
        template = "select 1 as a where a="
        category = sqlite.SQLITE_LIMIT_LENGTH
        with memory_database() as cx, cx_limit(cx, category=category) as lim:
            ok_param = "a"
            bad_param = "a" * lim

            unexpanded_query = template + "?"
            expected = [unexpanded_query]
            if sqlite.sqlite_version_info < (3, 14, 0):
                expected = []
            with self.check_stmt_trace(cx, expected):
                cx.execute(unexpanded_query, (bad_param,))

            expanded_query = f"{template}'{ok_param}'"
            with self.check_stmt_trace(cx, [expanded_query]):
                cx.execute(unexpanded_query, (ok_param,))

    @with_tracebacks(ZeroDivisionError, regex="division by zero")
    def test_trace_bad_handler(self):
        # An exception in the trace callback must not abort the statement.
        with memory_database() as cx:
            cx.set_trace_callback(lambda stmt: 5/0)
            cx.execute("select 1")
351
352
if __name__ == "__main__":
    # Script entry point: discover and run the TestCases defined in this module.
    unittest.main(module="__main__")
355