xref: /aosp_15_r20/external/pytorch/test/jit/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport cmath
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
7*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
8*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_MACOS
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import execWrapper, JitTestCase
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable: the directory two levels above
# this file (the pytorch ``test/`` root) is appended to ``sys.path``.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workerclass TestComplex(JitTestCase):
21*da0073e9SAndroid Build Coastguard Worker    def test_script(self):
22*da0073e9SAndroid Build Coastguard Worker        def fn(a: complex):
23*da0073e9SAndroid Build Coastguard Worker            return a
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (3 + 5j,))
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker    def test_complexlist(self):
28*da0073e9SAndroid Build Coastguard Worker        def fn(a: List[complex], idx: int):
29*da0073e9SAndroid Build Coastguard Worker            return a[idx]
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker        input = [1j, 2, 3 + 4j, -5, -7j]
32*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (input, 2))
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker    def test_complexdict(self):
35*da0073e9SAndroid Build Coastguard Worker        def fn(a: Dict[complex, complex], key: complex) -> complex:
36*da0073e9SAndroid Build Coastguard Worker            return a[key]
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
39*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (input, -4.3 - 2j))
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker    def test_pickle(self):
42*da0073e9SAndroid Build Coastguard Worker        class ComplexModule(torch.jit.ScriptModule):
43*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
44*da0073e9SAndroid Build Coastguard Worker                super().__init__()
45*da0073e9SAndroid Build Coastguard Worker                self.a = 3 + 5j
46*da0073e9SAndroid Build Coastguard Worker                self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
47*da0073e9SAndroid Build Coastguard Worker                self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
50*da0073e9SAndroid Build Coastguard Worker            def forward(self, b: int):
51*da0073e9SAndroid Build Coastguard Worker                return b + 2j
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker        loaded = self.getExportImportCopy(ComplexModule())
54*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.a, 3 + 5j)
55*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
56*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
57*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(2), 2 + 2j)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    def test_complex_parse(self):
60*da0073e9SAndroid Build Coastguard Worker        def fn(a: int, b: torch.Tensor, dim: int):
61*da0073e9SAndroid Build Coastguard Worker            # verifies `emitValueToTensor()` 's behavior
62*da0073e9SAndroid Build Coastguard Worker            b[dim] = 2.4 + 0.5j
63*da0073e9SAndroid Build Coastguard Worker            return (3 * 2j) + a + 5j - 7.4j - 4
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor(1)
66*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([0.4, 1.4j, 2.35])
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (t1, t2, 2))
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker    def test_complex_constants_and_ops(self):
71*da0073e9SAndroid Build Coastguard Worker        vals = (
72*da0073e9SAndroid Build Coastguard Worker            [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
73*da0073e9SAndroid Build Coastguard Worker            + [10.0**i for i in range(2)]
74*da0073e9SAndroid Build Coastguard Worker            + [-(10.0**i) for i in range(2)]
75*da0073e9SAndroid Build Coastguard Worker        )
76*da0073e9SAndroid Build Coastguard Worker        complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent(
79*da0073e9SAndroid Build Coastguard Worker            """
80*da0073e9SAndroid Build Coastguard Worker            def func(a: complex):
81*da0073e9SAndroid Build Coastguard Worker                return cmath.{func_or_const}(a)
82*da0073e9SAndroid Build Coastguard Worker            """
83*da0073e9SAndroid Build Coastguard Worker        )
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        def checkCmath(func_name, funcs_template=funcs_template):
86*da0073e9SAndroid Build Coastguard Worker            funcs_str = funcs_template.format(func_or_const=func_name)
87*da0073e9SAndroid Build Coastguard Worker            scope = {}
88*da0073e9SAndroid Build Coastguard Worker            execWrapper(funcs_str, globals(), scope)
89*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(funcs_str)
90*da0073e9SAndroid Build Coastguard Worker            f_script = cu.func
91*da0073e9SAndroid Build Coastguard Worker            f = scope["func"]
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker            if func_name in ["isinf", "isnan", "isfinite"]:
94*da0073e9SAndroid Build Coastguard Worker                new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
95*da0073e9SAndroid Build Coastguard Worker                final_vals = tuple(
96*da0073e9SAndroid Build Coastguard Worker                    complex(x, y) for x, y in product(new_vals, new_vals)
97*da0073e9SAndroid Build Coastguard Worker                )
98*da0073e9SAndroid Build Coastguard Worker            else:
99*da0073e9SAndroid Build Coastguard Worker                final_vals = complex_vals
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker            for a in final_vals:
102*da0073e9SAndroid Build Coastguard Worker                res_python = None
103*da0073e9SAndroid Build Coastguard Worker                res_script = None
104*da0073e9SAndroid Build Coastguard Worker                try:
105*da0073e9SAndroid Build Coastguard Worker                    res_python = f(a)
106*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
107*da0073e9SAndroid Build Coastguard Worker                    res_python = e
108*da0073e9SAndroid Build Coastguard Worker                try:
109*da0073e9SAndroid Build Coastguard Worker                    res_script = f_script(a)
110*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
111*da0073e9SAndroid Build Coastguard Worker                    res_script = e
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker                if res_python != res_script:
114*da0073e9SAndroid Build Coastguard Worker                    if isinstance(res_python, Exception):
115*da0073e9SAndroid Build Coastguard Worker                        continue
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker                    msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
118*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res_python, res_script, msg=msg)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker        unary_ops = [
121*da0073e9SAndroid Build Coastguard Worker            "log",
122*da0073e9SAndroid Build Coastguard Worker            "log10",
123*da0073e9SAndroid Build Coastguard Worker            "sqrt",
124*da0073e9SAndroid Build Coastguard Worker            "exp",
125*da0073e9SAndroid Build Coastguard Worker            "sin",
126*da0073e9SAndroid Build Coastguard Worker            "cos",
127*da0073e9SAndroid Build Coastguard Worker            "asin",
128*da0073e9SAndroid Build Coastguard Worker            "acos",
129*da0073e9SAndroid Build Coastguard Worker            "atan",
130*da0073e9SAndroid Build Coastguard Worker            "sinh",
131*da0073e9SAndroid Build Coastguard Worker            "cosh",
132*da0073e9SAndroid Build Coastguard Worker            "tanh",
133*da0073e9SAndroid Build Coastguard Worker            "asinh",
134*da0073e9SAndroid Build Coastguard Worker            "acosh",
135*da0073e9SAndroid Build Coastguard Worker            "atanh",
136*da0073e9SAndroid Build Coastguard Worker            "phase",
137*da0073e9SAndroid Build Coastguard Worker            "isinf",
138*da0073e9SAndroid Build Coastguard Worker            "isnan",
139*da0073e9SAndroid Build Coastguard Worker            "isfinite",
140*da0073e9SAndroid Build Coastguard Worker        ]
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        # --- Unary ops ---
143*da0073e9SAndroid Build Coastguard Worker        for op in unary_ops:
144*da0073e9SAndroid Build Coastguard Worker            checkCmath(op)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        def fn(x: complex):
147*da0073e9SAndroid Build Coastguard Worker            return abs(x)
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker        for val in complex_vals:
150*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn, (val,))
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker        def pow_complex_float(x: complex, y: float):
153*da0073e9SAndroid Build Coastguard Worker            return pow(x, y)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        def pow_float_complex(x: float, y: complex):
156*da0073e9SAndroid Build Coastguard Worker            return pow(x, y)
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        self.checkScript(pow_float_complex, (2, 3j))
159*da0073e9SAndroid Build Coastguard Worker        self.checkScript(pow_complex_float, (3j, 2))
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        def pow_complex_complex(x: complex, y: complex):
162*da0073e9SAndroid Build Coastguard Worker            return pow(x, y)
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(complex_vals, complex_vals):
165*da0073e9SAndroid Build Coastguard Worker            # Reference: https://github.com/pytorch/pytorch/issues/54622
166*da0073e9SAndroid Build Coastguard Worker            if x == 0:
167*da0073e9SAndroid Build Coastguard Worker                continue
168*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pow_complex_complex, (x, y))
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        if not IS_MACOS:
171*da0073e9SAndroid Build Coastguard Worker            # --- Binary op ---
172*da0073e9SAndroid Build Coastguard Worker            def rect_fn(x: float, y: float):
173*da0073e9SAndroid Build Coastguard Worker                return cmath.rect(x, y)
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker            for x, y in product(vals, vals):
176*da0073e9SAndroid Build Coastguard Worker                self.checkScript(
177*da0073e9SAndroid Build Coastguard Worker                    rect_fn,
178*da0073e9SAndroid Build Coastguard Worker                    (
179*da0073e9SAndroid Build Coastguard Worker                        x,
180*da0073e9SAndroid Build Coastguard Worker                        y,
181*da0073e9SAndroid Build Coastguard Worker                    ),
182*da0073e9SAndroid Build Coastguard Worker                )
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        func_constants_template = dedent(
185*da0073e9SAndroid Build Coastguard Worker            """
186*da0073e9SAndroid Build Coastguard Worker            def func():
187*da0073e9SAndroid Build Coastguard Worker                return cmath.{func_or_const}
188*da0073e9SAndroid Build Coastguard Worker            """
189*da0073e9SAndroid Build Coastguard Worker        )
190*da0073e9SAndroid Build Coastguard Worker        float_consts = ["pi", "e", "tau", "inf", "nan"]
191*da0073e9SAndroid Build Coastguard Worker        complex_consts = ["infj", "nanj"]
192*da0073e9SAndroid Build Coastguard Worker        for x in float_consts + complex_consts:
193*da0073e9SAndroid Build Coastguard Worker            checkCmath(x, funcs_template=func_constants_template)
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    def test_infj_nanj_pickle(self):
196*da0073e9SAndroid Build Coastguard Worker        class ComplexModule(torch.jit.ScriptModule):
197*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
198*da0073e9SAndroid Build Coastguard Worker                super().__init__()
199*da0073e9SAndroid Build Coastguard Worker                self.a = 3 + 5j
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
202*da0073e9SAndroid Build Coastguard Worker            def forward(self, infj: int, nanj: int):
203*da0073e9SAndroid Build Coastguard Worker                if infj == 2:
204*da0073e9SAndroid Build Coastguard Worker                    return infj + cmath.infj
205*da0073e9SAndroid Build Coastguard Worker                else:
206*da0073e9SAndroid Build Coastguard Worker                    return nanj + cmath.nanj
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker        loaded = self.getExportImportCopy(ComplexModule())
209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(2, 3), 2 + cmath.infj)
210*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(3, 4), 4 + cmath.nanj)
211*da0073e9SAndroid Build Coastguard Worker
212*da0073e9SAndroid Build Coastguard Worker    def test_complex_constructor(self):
213*da0073e9SAndroid Build Coastguard Worker        # Test all scalar types
214*da0073e9SAndroid Build Coastguard Worker        def fn_int(real: int, img: int):
215*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
218*da0073e9SAndroid Build Coastguard Worker            fn_int,
219*da0073e9SAndroid Build Coastguard Worker            (
220*da0073e9SAndroid Build Coastguard Worker                0,
221*da0073e9SAndroid Build Coastguard Worker                0,
222*da0073e9SAndroid Build Coastguard Worker            ),
223*da0073e9SAndroid Build Coastguard Worker        )
224*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
225*da0073e9SAndroid Build Coastguard Worker            fn_int,
226*da0073e9SAndroid Build Coastguard Worker            (
227*da0073e9SAndroid Build Coastguard Worker                -1234,
228*da0073e9SAndroid Build Coastguard Worker                0,
229*da0073e9SAndroid Build Coastguard Worker            ),
230*da0073e9SAndroid Build Coastguard Worker        )
231*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
232*da0073e9SAndroid Build Coastguard Worker            fn_int,
233*da0073e9SAndroid Build Coastguard Worker            (
234*da0073e9SAndroid Build Coastguard Worker                0,
235*da0073e9SAndroid Build Coastguard Worker                -1256,
236*da0073e9SAndroid Build Coastguard Worker            ),
237*da0073e9SAndroid Build Coastguard Worker        )
238*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
239*da0073e9SAndroid Build Coastguard Worker            fn_int,
240*da0073e9SAndroid Build Coastguard Worker            (
241*da0073e9SAndroid Build Coastguard Worker                -167,
242*da0073e9SAndroid Build Coastguard Worker                -1256,
243*da0073e9SAndroid Build Coastguard Worker            ),
244*da0073e9SAndroid Build Coastguard Worker        )
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker        def fn_float(real: float, img: float):
247*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
250*da0073e9SAndroid Build Coastguard Worker            fn_float,
251*da0073e9SAndroid Build Coastguard Worker            (
252*da0073e9SAndroid Build Coastguard Worker                0.0,
253*da0073e9SAndroid Build Coastguard Worker                0.0,
254*da0073e9SAndroid Build Coastguard Worker            ),
255*da0073e9SAndroid Build Coastguard Worker        )
256*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
257*da0073e9SAndroid Build Coastguard Worker            fn_float,
258*da0073e9SAndroid Build Coastguard Worker            (
259*da0073e9SAndroid Build Coastguard Worker                -1234.78,
260*da0073e9SAndroid Build Coastguard Worker                0,
261*da0073e9SAndroid Build Coastguard Worker            ),
262*da0073e9SAndroid Build Coastguard Worker        )
263*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
264*da0073e9SAndroid Build Coastguard Worker            fn_float,
265*da0073e9SAndroid Build Coastguard Worker            (
266*da0073e9SAndroid Build Coastguard Worker                0,
267*da0073e9SAndroid Build Coastguard Worker                56.18,
268*da0073e9SAndroid Build Coastguard Worker            ),
269*da0073e9SAndroid Build Coastguard Worker        )
270*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
271*da0073e9SAndroid Build Coastguard Worker            fn_float,
272*da0073e9SAndroid Build Coastguard Worker            (
273*da0073e9SAndroid Build Coastguard Worker                -1.9,
274*da0073e9SAndroid Build Coastguard Worker                -19.8,
275*da0073e9SAndroid Build Coastguard Worker            ),
276*da0073e9SAndroid Build Coastguard Worker        )
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        def fn_bool(real: bool, img: bool):
279*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
282*da0073e9SAndroid Build Coastguard Worker            fn_bool,
283*da0073e9SAndroid Build Coastguard Worker            (
284*da0073e9SAndroid Build Coastguard Worker                True,
285*da0073e9SAndroid Build Coastguard Worker                True,
286*da0073e9SAndroid Build Coastguard Worker            ),
287*da0073e9SAndroid Build Coastguard Worker        )
288*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
289*da0073e9SAndroid Build Coastguard Worker            fn_bool,
290*da0073e9SAndroid Build Coastguard Worker            (
291*da0073e9SAndroid Build Coastguard Worker                False,
292*da0073e9SAndroid Build Coastguard Worker                False,
293*da0073e9SAndroid Build Coastguard Worker            ),
294*da0073e9SAndroid Build Coastguard Worker        )
295*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
296*da0073e9SAndroid Build Coastguard Worker            fn_bool,
297*da0073e9SAndroid Build Coastguard Worker            (
298*da0073e9SAndroid Build Coastguard Worker                False,
299*da0073e9SAndroid Build Coastguard Worker                True,
300*da0073e9SAndroid Build Coastguard Worker            ),
301*da0073e9SAndroid Build Coastguard Worker        )
302*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
303*da0073e9SAndroid Build Coastguard Worker            fn_bool,
304*da0073e9SAndroid Build Coastguard Worker            (
305*da0073e9SAndroid Build Coastguard Worker                True,
306*da0073e9SAndroid Build Coastguard Worker                False,
307*da0073e9SAndroid Build Coastguard Worker            ),
308*da0073e9SAndroid Build Coastguard Worker        )
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        def fn_bool_int(real: bool, img: int):
311*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
314*da0073e9SAndroid Build Coastguard Worker            fn_bool_int,
315*da0073e9SAndroid Build Coastguard Worker            (
316*da0073e9SAndroid Build Coastguard Worker                True,
317*da0073e9SAndroid Build Coastguard Worker                0,
318*da0073e9SAndroid Build Coastguard Worker            ),
319*da0073e9SAndroid Build Coastguard Worker        )
320*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
321*da0073e9SAndroid Build Coastguard Worker            fn_bool_int,
322*da0073e9SAndroid Build Coastguard Worker            (
323*da0073e9SAndroid Build Coastguard Worker                False,
324*da0073e9SAndroid Build Coastguard Worker                0,
325*da0073e9SAndroid Build Coastguard Worker            ),
326*da0073e9SAndroid Build Coastguard Worker        )
327*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
328*da0073e9SAndroid Build Coastguard Worker            fn_bool_int,
329*da0073e9SAndroid Build Coastguard Worker            (
330*da0073e9SAndroid Build Coastguard Worker                False,
331*da0073e9SAndroid Build Coastguard Worker                -1,
332*da0073e9SAndroid Build Coastguard Worker            ),
333*da0073e9SAndroid Build Coastguard Worker        )
334*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
335*da0073e9SAndroid Build Coastguard Worker            fn_bool_int,
336*da0073e9SAndroid Build Coastguard Worker            (
337*da0073e9SAndroid Build Coastguard Worker                True,
338*da0073e9SAndroid Build Coastguard Worker                3,
339*da0073e9SAndroid Build Coastguard Worker            ),
340*da0073e9SAndroid Build Coastguard Worker        )
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker        def fn_int_bool(real: int, img: bool):
343*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
346*da0073e9SAndroid Build Coastguard Worker            fn_int_bool,
347*da0073e9SAndroid Build Coastguard Worker            (
348*da0073e9SAndroid Build Coastguard Worker                0,
349*da0073e9SAndroid Build Coastguard Worker                True,
350*da0073e9SAndroid Build Coastguard Worker            ),
351*da0073e9SAndroid Build Coastguard Worker        )
352*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
353*da0073e9SAndroid Build Coastguard Worker            fn_int_bool,
354*da0073e9SAndroid Build Coastguard Worker            (
355*da0073e9SAndroid Build Coastguard Worker                0,
356*da0073e9SAndroid Build Coastguard Worker                False,
357*da0073e9SAndroid Build Coastguard Worker            ),
358*da0073e9SAndroid Build Coastguard Worker        )
359*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
360*da0073e9SAndroid Build Coastguard Worker            fn_int_bool,
361*da0073e9SAndroid Build Coastguard Worker            (
362*da0073e9SAndroid Build Coastguard Worker                -3,
363*da0073e9SAndroid Build Coastguard Worker                True,
364*da0073e9SAndroid Build Coastguard Worker            ),
365*da0073e9SAndroid Build Coastguard Worker        )
366*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
367*da0073e9SAndroid Build Coastguard Worker            fn_int_bool,
368*da0073e9SAndroid Build Coastguard Worker            (
369*da0073e9SAndroid Build Coastguard Worker                6,
370*da0073e9SAndroid Build Coastguard Worker                False,
371*da0073e9SAndroid Build Coastguard Worker            ),
372*da0073e9SAndroid Build Coastguard Worker        )
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker        def fn_bool_float(real: bool, img: float):
375*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
378*da0073e9SAndroid Build Coastguard Worker            fn_bool_float,
379*da0073e9SAndroid Build Coastguard Worker            (
380*da0073e9SAndroid Build Coastguard Worker                True,
381*da0073e9SAndroid Build Coastguard Worker                0.0,
382*da0073e9SAndroid Build Coastguard Worker            ),
383*da0073e9SAndroid Build Coastguard Worker        )
384*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
385*da0073e9SAndroid Build Coastguard Worker            fn_bool_float,
386*da0073e9SAndroid Build Coastguard Worker            (
387*da0073e9SAndroid Build Coastguard Worker                False,
388*da0073e9SAndroid Build Coastguard Worker                0.0,
389*da0073e9SAndroid Build Coastguard Worker            ),
390*da0073e9SAndroid Build Coastguard Worker        )
391*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
392*da0073e9SAndroid Build Coastguard Worker            fn_bool_float,
393*da0073e9SAndroid Build Coastguard Worker            (
394*da0073e9SAndroid Build Coastguard Worker                False,
395*da0073e9SAndroid Build Coastguard Worker                -1.0,
396*da0073e9SAndroid Build Coastguard Worker            ),
397*da0073e9SAndroid Build Coastguard Worker        )
398*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
399*da0073e9SAndroid Build Coastguard Worker            fn_bool_float,
400*da0073e9SAndroid Build Coastguard Worker            (
401*da0073e9SAndroid Build Coastguard Worker                True,
402*da0073e9SAndroid Build Coastguard Worker                3.0,
403*da0073e9SAndroid Build Coastguard Worker            ),
404*da0073e9SAndroid Build Coastguard Worker        )
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker        def fn_float_bool(real: float, img: bool):
407*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
410*da0073e9SAndroid Build Coastguard Worker            fn_float_bool,
411*da0073e9SAndroid Build Coastguard Worker            (
412*da0073e9SAndroid Build Coastguard Worker                0.0,
413*da0073e9SAndroid Build Coastguard Worker                True,
414*da0073e9SAndroid Build Coastguard Worker            ),
415*da0073e9SAndroid Build Coastguard Worker        )
416*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
417*da0073e9SAndroid Build Coastguard Worker            fn_float_bool,
418*da0073e9SAndroid Build Coastguard Worker            (
419*da0073e9SAndroid Build Coastguard Worker                0.0,
420*da0073e9SAndroid Build Coastguard Worker                False,
421*da0073e9SAndroid Build Coastguard Worker            ),
422*da0073e9SAndroid Build Coastguard Worker        )
423*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
424*da0073e9SAndroid Build Coastguard Worker            fn_float_bool,
425*da0073e9SAndroid Build Coastguard Worker            (
426*da0073e9SAndroid Build Coastguard Worker                -3.0,
427*da0073e9SAndroid Build Coastguard Worker                True,
428*da0073e9SAndroid Build Coastguard Worker            ),
429*da0073e9SAndroid Build Coastguard Worker        )
430*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
431*da0073e9SAndroid Build Coastguard Worker            fn_float_bool,
432*da0073e9SAndroid Build Coastguard Worker            (
433*da0073e9SAndroid Build Coastguard Worker                6.0,
434*da0073e9SAndroid Build Coastguard Worker                False,
435*da0073e9SAndroid Build Coastguard Worker            ),
436*da0073e9SAndroid Build Coastguard Worker        )
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker        def fn_float_int(real: float, img: int):
439*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
442*da0073e9SAndroid Build Coastguard Worker            fn_float_int,
443*da0073e9SAndroid Build Coastguard Worker            (
444*da0073e9SAndroid Build Coastguard Worker                0.0,
445*da0073e9SAndroid Build Coastguard Worker                1,
446*da0073e9SAndroid Build Coastguard Worker            ),
447*da0073e9SAndroid Build Coastguard Worker        )
448*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
449*da0073e9SAndroid Build Coastguard Worker            fn_float_int,
450*da0073e9SAndroid Build Coastguard Worker            (
451*da0073e9SAndroid Build Coastguard Worker                0.0,
452*da0073e9SAndroid Build Coastguard Worker                -1,
453*da0073e9SAndroid Build Coastguard Worker            ),
454*da0073e9SAndroid Build Coastguard Worker        )
455*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
456*da0073e9SAndroid Build Coastguard Worker            fn_float_int,
457*da0073e9SAndroid Build Coastguard Worker            (
458*da0073e9SAndroid Build Coastguard Worker                1.8,
459*da0073e9SAndroid Build Coastguard Worker                -3,
460*da0073e9SAndroid Build Coastguard Worker            ),
461*da0073e9SAndroid Build Coastguard Worker        )
462*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
463*da0073e9SAndroid Build Coastguard Worker            fn_float_int,
464*da0073e9SAndroid Build Coastguard Worker            (
465*da0073e9SAndroid Build Coastguard Worker                2.7,
466*da0073e9SAndroid Build Coastguard Worker                8,
467*da0073e9SAndroid Build Coastguard Worker            ),
468*da0073e9SAndroid Build Coastguard Worker        )
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker        def fn_int_float(real: int, img: float):
471*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
474*da0073e9SAndroid Build Coastguard Worker            fn_int_float,
475*da0073e9SAndroid Build Coastguard Worker            (
476*da0073e9SAndroid Build Coastguard Worker                1,
477*da0073e9SAndroid Build Coastguard Worker                0.0,
478*da0073e9SAndroid Build Coastguard Worker            ),
479*da0073e9SAndroid Build Coastguard Worker        )
480*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
481*da0073e9SAndroid Build Coastguard Worker            fn_int_float,
482*da0073e9SAndroid Build Coastguard Worker            (
483*da0073e9SAndroid Build Coastguard Worker                -1,
484*da0073e9SAndroid Build Coastguard Worker                1.7,
485*da0073e9SAndroid Build Coastguard Worker            ),
486*da0073e9SAndroid Build Coastguard Worker        )
487*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
488*da0073e9SAndroid Build Coastguard Worker            fn_int_float,
489*da0073e9SAndroid Build Coastguard Worker            (
490*da0073e9SAndroid Build Coastguard Worker                -3,
491*da0073e9SAndroid Build Coastguard Worker                0.0,
492*da0073e9SAndroid Build Coastguard Worker            ),
493*da0073e9SAndroid Build Coastguard Worker        )
494*da0073e9SAndroid Build Coastguard Worker        self.checkScript(
495*da0073e9SAndroid Build Coastguard Worker            fn_int_float,
496*da0073e9SAndroid Build Coastguard Worker            (
497*da0073e9SAndroid Build Coastguard Worker                2,
498*da0073e9SAndroid Build Coastguard Worker                -8.9,
499*da0073e9SAndroid Build Coastguard Worker            ),
500*da0073e9SAndroid Build Coastguard Worker        )
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    def test_torch_complex_constructor_with_tensor(self):
503*da0073e9SAndroid Build Coastguard Worker        tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker        def fn_tensor_float(real, img: float):
506*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Worker        def fn_tensor_int(real, img: int):
509*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker        def fn_tensor_bool(real, img: bool):
512*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker        def fn_float_tensor(real: float, img):
515*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
516*da0073e9SAndroid Build Coastguard Worker
517*da0073e9SAndroid Build Coastguard Worker        def fn_int_tensor(real: int, img):
518*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        def fn_bool_tensor(real: bool, img):
521*da0073e9SAndroid Build Coastguard Worker            return complex(real, img)
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker        for tensor in tensors:
524*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_tensor_float, (tensor, 1.2))
525*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_tensor_int, (tensor, 3))
526*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_tensor_bool, (tensor, True))
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_float_tensor, (1.2, tensor))
529*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_int_tensor, (3, tensor))
530*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn_bool_tensor, (True, tensor))
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker        def fn_tensor_tensor(real, img):
533*da0073e9SAndroid Build Coastguard Worker            return complex(real, img) + complex(2)
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        for x, y in product(tensors, tensors):
536*da0073e9SAndroid Build Coastguard Worker            self.checkScript(
537*da0073e9SAndroid Build Coastguard Worker                fn_tensor_tensor,
538*da0073e9SAndroid Build Coastguard Worker                (
539*da0073e9SAndroid Build Coastguard Worker                    x,
540*da0073e9SAndroid Build Coastguard Worker                    y,
541*da0073e9SAndroid Build Coastguard Worker                ),
542*da0073e9SAndroid Build Coastguard Worker            )
543*da0073e9SAndroid Build Coastguard Worker
544*da0073e9SAndroid Build Coastguard Worker    def test_comparison_ops(self):
545*da0073e9SAndroid Build Coastguard Worker        def fn1(a: complex, b: complex):
546*da0073e9SAndroid Build Coastguard Worker            return a == b
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker        def fn2(a: complex, b: complex):
549*da0073e9SAndroid Build Coastguard Worker            return a != b
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker        def fn3(a: complex, b: float):
552*da0073e9SAndroid Build Coastguard Worker            return a == b
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker        def fn4(a: complex, b: float):
555*da0073e9SAndroid Build Coastguard Worker            return a != b
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker        x, y = 2 - 3j, 4j
558*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, x))
559*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, y))
560*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, (x, x))
561*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, (x, y))
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker        x1, y1 = 1 + 0j, 1.0
564*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, (x1, y1))
565*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn4, (x1, y1))
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker    def test_div(self):
568*da0073e9SAndroid Build Coastguard Worker        def fn1(a: complex, b: complex):
569*da0073e9SAndroid Build Coastguard Worker            return a / b
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker        x, y = 2 - 3j, 4j
572*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, y))
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker    def test_complex_list_sum(self):
575*da0073e9SAndroid Build Coastguard Worker        def fn(x: List[complex]):
576*da0073e9SAndroid Build Coastguard Worker            return sum(x)
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker    def test_tensor_attributes(self):
581*da0073e9SAndroid Build Coastguard Worker        def tensor_real(x):
582*da0073e9SAndroid Build Coastguard Worker            return x.real
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker        def tensor_imag(x):
585*da0073e9SAndroid Build Coastguard Worker            return x.imag
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, dtype=torch.cdouble)
588*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensor_real, (t,))
589*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensor_imag, (t,))
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker    def test_binary_op_complex_tensor(self):
592*da0073e9SAndroid Build Coastguard Worker        def mul(x: complex, y: torch.Tensor):
593*da0073e9SAndroid Build Coastguard Worker            return x * y
594*da0073e9SAndroid Build Coastguard Worker
595*da0073e9SAndroid Build Coastguard Worker        def add(x: complex, y: torch.Tensor):
596*da0073e9SAndroid Build Coastguard Worker            return x + y
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker        def eq(x: complex, y: torch.Tensor):
599*da0073e9SAndroid Build Coastguard Worker            return x == y
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        def ne(x: complex, y: torch.Tensor):
602*da0073e9SAndroid Build Coastguard Worker            return x != y
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker        def sub(x: complex, y: torch.Tensor):
605*da0073e9SAndroid Build Coastguard Worker            return x - y
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        def div(x: complex, y: torch.Tensor):
608*da0073e9SAndroid Build Coastguard Worker            return x - y
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker        ops = [mul, add, eq, ne, sub, div]
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        for shape in [(1,), (2, 2)]:
613*da0073e9SAndroid Build Coastguard Worker            x = 0.71 + 0.71j
614*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(shape, dtype=torch.cfloat)
615*da0073e9SAndroid Build Coastguard Worker            for op in ops:
616*da0073e9SAndroid Build Coastguard Worker                eager_result = op(x, y)
617*da0073e9SAndroid Build Coastguard Worker                scripted = torch.jit.script(op)
618*da0073e9SAndroid Build Coastguard Worker                jit_result = scripted(x, y)
619*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(eager_result, jit_result)
620