xref: /aosp_15_r20/external/pytorch/test/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: complex"]
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
6*da0073e9SAndroid Build Coastguard Worker    dtypes,
7*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
8*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
9*da0073e9SAndroid Build Coastguard Worker)
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import complex_types
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
# Module-level (CPU, first CUDA device) pair. No test in this file reads it;
# the per-device tests receive their device through the `device` argument.
devices = tuple(map(torch.device, ("cpu", "cuda:0")))
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerclass TestComplexTensor(TestCase):
18*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
19*da0073e9SAndroid Build Coastguard Worker    def test_to_list(self, device, dtype):
20*da0073e9SAndroid Build Coastguard Worker        # test that the complex float tensor has expected values and
21*da0073e9SAndroid Build Coastguard Worker        # there's no garbage value in the resultant list
22*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
23*da0073e9SAndroid Build Coastguard Worker            torch.zeros((2, 2), device=device, dtype=dtype).tolist(),
24*da0073e9SAndroid Build Coastguard Worker            [[0j, 0j], [0j, 0j]],
25*da0073e9SAndroid Build Coastguard Worker        )
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64, torch.float16)
28*da0073e9SAndroid Build Coastguard Worker    def test_dtype_inference(self, device, dtype):
29*da0073e9SAndroid Build Coastguard Worker        # issue: https://github.com/pytorch/pytorch/issues/36834
30*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(dtype):
31*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([3.0, 3.0 + 5.0j], device=device)
32*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
33*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dtype, torch.chalf)
34*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.float32:
35*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dtype, torch.cfloat)
36*da0073e9SAndroid Build Coastguard Worker        else:
37*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dtype, torch.cdouble)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
40*da0073e9SAndroid Build Coastguard Worker    def test_conj_copy(self, device, dtype):
41*da0073e9SAndroid Build Coastguard Worker        # issue: https://github.com/pytorch/pytorch/issues/106051
42*da0073e9SAndroid Build Coastguard Worker        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
43*da0073e9SAndroid Build Coastguard Worker        xc1 = torch.conj(x1)
44*da0073e9SAndroid Build Coastguard Worker        x1.copy_(xc1)
45*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
48*da0073e9SAndroid Build Coastguard Worker    def test_all(self, device, dtype):
49*da0073e9SAndroid Build Coastguard Worker        # issue: https://github.com/pytorch/pytorch/issues/120875
50*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
51*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(x))
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
54*da0073e9SAndroid Build Coastguard Worker    def test_any(self, device, dtype):
55*da0073e9SAndroid Build Coastguard Worker        # issue: https://github.com/pytorch/pytorch/issues/120875
56*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
57*da0073e9SAndroid Build Coastguard Worker            [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
58*da0073e9SAndroid Build Coastguard Worker        )
59*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.any(x))
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
62*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
63*da0073e9SAndroid Build Coastguard Worker    def test_eq(self, device, dtype):
64*da0073e9SAndroid Build Coastguard Worker        "Test eq on complex types"
65*da0073e9SAndroid Build Coastguard Worker        nan = float("nan")
66*da0073e9SAndroid Build Coastguard Worker        # Non-vectorized operations
67*da0073e9SAndroid Build Coastguard Worker        for a, b in (
68*da0073e9SAndroid Build Coastguard Worker            (
69*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
70*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
71*da0073e9SAndroid Build Coastguard Worker            ),
72*da0073e9SAndroid Build Coastguard Worker            (
73*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
74*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
75*da0073e9SAndroid Build Coastguard Worker            ),
76*da0073e9SAndroid Build Coastguard Worker            (
77*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
78*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
79*da0073e9SAndroid Build Coastguard Worker            ),
80*da0073e9SAndroid Build Coastguard Worker        ):
81*da0073e9SAndroid Build Coastguard Worker            actual = torch.eq(a, b)
82*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([False], device=device, dtype=torch.bool)
83*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
84*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
85*da0073e9SAndroid Build Coastguard Worker            )
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker            actual = torch.eq(a, a)
88*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([True], device=device, dtype=torch.bool)
89*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
90*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
91*da0073e9SAndroid Build Coastguard Worker            )
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
94*da0073e9SAndroid Build Coastguard Worker            torch.eq(a, b, out=actual)
95*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
96*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
97*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
98*da0073e9SAndroid Build Coastguard Worker            )
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
101*da0073e9SAndroid Build Coastguard Worker            torch.eq(a, a, out=actual)
102*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
103*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
104*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
105*da0073e9SAndroid Build Coastguard Worker            )
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        # Vectorized operations
108*da0073e9SAndroid Build Coastguard Worker        for a, b in (
109*da0073e9SAndroid Build Coastguard Worker            (
110*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
111*da0073e9SAndroid Build Coastguard Worker                    [
112*da0073e9SAndroid Build Coastguard Worker                        -0.0610 - 2.1172j,
113*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 5.4775j,
114*da0073e9SAndroid Build Coastguard Worker                        complex(2.8871, nan),
115*da0073e9SAndroid Build Coastguard Worker                        -6.6545 - 3.7655j,
116*da0073e9SAndroid Build Coastguard Worker                        -2.7036 - 1.4470j,
117*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
118*da0073e9SAndroid Build Coastguard Worker                        -0.0610 - 2.1172j,
119*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 5.4775j,
120*da0073e9SAndroid Build Coastguard Worker                        complex(nan, -3.2650),
121*da0073e9SAndroid Build Coastguard Worker                        -6.6545 - 3.7655j,
122*da0073e9SAndroid Build Coastguard Worker                        -2.7036 - 1.4470j,
123*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
124*da0073e9SAndroid Build Coastguard Worker                    ],
125*da0073e9SAndroid Build Coastguard Worker                    device=device,
126*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
127*da0073e9SAndroid Build Coastguard Worker                ),
128*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
129*da0073e9SAndroid Build Coastguard Worker                    [
130*da0073e9SAndroid Build Coastguard Worker                        -6.1278 - 8.5019j,
131*da0073e9SAndroid Build Coastguard Worker                        0.5886 + 8.8816j,
132*da0073e9SAndroid Build Coastguard Worker                        complex(2.8871, nan),
133*da0073e9SAndroid Build Coastguard Worker                        6.3505 + 2.2683j,
134*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.9659j,
135*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
136*da0073e9SAndroid Build Coastguard Worker                        -6.1278 - 2.1172j,
137*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 8.8816j,
138*da0073e9SAndroid Build Coastguard Worker                        complex(nan, -3.2650),
139*da0073e9SAndroid Build Coastguard Worker                        6.3505 + 2.2683j,
140*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.9659j,
141*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
142*da0073e9SAndroid Build Coastguard Worker                    ],
143*da0073e9SAndroid Build Coastguard Worker                    device=device,
144*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
145*da0073e9SAndroid Build Coastguard Worker                ),
146*da0073e9SAndroid Build Coastguard Worker            ),
147*da0073e9SAndroid Build Coastguard Worker        ):
148*da0073e9SAndroid Build Coastguard Worker            actual = torch.eq(a, b)
149*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
150*da0073e9SAndroid Build Coastguard Worker                [
151*da0073e9SAndroid Build Coastguard Worker                    False,
152*da0073e9SAndroid Build Coastguard Worker                    False,
153*da0073e9SAndroid Build Coastguard Worker                    False,
154*da0073e9SAndroid Build Coastguard Worker                    False,
155*da0073e9SAndroid Build Coastguard Worker                    False,
156*da0073e9SAndroid Build Coastguard Worker                    True,
157*da0073e9SAndroid Build Coastguard Worker                    False,
158*da0073e9SAndroid Build Coastguard Worker                    False,
159*da0073e9SAndroid Build Coastguard Worker                    False,
160*da0073e9SAndroid Build Coastguard Worker                    False,
161*da0073e9SAndroid Build Coastguard Worker                    False,
162*da0073e9SAndroid Build Coastguard Worker                    True,
163*da0073e9SAndroid Build Coastguard Worker                ],
164*da0073e9SAndroid Build Coastguard Worker                device=device,
165*da0073e9SAndroid Build Coastguard Worker                dtype=torch.bool,
166*da0073e9SAndroid Build Coastguard Worker            )
167*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
168*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
169*da0073e9SAndroid Build Coastguard Worker            )
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker            actual = torch.eq(a, a)
172*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
173*da0073e9SAndroid Build Coastguard Worker                [
174*da0073e9SAndroid Build Coastguard Worker                    True,
175*da0073e9SAndroid Build Coastguard Worker                    True,
176*da0073e9SAndroid Build Coastguard Worker                    False,
177*da0073e9SAndroid Build Coastguard Worker                    True,
178*da0073e9SAndroid Build Coastguard Worker                    True,
179*da0073e9SAndroid Build Coastguard Worker                    True,
180*da0073e9SAndroid Build Coastguard Worker                    True,
181*da0073e9SAndroid Build Coastguard Worker                    True,
182*da0073e9SAndroid Build Coastguard Worker                    False,
183*da0073e9SAndroid Build Coastguard Worker                    True,
184*da0073e9SAndroid Build Coastguard Worker                    True,
185*da0073e9SAndroid Build Coastguard Worker                    True,
186*da0073e9SAndroid Build Coastguard Worker                ],
187*da0073e9SAndroid Build Coastguard Worker                device=device,
188*da0073e9SAndroid Build Coastguard Worker                dtype=torch.bool,
189*da0073e9SAndroid Build Coastguard Worker            )
190*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
191*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
192*da0073e9SAndroid Build Coastguard Worker            )
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
195*da0073e9SAndroid Build Coastguard Worker            torch.eq(a, b, out=actual)
196*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
197*da0073e9SAndroid Build Coastguard Worker                [
198*da0073e9SAndroid Build Coastguard Worker                    complex(0),
199*da0073e9SAndroid Build Coastguard Worker                    complex(0),
200*da0073e9SAndroid Build Coastguard Worker                    complex(0),
201*da0073e9SAndroid Build Coastguard Worker                    complex(0),
202*da0073e9SAndroid Build Coastguard Worker                    complex(0),
203*da0073e9SAndroid Build Coastguard Worker                    complex(1),
204*da0073e9SAndroid Build Coastguard Worker                    complex(0),
205*da0073e9SAndroid Build Coastguard Worker                    complex(0),
206*da0073e9SAndroid Build Coastguard Worker                    complex(0),
207*da0073e9SAndroid Build Coastguard Worker                    complex(0),
208*da0073e9SAndroid Build Coastguard Worker                    complex(0),
209*da0073e9SAndroid Build Coastguard Worker                    complex(1),
210*da0073e9SAndroid Build Coastguard Worker                ],
211*da0073e9SAndroid Build Coastguard Worker                device=device,
212*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
213*da0073e9SAndroid Build Coastguard Worker            )
214*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
215*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
216*da0073e9SAndroid Build Coastguard Worker            )
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
219*da0073e9SAndroid Build Coastguard Worker            torch.eq(a, a, out=actual)
220*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
221*da0073e9SAndroid Build Coastguard Worker                [
222*da0073e9SAndroid Build Coastguard Worker                    complex(1),
223*da0073e9SAndroid Build Coastguard Worker                    complex(1),
224*da0073e9SAndroid Build Coastguard Worker                    complex(0),
225*da0073e9SAndroid Build Coastguard Worker                    complex(1),
226*da0073e9SAndroid Build Coastguard Worker                    complex(1),
227*da0073e9SAndroid Build Coastguard Worker                    complex(1),
228*da0073e9SAndroid Build Coastguard Worker                    complex(1),
229*da0073e9SAndroid Build Coastguard Worker                    complex(1),
230*da0073e9SAndroid Build Coastguard Worker                    complex(0),
231*da0073e9SAndroid Build Coastguard Worker                    complex(1),
232*da0073e9SAndroid Build Coastguard Worker                    complex(1),
233*da0073e9SAndroid Build Coastguard Worker                    complex(1),
234*da0073e9SAndroid Build Coastguard Worker                ],
235*da0073e9SAndroid Build Coastguard Worker                device=device,
236*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
237*da0073e9SAndroid Build Coastguard Worker            )
238*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
239*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
240*da0073e9SAndroid Build Coastguard Worker            )
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
243*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
244*da0073e9SAndroid Build Coastguard Worker    def test_ne(self, device, dtype):
245*da0073e9SAndroid Build Coastguard Worker        "Test ne on complex types"
246*da0073e9SAndroid Build Coastguard Worker        nan = float("nan")
247*da0073e9SAndroid Build Coastguard Worker        # Non-vectorized operations
248*da0073e9SAndroid Build Coastguard Worker        for a, b in (
249*da0073e9SAndroid Build Coastguard Worker            (
250*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
251*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
252*da0073e9SAndroid Build Coastguard Worker            ),
253*da0073e9SAndroid Build Coastguard Worker            (
254*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
255*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
256*da0073e9SAndroid Build Coastguard Worker            ),
257*da0073e9SAndroid Build Coastguard Worker            (
258*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
259*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
260*da0073e9SAndroid Build Coastguard Worker            ),
261*da0073e9SAndroid Build Coastguard Worker        ):
262*da0073e9SAndroid Build Coastguard Worker            actual = torch.ne(a, b)
263*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([True], device=device, dtype=torch.bool)
264*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
265*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
266*da0073e9SAndroid Build Coastguard Worker            )
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker            actual = torch.ne(a, a)
269*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([False], device=device, dtype=torch.bool)
270*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
271*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
272*da0073e9SAndroid Build Coastguard Worker            )
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
275*da0073e9SAndroid Build Coastguard Worker            torch.ne(a, b, out=actual)
276*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
277*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
278*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
279*da0073e9SAndroid Build Coastguard Worker            )
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
282*da0073e9SAndroid Build Coastguard Worker            torch.ne(a, a, out=actual)
283*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
284*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
285*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
286*da0073e9SAndroid Build Coastguard Worker            )
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        # Vectorized operations
289*da0073e9SAndroid Build Coastguard Worker        for a, b in (
290*da0073e9SAndroid Build Coastguard Worker            (
291*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
292*da0073e9SAndroid Build Coastguard Worker                    [
293*da0073e9SAndroid Build Coastguard Worker                        -0.0610 - 2.1172j,
294*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 5.4775j,
295*da0073e9SAndroid Build Coastguard Worker                        complex(2.8871, nan),
296*da0073e9SAndroid Build Coastguard Worker                        -6.6545 - 3.7655j,
297*da0073e9SAndroid Build Coastguard Worker                        -2.7036 - 1.4470j,
298*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
299*da0073e9SAndroid Build Coastguard Worker                        -0.0610 - 2.1172j,
300*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 5.4775j,
301*da0073e9SAndroid Build Coastguard Worker                        complex(nan, -3.2650),
302*da0073e9SAndroid Build Coastguard Worker                        -6.6545 - 3.7655j,
303*da0073e9SAndroid Build Coastguard Worker                        -2.7036 - 1.4470j,
304*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
305*da0073e9SAndroid Build Coastguard Worker                    ],
306*da0073e9SAndroid Build Coastguard Worker                    device=device,
307*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
308*da0073e9SAndroid Build Coastguard Worker                ),
309*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
310*da0073e9SAndroid Build Coastguard Worker                    [
311*da0073e9SAndroid Build Coastguard Worker                        -6.1278 - 8.5019j,
312*da0073e9SAndroid Build Coastguard Worker                        0.5886 + 8.8816j,
313*da0073e9SAndroid Build Coastguard Worker                        complex(2.8871, nan),
314*da0073e9SAndroid Build Coastguard Worker                        6.3505 + 2.2683j,
315*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.9659j,
316*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
317*da0073e9SAndroid Build Coastguard Worker                        -6.1278 - 2.1172j,
318*da0073e9SAndroid Build Coastguard Worker                        5.1576 + 8.8816j,
319*da0073e9SAndroid Build Coastguard Worker                        complex(nan, -3.2650),
320*da0073e9SAndroid Build Coastguard Worker                        6.3505 + 2.2683j,
321*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.9659j,
322*da0073e9SAndroid Build Coastguard Worker                        0.3712 + 7.989j,
323*da0073e9SAndroid Build Coastguard Worker                    ],
324*da0073e9SAndroid Build Coastguard Worker                    device=device,
325*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
326*da0073e9SAndroid Build Coastguard Worker                ),
327*da0073e9SAndroid Build Coastguard Worker            ),
328*da0073e9SAndroid Build Coastguard Worker        ):
329*da0073e9SAndroid Build Coastguard Worker            actual = torch.ne(a, b)
330*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
331*da0073e9SAndroid Build Coastguard Worker                [
332*da0073e9SAndroid Build Coastguard Worker                    True,
333*da0073e9SAndroid Build Coastguard Worker                    True,
334*da0073e9SAndroid Build Coastguard Worker                    True,
335*da0073e9SAndroid Build Coastguard Worker                    True,
336*da0073e9SAndroid Build Coastguard Worker                    True,
337*da0073e9SAndroid Build Coastguard Worker                    False,
338*da0073e9SAndroid Build Coastguard Worker                    True,
339*da0073e9SAndroid Build Coastguard Worker                    True,
340*da0073e9SAndroid Build Coastguard Worker                    True,
341*da0073e9SAndroid Build Coastguard Worker                    True,
342*da0073e9SAndroid Build Coastguard Worker                    True,
343*da0073e9SAndroid Build Coastguard Worker                    False,
344*da0073e9SAndroid Build Coastguard Worker                ],
345*da0073e9SAndroid Build Coastguard Worker                device=device,
346*da0073e9SAndroid Build Coastguard Worker                dtype=torch.bool,
347*da0073e9SAndroid Build Coastguard Worker            )
348*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
349*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
350*da0073e9SAndroid Build Coastguard Worker            )
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker            actual = torch.ne(a, a)
353*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
354*da0073e9SAndroid Build Coastguard Worker                [
355*da0073e9SAndroid Build Coastguard Worker                    False,
356*da0073e9SAndroid Build Coastguard Worker                    False,
357*da0073e9SAndroid Build Coastguard Worker                    True,
358*da0073e9SAndroid Build Coastguard Worker                    False,
359*da0073e9SAndroid Build Coastguard Worker                    False,
360*da0073e9SAndroid Build Coastguard Worker                    False,
361*da0073e9SAndroid Build Coastguard Worker                    False,
362*da0073e9SAndroid Build Coastguard Worker                    False,
363*da0073e9SAndroid Build Coastguard Worker                    True,
364*da0073e9SAndroid Build Coastguard Worker                    False,
365*da0073e9SAndroid Build Coastguard Worker                    False,
366*da0073e9SAndroid Build Coastguard Worker                    False,
367*da0073e9SAndroid Build Coastguard Worker                ],
368*da0073e9SAndroid Build Coastguard Worker                device=device,
369*da0073e9SAndroid Build Coastguard Worker                dtype=torch.bool,
370*da0073e9SAndroid Build Coastguard Worker            )
371*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
372*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
373*da0073e9SAndroid Build Coastguard Worker            )
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
376*da0073e9SAndroid Build Coastguard Worker            torch.ne(a, b, out=actual)
377*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
378*da0073e9SAndroid Build Coastguard Worker                [
379*da0073e9SAndroid Build Coastguard Worker                    complex(1),
380*da0073e9SAndroid Build Coastguard Worker                    complex(1),
381*da0073e9SAndroid Build Coastguard Worker                    complex(1),
382*da0073e9SAndroid Build Coastguard Worker                    complex(1),
383*da0073e9SAndroid Build Coastguard Worker                    complex(1),
384*da0073e9SAndroid Build Coastguard Worker                    complex(0),
385*da0073e9SAndroid Build Coastguard Worker                    complex(1),
386*da0073e9SAndroid Build Coastguard Worker                    complex(1),
387*da0073e9SAndroid Build Coastguard Worker                    complex(1),
388*da0073e9SAndroid Build Coastguard Worker                    complex(1),
389*da0073e9SAndroid Build Coastguard Worker                    complex(1),
390*da0073e9SAndroid Build Coastguard Worker                    complex(0),
391*da0073e9SAndroid Build Coastguard Worker                ],
392*da0073e9SAndroid Build Coastguard Worker                device=device,
393*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
394*da0073e9SAndroid Build Coastguard Worker            )
395*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
396*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
397*da0073e9SAndroid Build Coastguard Worker            )
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker            actual = torch.full_like(b, complex(2, 2))
400*da0073e9SAndroid Build Coastguard Worker            torch.ne(a, a, out=actual)
401*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
402*da0073e9SAndroid Build Coastguard Worker                [
403*da0073e9SAndroid Build Coastguard Worker                    complex(0),
404*da0073e9SAndroid Build Coastguard Worker                    complex(0),
405*da0073e9SAndroid Build Coastguard Worker                    complex(1),
406*da0073e9SAndroid Build Coastguard Worker                    complex(0),
407*da0073e9SAndroid Build Coastguard Worker                    complex(0),
408*da0073e9SAndroid Build Coastguard Worker                    complex(0),
409*da0073e9SAndroid Build Coastguard Worker                    complex(0),
410*da0073e9SAndroid Build Coastguard Worker                    complex(0),
411*da0073e9SAndroid Build Coastguard Worker                    complex(1),
412*da0073e9SAndroid Build Coastguard Worker                    complex(0),
413*da0073e9SAndroid Build Coastguard Worker                    complex(0),
414*da0073e9SAndroid Build Coastguard Worker                    complex(0),
415*da0073e9SAndroid Build Coastguard Worker                ],
416*da0073e9SAndroid Build Coastguard Worker                device=device,
417*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
418*da0073e9SAndroid Build Coastguard Worker            )
419*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
420*da0073e9SAndroid Build Coastguard Worker                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
421*da0073e9SAndroid Build Coastguard Worker            )
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker
# Expand the device-generic TestComplexTensor into concrete per-device test
# classes. This module's globals() is handed over as the scope those
# generated classes are registered in.
instantiate_device_type_tests(TestComplexTensor, scope=globals())
425*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Direct runs of this file opt in to the harness's default-dtype check
    # (presumably verifying tests leave the global default dtype as they
    # found it; confirm in common_utils) before starting the shared runner.
    setattr(TestCase, "_default_dtype_check_enabled", True)
    run_tests()
429