1# Owner(s): ["module: dynamo"] 2 3import torch 4import torch._dynamo.test_case 5from torch._C._dynamo.eval_frame import set_eval_frame 6from torch._guards import CompileId 7 8 9def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): 10 local = 1 11 return { 12 "local": local, 13 "arg1": arg1, 14 "positional_only_arg": positional_only_arg, 15 "keyword_only_arg": keyword_only_arg, 16 "kwargs": kwargs, 17 } 18 19 20def varkwargs_code1(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): 21 # remove a local variable: local = 1 22 return { 23 "local": 1, 24 "arg1": arg1, 25 "positional_only_arg": positional_only_arg, 26 "keyword_only_arg": keyword_only_arg, 27 "kwargs": kwargs, 28 } 29 30 31def varkwargs_code2(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): 32 # introduce a local variable 33 local1 = 0 34 local2 = 1 35 return { 36 "local": local1 + local2, 37 "arg1": arg1, 38 "positional_only_arg": positional_only_arg, 39 "keyword_only_arg": keyword_only_arg, 40 "kwargs": kwargs, 41 } 42 43 44def target_with_varargs(arg1, /, positional_only_arg, *varargs, **kwargs): 45 local = 1 46 return { 47 "local": local, 48 "arg1": arg1, 49 "positional_only_arg": positional_only_arg, 50 "varargs": varargs, 51 "kwargs": kwargs, 52 } 53 54 55def varargs_code1(arg1, /, positional_only_arg, *varargs, **kwargs): 56 # remove a local variable: local = 1 57 return { 58 "local": 1, 59 "arg1": arg1, 60 "positional_only_arg": positional_only_arg, 61 "varargs": varargs, 62 "kwargs": kwargs, 63 } 64 65 66def varargs_code2(arg1, /, positional_only_arg, *varargs, **kwargs): 67 # introduce a local variable 68 local1 = 0 69 local2 = 1 70 return { 71 "local": local1 + local2, 72 "arg1": arg1, 73 "positional_only_arg": positional_only_arg, 74 "varargs": varargs, 75 "kwargs": kwargs, 76 } 77 78 79class FrameInitTests(torch._dynamo.test_case.TestCase): 80 def test_frame_init(self): 81 code_map1 = { 82 target_with_varargs.__code__: varargs_code1.__code__, 83 target_with_varkwargs.__code__: varkwargs_code1.__code__, 84 } 85 code_map2 = { 86 target_with_varargs.__code__: varargs_code2.__code__, 87 target_with_varkwargs.__code__: varkwargs_code2.__code__, 88 } 89 90 def callback1(frame, cache_entry, frame_state): 91 if frame.f_code in code_map1: 92 transformed_code = code_map1[frame.f_code] 93 return torch._dynamo.types.GuardedCode( 94 transformed_code, lambda f_locals: True, CompileId(0, 0) 95 ) 96 return None 97 98 def callback2(frame, cache_entry, frame_state): 99 if frame.f_code in code_map2: 100 transformed_code = code_map2[frame.f_code] 101 return torch._dynamo.types.GuardedCode( 102 transformed_code, lambda f_locals: True, CompileId(0, 0) 103 ) 104 return None 105 106 for callback in [callback1, callback2]: 107 torch._dynamo.reset() 108 expected_varargs_output = target_with_varargs( 109 1, 2, 3, 4, name1=1, name2=2, name3=3 110 ) 111 expected_kwargs_output = target_with_varkwargs( 112 1, 2, keyword_only_arg=1, name2=2, name3=3 113 ) 114 original = set_eval_frame(callback1) 115 real_varargs_output = target_with_varargs( 116 1, 2, 3, 4, name1=1, name2=2, name3=3 117 ) 118 real_kwargs_output = target_with_varkwargs( 119 1, 2, keyword_only_arg=1, name2=2, name3=3 120 ) 121 self.assertEqual(real_varargs_output, expected_varargs_output) 122 self.assertEqual(real_kwargs_output, expected_kwargs_output) 123 set_eval_frame(original) 124 125 126if __name__ == "__main__": 127 from torch._dynamo.test_case import run_tests 128 129 run_tests() 130