# Owner(s): ["oncall: jit"]

import cmath
import os
import sys
from itertools import product
from textwrap import dedent
from typing import Dict, List

import torch
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import execWrapper, JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)


class TestComplex(JitTestCase):
    """TorchScript tests for Python ``complex`` scalars.

    Covers complex arguments, complex lists/dicts, export/import of modules
    holding complex attributes, ``cmath`` functions and constants, the
    ``complex(...)`` constructor, comparison/division operators, and
    complex-scalar / tensor binary ops.
    """

    def test_script(self):
        def fn(a: complex):
            return a

        self.checkScript(fn, (3 + 5j,))

    def test_complexlist(self):
        def fn(a: List[complex], idx: int):
            return a[idx]

        values = [1j, 2, 3 + 4j, -5, -7j]
        self.checkScript(fn, (values, 2))

    def test_complexdict(self):
        def fn(a: Dict[complex, complex], key: complex) -> complex:
            return a[key]

        mapping = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
        self.checkScript(fn, (mapping, -4.3 - 2j))

    def test_pickle(self):
        # Complex scalar/list/dict attributes and a complex-returning forward
        # must survive an export/import round trip.
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j
                self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
                self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}

            @torch.jit.script_method
            def forward(self, b: int):
                return b + 2j

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded.a, 3 + 5j)
        self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
        self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
        self.assertEqual(loaded(2), 2 + 2j)

    def test_complex_parse(self):
        def fn(a: int, b: torch.Tensor, dim: int):
            # verifies `emitValueToTensor()` 's behavior
            b[dim] = 2.4 + 0.5j
            return (3 * 2j) + a + 5j - 7.4j - 4

        t1 = torch.tensor(1)
        t2 = torch.tensor([0.4, 1.4j, 2.35])

        self.checkScript(fn, (t1, t2, 2))

    def test_complex_constants_and_ops(self):
        vals = (
            [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
            + [10.0**i for i in range(2)]
            + [-(10.0**i) for i in range(2)]
        )
        complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))

        funcs_template = dedent(
            """
        def func(a: complex):
            return cmath.{func_or_const}(a)
        """
        )

        def compileBoth(func_name, template):
            # Instantiate `template` for `func_name` and compile the same source
            # twice: once as plain Python, once as TorchScript.
            funcs_str = template.format(func_or_const=func_name)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            return scope["func"], cu.func

        def checkCmath(func_name, funcs_template=funcs_template):
            # Compare Python and TorchScript results of a one-argument cmath
            # function over a grid of complex inputs.
            f, f_script = compileBoth(func_name, funcs_template)

            if func_name in ["isinf", "isnan", "isfinite"]:
                # Classification functions also need non-finite inputs.
                new_vals = vals + [float("inf"), float("nan"), -1 * float("inf")]
                final_vals = tuple(
                    complex(x, y) for x, y in product(new_vals, new_vals)
                )
            else:
                final_vals = complex_vals

            for a in final_vals:
                try:
                    res_python = f(a)
                except Exception as e:
                    res_python = e
                try:
                    res_script = f_script(a)
                except Exception as e:
                    res_script = e

                if res_python != res_script:
                    # Inputs on which Python raises are skipped; the scripted
                    # function is not required to raise the same error.
                    if isinstance(res_python, Exception):
                        continue

                    msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
                    self.assertEqual(res_python, res_script, msg=msg)

        unary_ops = [
            "log",
            "log10",
            "sqrt",
            "exp",
            "sin",
            "cos",
            "asin",
            "acos",
            "atan",
            "sinh",
            "cosh",
            "tanh",
            "asinh",
            "acosh",
            "atanh",
            "phase",
            "isinf",
            "isnan",
            "isfinite",
        ]

        # --- Unary ops ---
        for op in unary_ops:
            checkCmath(op)

        def fn(x: complex):
            return abs(x)

        for val in complex_vals:
            self.checkScript(fn, (val,))

        def pow_complex_float(x: complex, y: float):
            return pow(x, y)

        def pow_float_complex(x: float, y: complex):
            return pow(x, y)

        self.checkScript(pow_float_complex, (2, 3j))
        self.checkScript(pow_complex_float, (3j, 2))

        def pow_complex_complex(x: complex, y: complex):
            return pow(x, y)

        # zip pairs each value with itself, i.e. this checks x ** x.
        for x, y in zip(complex_vals, complex_vals):
            # Reference: https://github.com/pytorch/pytorch/issues/54622
            if x == 0:
                continue
            self.checkScript(pow_complex_complex, (x, y))

        if not IS_MACOS:
            # --- Binary op ---
            def rect_fn(x: float, y: float):
                return cmath.rect(x, y)

            for x, y in product(vals, vals):
                self.checkScript(rect_fn, (x, y))

        # --- Constants ---
        # This template defines a zero-argument `func`, so it must be called
        # without arguments. Passing it through checkCmath (which calls
        # `func(a)`) would make both sides raise on every input, and
        # checkCmath skips inputs on which Python raises, so nothing would
        # actually be compared.
        func_constants_template = dedent(
            """
        def func():
            return cmath.{func_or_const}
        """
        )

        def checkCmathConst(const_name):
            # Compare a cmath constant read in Python vs. in TorchScript.
            f, f_script = compileBoth(const_name, func_constants_template)
            res_python = f()
            res_script = f_script()
            msg = f"Failed on {const_name}. Python: {res_python}, Script: {res_script}"
            self.assertEqual(res_python, res_script, msg=msg)

        float_consts = ["pi", "e", "tau", "inf", "nan"]
        complex_consts = ["infj", "nanj"]
        for x in float_consts + complex_consts:
            checkCmathConst(x)

    def test_infj_nanj_pickle(self):
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j

            @torch.jit.script_method
            def forward(self, infj: int, nanj: int):
                if infj == 2:
                    return infj + cmath.infj
                else:
                    return nanj + cmath.nanj

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded(2, 3), 2 + cmath.infj)
        self.assertEqual(loaded(3, 4), 4 + cmath.nanj)

    def test_complex_constructor(self):
        # Test all scalar types
        def fn_int(real: int, img: int):
            return complex(real, img)

        for args in [(0, 0), (-1234, 0), (0, -1256), (-167, -1256)]:
            self.checkScript(fn_int, args)

        def fn_float(real: float, img: float):
            return complex(real, img)

        # Some cases intentionally pass ints to float parameters.
        for args in [(0.0, 0.0), (-1234.78, 0), (0, 56.18), (-1.9, -19.8)]:
            self.checkScript(fn_float, args)

        def fn_bool(real: bool, img: bool):
            return complex(real, img)

        for args in [(True, True), (False, False), (False, True), (True, False)]:
            self.checkScript(fn_bool, args)

        def fn_bool_int(real: bool, img: int):
            return complex(real, img)

        for args in [(True, 0), (False, 0), (False, -1), (True, 3)]:
            self.checkScript(fn_bool_int, args)

        def fn_int_bool(real: int, img: bool):
            return complex(real, img)

        for args in [(0, True), (0, False), (-3, True), (6, False)]:
            self.checkScript(fn_int_bool, args)

        def fn_bool_float(real: bool, img: float):
            return complex(real, img)

        for args in [(True, 0.0), (False, 0.0), (False, -1.0), (True, 3.0)]:
            self.checkScript(fn_bool_float, args)

        def fn_float_bool(real: float, img: bool):
            return complex(real, img)

        for args in [(0.0, True), (0.0, False), (-3.0, True), (6.0, False)]:
            self.checkScript(fn_float_bool, args)

        def fn_float_int(real: float, img: int):
            return complex(real, img)

        for args in [(0.0, 1), (0.0, -1), (1.8, -3), (2.7, 8)]:
            self.checkScript(fn_float_int, args)

        def fn_int_float(real: int, img: float):
            return complex(real, img)

        for args in [(1, 0.0), (-1, 1.7), (-3, 0.0), (2, -8.9)]:
            self.checkScript(fn_int_float, args)

    def test_torch_complex_constructor_with_tensor(self):
        tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]

        def fn_tensor_float(real, img: float):
            return complex(real, img)

        def fn_tensor_int(real, img: int):
            return complex(real, img)

        def fn_tensor_bool(real, img: bool):
            return complex(real, img)

        def fn_float_tensor(real: float, img):
            return complex(real, img)

        def fn_int_tensor(real: int, img):
            return complex(real, img)

        def fn_bool_tensor(real: bool, img):
            return complex(real, img)

        for tensor in tensors:
            self.checkScript(fn_tensor_float, (tensor, 1.2))
            self.checkScript(fn_tensor_int, (tensor, 3))
            self.checkScript(fn_tensor_bool, (tensor, True))

            self.checkScript(fn_float_tensor, (1.2, tensor))
            self.checkScript(fn_int_tensor, (3, tensor))
            self.checkScript(fn_bool_tensor, (True, tensor))

        def fn_tensor_tensor(real, img):
            return complex(real, img) + complex(2)

        for x, y in product(tensors, tensors):
            self.checkScript(fn_tensor_tensor, (x, y))

    def test_comparison_ops(self):
        def fn1(a: complex, b: complex):
            return a == b

        def fn2(a: complex, b: complex):
            return a != b

        def fn3(a: complex, b: float):
            return a == b

        def fn4(a: complex, b: float):
            return a != b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, x))
        self.checkScript(fn1, (x, y))
        self.checkScript(fn2, (x, x))
        self.checkScript(fn2, (x, y))

        x1, y1 = 1 + 0j, 1.0
        self.checkScript(fn3, (x1, y1))
        self.checkScript(fn4, (x1, y1))

    def test_div(self):
        def fn1(a: complex, b: complex):
            return a / b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, y))

    def test_complex_list_sum(self):
        def fn(x: List[complex]):
            return sum(x)

        self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))

    def test_tensor_attributes(self):
        def tensor_real(x):
            return x.real

        def tensor_imag(x):
            return x.imag

        t = torch.randn(2, 3, dtype=torch.cdouble)
        self.checkScript(tensor_real, (t,))
        self.checkScript(tensor_imag, (t,))

    def test_binary_op_complex_tensor(self):
        # complex scalar on the left, tensor on the right: scripted result
        # must match eager for each operator.
        def mul(x: complex, y: torch.Tensor):
            return x * y

        def add(x: complex, y: torch.Tensor):
            return x + y

        def eq(x: complex, y: torch.Tensor):
            return x == y

        def ne(x: complex, y: torch.Tensor):
            return x != y

        def sub(x: complex, y: torch.Tensor):
            return x - y

        def div(x: complex, y: torch.Tensor):
            return x / y

        ops = [mul, add, eq, ne, sub, div]

        x = 0.71 + 0.71j
        for shape in [(1,), (2, 2)]:
            y = torch.randn(shape, dtype=torch.cfloat)
            for op in ops:
                eager_result = op(x, y)
                scripted = torch.jit.script(op)
                jit_result = scripted(x, y)
                self.assertEqual(eager_result, jit_result)