1 """bytecode_helper - support tools for testing correct bytecode generation"""
2 
3 import unittest
4 import dis
5 import io
6 
# Sentinel default for the ``argval`` parameters below: it is compared with
# ``is`` so that callers can still search for an argval of ``None``.
_UNSPECIFIED = object()

class BytecodeTestCase(unittest.TestCase):
    """Custom assertion methods for inspecting bytecode."""

    def get_disassembly_as_string(self, co):
        """Return the ``dis.dis`` listing of *co* as a string.

        Used to build readable failure messages for the assertions below.
        """
        s = io.StringIO()
        dis.dis(co, file=s)
        return s.getvalue()

    def assertInBytecode(self, x, opname, argval=_UNSPECIFIED):
        """Returns instr if opname is found, otherwise throws AssertionError.

        x: anything accepted by ``dis.get_instructions`` (e.g. a function or
           code object).
        opname: the instruction name to look for, e.g. ``'LOAD_GLOBAL'``.
        argval: if given, the matching instruction's ``argval`` must also
           compare equal to it.

        Returns the first matching ``dis.Instruction``.  On failure the
        message includes the full disassembly of *x*.
        """
        for instr in dis.get_instructions(x):
            if instr.opname == opname:
                if argval is _UNSPECIFIED or instr.argval == argval:
                    return instr
        disassembly = self.get_disassembly_as_string(x)
        if argval is _UNSPECIFIED:
            msg = '%s not found in bytecode:\n%s' % (opname, disassembly)
        else:
            msg = '(%s,%r) not found in bytecode:\n%s'
            msg = msg % (opname, argval, disassembly)
        self.fail(msg)

    def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED):
        """Throws AssertionError if opname is found.

        x: anything accepted by ``dis.get_instructions``.
        opname: the instruction name that must not appear.
        argval: if given, only an instruction whose ``argval`` also compares
           equal to it counts as an occurrence.

        The disassembly for the failure message is produced only when an
        occurrence is actually found.  This matters when *argval* is given
        and many instructions share *opname* without matching it.
        """
        for instr in dis.get_instructions(x):
            if instr.opname != opname:
                continue
            if argval is _UNSPECIFIED:
                msg = '%s occurs in bytecode:\n%s'
                msg = msg % (opname, self.get_disassembly_as_string(x))
                self.fail(msg)
            elif instr.argval == argval:
                msg = '(%s,%r) occurs in bytecode:\n%s'
                msg = msg % (opname, argval,
                             self.get_disassembly_as_string(x))
                self.fail(msg)
43