# Owner(s): ["module: dynamo"]

import logging
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    munge_exc,
    skipIfWindows,
    TEST_Z3,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class ExcTests(LoggingTestCase):
    maxDiff = None

    def test_unsupported_real_stack(self):
        # exercise Unsupported constructor and augment_exc_message
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
   File "test_exc.py", line N, in fn001
    fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
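        # With suppress_errors=True and verbose=True, the AssertionError raised
        # inside comptime should be swallowed and reported as a "WON'T CONVERT"
        # log record that includes both the TorchDynamo stack trace and the
        # user code frames.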
        def fn001(x):
            def f(ctx):
                raise AssertionError

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError
AssertionError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
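        # A NotImplementedError raised inside comptime should be wrapped as
        # InternalTorchDynamoError and surface in the "WON'T CONVERT" log
        # record along with the user code frame.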
        def fn001(x):
            def f(ctx):
                raise NotImplementedError

            # Ensure graph break is not possible
            for i in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
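        # inject_BUILD_SET_unimplemented_TESTING_ONLY forces the BUILD_SET
        # bytecode (the set literal below) down the unsupported-instruction
        # path; the test then looks for the resulting "Graph break:" record.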
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log!  This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
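        # With suppress_errors=False, the AssertionError from comptime should
        # propagate out of torch.compile, annotated with the user code frame
        # that invoked comptime.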
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation
            def f(ctx):
                raise AssertionError

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
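        # With graph_breaks logging enabled, the graph break inside fn002
        # should be logged together with the chain of user frames
        # (fn001 -> fn002) that led to it.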
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break:")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break: from user code at:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    @skipIfWindows(
        msg='AssertionError: "tran[551 chars]s1 s2 s3) s0)\n  ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])'  # noqa: PLR0133
        != 'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n  ==> (<= (+ s1 s2) [483 chars][0])"'
    )
    def test_trigger_on_error(self):
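        # With translation validation enabled (and bisection disabled), the
        # injected equality flip is expected to make Z3 validation fail on the
        # split, raising ValidationException with the model and assertion dump
        # checked below.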
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 1
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 1
  ==> s3: 1

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2) (+ s0 (* -1 s3)))
  ==> (<= (+ s1 s2) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (<= s1 (+ s0 (* -1 s2)))
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 0 s1)
  ==> (And (<= (+ s1 s2) s0) (<= (* -1 s0) (+ s1 s2)))

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
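        # Same setup as test_trigger_on_error, but with bisection left enabled:
        # the validator should bisect to the node that introduced the failing
        # guard and raise BisectValidationException instead.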
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: 1
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 1
  ==> s2: 1
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()