xref: /aosp_15_r20/external/pytorch/test/dynamo/test_frame_init.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import torch
4import torch._dynamo.test_case
5from torch._C._dynamo.eval_frame import set_eval_frame
6from torch._guards import CompileId
7
8
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Reference function for the keyword-only / **kwargs signature.
    #
    # The test below calls it once directly to get the expected output, then
    # again with an eval-frame callback installed that substitutes the code
    # object of varkwargs_code1 or varkwargs_code2 for this function's code.
    #
    # Signature notes:
    #   * ``arg1`` is the only truly positional-only parameter (it precedes
    #     ``/``); ``positional_only_arg`` can be passed either way despite
    #     its name.
    #   * ``keyword_only_arg`` must be passed by keyword (it follows ``*``).
    #   * ``kwargs`` collects any remaining keyword arguments.
    #
    # Returns a dict echoing every argument plus the value of ``local``.

    # Deliberately a real local variable: the replacement code objects differ
    # from this one in how many non-argument locals they declare.
    local = 1
    return {
        "local": local,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
18
19
def varkwargs_code1(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Replacement body for target_with_varkwargs with one FEWER local
    # variable: the ``local = 1`` assignment is removed and the constant is
    # used inline instead. The signature and the returned dict are the same as
    # the target's, so the two are interchangeable from the caller's view.
    return {
        "local": 1,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
29
30
def varkwargs_code2(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Replacement body for target_with_varkwargs with one MORE local variable:
    # two locals (``local1``, ``local2``) instead of the target's single
    # ``local``. Their sum is 1, so the returned dict still equals the
    # target's output.
    local1 = 0
    local2 = 1
    return {
        "local": local1 + local2,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
42
43
def target_with_varargs(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Reference function for the *varargs / **kwargs signature.
    #
    # The test below calls it once directly to get the expected output, then
    # again with an eval-frame callback installed that substitutes the code
    # object of varargs_code1 or varargs_code2 for this function's code.
    #
    # Signature notes:
    #   * ``arg1`` is the only truly positional-only parameter (it precedes
    #     ``/``); ``positional_only_arg`` can be passed either way despite
    #     its name.
    #   * ``varargs`` collects extra positional arguments as a tuple.
    #   * ``kwargs`` collects all keyword arguments not bound above.
    #
    # Returns a dict echoing every argument plus the value of ``local``.

    # Deliberately a real local variable: the replacement code objects differ
    # from this one in how many non-argument locals they declare.
    local = 1
    return {
        "local": local,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
53
54
def varargs_code1(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Replacement body for target_with_varargs with one FEWER local variable:
    # the ``local = 1`` assignment is removed and the constant is used inline
    # instead. The signature and the returned dict are the same as the
    # target's, so the two are interchangeable from the caller's view.
    return {
        "local": 1,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
64
65
def varargs_code2(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Replacement body for target_with_varargs with one MORE local variable:
    # two locals (``local1``, ``local2``) instead of the target's single
    # ``local``. Their sum is 1, so the returned dict still equals the
    # target's output.
    local1 = 0
    local2 = 1
    return {
        "local": local1 + local2,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
77
78
class FrameInitTests(torch._dynamo.test_case.TestCase):
    """Frame-initialisation tests for code substituted by an eval-frame callback.

    The callbacks used here bypass compilation: for a known target function
    they return a ``GuardedCode`` wrapping a hand-written replacement code
    object that has the same signature but a different number of local
    variables. The targets must still produce the same result.
    """

    def test_frame_init(self):
        """Run both targets under each replacement set and compare the outputs.

        For every code map (fewer locals, then more locals) the expected
        outputs are computed with the current eval-frame setting, a
        substituting callback is installed, the targets are called again, and
        the results are compared.
        """
        # Replacement code objects with one local variable fewer than the target.
        code_map1 = {
            target_with_varargs.__code__: varargs_code1.__code__,
            target_with_varkwargs.__code__: varkwargs_code1.__code__,
        }
        # Replacement code objects with one local variable more than the target.
        code_map2 = {
            target_with_varargs.__code__: varargs_code2.__code__,
            target_with_varkwargs.__code__: varkwargs_code2.__code__,
        }

        def make_callback(code_map):
            # Build an eval-frame callback that swaps in code from ``code_map``.
            def callback(frame, cache_entry, frame_state):
                transformed_code = code_map.get(frame.f_code)
                if transformed_code is None:
                    # Not one of our targets: return None, as the original
                    # callbacks did for unmapped frames.
                    return None
                # The guard always passes; the compile id is a fixed dummy.
                return torch._dynamo.types.GuardedCode(
                    transformed_code, lambda f_locals: True, CompileId(0, 0)
                )

            return callback

        for code_map in (code_map1, code_map2):
            callback = make_callback(code_map)
            # NOTE(review): reset() is relied on to drop any substitution
            # cached from the previous iteration - confirm it does so for
            # callbacks installed directly through set_eval_frame.
            torch._dynamo.reset()

            # Expected outputs, computed before the substituting callback is
            # installed.
            expected_varargs_output = target_with_varargs(
                1, 2, 3, 4, name1=1, name2=2, name3=3
            )
            expected_kwargs_output = target_with_varkwargs(
                1, 2, keyword_only_arg=1, name2=2, name3=3
            )

            # Install the callback for THIS iteration (previously ``callback1``
            # was always installed, so the second code map was never used).
            original = set_eval_frame(callback)
            try:
                real_varargs_output = target_with_varargs(
                    1, 2, 3, 4, name1=1, name2=2, name3=3
                )
                real_kwargs_output = target_with_varkwargs(
                    1, 2, keyword_only_arg=1, name2=2, name3=3
                )
            finally:
                # Always restore the previous callback so that an exception
                # here cannot leave the substituting callback installed.
                set_eval_frame(original)

            self.assertEqual(real_varargs_output, expected_varargs_output)
            self.assertEqual(real_kwargs_output, expected_kwargs_output)
124
125
if __name__ == "__main__":
    # Executed as a script: run this file's tests with the Dynamo test runner,
    # reached through the module already imported at the top of the file.
    torch._dynamo.test_case.run_tests()
130