1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: complex"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 6*da0073e9SAndroid Build Coastguard Worker dtypes, 7*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 8*da0073e9SAndroid Build Coastguard Worker onlyCPU, 9*da0073e9SAndroid Build Coastguard Worker) 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import complex_types 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerdevices = (torch.device("cpu"), torch.device("cuda:0")) 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerclass TestComplexTensor(TestCase): 18*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 19*da0073e9SAndroid Build Coastguard Worker def test_to_list(self, device, dtype): 20*da0073e9SAndroid Build Coastguard Worker # test that the complex float tensor has expected values and 21*da0073e9SAndroid Build Coastguard Worker # there's no garbage value in the resultant list 22*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 23*da0073e9SAndroid Build Coastguard Worker torch.zeros((2, 2), device=device, dtype=dtype).tolist(), 24*da0073e9SAndroid Build Coastguard Worker [[0j, 0j], [0j, 0j]], 25*da0073e9SAndroid Build Coastguard Worker ) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64, torch.float16) 28*da0073e9SAndroid Build Coastguard Worker def test_dtype_inference(self, device, dtype): 29*da0073e9SAndroid Build Coastguard Worker # issue: https://github.com/pytorch/pytorch/issues/36834 30*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(dtype): 31*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([3.0, 3.0 + 5.0j], device=device) 32*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float16: 33*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, torch.chalf) 34*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.float32: 35*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, torch.cfloat) 36*da0073e9SAndroid Build Coastguard Worker else: 37*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, torch.cdouble) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 40*da0073e9SAndroid Build Coastguard Worker def test_conj_copy(self, device, dtype): 41*da0073e9SAndroid Build Coastguard Worker # issue: https://github.com/pytorch/pytorch/issues/106051 42*da0073e9SAndroid Build Coastguard Worker x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype) 43*da0073e9SAndroid Build Coastguard Worker xc1 = torch.conj(x1) 44*da0073e9SAndroid Build Coastguard Worker x1.copy_(xc1) 45*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype)) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 48*da0073e9SAndroid Build Coastguard Worker def test_all(self, device, dtype): 49*da0073e9SAndroid Build Coastguard Worker # issue: https://github.com/pytorch/pytorch/issues/120875 50*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype) 51*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(x)) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 54*da0073e9SAndroid Build Coastguard Worker def test_any(self, device, dtype): 55*da0073e9SAndroid Build Coastguard Worker # issue: https://github.com/pytorch/pytorch/issues/120875 56*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 57*da0073e9SAndroid Build Coastguard Worker [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype 58*da0073e9SAndroid Build Coastguard Worker ) 59*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.any(x)) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker @onlyCPU 62*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 63*da0073e9SAndroid Build Coastguard Worker def test_eq(self, device, dtype): 64*da0073e9SAndroid Build Coastguard Worker "Test eq on complex types" 65*da0073e9SAndroid Build Coastguard Worker nan = float("nan") 66*da0073e9SAndroid Build Coastguard Worker # Non-vectorized operations 67*da0073e9SAndroid Build Coastguard Worker for a, b in ( 68*da0073e9SAndroid Build Coastguard Worker ( 69*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 70*da0073e9SAndroid Build Coastguard Worker torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), 71*da0073e9SAndroid Build Coastguard Worker ), 72*da0073e9SAndroid Build Coastguard Worker ( 73*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 74*da0073e9SAndroid Build Coastguard Worker torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), 75*da0073e9SAndroid Build Coastguard Worker ), 76*da0073e9SAndroid Build Coastguard Worker ( 77*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 78*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), 79*da0073e9SAndroid Build Coastguard Worker ), 80*da0073e9SAndroid Build Coastguard Worker ): 81*da0073e9SAndroid Build Coastguard Worker actual = torch.eq(a, b) 82*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([False], device=device, dtype=torch.bool) 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 84*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 85*da0073e9SAndroid Build Coastguard Worker ) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker actual = torch.eq(a, a) 88*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([True], device=device, dtype=torch.bool) 89*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 90*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 91*da0073e9SAndroid Build Coastguard Worker ) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 94*da0073e9SAndroid Build Coastguard Worker torch.eq(a, b, out=actual) 95*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([complex(0)], device=device, dtype=dtype) 96*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 97*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 98*da0073e9SAndroid Build Coastguard Worker ) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 101*da0073e9SAndroid Build Coastguard Worker torch.eq(a, a, out=actual) 102*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([complex(1)], device=device, dtype=dtype) 103*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 104*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 105*da0073e9SAndroid Build Coastguard Worker ) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker # Vectorized operations 108*da0073e9SAndroid Build Coastguard Worker for a, b in ( 109*da0073e9SAndroid Build Coastguard Worker ( 110*da0073e9SAndroid Build Coastguard Worker torch.tensor( 111*da0073e9SAndroid Build Coastguard Worker [ 112*da0073e9SAndroid Build Coastguard Worker -0.0610 - 2.1172j, 113*da0073e9SAndroid Build Coastguard Worker 5.1576 + 5.4775j, 114*da0073e9SAndroid Build Coastguard Worker complex(2.8871, nan), 115*da0073e9SAndroid Build Coastguard Worker -6.6545 - 3.7655j, 116*da0073e9SAndroid Build Coastguard Worker -2.7036 - 1.4470j, 117*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 118*da0073e9SAndroid Build Coastguard Worker -0.0610 - 2.1172j, 119*da0073e9SAndroid Build Coastguard Worker 5.1576 + 5.4775j, 120*da0073e9SAndroid Build Coastguard Worker complex(nan, -3.2650), 121*da0073e9SAndroid Build Coastguard Worker -6.6545 - 3.7655j, 122*da0073e9SAndroid Build Coastguard Worker -2.7036 - 1.4470j, 123*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 124*da0073e9SAndroid Build Coastguard Worker ], 125*da0073e9SAndroid Build Coastguard Worker device=device, 126*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 127*da0073e9SAndroid Build Coastguard Worker ), 128*da0073e9SAndroid Build Coastguard Worker torch.tensor( 129*da0073e9SAndroid Build Coastguard Worker [ 130*da0073e9SAndroid Build Coastguard Worker -6.1278 - 8.5019j, 131*da0073e9SAndroid Build Coastguard Worker 0.5886 + 8.8816j, 132*da0073e9SAndroid Build Coastguard Worker complex(2.8871, nan), 133*da0073e9SAndroid Build Coastguard Worker 6.3505 + 2.2683j, 134*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.9659j, 135*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 136*da0073e9SAndroid Build Coastguard Worker -6.1278 - 2.1172j, 137*da0073e9SAndroid Build Coastguard Worker 5.1576 + 8.8816j, 138*da0073e9SAndroid Build Coastguard Worker complex(nan, -3.2650), 139*da0073e9SAndroid Build Coastguard Worker 6.3505 + 2.2683j, 140*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.9659j, 141*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 142*da0073e9SAndroid Build Coastguard Worker ], 143*da0073e9SAndroid Build Coastguard Worker device=device, 144*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 145*da0073e9SAndroid Build Coastguard Worker ), 146*da0073e9SAndroid Build Coastguard Worker ), 147*da0073e9SAndroid Build Coastguard Worker ): 148*da0073e9SAndroid Build Coastguard Worker actual = torch.eq(a, b) 149*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 150*da0073e9SAndroid Build Coastguard Worker [ 151*da0073e9SAndroid Build Coastguard Worker False, 152*da0073e9SAndroid Build Coastguard Worker False, 153*da0073e9SAndroid Build Coastguard Worker False, 154*da0073e9SAndroid Build Coastguard Worker False, 155*da0073e9SAndroid Build Coastguard Worker False, 156*da0073e9SAndroid Build Coastguard Worker True, 157*da0073e9SAndroid Build Coastguard Worker False, 158*da0073e9SAndroid Build Coastguard Worker False, 159*da0073e9SAndroid Build Coastguard Worker False, 160*da0073e9SAndroid Build Coastguard Worker False, 161*da0073e9SAndroid Build Coastguard Worker False, 162*da0073e9SAndroid Build Coastguard Worker True, 163*da0073e9SAndroid Build Coastguard Worker ], 164*da0073e9SAndroid Build Coastguard Worker device=device, 165*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 166*da0073e9SAndroid Build Coastguard Worker ) 167*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 168*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 169*da0073e9SAndroid Build Coastguard Worker ) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker actual = torch.eq(a, a) 172*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 173*da0073e9SAndroid Build Coastguard Worker [ 174*da0073e9SAndroid Build Coastguard Worker True, 175*da0073e9SAndroid Build Coastguard Worker True, 176*da0073e9SAndroid Build Coastguard Worker False, 177*da0073e9SAndroid Build Coastguard Worker True, 178*da0073e9SAndroid Build Coastguard Worker True, 179*da0073e9SAndroid Build Coastguard Worker True, 180*da0073e9SAndroid Build Coastguard Worker True, 181*da0073e9SAndroid Build Coastguard Worker True, 182*da0073e9SAndroid Build Coastguard Worker False, 183*da0073e9SAndroid Build Coastguard Worker True, 184*da0073e9SAndroid Build Coastguard Worker True, 185*da0073e9SAndroid Build Coastguard Worker True, 186*da0073e9SAndroid Build Coastguard Worker ], 187*da0073e9SAndroid Build Coastguard Worker device=device, 188*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 189*da0073e9SAndroid Build Coastguard Worker ) 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 191*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 195*da0073e9SAndroid Build Coastguard Worker torch.eq(a, b, out=actual) 196*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 197*da0073e9SAndroid Build Coastguard Worker [ 198*da0073e9SAndroid Build Coastguard Worker complex(0), 199*da0073e9SAndroid Build Coastguard Worker complex(0), 200*da0073e9SAndroid Build Coastguard Worker complex(0), 201*da0073e9SAndroid Build Coastguard Worker complex(0), 202*da0073e9SAndroid Build Coastguard Worker complex(0), 203*da0073e9SAndroid Build Coastguard Worker complex(1), 204*da0073e9SAndroid Build Coastguard Worker complex(0), 205*da0073e9SAndroid Build Coastguard Worker complex(0), 206*da0073e9SAndroid Build Coastguard Worker complex(0), 207*da0073e9SAndroid Build Coastguard Worker complex(0), 208*da0073e9SAndroid Build Coastguard Worker complex(0), 209*da0073e9SAndroid Build Coastguard Worker complex(1), 210*da0073e9SAndroid Build Coastguard Worker ], 211*da0073e9SAndroid Build Coastguard Worker device=device, 212*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 213*da0073e9SAndroid Build Coastguard Worker ) 214*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 215*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 216*da0073e9SAndroid Build Coastguard Worker ) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 219*da0073e9SAndroid Build Coastguard Worker torch.eq(a, a, out=actual) 220*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 221*da0073e9SAndroid Build Coastguard Worker [ 222*da0073e9SAndroid Build Coastguard Worker complex(1), 223*da0073e9SAndroid Build Coastguard Worker complex(1), 224*da0073e9SAndroid Build Coastguard Worker complex(0), 225*da0073e9SAndroid Build Coastguard Worker complex(1), 226*da0073e9SAndroid Build Coastguard Worker complex(1), 227*da0073e9SAndroid Build Coastguard Worker complex(1), 228*da0073e9SAndroid Build Coastguard Worker complex(1), 229*da0073e9SAndroid Build Coastguard Worker complex(1), 230*da0073e9SAndroid Build Coastguard Worker complex(0), 231*da0073e9SAndroid Build Coastguard Worker complex(1), 232*da0073e9SAndroid Build Coastguard Worker complex(1), 233*da0073e9SAndroid Build Coastguard Worker complex(1), 234*da0073e9SAndroid Build Coastguard Worker ], 235*da0073e9SAndroid Build Coastguard Worker device=device, 236*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 237*da0073e9SAndroid Build Coastguard Worker ) 238*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 239*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 240*da0073e9SAndroid Build Coastguard Worker ) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker @onlyCPU 243*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 244*da0073e9SAndroid Build Coastguard Worker def test_ne(self, device, dtype): 245*da0073e9SAndroid Build Coastguard Worker "Test ne on complex types" 246*da0073e9SAndroid Build Coastguard Worker nan = float("nan") 247*da0073e9SAndroid Build Coastguard Worker # Non-vectorized operations 248*da0073e9SAndroid Build Coastguard Worker for a, b in ( 249*da0073e9SAndroid Build Coastguard Worker ( 250*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 251*da0073e9SAndroid Build Coastguard Worker torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), 252*da0073e9SAndroid Build Coastguard Worker ), 253*da0073e9SAndroid Build Coastguard Worker ( 254*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 255*da0073e9SAndroid Build Coastguard Worker torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), 256*da0073e9SAndroid Build Coastguard Worker ), 257*da0073e9SAndroid Build Coastguard Worker ( 258*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 259*da0073e9SAndroid Build Coastguard Worker torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), 260*da0073e9SAndroid Build Coastguard Worker ), 261*da0073e9SAndroid Build Coastguard Worker ): 262*da0073e9SAndroid Build Coastguard Worker actual = torch.ne(a, b) 263*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([True], device=device, dtype=torch.bool) 264*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 265*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 266*da0073e9SAndroid Build Coastguard Worker ) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker actual = torch.ne(a, a) 269*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([False], device=device, dtype=torch.bool) 270*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 271*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 272*da0073e9SAndroid Build Coastguard Worker ) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 275*da0073e9SAndroid Build Coastguard Worker torch.ne(a, b, out=actual) 276*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([complex(1)], device=device, dtype=dtype) 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 278*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 279*da0073e9SAndroid Build Coastguard Worker ) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 282*da0073e9SAndroid Build Coastguard Worker torch.ne(a, a, out=actual) 283*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([complex(0)], device=device, dtype=dtype) 284*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 285*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 286*da0073e9SAndroid Build Coastguard Worker ) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker # Vectorized operations 289*da0073e9SAndroid Build Coastguard Worker for a, b in ( 290*da0073e9SAndroid Build Coastguard Worker ( 291*da0073e9SAndroid Build Coastguard Worker torch.tensor( 292*da0073e9SAndroid Build Coastguard Worker [ 293*da0073e9SAndroid Build Coastguard Worker -0.0610 - 2.1172j, 294*da0073e9SAndroid Build Coastguard Worker 5.1576 + 5.4775j, 295*da0073e9SAndroid Build Coastguard Worker complex(2.8871, nan), 296*da0073e9SAndroid Build Coastguard Worker -6.6545 - 3.7655j, 297*da0073e9SAndroid Build Coastguard Worker -2.7036 - 1.4470j, 298*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 299*da0073e9SAndroid Build Coastguard Worker -0.0610 - 2.1172j, 300*da0073e9SAndroid Build Coastguard Worker 5.1576 + 5.4775j, 301*da0073e9SAndroid Build Coastguard Worker complex(nan, -3.2650), 302*da0073e9SAndroid Build Coastguard Worker -6.6545 - 3.7655j, 303*da0073e9SAndroid Build Coastguard Worker -2.7036 - 1.4470j, 304*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 305*da0073e9SAndroid Build Coastguard Worker ], 306*da0073e9SAndroid Build Coastguard Worker device=device, 307*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 308*da0073e9SAndroid Build Coastguard Worker ), 309*da0073e9SAndroid Build Coastguard Worker torch.tensor( 310*da0073e9SAndroid Build Coastguard Worker [ 311*da0073e9SAndroid Build Coastguard Worker -6.1278 - 8.5019j, 312*da0073e9SAndroid Build Coastguard Worker 0.5886 + 8.8816j, 313*da0073e9SAndroid Build Coastguard Worker complex(2.8871, nan), 314*da0073e9SAndroid Build Coastguard Worker 6.3505 + 2.2683j, 315*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.9659j, 316*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 317*da0073e9SAndroid Build Coastguard Worker -6.1278 - 2.1172j, 318*da0073e9SAndroid Build Coastguard Worker 5.1576 + 8.8816j, 319*da0073e9SAndroid Build Coastguard Worker complex(nan, -3.2650), 320*da0073e9SAndroid Build Coastguard Worker 6.3505 + 2.2683j, 321*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.9659j, 322*da0073e9SAndroid Build Coastguard Worker 0.3712 + 7.989j, 323*da0073e9SAndroid Build Coastguard Worker ], 324*da0073e9SAndroid Build Coastguard Worker device=device, 325*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 326*da0073e9SAndroid Build Coastguard Worker ), 327*da0073e9SAndroid Build Coastguard Worker ), 328*da0073e9SAndroid Build Coastguard Worker ): 329*da0073e9SAndroid Build Coastguard Worker actual = torch.ne(a, b) 330*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 331*da0073e9SAndroid Build Coastguard Worker [ 332*da0073e9SAndroid Build Coastguard Worker True, 333*da0073e9SAndroid Build Coastguard Worker True, 334*da0073e9SAndroid Build Coastguard Worker True, 335*da0073e9SAndroid Build Coastguard Worker True, 336*da0073e9SAndroid Build Coastguard Worker True, 337*da0073e9SAndroid Build Coastguard Worker False, 338*da0073e9SAndroid Build Coastguard Worker True, 339*da0073e9SAndroid Build Coastguard Worker True, 340*da0073e9SAndroid Build Coastguard Worker True, 341*da0073e9SAndroid Build Coastguard Worker True, 342*da0073e9SAndroid Build Coastguard Worker True, 343*da0073e9SAndroid Build Coastguard Worker False, 344*da0073e9SAndroid Build Coastguard Worker ], 345*da0073e9SAndroid Build Coastguard Worker device=device, 346*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 347*da0073e9SAndroid Build Coastguard Worker ) 348*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 349*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker actual = torch.ne(a, a) 353*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 354*da0073e9SAndroid Build Coastguard Worker [ 355*da0073e9SAndroid Build Coastguard Worker False, 356*da0073e9SAndroid Build Coastguard Worker False, 357*da0073e9SAndroid Build Coastguard Worker True, 358*da0073e9SAndroid Build Coastguard Worker False, 359*da0073e9SAndroid Build Coastguard Worker False, 360*da0073e9SAndroid Build Coastguard Worker False, 361*da0073e9SAndroid Build Coastguard Worker False, 362*da0073e9SAndroid Build Coastguard Worker False, 363*da0073e9SAndroid Build Coastguard Worker True, 364*da0073e9SAndroid Build Coastguard Worker False, 365*da0073e9SAndroid Build Coastguard Worker False, 366*da0073e9SAndroid Build Coastguard Worker False, 367*da0073e9SAndroid Build Coastguard Worker ], 368*da0073e9SAndroid Build Coastguard Worker device=device, 369*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 372*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 373*da0073e9SAndroid Build Coastguard Worker ) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 376*da0073e9SAndroid Build Coastguard Worker torch.ne(a, b, out=actual) 377*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 378*da0073e9SAndroid Build Coastguard Worker [ 379*da0073e9SAndroid Build Coastguard Worker complex(1), 380*da0073e9SAndroid Build Coastguard Worker complex(1), 381*da0073e9SAndroid Build Coastguard Worker complex(1), 382*da0073e9SAndroid Build Coastguard Worker complex(1), 383*da0073e9SAndroid Build Coastguard Worker complex(1), 384*da0073e9SAndroid Build Coastguard Worker complex(0), 385*da0073e9SAndroid Build Coastguard Worker complex(1), 386*da0073e9SAndroid Build Coastguard Worker complex(1), 387*da0073e9SAndroid Build Coastguard Worker complex(1), 388*da0073e9SAndroid Build Coastguard Worker complex(1), 389*da0073e9SAndroid Build Coastguard Worker complex(1), 390*da0073e9SAndroid Build Coastguard Worker complex(0), 391*da0073e9SAndroid Build Coastguard Worker ], 392*da0073e9SAndroid Build Coastguard Worker device=device, 393*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 394*da0073e9SAndroid Build Coastguard Worker ) 395*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 396*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 397*da0073e9SAndroid Build Coastguard Worker ) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker actual = torch.full_like(b, complex(2, 2)) 400*da0073e9SAndroid Build Coastguard Worker torch.ne(a, a, out=actual) 401*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 402*da0073e9SAndroid Build Coastguard Worker [ 403*da0073e9SAndroid Build Coastguard Worker complex(0), 404*da0073e9SAndroid Build Coastguard Worker complex(0), 405*da0073e9SAndroid Build Coastguard Worker complex(1), 406*da0073e9SAndroid Build Coastguard Worker complex(0), 407*da0073e9SAndroid Build Coastguard Worker complex(0), 408*da0073e9SAndroid Build Coastguard Worker complex(0), 409*da0073e9SAndroid Build Coastguard Worker complex(0), 410*da0073e9SAndroid Build Coastguard Worker complex(0), 411*da0073e9SAndroid Build Coastguard Worker complex(1), 412*da0073e9SAndroid Build Coastguard Worker complex(0), 413*da0073e9SAndroid Build Coastguard Worker complex(0), 414*da0073e9SAndroid Build Coastguard Worker complex(0), 415*da0073e9SAndroid Build Coastguard Worker ], 416*da0073e9SAndroid Build Coastguard Worker device=device, 417*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 418*da0073e9SAndroid Build Coastguard Worker ) 419*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 420*da0073e9SAndroid Build Coastguard Worker actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 421*da0073e9SAndroid Build Coastguard Worker ) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestComplexTensor, globals()) 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 427*da0073e9SAndroid Build Coastguard Worker TestCase._default_dtype_check_enabled = True 428*da0073e9SAndroid Build Coastguard Worker run_tests() 429