# Owner(s): ["module: linear algebra"]

import torch
import numpy as np

import unittest
import itertools
import warnings
import math
from math import inf, nan, isnan
import re
import random
from random import randrange
from itertools import product
from functools import reduce, partial

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
     dtypesIfMPS, largeTensorTest)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

if TEST_SCIPY:
    import scipy

def blaslt_supported_device():
    if torch.cuda.is_available():
        if torch.version.hip:
            for arch in ['gfx90a', 'gfx94']:
                if arch in torch.cuda.get_device_properties(0).gcnArchName:
                    return True
        else:
            return True
    return False
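# For reference, a typical (hypothetical) use of the helper above is as a skip
# guard for tests that exercise a BLASLt/hipBLASLt code path, e.g.:
#   @unittest.skipIf(not blaslt_supported_device(), "BLASLt not supported on this device")
#   def test_some_blaslt_path(self, device, dtype): ...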

def set_tunableop_defaults():
    if not torch.cuda.is_available():
        # TunableOp not supported on CPU at this time.
        return

    # disable TunableOp and restore to default values
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)
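# Illustrative sketch (an assumption about intended usage, not shown in this file):
# calling the helper from a TunableOp test's setUp/tearDown keeps one test's tuning
# state and results file from leaking into the next, e.g.:
#   def setUp(self):
#       super().setUp()
#       set_tunableop_defaults()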


class TestLinalg(TestCase):
    def setUp(self):
        super().setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super().tearDown()

    exact_dtype = True

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
    @tf32_on_and_off(5e-3)
    @bf32_on_and_off(5e-3)
    def test_inner(self, device, dtype):
        def check(a_sizes_, b_sizes_):
            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
                a = torch.randn(a_sizes, dtype=dtype, device=device)
                b = torch.randn(b_sizes, dtype=dtype, device=device)
                res = torch.inner(a, b)
                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
                out = torch.zeros_like(res)
                torch.inner(a, b, out=out)
                self.assertEqual(res, out)

        check([], [])                       # scalar x scalar
        check([], [0])                      # scalar x empty
        check([], [3])                      # scalar x 1D
        check([], [2, 3, 4])                # scalar x 3D

        check([0], [0])                     # empty x empty
        check([0], [2, 0])                  # empty x 2D

        check([2], [2])                     # 1D x 1D
        check([2], [3, 1, 2])               # 1D x 3D
        check([2], [3, 0, 2])               # 1D x 3D empty

        check([1, 2], [3, 2])               # 2D x 2D
        check([1, 2], [3, 4, 2])            # 2D x 3D
        check([2, 1, 3, 2], [1, 3, 2, 2])   # 4D x 4D

        # Test error message
        with self.assertRaisesRegex(RuntimeError,
                                    r"inner\(\) the last dimension must match on both "
                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))
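        # For reference: torch.inner contracts over the last dimension of both
        # inputs (without conjugating complex entries), matching np.inner, so for
        # 2-D inputs it is equivalent to
        #   a @ b.transpose(-1, -2)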

    # Tests torch.outer and its alias, torch.ger, against NumPy
    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_outer(self, device, dtype):
        def run_test_case(a, b):
            if dtype == torch.bfloat16:
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
                exact_dtype = False
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()
                exact_dtype = True
            expected = np.outer(a_np, b_np)

            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)

            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)

            # test out variant
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.outer(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.ger(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

        a = torch.randn(50).to(device=device, dtype=dtype)
        b = torch.randn(50).to(device=device, dtype=dtype)
        run_test_case(a, b)

        # test 0-strided tensor
        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
        run_test_case(zero_strided, b)
        run_test_case(a, zero_strided)
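        # For reference, the outer product being checked is, elementwise,
        # out[i][j] = a[i] * b[j], i.e. the same result as
        #   a.unsqueeze(-1) * b.unsqueeze(0)
        # for 1-D inputs (np.outer additionally flattens higher-dimensional inputs).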

    def test_matrix_rank_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.matrix_rank(a)

    def test_solve_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        b = make_tensor(5, 1, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.solve(b, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            b.solve(a)

    def test_eig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.eig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.eig()

    def test_symeig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.symeig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.symeig()

    def test_lstsq_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.lstsq(a, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.lstsq(a)
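    # For reference, the removed operators above have torch.linalg replacements:
    # torch.matrix_rank -> torch.linalg.matrix_rank, torch.solve -> torch.linalg.solve,
    # torch.eig -> torch.linalg.eig, torch.symeig -> torch.linalg.eigh,
    # torch.lstsq -> torch.linalg.lstsq; the tests below exercise torch.linalg.lstsq directly.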

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix
        if self.device_type == 'cpu':
            drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
        else:
            drivers = ('gels', None)

        def check_solution_correctness(a, b, sol):
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)

        def check_correctness_ref(a, b, res, ref, driver="default"):
            def apply_if_not_empty(t, f):
                if t.numel():
                    return f(t)
                else:
                    return t

            def select_if_not_empty(t, i):
                selected = apply_if_not_empty(t, lambda x: x.select(0, i))
                return selected

            m = a.size(-2)
            n = a.size(-1)
            nrhs = b.size(-1)
            batch_size = int(np.prod(a.shape[:-2]))
            if batch_size == 0:
                batch_size = 1
            a_3d = a.view(batch_size, m, n)
            b_3d = b.view(batch_size, m, nrhs)

            solution_3d = res.solution.view(batch_size, n, nrhs)
            residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
            rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
            singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])

            if a.numel() > 0:
                for i in range(batch_size):
                    sol, residuals, rank, singular_values = ref(
                        a_3d.select(0, i).numpy(),
                        b_3d.select(0, i).numpy()
                    )
                    # Singular values are None when lapack_driver='gelsy' in SciPy
                    if singular_values is None:
                        singular_values = []
                    self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)

                    # SciPy and NumPy operate only on non-batched input and
                    # return an empty array with shape (0,) if rank(a) != n.
                    # In PyTorch, batched inputs are supported and the matrices
                    # within a batch can have different ranks, so residuals are
                    # computed only if all matrices in the batch have rank == n;
                    # see https://github.com/pytorch/pytorch/issues/56483
                    if m > n:
                        if torch.all(rank_1d == n):
                            self.assertEqual(
                                residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
                            )
                        else:
                            self.assertTrue(residuals_2d.numel() == 0)

            else:
                self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
                self.assertEqual(res.rank.shape, a.shape[:-2])

                # residuals are not always computed (and have non-zero shape)
                if m > n and driver != "gelsy":
                    self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
                else:
                    self.assertEqual(res.residuals.shape, (0, ))

                # singular_values are not always computed (and have non-zero shape)
                if driver == "default" or driver == "gelsd" or driver == "gelss":
                    self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
                else:
                    self.assertEqual(res.singular_values.shape, (0, ))

        def check_correctness_scipy(a, b, res, driver, cond):
            # SciPy provides 3 driver options: gelsd, gelss, gelsy
            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                import scipy.linalg

                def scipy_ref(a, b):
                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                check_correctness_ref(a, b, res, scipy_ref, driver=driver)

        def check_correctness_numpy(a, b, res, driver, rcond):
            # NumPy uses only gelsd routine
            if driver == 'gelsd':

                def numpy_ref(a, b):
                    return np.linalg.lstsq(a, b, rcond=rcond)
                check_correctness_ref(a, b, res, numpy_ref)

        ms = [2 ** i for i in range(5)]
        m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
        # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
        m_l_n_sizes = [(m // 2, m) for m in ms]
        include_m_l_n_case = (has_cusolver() or device == 'cpu')
        matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
        batches = [(), (2,), (2, 2), (2, 2, 2)]
        # We generate matrices with singular values sampled from a normal distribution;
        # `cond=1.0` (the mean of that distribution) therefore cuts roughly half of all
        # the singular values, and we compare whether torch.linalg.lstsq agrees with
        # SciPy and NumPy under that truncation.
        # If rcond is True, a driver-specific value is substituted for it below;
        # rcond == -1 (or any other negative value) forces LAPACK to use machine precision tolerance.
        rconds = (None, True, -1)

        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
            # keep the rcond value if it is None or -1, set the driver specific value if it is True
            if rcond and rcond != -1:
                if driver in ('gelss', 'gelsd'):
                    # SVD based algorithm; set to zero roughly half of all the singular values
                    rcond = 1.0
                else:
                    # driver == 'gelsy'
                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
                    # so we skip this case
                    continue

            # specifying rcond value has no effect for gels driver so no need to run the tests again
            if driver == 'gels' and rcond is not None:
                continue

            shape = batch + matrix_size
            a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
            b = torch.rand(*shape, dtype=dtype, device=device)

            m = a.size(-2)
            n = a.size(-1)
            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
            sol = res.solution

            # Only checks gelsd, gelss, gelsy drivers
            check_correctness_scipy(a, b, res, driver, rcond)

            # Only checks gelsd driver
            check_correctness_numpy(a, b, res, driver, rcond)

            # gels driver is not checked by comparing to NumPy or SciPy implementation
            # because NumPy and SciPy do not implement this driver
            if driver == 'gels' and rcond is None:
                check_solution_correctness(a, b, sol)
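        # For reference, torch.linalg.lstsq returns a named tuple
        # (solution, residuals, rank, singular_values); a minimal call mirroring
        # the loop above (illustrative only) looks like:
        #   solution, residuals, rank, singular_values = torch.linalg.lstsq(a, b, driver=driver)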

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix

        def check_correctness(a, b):
            sol = torch.linalg.lstsq(a, b).solution
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)

        ms = [2 ** i for i in range(5)]
        batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
        # the case when a single matrix is batch-broadcasted over the rhs
        for m, batch in itertools.product(ms, batches):
            a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
            b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
            check_correctness(a, b)

        # cases with broadcastable shapes
        for m in ms:
            a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
            b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
            # unsqueeze for b because `check_correctness` checks against
            # a.pinverse() @ b, which requires b to be a matrix
            check_correctness(a, b.unsqueeze(-1))

            a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
            b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
            check_correctness(a, b.unsqueeze(-1))
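        # For reference, the batch dimensions above broadcast the same way as in
        # torch.broadcast_shapes: e.g. batch shapes (1, 3, 1, 3) and (3, 1, 3, 1)
        # broadcast to (3, 3, 3, 3) before the per-matrix solves.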

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_input_checks(self, device, dtype):
        # check empty inputs
        # empty batches
        a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
        b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a but not b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
        )

        # empty a and b
        if torch.device(device).type == 'cpu':
            # only CPU since CUDA supports only overdetermined systems (m >= n)
            a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            self.assertEqual(
                torch.linalg.lstsq(a, b)[0],
                torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
            )

        a = torch.rand(2, 3, dtype=dtype, device=device)
        b = torch.rand(3, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
            torch.linalg.lstsq(b, b)

        with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
            torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
            torch.linalg.lstsq(a, b)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b.unsqueeze(-1))

        a = torch.randn(1, 1, 1, dtype=dtype, device=device)
        b = torch.randn(3, 1, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b)

        def complement_device(device):
            if device == 'cpu' and torch.cuda.is_available():
                return 'cuda'
            else:
                return 'cpu'

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
        if a.device != b.device:
            with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
                torch.linalg.lstsq(a, b)

        b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
        with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
            torch.linalg.lstsq(a, b)

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=device)

        if device != 'cpu':
            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')
        # if on cpu
        else:
            with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')

        # cuSOLVER path supports underdetermined systems
        version = torch.testing._internal.common_cuda._get_torch_cuda_version()
        cusolver_not_available = (version < (10, 1))

        if device != 'cpu' and cusolver_not_available:
            a = torch.rand(2, 3, dtype=dtype, device=device)
            b = torch.rand(2, 1, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
                torch.linalg.lstsq(a, b)
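        # Driver availability recap (from the checks above): on CPU the LAPACK drivers
        # gels, gelsy, gelsd and gelss are accepted, on CUDA only gels is, and the
        # MAGMA-only path (no cuSOLVER) additionally requires overdetermined systems.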

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, contiguous):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            if A.numel() > 0 and not contiguous:
                A = A.mT
                self.assertFalse(A.is_contiguous())
            expected_L = np.linalg.cholesky(A.cpu().numpy())
            actual_L = torch.linalg.cholesky(A)

            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
            # Let's compare the norms of matrices instead
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
                # axis is specified to calculate matrix norm for batched input
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
                # Compare the norms with standard tolerances
                self.assertEqual(actual_norm, expected_norm)
                # and individual values with a higher tolerance
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
            else:
                self.assertEqual(actual_L, expected_L)

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        larger_input_case = [(100, (5, ), True)]
        for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
            run_test(shape, batch, contiguous)

        # check the out= variant
        A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
        out = torch.empty_like(A)
        ans = torch.linalg.cholesky(A, out=out)
        self.assertEqual(ans, out)
        expected = torch.linalg.cholesky(A)
        self.assertEqual(expected, out)

        # check the upper= variant
        expected = torch.linalg.cholesky(A).mH
        actual = torch.linalg.cholesky(A, upper=True)
        self.assertEqual(expected, actual)
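        # For reference, the identity exercised above: for a Hermitian
        # positive-definite A, torch.linalg.cholesky returns a lower-triangular L
        # with A = L @ L.mH, and upper=True returns the conjugate transpose U = L.mH
        # with A = U.mH @ U.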

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_errors_and_warnings(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        # cholesky requires the input to be a square matrix or batch of square matrices
        A = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        A = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
            np.linalg.cholesky(A.cpu().numpy())

        # cholesky requires the input to be at least a 2-dimensional tensor
        A = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError,
                                    r'1-dimensional array given\. Array must be at least two-dimensional'):
            np.linalg.cholesky(A.cpu().numpy())

        # if the input matrix is not positive definite, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A[-1, -1] = 0  # Now A is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
            np.linalg.cholesky(A.cpu().numpy())

        # if at least one matrix in the batch is singular, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[4, -1, -1] = 0  # Now A[4] is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
            torch.linalg.cholesky(A)

        # if an out tensor with the wrong shape is passed, a warning is given
        A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
        out = torch.empty(2, 3, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.cholesky(A, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(*A.shape, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.cholesky(A, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.linalg.cholesky(A, out=out)
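        # For reference (not exercised here): torch.linalg.cholesky_ex offers a
        # non-raising variant that returns (L, info), where a nonzero info entry
        # reports the order of the leading minor that is not positive-definite.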

    # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_old_cholesky_batched_many_batches(self, device, dtype):
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix

        def cholesky_test_helper(n, batchsize, device, upper):
            A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
            chol_fact = torch.cholesky(A, upper=upper)
            if upper:
                # Correctness check
                self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
                # Upper triangular check
                self.assertEqual(chol_fact, chol_fact.triu())
            else:
                # Correctness check
                self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
                # Lower triangular check
                self.assertEqual(chol_fact, chol_fact.tril())

        for upper, batchsize in itertools.product([True, False], [262144, 524288]):
            cholesky_test_helper(2, batchsize, device, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_batched(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def cholesky_test_helper(n, batch_dims, upper):
            A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
            cholesky_exp = cholesky_exp.reshape_as(A)
            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))

        for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
            cholesky_test_helper(3, batchsize, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @tf32_on_and_off(0.01)
    @bf32_on_and_off(0.01)
    def test_old_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)

647*da0073e9SAndroid Build Coastguard Worker        # default case (upper=False)
648*da0073e9SAndroid Build Coastguard Worker        C = torch.cholesky(A)
649*da0073e9SAndroid Build Coastguard Worker        B = torch.mm(C, C.t().conj())
650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(A, B, atol=1e-14, rtol=0)
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker        # test upper triangular (upper=True)
653*da0073e9SAndroid Build Coastguard Worker        U = torch.cholesky(A, True)
654*da0073e9SAndroid Build Coastguard Worker        B = torch.mm(U.t().conj(), U)
655*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker        # test lower triangular (upper=False)
658*da0073e9SAndroid Build Coastguard Worker        L = torch.cholesky(A, False)
659*da0073e9SAndroid Build Coastguard Worker        B = torch.mm(L, L.t().conj())
660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
663*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
664*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
665*da0073e9SAndroid Build Coastguard Worker    def test_old_cholesky_empty(self, device, dtype):
666*da0073e9SAndroid Build Coastguard Worker        def run_test(upper):
667*da0073e9SAndroid Build Coastguard Worker            A = torch.empty(0, 0, dtype=dtype, device=device)
668*da0073e9SAndroid Build Coastguard Worker            chol = torch.cholesky(A, upper)
669*da0073e9SAndroid Build Coastguard Worker            chol_A = torch.matmul(chol, chol.t().conj())
670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A, chol_A)
671*da0073e9SAndroid Build Coastguard Worker        for upper in [True, False]:
672*da0073e9SAndroid Build Coastguard Worker            run_test(upper)
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker    # Regression test for
675*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/57032
676*da0073e9SAndroid Build Coastguard Worker    # torch.cholesky with upper=True was wrong for batched CUDA inputs:
677*da0073e9SAndroid Build Coastguard Worker    # it used the lower triangular part instead of the upper one.
678*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
679*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
680*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
681*da0073e9SAndroid Build Coastguard Worker    def test_old_cholesky_batched_upper(self, device, dtype):
682*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker        batchsize = 2
685*da0073e9SAndroid Build Coastguard Worker        A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
686*da0073e9SAndroid Build Coastguard Worker        A_triu = A.triu()  # fill the lower triangular part with zero
687*da0073e9SAndroid Build Coastguard Worker
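        # cholesky(upper=True) should read only the upper triangle, so zeroing the lower
        # triangle must not change the factor and U^H @ U must still rebuild A.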
688*da0073e9SAndroid Build Coastguard Worker        U = torch.cholesky(A_triu, upper=True)
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker        reconstruct_A = U.mH @ U
691*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(A, reconstruct_A)
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
694*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
695*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
696*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_ex(self, device, dtype):
697*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        def run_test(n, batch):
700*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
701*da0073e9SAndroid Build Coastguard Worker            expected_L = np.linalg.cholesky(A.cpu().numpy())
702*da0073e9SAndroid Build Coastguard Worker            expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
703*da0073e9SAndroid Build Coastguard Worker            actual_L, actual_info = torch.linalg.cholesky_ex(A)
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
706*da0073e9SAndroid Build Coastguard Worker            # Let's compare the norms of matrices instead
707*da0073e9SAndroid Build Coastguard Worker            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
708*da0073e9SAndroid Build Coastguard Worker                # axis is specified to calculate matrix norm for batched input
709*da0073e9SAndroid Build Coastguard Worker                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
710*da0073e9SAndroid Build Coastguard Worker                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
711*da0073e9SAndroid Build Coastguard Worker                # Compare the norms with standard tolerances
712*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_norm, expected_norm)
713*da0073e9SAndroid Build Coastguard Worker                # and individual values with a higher tolerance
714*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
715*da0073e9SAndroid Build Coastguard Worker            else:
716*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_L, expected_L)
717*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_info, expected_info)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        ns = (0, 3, 5)
720*da0073e9SAndroid Build Coastguard Worker        batches = ((), (2, ), (2, 1))
721*da0073e9SAndroid Build Coastguard Worker        for n, batch in itertools.product(ns, batches):
722*da0073e9SAndroid Build Coastguard Worker            run_test(n, batch)
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
725*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
726*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
727*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_ex_non_pd(self, device, dtype):
728*da0073e9SAndroid Build Coastguard Worker        # if the input matrix is not positive definite, the returned info is a positive integer
729*da0073e9SAndroid Build Coastguard Worker        A = torch.eye(3, 3, dtype=dtype, device=device)
730*da0073e9SAndroid Build Coastguard Worker        A[-1, -1] = 0  # Now A is singular
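        # Following the LAPACK potrf convention, info == k means the leading minor of
        # order k is not positive definite; here that is the full 3x3 matrix, so info == 3.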
731*da0073e9SAndroid Build Coastguard Worker        _, info = torch.linalg.cholesky_ex(A)
732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(info, 3)
733*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
734*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cholesky_ex(A, check_errors=True)
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker        # if at least one matrix in the batch is not positive definite,
737*da0073e9SAndroid Build Coastguard Worker        # the returned batched info holds a positive integer at the index of that matrix
738*da0073e9SAndroid Build Coastguard Worker        A = torch.eye(3, 3, dtype=dtype, device=device)
739*da0073e9SAndroid Build Coastguard Worker        A = A.reshape((1, 3, 3))
740*da0073e9SAndroid Build Coastguard Worker        A = A.repeat(5, 1, 1)
741*da0073e9SAndroid Build Coastguard Worker        A[3, -2, -2] = 0  # Now A[3] is singular
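        # A[3] now has a zero at position (1, 1), so its leading minor of order 2 is
        # singular and cholesky_ex should report info == 2 for that batch element.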
742*da0073e9SAndroid Build Coastguard Worker        _, info = torch.linalg.cholesky_ex(A)
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
745*da0073e9SAndroid Build Coastguard Worker        expected_info[3] = 2
746*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(info, expected_info)
747*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
748*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cholesky_ex(A, check_errors=True)
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker    def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1):
751*da0073e9SAndroid Build Coastguard Worker        def check(m, a, b, beta, alpha):
752*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
753*da0073e9SAndroid Build Coastguard Worker                a_np = a.to(torch.double).cpu().numpy()
754*da0073e9SAndroid Build Coastguard Worker                b_np = b.to(torch.double).cpu().numpy()
755*da0073e9SAndroid Build Coastguard Worker                m_np = m.to(torch.double).cpu().numpy()
756*da0073e9SAndroid Build Coastguard Worker                exact_dtype = False
757*da0073e9SAndroid Build Coastguard Worker            else:
758*da0073e9SAndroid Build Coastguard Worker                a_np = a.cpu().numpy()
759*da0073e9SAndroid Build Coastguard Worker                b_np = b.cpu().numpy()
760*da0073e9SAndroid Build Coastguard Worker                m_np = m.cpu().numpy()
761*da0073e9SAndroid Build Coastguard Worker                exact_dtype = True
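            # When beta == 0, torch.addr ignores the input matrix m entirely (rather than
            # multiplying it by zero), so NaN/Inf entries in m must not reach the result.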
762*da0073e9SAndroid Build Coastguard Worker            if beta == 0:
763*da0073e9SAndroid Build Coastguard Worker                expected = alpha * np.outer(a_np, b_np)
764*da0073e9SAndroid Build Coastguard Worker            else:
765*da0073e9SAndroid Build Coastguard Worker                expected = beta * m_np + alpha * np.outer(a_np, b_np)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker            res = torch.addr(m, a, b, beta=beta, alpha=alpha)
768*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected, exact_dtype=exact_dtype)
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker            # Test out variant
771*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(res)
772*da0073e9SAndroid Build Coastguard Worker            torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
773*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, expected, exact_dtype=exact_dtype)
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker        m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
776*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
777*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
778*da0073e9SAndroid Build Coastguard Worker
779*da0073e9SAndroid Build Coastguard Worker        check(m, a, b, beta, alpha)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        # test transpose
782*da0073e9SAndroid Build Coastguard Worker        m_transpose = torch.transpose(m, 0, 1)
783*da0073e9SAndroid Build Coastguard Worker        check(m_transpose, a, b, beta, alpha)
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker        # test a zero-strided tensor
786*da0073e9SAndroid Build Coastguard Worker        zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
787*da0073e9SAndroid Build Coastguard Worker        check(m, zero_strided, b, beta, alpha)
788*da0073e9SAndroid Build Coastguard Worker
789*da0073e9SAndroid Build Coastguard Worker        # test scalar
790*da0073e9SAndroid Build Coastguard Worker        m_scalar = torch.tensor(1, device=device, dtype=dtype)
791*da0073e9SAndroid Build Coastguard Worker        check(m_scalar, a, b, beta, alpha)
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker        # test nans and infs are not propagated to the output when beta == 0
794*da0073e9SAndroid Build Coastguard Worker        float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
795*da0073e9SAndroid Build Coastguard Worker        if beta == 0 and dtype in float_and_complex_dtypes:
796*da0073e9SAndroid Build Coastguard Worker            m[0][10] = m[10][10] = m[20][20] = float('inf')
797*da0073e9SAndroid Build Coastguard Worker            m[1][10] = m[11][10] = m[21][20] = float('nan')
798*da0073e9SAndroid Build Coastguard Worker        check(m, a, b, 0, alpha)
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bool)
801*da0073e9SAndroid Build Coastguard Worker    def test_addr_bool(self, device, dtype):
802*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
803*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
804*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
805*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker    @dtypes(*integral_types())
808*da0073e9SAndroid Build Coastguard Worker    def test_addr_integral(self, device, dtype):
809*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
810*da0073e9SAndroid Build Coastguard Worker                                    'argument beta must not be a floating point number.'):
811*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
812*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
813*da0073e9SAndroid Build Coastguard Worker                                    'argument alpha must not be a floating point number.'):
814*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
815*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
816*da0073e9SAndroid Build Coastguard Worker                                    'Boolean beta only supported for Boolean results.'):
817*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
818*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
819*da0073e9SAndroid Build Coastguard Worker                                    'Boolean alpha only supported for Boolean results.'):
820*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker        # when beta is zero
823*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
824*da0073e9SAndroid Build Coastguard Worker        # when beta is not zero
825*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-1})
828*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
829*da0073e9SAndroid Build Coastguard Worker    def test_addr_float_and_complex(self, device, dtype):
830*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
831*da0073e9SAndroid Build Coastguard Worker                                    'Boolean beta only supported for Boolean results.'):
832*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
833*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
834*da0073e9SAndroid Build Coastguard Worker                                    'Boolean alpha only supported for Boolean results.'):
835*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        # when beta is zero
838*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
839*da0073e9SAndroid Build Coastguard Worker        # when beta is not zero
840*da0073e9SAndroid Build Coastguard Worker        self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
841*da0073e9SAndroid Build Coastguard Worker        if dtype in complex_types():
842*da0073e9SAndroid Build Coastguard Worker            self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
845*da0073e9SAndroid Build Coastguard Worker                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
846*da0073e9SAndroid Build Coastguard Worker    def test_outer_type_promotion(self, device, dtypes):
847*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5).to(device=device, dtype=dtypes[0])
848*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5).to(device=device, dtype=dtypes[1])
849*da0073e9SAndroid Build Coastguard Worker        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
850*da0073e9SAndroid Build Coastguard Worker            result = op(a, b)
851*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, torch.result_type(a, b))
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker    # don't use @dtypes decorator to avoid generating ~1700 tests per device
854*da0073e9SAndroid Build Coastguard Worker    def test_addr_type_promotion(self, device):
855*da0073e9SAndroid Build Coastguard Worker        for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
856*da0073e9SAndroid Build Coastguard Worker            a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
857*da0073e9SAndroid Build Coastguard Worker            b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
858*da0073e9SAndroid Build Coastguard Worker            m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)
859*da0073e9SAndroid Build Coastguard Worker
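            # addr computes beta * m + alpha * outer(a, b), so the result dtype should be
            # the common promoted dtype of m, a, and b.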
860*da0073e9SAndroid Build Coastguard Worker            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
861*da0073e9SAndroid Build Coastguard Worker                                                dtypes2)
862*da0073e9SAndroid Build Coastguard Worker            for op in (torch.addr, torch.Tensor.addr):
863*da0073e9SAndroid Build Coastguard Worker                result = op(m, a, b)
864*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.dtype, desired_dtype)
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker    # Tests migrated from test_torch.py
867*da0073e9SAndroid Build Coastguard Worker    # 1) test the shape of the result tensor when the input tensor is empty
868*da0073e9SAndroid Build Coastguard Worker    # 2) test that a RuntimeError is raised when the input tensor is a scalar
869*da0073e9SAndroid Build Coastguard Worker    def test_outer_ger_addr_legacy_tests(self, device):
870*da0073e9SAndroid Build Coastguard Worker        for size in ((0, 0), (0, 5), (5, 0)):
871*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(size[0], device=device)
872*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(size[1], device=device)
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.outer(a, b).shape, size)
875*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.ger(a, b).shape, size)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker            m = torch.empty(size, device=device)
878*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addr(m, a, b).shape, size)
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker        m = torch.randn(5, 6, device=device)
881*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, device=device)
882*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(6, device=device)
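        # b is a 0-dim (scalar) tensor; outer/ger/addr only accept 1-D vectors and must raise.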
883*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
884*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
885*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
886*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
887*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
888*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    # Tests torch.linalg.det and its aliases (torch.det, torch.Tensor.det) against NumPy
891*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
892*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
893*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
894*da0073e9SAndroid Build Coastguard Worker    def test_det(self, device, dtype):
895*da0073e9SAndroid Build Coastguard Worker        tensors = (
896*da0073e9SAndroid Build Coastguard Worker            torch.randn((2, 2), device=device, dtype=dtype),
897*da0073e9SAndroid Build Coastguard Worker            torch.randn((129, 129), device=device, dtype=dtype),
898*da0073e9SAndroid Build Coastguard Worker            torch.randn((3, 52, 52), device=device, dtype=dtype),
899*da0073e9SAndroid Build Coastguard Worker            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
900*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker        ops = (torch.det, torch.Tensor.det,
903*da0073e9SAndroid Build Coastguard Worker               torch.linalg.det)
904*da0073e9SAndroid Build Coastguard Worker        for t in tensors:
905*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.det(t.cpu().numpy())
906*da0073e9SAndroid Build Coastguard Worker            for op in ops:
907*da0073e9SAndroid Build Coastguard Worker                actual = op(t)
908*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
909*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(op, np.linalg.det, t)
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker        # NOTE: det requires a 2D+ tensor
912*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(1, device=device, dtype=dtype)
913*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
914*da0073e9SAndroid Build Coastguard Worker            torch.linalg.det(t)
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
917*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
918*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
919*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
920*da0073e9SAndroid Build Coastguard Worker    def test_eigh(self, device, dtype):
921*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_matrix
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, uplo):
924*da0073e9SAndroid Build Coastguard Worker            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
925*da0073e9SAndroid Build Coastguard Worker            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
926*da0073e9SAndroid Build Coastguard Worker            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
927*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_w, expected_w)
928*da0073e9SAndroid Build Coastguard Worker            # sign of eigenvectors is not unique and therefore absolute values are compared
929*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(abs(actual_v), abs(expected_v))
930*da0073e9SAndroid Build Coastguard Worker            # additionally, each eigenvector can be multiplied by a phase factor e^{i\phi} before comparing values;
931*da0073e9SAndroid Build Coastguard Worker            # we choose the convention that the first elements of the torch and numpy eigenvectors agree
932*da0073e9SAndroid Build Coastguard Worker            # (for real inputs this phase factor is plus or minus one)
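            # rescaling a unit eigenvector by a unit-modulus scalar keeps it a valid unit
            # eigenvector, so after this normalization the decompositions should match elementwise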
933*da0073e9SAndroid Build Coastguard Worker            if matrix.numel() > 0:
934*da0073e9SAndroid Build Coastguard Worker                phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
935*da0073e9SAndroid Build Coastguard Worker                actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
936*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_v_rotated, expected_v)
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
939*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty_like(actual_w)
940*da0073e9SAndroid Build Coastguard Worker            out_v = torch.empty_like(actual_v)
941*da0073e9SAndroid Build Coastguard Worker            ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
942*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans_w, out_w)
943*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans_v, out_v)
944*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans_w, actual_w)
945*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(abs(ans_v), abs(actual_v))
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker        shapes = (0, 3, 5)
948*da0073e9SAndroid Build Coastguard Worker        batches = ((), (3, ), (2, 2))
949*da0073e9SAndroid Build Coastguard Worker        uplos = ["U", "L"]
950*da0073e9SAndroid Build Coastguard Worker        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
951*da0073e9SAndroid Build Coastguard Worker            run_test(shape, batch, uplo)
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
954*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
955*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
956*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
957*da0073e9SAndroid Build Coastguard Worker    def test_eigh_lower_uplo(self, device, dtype):
958*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, uplo):
959*da0073e9SAndroid Build Coastguard Worker            # check lower case uplo
960*da0073e9SAndroid Build Coastguard Worker            # use non-symmetric input to check whether uplo argument is working as intended
961*da0073e9SAndroid Build Coastguard Worker            matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
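            # eigh reads only the triangle selected by UPLO, so torch and NumPy should agree
            # here even though the full matrix is not Hermitian.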
962*da0073e9SAndroid Build Coastguard Worker            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
963*da0073e9SAndroid Build Coastguard Worker            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
964*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_w, expected_w)
965*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(abs(actual_v), abs(expected_v))
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        uplos = ["u", "l"]
968*da0073e9SAndroid Build Coastguard Worker        for uplo in uplos:
969*da0073e9SAndroid Build Coastguard Worker            run_test(3, (2, 2), uplo)
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
972*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
973*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
974*da0073e9SAndroid Build Coastguard Worker    def test_eigh_errors_and_warnings(self, device, dtype):
975*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_matrix
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker        # eigh requires a square matrix
978*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, device=device, dtype=dtype)
979*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
980*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigh(t)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker        # eigh requires 'uplo' parameter to be 'U' or 'L'
983*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device=device, dtype=dtype)
984*da0073e9SAndroid Build Coastguard Worker        for uplo in ["a", "wrong"]:
985*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
986*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigh(t, UPLO=uplo)
987*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
988*da0073e9SAndroid Build Coastguard Worker                np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
991*da0073e9SAndroid Build Coastguard Worker        a = random_hermitian_matrix(3, dtype=dtype, device=device)
992*da0073e9SAndroid Build Coastguard Worker        real_dtype = a.real.dtype if dtype.is_complex else dtype
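        # Eigenvalues of a Hermitian matrix are real, so the eigenvalue output uses the
        # corresponding real dtype even for complex inputs.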
993*da0073e9SAndroid Build Coastguard Worker        out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
994*da0073e9SAndroid Build Coastguard Worker        out_v = torch.empty(7, 7, dtype=dtype, device=device)
995*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
996*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
997*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigh(a, out=(out_w, out_v))
998*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
999*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
1000*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
1001*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
1004*da0073e9SAndroid Build Coastguard Worker        out_w = torch.empty(0, dtype=real_dtype, device=device)
1005*da0073e9SAndroid Build Coastguard Worker        out_v = torch.empty(0, dtype=torch.int, device=device)
1006*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1007*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigh(a, out=(out_w, out_v))
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker        out_w = torch.empty(0, dtype=torch.int, device=device)
1010*da0073e9SAndroid Build Coastguard Worker        out_v = torch.empty(0, dtype=dtype, device=device)
1011*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1012*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigh(a, out=(out_w, out_v))
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker        # device should match
1015*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1016*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1017*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty(0, device=wrong_device, dtype=dtype)
1018*da0073e9SAndroid Build Coastguard Worker            out_v = torch.empty(0, device=device, dtype=dtype)
1019*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1020*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigh(a, out=(out_w, out_v))
1021*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty(0, device=device, dtype=dtype)
1022*da0073e9SAndroid Build Coastguard Worker            out_v = torch.empty(0, device=wrong_device, dtype=dtype)
1023*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1024*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigh(a, out=(out_w, out_v))
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1027*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1028*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
1029*da0073e9SAndroid Build Coastguard Worker    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
1030*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
1031*da0073e9SAndroid Build Coastguard Worker        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
1032*da0073e9SAndroid Build Coastguard Worker        # but passes on cuda 12.1 update 1 or later.
1033*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(512, 512, dtype=dtype, device=device)
1034*da0073e9SAndroid Build Coastguard Worker        a[0, 0] = 1.0e-5
1035*da0073e9SAndroid Build Coastguard Worker        a[-1, -1] = 1.0e5
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker        eigh_out = torch.linalg.eigh(a)
1038*da0073e9SAndroid Build Coastguard Worker        svd_out = torch.linalg.svd(a)
1039*da0073e9SAndroid Build Coastguard Worker
1040*da0073e9SAndroid Build Coastguard Worker        # The input matrix a is extremely ill-conditioned.
1041*da0073e9SAndroid Build Coastguard Worker        # We only compare the two largest singular values/eigenvalues, which are 1.0e5 and 511.0.
1042*da0073e9SAndroid Build Coastguard Worker        # The loose tolerance (atol=1.0) makes sense because solvers struggle to converge to
1043*da0073e9SAndroid Build Coastguard Worker        # exact values for ill-conditioned inputs.
1044*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1048*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1049*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1050*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
1051*da0073e9SAndroid Build Coastguard Worker    def test_eigvalsh(self, device, dtype):
1052*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_matrix
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, uplo):
1055*da0073e9SAndroid Build Coastguard Worker            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
1056*da0073e9SAndroid Build Coastguard Worker            expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
1057*da0073e9SAndroid Build Coastguard Worker            actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
1058*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_w, expected_w)
1059*da0073e9SAndroid Build Coastguard Worker
1060*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
1061*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(actual_w)
1062*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
1063*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
1064*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, actual_w)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        shapes = (0, 3, 5)
1067*da0073e9SAndroid Build Coastguard Worker        batches = ((), (3, ), (2, 2))
1068*da0073e9SAndroid Build Coastguard Worker        uplos = ["U", "L"]
1069*da0073e9SAndroid Build Coastguard Worker        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
1070*da0073e9SAndroid Build Coastguard Worker            run_test(shape, batch, uplo)
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1073*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1074*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1075*da0073e9SAndroid Build Coastguard Worker    def test_eigvalsh_errors_and_warnings(self, device, dtype):
1076*da0073e9SAndroid Build Coastguard Worker        # eigvalsh requires a square matrix
1077*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, device=device, dtype=dtype)
1078*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1079*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvalsh(t)
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker        # eigvalsh requires 'uplo' parameter to be 'U' or 'L'
1082*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device=device, dtype=dtype)
1083*da0073e9SAndroid Build Coastguard Worker        for uplo in ["a", "wrong"]:
1084*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
1085*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigvalsh(t, UPLO=uplo)
1086*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
1087*da0073e9SAndroid Build Coastguard Worker                np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
1090*da0073e9SAndroid Build Coastguard Worker        real_dtype = t.real.dtype if dtype.is_complex else dtype
1091*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(t).to(real_dtype)
1092*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1093*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
1094*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvalsh(t, out=out)
1095*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
1096*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
1097*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
1100*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.int, device=device)
1101*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1102*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvalsh(t, out=out)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker        # device should match
1105*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1106*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1107*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, device=wrong_device, dtype=dtype)
1108*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1109*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigvalsh(t, out=out)
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1112*da0073e9SAndroid Build Coastguard Worker    def test_kron(self, device, dtype):
1113*da0073e9SAndroid Build Coastguard Worker
1114*da0073e9SAndroid Build Coastguard Worker        def run_test_case(a_shape, b_shape):
1115*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(a_shape, dtype=dtype, device=device)
1116*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(b_shape, dtype=dtype, device=device)
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1119*da0073e9SAndroid Build Coastguard Worker            result = torch.kron(a, b)
1120*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
1123*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(result)
1124*da0073e9SAndroid Build Coastguard Worker            ans = torch.kron(a, b, out=out)
1125*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
1126*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, result)
1127*da0073e9SAndroid Build Coastguard Worker
1128*da0073e9SAndroid Build Coastguard Worker        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
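        # product with reversed(shapes) yields every ordered pair of shapes, including
        # operands with different numbers of dimensions.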
1129*da0073e9SAndroid Build Coastguard Worker        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
1130*da0073e9SAndroid Build Coastguard Worker            run_test_case(a_shape, b_shape)
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1133*da0073e9SAndroid Build Coastguard Worker    def test_kron_empty(self, device, dtype):
1134*da0073e9SAndroid Build Coastguard Worker
1135*da0073e9SAndroid Build Coastguard Worker        def run_test_case(empty_shape):
1136*da0073e9SAndroid Build Coastguard Worker            a = torch.eye(3, dtype=dtype, device=device)
1137*da0073e9SAndroid Build Coastguard Worker            b = torch.empty(empty_shape, dtype=dtype, device=device)
1138*da0073e9SAndroid Build Coastguard Worker            result = torch.kron(a, b)
1139*da0073e9SAndroid Build Coastguard Worker            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1140*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker            # NumPy doesn't work if the first argument is empty
1143*da0073e9SAndroid Build Coastguard Worker            result = torch.kron(b, a)
1144*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, expected.shape)
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
1147*da0073e9SAndroid Build Coastguard Worker        for empty_shape in empty_shapes:
1148*da0073e9SAndroid Build Coastguard Worker            run_test_case(empty_shape)
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1151*da0073e9SAndroid Build Coastguard Worker    def test_kron_errors_and_warnings(self, device, dtype):
1152*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
1153*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(3, dtype=dtype, device=device)
1154*da0073e9SAndroid Build Coastguard Worker        b = torch.ones((2, 2), dtype=dtype, device=device)
1155*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(a)
1156*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1157*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
1158*da0073e9SAndroid Build Coastguard Worker            torch.kron(a, b, out=out)
1159*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
1160*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
1161*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1162*da0073e9SAndroid Build Coastguard Worker
1163*da0073e9SAndroid Build Coastguard Worker        # dtypes should match
1164*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(a).to(torch.int)
1165*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
1166*da0073e9SAndroid Build Coastguard Worker            torch.kron(a, b, out=out)
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker    # This test confirms that torch.linalg.norm's dtype argument works
1169*da0073e9SAndroid Build Coastguard Worker    # as expected, according to the function's documentation
1170*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1171*da0073e9SAndroid Build Coastguard Worker    def test_norm_dtype(self, device, dtype):
1172*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
1173*da0073e9SAndroid Build Coastguard Worker
1174*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input_size, ord, keepdim, to_dtype):
1175*da0073e9SAndroid Build Coastguard Worker            msg = (
1176*da0073e9SAndroid Build Coastguard Worker                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1177*da0073e9SAndroid Build Coastguard Worker                f'dtype={dtype}, to_dtype={to_dtype}')
1178*da0073e9SAndroid Build Coastguard Worker            input = make_arg(input_size)
1179*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1180*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, input.real.dtype, msg=msg)
1181*da0073e9SAndroid Build Coastguard Worker
1182*da0073e9SAndroid Build Coastguard Worker            result_out = torch.empty((0), dtype=result.dtype, device=device)
1183*da0073e9SAndroid Build Coastguard Worker            torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
1184*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_out, msg=msg)
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
1187*da0073e9SAndroid Build Coastguard Worker            result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
1188*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_with_dtype, msg=msg)
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker            result_out_with_dtype = torch.empty_like(result_with_dtype)
1191*da0073e9SAndroid Build Coastguard Worker            torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
1192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker        # For these orders we compute the 10-th power and 10-th root of numbers.
1197*da0073e9SAndroid Build Coastguard Worker        # We skip them for half-precision types because they make the checks too badly conditioned.
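        # For p = 0.1, ||x||_p = (sum_i |x_i| ** 0.1) ** 10: taking 10-th roots and then a
        # 10-th power amplifies rounding error beyond what fp16/bf16 can absorb.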
1198*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.float16 and dtype != torch.bfloat16:
1199*da0073e9SAndroid Build Coastguard Worker            ord_vector.extend([0.1, -0.1])
1200*da0073e9SAndroid Build Coastguard Worker        ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
1201*da0073e9SAndroid Build Coastguard Worker        S = 10
1202*da0073e9SAndroid Build Coastguard Worker
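        # Only exercise dtype= targets that keep the input's real/complex class and do not
        # lose precision relative to the input dtype.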
1203*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.cfloat:
1204*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (torch.cfloat, torch.cdouble)
1205*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.cdouble:
1206*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (torch.cdouble,)
1207*da0073e9SAndroid Build Coastguard Worker        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1208*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (torch.float, torch.double)
1209*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.double:
1210*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (torch.double,)
1211*da0073e9SAndroid Build Coastguard Worker        else:
1212*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Unsupported dtype")
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Worker        for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
1215*da0073e9SAndroid Build Coastguard Worker            run_test_case((S,), ord, keepdim, norm_dtype)
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker        for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
1218*da0073e9SAndroid Build Coastguard Worker            if ord in [2, -2, 'nuc']:
1219*da0073e9SAndroid Build Coastguard Worker                # We need torch.svdvals
1220*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.float16 or dtype == torch.bfloat16:
1221*da0073e9SAndroid Build Coastguard Worker                    continue
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker                # We need LAPACK or equivalent
1224*da0073e9SAndroid Build Coastguard Worker                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1225*da0073e9SAndroid Build Coastguard Worker                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1226*da0073e9SAndroid Build Coastguard Worker                    continue
1227*da0073e9SAndroid Build Coastguard Worker            run_test_case((S, S), ord, keepdim, norm_dtype)
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker    # This test confirms that torch.linalg.norm produces correct results for bfloat16 and half inputs.
1230*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float16)
1231*da0073e9SAndroid Build Coastguard Worker    def test_norm_bfloat16_and_half(self, device, dtype):
1232*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
1233*da0073e9SAndroid Build Coastguard Worker
1234*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input_size, ord, keepdim):
1235*da0073e9SAndroid Build Coastguard Worker            msg = (
1236*da0073e9SAndroid Build Coastguard Worker                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1237*da0073e9SAndroid Build Coastguard Worker                f'dtype={dtype}')
1238*da0073e9SAndroid Build Coastguard Worker            input = make_arg(input_size).fill_(1)
1239*da0073e9SAndroid Build Coastguard Worker            result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
1240*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1241*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_ref, result, msg=msg)
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1244*da0073e9SAndroid Build Coastguard Worker        for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
1245*da0073e9SAndroid Build Coastguard Worker            run_test_case((S,), ord, keepdim)
1246*da0073e9SAndroid Build Coastguard Worker
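    # Illustrative sketch (not used by the test above): the reduced-precision norm is
    # expected to match a float32 reference computed in full precision and then cast
    # back down.  This helper is an assumption-level sketch, not part of the test API.
    @staticmethod
    def _low_precision_norm_reference_sketch(x, ord=None, keepdim=False):
        # Compute the norm in float32, then cast the result back to the input dtype.
        return torch.linalg.norm(x.float(), ord, keepdim=keepdim).to(dtype=x.dtype)
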
1247*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1248*da0073e9SAndroid Build Coastguard Worker    def test_vector_norm(self, device, dtype):
1249*da0073e9SAndroid Build Coastguard Worker        if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
1250*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
1252*da0073e9SAndroid Build Coastguard Worker        # This test compares torch.linalg.vector_norm's output with
1253*da0073e9SAndroid Build Coastguard Worker        # torch.linalg.norm given a flattened tensor
1254*da0073e9SAndroid Build Coastguard Worker        ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1255*da0073e9SAndroid Build Coastguard Worker        input_sizes = [
1256*da0073e9SAndroid Build Coastguard Worker            (1, ),
1257*da0073e9SAndroid Build Coastguard Worker            (10, ),
1258*da0073e9SAndroid Build Coastguard Worker            (4, 5),
1259*da0073e9SAndroid Build Coastguard Worker            (3, 4, 5),
1260*da0073e9SAndroid Build Coastguard Worker            (0, ),
1261*da0073e9SAndroid Build Coastguard Worker            (0, 10),
1262*da0073e9SAndroid Build Coastguard Worker            (0, 0),
1263*da0073e9SAndroid Build Coastguard Worker            (10, 0, 10),
1264*da0073e9SAndroid Build Coastguard Worker        ]
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
1267*da0073e9SAndroid Build Coastguard Worker            if dim is None:
1268*da0073e9SAndroid Build Coastguard Worker                input_maybe_flat = input.flatten(0, -1)
1269*da0073e9SAndroid Build Coastguard Worker            else:
1270*da0073e9SAndroid Build Coastguard Worker                input_maybe_flat = input
1271*da0073e9SAndroid Build Coastguard Worker
1272*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
1273*da0073e9SAndroid Build Coastguard Worker            if keepdim and dim is None:
1274*da0073e9SAndroid Build Coastguard Worker                result = result.reshape([1] * input.dim())
1275*da0073e9SAndroid Build Coastguard Worker            return result
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, ord, dim, keepdim, norm_dtype):
1278*da0073e9SAndroid Build Coastguard Worker            if (input.numel() == 0 and
1279*da0073e9SAndroid Build Coastguard Worker                (ord < 0. or ord == inf) and
1280*da0073e9SAndroid Build Coastguard Worker               (dim is None or input.shape[dim] == 0)):
1281*da0073e9SAndroid Build Coastguard Worker                # The reduction has no identity element, so the norm of an empty tensor is undefined and should raise.
1282*da0073e9SAndroid Build Coastguard Worker                error_msg = "linalg.vector_norm cannot compute"
1283*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, error_msg):
1284*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
1285*da0073e9SAndroid Build Coastguard Worker            else:
1286*da0073e9SAndroid Build Coastguard Worker                msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
1287*da0073e9SAndroid Build Coastguard Worker                       f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
1288*da0073e9SAndroid Build Coastguard Worker                result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1289*da0073e9SAndroid Build Coastguard Worker                result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1290*da0073e9SAndroid Build Coastguard Worker                if dtype.is_complex:
1291*da0073e9SAndroid Build Coastguard Worker                    result_dtype_reference = result_dtype_reference.real
1292*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result_dtype, result_dtype_reference, msg=msg)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker                if norm_dtype is not None:
1295*da0073e9SAndroid Build Coastguard Worker                    ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
1296*da0073e9SAndroid Build Coastguard Worker                    actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1297*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(ref, actual, msg=msg)
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.cfloat:
1300*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (None, torch.cfloat, torch.cdouble)
1301*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.cdouble:
1302*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (None, torch.cdouble)
1303*da0073e9SAndroid Build Coastguard Worker        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1304*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (None, torch.float, torch.double)
1305*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.double:
1306*da0073e9SAndroid Build Coastguard Worker            norm_dtypes = (None, torch.double)
1307*da0073e9SAndroid Build Coastguard Worker        else:
1308*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Unsupported dtype")
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker        for amp in [False, True]:
1311*da0073e9SAndroid Build Coastguard Worker            with torch.autocast(device_type=device, enabled=amp):
1312*da0073e9SAndroid Build Coastguard Worker                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
1313*da0073e9SAndroid Build Coastguard Worker                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1314*da0073e9SAndroid Build Coastguard Worker                    for dim in [None, random.randint(0, len(input_size) - 1)]:
1315*da0073e9SAndroid Build Coastguard Worker                        run_test_case(
1316*da0073e9SAndroid Build Coastguard Worker                            input,
1317*da0073e9SAndroid Build Coastguard Worker                            ord,
1318*da0073e9SAndroid Build Coastguard Worker                            dim,
1319*da0073e9SAndroid Build Coastguard Worker                            keepdim,
1320*da0073e9SAndroid Build Coastguard Worker                            norm_dtype)
1321*da0073e9SAndroid Build Coastguard Worker
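    # Minimal sketch of the identity exercised above (assuming dim=None): the full
    # vector norm of an n-D tensor equals torch.linalg.norm of the flattened tensor.
    # Illustrative only; not called by the tests.
    @staticmethod
    def _vector_norm_via_flatten_sketch(x, ord=2):
        # vector_norm reduces over all elements when dim is None, which matches the
        # vector norm of the flattened input.
        return torch.linalg.vector_norm(x, ord), torch.linalg.norm(x.flatten(), ord)
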
1322*da0073e9SAndroid Build Coastguard Worker    def test_vector_norm_dim_tuple_arg(self, device):
1323*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1324*da0073e9SAndroid Build Coastguard Worker            # input size, dim, error, error message
1325*da0073e9SAndroid Build Coastguard Worker            ((4, ), (0, ), None, None),
1326*da0073e9SAndroid Build Coastguard Worker            ((4, ), (1, ), IndexError, r'Dimension out of range'),
1327*da0073e9SAndroid Build Coastguard Worker            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
1328*da0073e9SAndroid Build Coastguard Worker            ((4, 3), (0, -1), None, None),
1329*da0073e9SAndroid Build Coastguard Worker            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1330*da0073e9SAndroid Build Coastguard Worker            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1331*da0073e9SAndroid Build Coastguard Worker            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
1332*da0073e9SAndroid Build Coastguard Worker            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
1333*da0073e9SAndroid Build Coastguard Worker        ]
1334*da0073e9SAndroid Build Coastguard Worker        for input_size, dim_tuple, error, error_msg in test_cases:
1335*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(input_size, device=device)
1336*da0073e9SAndroid Build Coastguard Worker            # vector_norm should accept a tuple or a list for dim arg
1337*da0073e9SAndroid Build Coastguard Worker            for dim in [dim_tuple, list(dim_tuple)]:
1338*da0073e9SAndroid Build Coastguard Worker                if error is None:
1339*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.vector_norm(input, dim=dim)
1340*da0073e9SAndroid Build Coastguard Worker                else:
1341*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(error):
1342*da0073e9SAndroid Build Coastguard Worker                        torch.linalg.vector_norm(input, dim=dim)
1343*da0073e9SAndroid Build Coastguard Worker
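    # Sketch of what a tuple-valued dim means (illustrative, assuming a 2-D input):
    # reducing over dim=(0, 1) is the same as flattening those dimensions and reducing
    # over the single resulting dimension.
    @staticmethod
    def _vector_norm_dim_tuple_sketch(x, ord=2):
        full = torch.linalg.vector_norm(x, ord, dim=(0, 1))
        flat = torch.linalg.vector_norm(x.flatten(0, 1), ord, dim=0)
        return full, flat  # both entries are expected to be equal
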
1344*da0073e9SAndroid Build Coastguard Worker    # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
1345*da0073e9SAndroid Build Coastguard Worker    # their vector norm results match
1346*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1347*da0073e9SAndroid Build Coastguard Worker    def test_norm_vector(self, device, dtype):
1348*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, ord, dim, keepdim):
1349*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input, ord, dim, keepdim)
1350*da0073e9SAndroid Build Coastguard Worker            input_numpy = input.cpu().numpy()
1351*da0073e9SAndroid Build Coastguard Worker            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1354*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_numpy, msg=msg)
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker            result_out = torch.empty_like(result)
1357*da0073e9SAndroid Build Coastguard Worker            torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
1358*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_out, msg=msg)
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Worker        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
1361*da0073e9SAndroid Build Coastguard Worker        S = 10
1362*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1363*da0073e9SAndroid Build Coastguard Worker            # input size, p settings, dim
1364*da0073e9SAndroid Build Coastguard Worker            ((S, ), ord_vector, None),
1365*da0073e9SAndroid Build Coastguard Worker            ((S, ), ord_vector, 0),
1366*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), ord_vector, 0),
1367*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), ord_vector, 1),
1368*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), ord_vector, 2),
1369*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), ord_vector, -1),
1370*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), ord_vector, -2),
1371*da0073e9SAndroid Build Coastguard Worker        ]
1372*da0073e9SAndroid Build Coastguard Worker        L = 1_000_000
1373*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.double:
1374*da0073e9SAndroid Build Coastguard Worker            test_cases.append(((L, ), ord_vector, None))
1375*da0073e9SAndroid Build Coastguard Worker        for keepdim in [True, False]:
1376*da0073e9SAndroid Build Coastguard Worker            for input_size, ord_settings, dim in test_cases:
1377*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(*input_size, dtype=dtype, device=device)
1378*da0073e9SAndroid Build Coastguard Worker                for ord in ord_settings:
1379*da0073e9SAndroid Build Coastguard Worker                    run_test_case(input, ord, dim, keepdim)
1380*da0073e9SAndroid Build Coastguard Worker
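    # Sketch of the semantics assumed for the non-standard vector orders above
    # (matching NumPy): ord=inf is max(|x|), ord=-inf is min(|x|), ord=0 counts
    # non-zero entries, and negative finite orders use the usual power formula.
    # Illustrative only; assumes a 1-D input with non-zero entries for the negative order.
    @staticmethod
    def _vector_ord_semantics_sketch(x):
        ax = x.abs()
        return {
            inf: ax.max(),
            -inf: ax.min(),
            0: (x != 0).sum().to(x.dtype),
            -2: ax.pow(-2).sum().pow(-0.5),
        }
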
1381*da0073e9SAndroid Build Coastguard Worker    # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to
1382*da0073e9SAndroid Build Coastguard Worker    # ensure that their matrix norm results match.
1383*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # https://github.com/pytorch/pytorch/issues/54082
1384*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1385*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1386*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 2e-4})
1387*da0073e9SAndroid Build Coastguard Worker    def test_norm_matrix(self, device, dtype):
1388*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, ord, dim, keepdim):
1391*da0073e9SAndroid Build Coastguard Worker            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1392*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.norm(input, ord, dim, keepdim)
1393*da0073e9SAndroid Build Coastguard Worker            input_numpy = input.cpu().numpy()
1394*da0073e9SAndroid Build Coastguard Worker            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1395*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_numpy, msg=msg)
1398*da0073e9SAndroid Build Coastguard Worker            if ord is not None and dim is not None:
1399*da0073e9SAndroid Build Coastguard Worker                result = torch.linalg.matrix_norm(input, ord, dim, keepdim)
1400*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_numpy, msg=msg)
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
1403*da0073e9SAndroid Build Coastguard Worker        S = 10
1404*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1405*da0073e9SAndroid Build Coastguard Worker            # input size, dim
1406*da0073e9SAndroid Build Coastguard Worker            ((S, S), None),
1407*da0073e9SAndroid Build Coastguard Worker            ((S, S), (0, 1)),
1408*da0073e9SAndroid Build Coastguard Worker            ((S, S), (1, 0)),
1409*da0073e9SAndroid Build Coastguard Worker            ((S, S, S, S), (2, 0)),
1410*da0073e9SAndroid Build Coastguard Worker            ((S, S, S, S), (-1, -2)),
1411*da0073e9SAndroid Build Coastguard Worker            ((S, S, S, S), (-1, -3)),
1412*da0073e9SAndroid Build Coastguard Worker            ((S, S, S, S), (-3, 2)),
1413*da0073e9SAndroid Build Coastguard Worker        ]
1414*da0073e9SAndroid Build Coastguard Worker
1415*da0073e9SAndroid Build Coastguard Worker        for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix):
1416*da0073e9SAndroid Build Coastguard Worker            if ord in [2, -2, 'nuc']:
1417*da0073e9SAndroid Build Coastguard Worker                # We need torch.svdvals
1418*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.float16 or dtype == torch.bfloat16:
1419*da0073e9SAndroid Build Coastguard Worker                    continue
1420*da0073e9SAndroid Build Coastguard Worker                # We need LAPACK or equivalent
1421*da0073e9SAndroid Build Coastguard Worker                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1422*da0073e9SAndroid Build Coastguard Worker                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1423*da0073e9SAndroid Build Coastguard Worker                    continue
1424*da0073e9SAndroid Build Coastguard Worker            run_test_case(make_arg(shape), ord, dim, keepdim)
1425*da0073e9SAndroid Build Coastguard Worker
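    # Sketch of why ord in (2, -2, 'nuc') needs torch.svdvals (illustrative): these
    # matrix norms are functions of the singular values of the input matrix.
    @staticmethod
    def _matrix_norm_via_svdvals_sketch(A):
        sigma = torch.linalg.svdvals(A)
        return {
            2: sigma.max(),       # spectral norm: largest singular value
            -2: sigma.min(),      # smallest singular value
            'nuc': sigma.sum(),   # nuclear norm: sum of singular values
        }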
1426*da0073e9SAndroid Build Coastguard Worker
1427*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1428*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float16)
1429*da0073e9SAndroid Build Coastguard Worker    def test_norm_fused_type_promotion(self, device, dtype):
1430*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, device=device, dtype=dtype)
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        def profile_and_check(fn, x, kwargs):
1433*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
1434*da0073e9SAndroid Build Coastguard Worker                fn(x, **kwargs, dtype=torch.float)
1435*da0073e9SAndroid Build Coastguard Worker            # smoke check that the profiler captured the vector norm kernel
1436*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events()))
1437*da0073e9SAndroid Build Coastguard Worker            # test that there was no explicit copy
1438*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("aten::to" in (e.name for e in p.events()))
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker        for f, kwargs, in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})):
1441*da0073e9SAndroid Build Coastguard Worker            profile_and_check(f, x, kwargs)
1442*da0073e9SAndroid Build Coastguard Worker
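    # Sketch of the fused type promotion exercised above (assumed behaviour): passing
    # dtype=torch.float should upcast inside the reduction kernel, so the result matches
    # an explicit .float() copy while avoiding a separate aten::to dispatch.
    @staticmethod
    def _fused_promotion_equivalence_sketch(x_half):
        fused = torch.linalg.vector_norm(x_half, dtype=torch.float)
        explicit = torch.linalg.vector_norm(x_half.float())
        return fused, explicit  # expected to be equal; only the first avoids the copy
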
1443*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1444*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1445*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1446*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1447*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3})
1448*da0073e9SAndroid Build Coastguard Worker    def test_cond(self, device, dtype):
1449*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, p):
1450*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.cond(input, p)
1451*da0073e9SAndroid Build Coastguard Worker            result_numpy = np.linalg.cond(input.cpu().numpy(), p)
1452*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False)
1453*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, result_numpy.shape)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker            # test out= variant
1456*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(result)
1457*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.cond(input, p, out=out)
1458*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
1459*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, result)
1460*da0073e9SAndroid Build Coastguard Worker
1461*da0073e9SAndroid Build Coastguard Worker        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1462*da0073e9SAndroid Build Coastguard Worker        input_sizes = [(32, 32), (2, 3, 3, 3)]
1463*da0073e9SAndroid Build Coastguard Worker        for input_size in input_sizes:
1464*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*input_size, dtype=dtype, device=device)
1465*da0073e9SAndroid Build Coastguard Worker            for p in norm_types:
1466*da0073e9SAndroid Build Coastguard Worker                run_test_case(input, p)
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker        # test empty batch sizes
1469*da0073e9SAndroid Build Coastguard Worker        input_sizes = [(0, 3, 3), (0, 2, 5, 5)]
1470*da0073e9SAndroid Build Coastguard Worker        for input_size in input_sizes:
1471*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*input_size, dtype=dtype, device=device)
1472*da0073e9SAndroid Build Coastguard Worker            for p in norm_types:
1473*da0073e9SAndroid Build Coastguard Worker                run_test_case(input, p)
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        # test non-square input
1476*da0073e9SAndroid Build Coastguard Worker        input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)]
1477*da0073e9SAndroid Build Coastguard Worker        for input_size in input_sizes:
1478*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*input_size, dtype=dtype, device=device)
1479*da0073e9SAndroid Build Coastguard Worker            for p in [2, -2, None]:
1480*da0073e9SAndroid Build Coastguard Worker                run_test_case(input, p)
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker        # test for singular input
1483*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(3, dtype=dtype, device=device)
1484*da0073e9SAndroid Build Coastguard Worker        a[-1, -1] = 0  # make 'a' singular
1485*da0073e9SAndroid Build Coastguard Worker        for p in norm_types:
1486*da0073e9SAndroid Build Coastguard Worker            try:
1487*da0073e9SAndroid Build Coastguard Worker                run_test_case(a, p)
1488*da0073e9SAndroid Build Coastguard Worker            except np.linalg.LinAlgError:
1489*da0073e9SAndroid Build Coastguard Worker                # Numpy may fail to converge for some BLAS backends (although this is very rare)
1490*da0073e9SAndroid Build Coastguard Worker                # See the discussion in https://github.com/pytorch/pytorch/issues/67675
1491*da0073e9SAndroid Build Coastguard Worker                pass
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker        # test for 0x0 matrices. NumPy doesn't support such inputs; PyTorch returns 0
1494*da0073e9SAndroid Build Coastguard Worker        input_sizes = [(0, 0), (2, 5, 0, 0)]
1495*da0073e9SAndroid Build Coastguard Worker        for input_size in input_sizes:
1496*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*input_size, dtype=dtype, device=device)
1497*da0073e9SAndroid Build Coastguard Worker            for p in ['fro', 2]:
1498*da0073e9SAndroid Build Coastguard Worker                expected_dtype = input.real.dtype if dtype.is_complex else dtype
1499*da0073e9SAndroid Build Coastguard Worker                expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
1500*da0073e9SAndroid Build Coastguard Worker                actual = torch.linalg.cond(input, p)
1501*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
1502*da0073e9SAndroid Build Coastguard Worker
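    # Sketch of the 2-norm condition number computed above (assumed definition, matching
    # NumPy): the ratio of the largest to the smallest singular value, which becomes inf
    # for singular matrices.  Illustrative only.
    @staticmethod
    def _cond_2_sketch(A):
        sigma = torch.linalg.svdvals(A)
        return sigma.max() / sigma.min()
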
1503*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1504*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1505*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1506*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1507*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3})
1508*da0073e9SAndroid Build Coastguard Worker    def test_cond_errors_and_warnings(self, device, dtype):
1509*da0073e9SAndroid Build Coastguard Worker        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker        # cond expects the input to be at least 2-dimensional
1512*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(3, dtype=dtype, device=device)
1513*da0073e9SAndroid Build Coastguard Worker        for p in norm_types:
1514*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
1515*da0073e9SAndroid Build Coastguard Worker                torch.linalg.cond(a, p)
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        # for some norm types cond expects the input to be square
1518*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(3, 2, dtype=dtype, device=device)
1519*da0073e9SAndroid Build Coastguard Worker        norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
1520*da0073e9SAndroid Build Coastguard Worker        for p in norm_types:
1521*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
1522*da0073e9SAndroid Build Coastguard Worker                torch.linalg.cond(a, p)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        # if a non-empty out tensor with the wrong shape is passed, a warning is given
1525*da0073e9SAndroid Build Coastguard Worker        a = torch.ones((2, 2), dtype=dtype, device=device)
1526*da0073e9SAndroid Build Coastguard Worker        for p in ['fro', 2]:
1527*da0073e9SAndroid Build Coastguard Worker            real_dtype = a.real.dtype if dtype.is_complex else dtype
1528*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(a.shape, dtype=real_dtype, device=device)
1529*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
1530*da0073e9SAndroid Build Coastguard Worker                # Trigger warning
1531*da0073e9SAndroid Build Coastguard Worker                torch.linalg.cond(a, p, out=out)
1532*da0073e9SAndroid Build Coastguard Worker                # Check warning occurs
1533*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 1)
1534*da0073e9SAndroid Build Coastguard Worker                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
1537*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.int, device=device)
1538*da0073e9SAndroid Build Coastguard Worker        for p in ['fro', 2]:
1539*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
1540*da0073e9SAndroid Build Coastguard Worker                torch.linalg.cond(a, p, out=out)
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker        # device should match
1543*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1544*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1545*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=wrong_device)
1546*da0073e9SAndroid Build Coastguard Worker            for p in ['fro', 2]:
1547*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1548*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.cond(a, p, out=out)
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker        # For batched input, if at least one matrix in the batch is not invertible,
1551*da0073e9SAndroid Build Coastguard Worker        # we cannot compute the result for the other (possibly invertible) matrices in the batch without an explicit for loop.
1552*da0073e9SAndroid Build Coastguard Worker        # This should change once at::inverse supports silent errors.
1553*da0073e9SAndroid Build Coastguard Worker        # NumPy works fine in this case because the error can be silenced and the inverse results returned,
1554*da0073e9SAndroid Build Coastguard Worker        # possibly filled with NaNs.
1555*da0073e9SAndroid Build Coastguard Worker        batch_dim = 3
1556*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(3, 3, dtype=dtype, device=device)
1557*da0073e9SAndroid Build Coastguard Worker        a = a.reshape((1, 3, 3))
1558*da0073e9SAndroid Build Coastguard Worker        a = a.repeat(batch_dim, 1, 1)
1559*da0073e9SAndroid Build Coastguard Worker        a[1, -1, -1] = 0  # now a[1] is singular
1560*da0073e9SAndroid Build Coastguard Worker        for p in [1, -1, inf, -inf, 'fro', 'nuc']:
1561*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.cond(a, p)
1562*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result[1], float('inf'))
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker        # check invalid norm type
1565*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(3, 3, dtype=dtype, device=device)
1566*da0073e9SAndroid Build Coastguard Worker        for p in ['wrong_norm', 5]:
1567*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
1568*da0073e9SAndroid Build Coastguard Worker                torch.linalg.cond(a, p)
1569*da0073e9SAndroid Build Coastguard Worker
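    # Sketch of the batched behaviour asserted above (illustrative): cond is applied
    # per matrix in the batch, so one singular entry yields inf at that index while the
    # other entries stay finite.
    @staticmethod
    def _batched_cond_with_singular_entry_sketch(dtype=torch.float64):
        a = torch.eye(3, dtype=dtype).expand(3, 3, 3).clone()
        a[1, -1, -1] = 0  # make the middle matrix singular
        return torch.linalg.cond(a, 1)  # expected: finite, inf, finite
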
1570*da0073e9SAndroid Build Coastguard Worker    # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments
1571*da0073e9SAndroid Build Coastguard Worker    # to ensure that they both throw errors
1572*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1573*da0073e9SAndroid Build Coastguard Worker    def test_norm_errors(self, device, dtype):
1574*da0073e9SAndroid Build Coastguard Worker        def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex):
1575*da0073e9SAndroid Build Coastguard Worker            test_case_info = (
1576*da0073e9SAndroid Build Coastguard Worker                f'test case input.size()={input.size()}, ord={ord}, dim={dim}, '
1577*da0073e9SAndroid Build Coastguard Worker                f'keepdim={keepdim}, dtype={dtype}')
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info):
1580*da0073e9SAndroid Build Coastguard Worker                torch.linalg.norm(input, ord, dim, keepdim)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker            input_numpy = input.cpu().numpy()
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker            msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"'
1585*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(Exception, msg=msg):
1586*da0073e9SAndroid Build Coastguard Worker                np.linalg.norm(input_numpy, ord, dim, keepdim)
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker        S = 10
1589*da0073e9SAndroid Build Coastguard Worker        error_test_cases = [
1590*da0073e9SAndroid Build Coastguard Worker            # input size, p settings, dim, error type, error regex
1591*da0073e9SAndroid Build Coastguard Worker            ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'),
1592*da0073e9SAndroid Build Coastguard Worker            ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'),
1593*da0073e9SAndroid Build Coastguard Worker            ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'),
1594*da0073e9SAndroid Build Coastguard Worker            ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'),
1595*da0073e9SAndroid Build Coastguard Worker            ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'),
1596*da0073e9SAndroid Build Coastguard Worker            ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'),
1597*da0073e9SAndroid Build Coastguard Worker            ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'),
1598*da0073e9SAndroid Build Coastguard Worker            ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'),
1599*da0073e9SAndroid Build Coastguard Worker            ((S, ), [0], (4, ), IndexError, r'Dimension out of range'),
1600*da0073e9SAndroid Build Coastguard Worker            ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'),
1601*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."),
1602*da0073e9SAndroid Build Coastguard Worker            ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"),
1603*da0073e9SAndroid Build Coastguard Worker        ]
1604*da0073e9SAndroid Build Coastguard Worker        for keepdim in [True, False]:
1605*da0073e9SAndroid Build Coastguard Worker            for input_size, ord_settings, dim, error_type, error_regex in error_test_cases:
1606*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(*input_size, dtype=dtype, device=device)
1607*da0073e9SAndroid Build Coastguard Worker                for ord in ord_settings:
1608*da0073e9SAndroid Build Coastguard Worker                    run_error_test_case(input, ord, dim, keepdim, error_type, error_regex)
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker    # Test complex number inputs for linalg.norm
1611*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1612*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1613*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cfloat, torch.cdouble)
1614*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.cfloat: 5e-4})
1615*da0073e9SAndroid Build Coastguard Worker    def test_norm_complex(self, device, dtype):
1616*da0073e9SAndroid Build Coastguard Worker        def gen_error_message(input_size, ord, keepdim, dim=None):
1617*da0073e9SAndroid Build Coastguard Worker            return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}"
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf]
1620*da0073e9SAndroid Build Coastguard Worker        matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf]
1621*da0073e9SAndroid Build Coastguard Worker
1622*da0073e9SAndroid Build Coastguard Worker        # Test supported ords
1623*da0073e9SAndroid Build Coastguard Worker        for keepdim in [False, True]:
1624*da0073e9SAndroid Build Coastguard Worker            # vector norm
1625*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, device=device, dtype=dtype)
1626*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
1627*da0073e9SAndroid Build Coastguard Worker            for ord in vector_ords:
1628*da0073e9SAndroid Build Coastguard Worker                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1629*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1630*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), ord, keepdim)
1631*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.shape, expected.shape, msg=msg)
1632*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker                res_out = torch.tensor([], device=device, dtype=res.dtype)
1635*da0073e9SAndroid Build Coastguard Worker                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1636*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1637*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_out, expected, msg=msg)
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker            # matrix norm
1640*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, 25, device=device, dtype=dtype)
1641*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
1642*da0073e9SAndroid Build Coastguard Worker            for ord in matrix_ords:
1643*da0073e9SAndroid Build Coastguard Worker                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1644*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1645*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), ord, keepdim)
1646*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.shape, expected.shape, msg=msg)
1647*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker                res_out = torch.tensor([], device=device, dtype=res.dtype)
1650*da0073e9SAndroid Build Coastguard Worker                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1651*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1652*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_out, expected, msg=msg)
1653*da0073e9SAndroid Build Coastguard Worker
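    # Sketch of the complex-input convention assumed above: norms of complex tensors are
    # real-valued and computed from the element moduli, e.g. the 2-norm is
    # sqrt(sum(|x_i|**2)).  Illustrative only.
    @staticmethod
    def _complex_vector_2_norm_sketch(x):
        return x.abs().pow(2).sum().sqrt()  # real-valued, matches torch.linalg.norm(x)
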
1654*da0073e9SAndroid Build Coastguard Worker    # Test that linalg.vector_norm gives the same result as numpy when inputs
1655*da0073e9SAndroid Build Coastguard Worker    # contain extreme values (inf, -inf, nan)
1656*da0073e9SAndroid Build Coastguard Worker    def test_vector_norm_extreme_values(self, device):
1657*da0073e9SAndroid Build Coastguard Worker        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1658*da0073e9SAndroid Build Coastguard Worker        vectors = []
1659*da0073e9SAndroid Build Coastguard Worker        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1660*da0073e9SAndroid Build Coastguard Worker            vectors.append(list(pair))
1661*da0073e9SAndroid Build Coastguard Worker        for vector in vectors:
1662*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(vector, device=device)
1663*da0073e9SAndroid Build Coastguard Worker            x_n = x.cpu().numpy()
1664*da0073e9SAndroid Build Coastguard Worker            for ord in vector_ords:
1665*da0073e9SAndroid Build Coastguard Worker                msg = f'ord={ord}, vector={vector}'
1666*da0073e9SAndroid Build Coastguard Worker                result = torch.linalg.vector_norm(x, ord=ord)
1667*da0073e9SAndroid Build Coastguard Worker                result_n = np.linalg.norm(x_n, ord=ord)
1668*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_n, msg=msg)
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1671*da0073e9SAndroid Build Coastguard Worker    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
1672*da0073e9SAndroid Build Coastguard Worker        input_sizes_and_dims = [
1673*da0073e9SAndroid Build Coastguard Worker            ((6, 1), -1),
1674*da0073e9SAndroid Build Coastguard Worker            ((3, 1, 2, 1), (1, 3)),
1675*da0073e9SAndroid Build Coastguard Worker            ((1,), None),
1676*da0073e9SAndroid Build Coastguard Worker        ]
1677*da0073e9SAndroid Build Coastguard Worker        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
1678*da0073e9SAndroid Build Coastguard Worker        keepdims = [True, False]
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
1681*da0073e9SAndroid Build Coastguard Worker            input_size = input_size_and_dim[0]
1682*da0073e9SAndroid Build Coastguard Worker            dim = input_size_and_dim[1]
1683*da0073e9SAndroid Build Coastguard Worker            if type(dim) is tuple and ord == 0:
1684*da0073e9SAndroid Build Coastguard Worker                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
1685*da0073e9SAndroid Build Coastguard Worker                continue
1686*da0073e9SAndroid Build Coastguard Worker            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1687*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
1688*da0073e9SAndroid Build Coastguard Worker            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
1689*da0073e9SAndroid Build Coastguard Worker
1690*da0073e9SAndroid Build Coastguard Worker            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1691*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_numpy, msg=msg)
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
1694*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1695*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1696*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 2e-5})
1697*da0073e9SAndroid Build Coastguard Worker    def test_matrix_norm(self, device, dtype):
1698*da0073e9SAndroid Build Coastguard Worker        # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm
1699*da0073e9SAndroid Build Coastguard Worker        A = make_tensor((2, 2, 2), dtype=dtype, device=device)
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'):
1702*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device))
1703*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'):
1704*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_norm(A, dim=(0,))
1705*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1706*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_norm(A, ord=0)
1707*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1708*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_norm(A, ord=3.0)
1709*da0073e9SAndroid Build Coastguard Worker
1710*da0073e9SAndroid Build Coastguard Worker        # Test dim=None behavior
1711*da0073e9SAndroid Build Coastguard Worker        ref = torch.linalg.norm(A, dim=(-2, -1))
1712*da0073e9SAndroid Build Coastguard Worker        res = torch.linalg.matrix_norm(A)
1713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1714*da0073e9SAndroid Build Coastguard Worker
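    # Sketch of the dim=None default checked above (assumed): matrix_norm reduces over
    # the last two dimensions with ord='fro', and the Frobenius norm is the vector
    # 2-norm taken over those dimensions.  Illustrative only.
    @staticmethod
    def _frobenius_norm_sketch(A):
        return torch.linalg.vector_norm(A, 2, dim=(-2, -1))
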
1715*da0073e9SAndroid Build Coastguard Worker    # Test that linalg.norm gives the same result as numpy when inputs
1716*da0073e9SAndroid Build Coastguard Worker    # contain extreme values (inf, -inf, nan)
1717*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
1718*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
1719*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1720*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1721*da0073e9SAndroid Build Coastguard Worker    def test_norm_extreme_values(self, device):
1722*da0073e9SAndroid Build Coastguard Worker        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1723*da0073e9SAndroid Build Coastguard Worker        # matrix_ords 'nuc', 2, -2 are skipped currently
1724*da0073e9SAndroid Build Coastguard Worker        # See issue https://github.com/pytorch/pytorch/issues/71911
1725*da0073e9SAndroid Build Coastguard Worker        matrix_ords = ['fro', 1, inf, -1, -inf]
1726*da0073e9SAndroid Build Coastguard Worker        vectors = []
1727*da0073e9SAndroid Build Coastguard Worker        matrices = []
1728*da0073e9SAndroid Build Coastguard Worker        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1729*da0073e9SAndroid Build Coastguard Worker            vectors.append(list(pair))
1730*da0073e9SAndroid Build Coastguard Worker            matrices.append([[pair[0], pair[1]]])
1731*da0073e9SAndroid Build Coastguard Worker            matrices.append([[pair[0]], [pair[1]]])
1732*da0073e9SAndroid Build Coastguard Worker        for vector in vectors:
1733*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(vector).to(device)
1734*da0073e9SAndroid Build Coastguard Worker            x_n = x.cpu().numpy()
1735*da0073e9SAndroid Build Coastguard Worker            for ord in vector_ords:
1736*da0073e9SAndroid Build Coastguard Worker                msg = f'ord={ord}, vector={vector}'
1737*da0073e9SAndroid Build Coastguard Worker                result = torch.linalg.norm(x, ord=ord)
1738*da0073e9SAndroid Build Coastguard Worker                result_n = np.linalg.norm(x_n, ord=ord)
1739*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_n, msg=msg)
1740*da0073e9SAndroid Build Coastguard Worker
1741*da0073e9SAndroid Build Coastguard Worker        # TODO: Remove this function once the broken cases are fixed
1742*da0073e9SAndroid Build Coastguard Worker        def is_broken_matrix_norm_case(ord, x):
1743*da0073e9SAndroid Build Coastguard Worker            if self.device_type == 'cuda':
1744*da0073e9SAndroid Build Coastguard Worker                if x.size() == torch.Size([1, 2]):
1745*da0073e9SAndroid Build Coastguard Worker                    if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1:
1746*da0073e9SAndroid Build Coastguard Worker                        # These cases are broken because of an issue with svd
1747*da0073e9SAndroid Build Coastguard Worker                        # https://github.com/pytorch/pytorch/issues/43567
1748*da0073e9SAndroid Build Coastguard Worker                        return True
1749*da0073e9SAndroid Build Coastguard Worker                if ord in ['nuc', 2, -2]:
1750*da0073e9SAndroid Build Coastguard Worker                    # These cases are broken because of another issue with svd
1751*da0073e9SAndroid Build Coastguard Worker                    # https://github.com/pytorch/pytorch/issues/52633
1752*da0073e9SAndroid Build Coastguard Worker                    return True
1753*da0073e9SAndroid Build Coastguard Worker            return False
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker        for matrix in matrices:
1756*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(matrix).to(device)
1757*da0073e9SAndroid Build Coastguard Worker            x_n = x.cpu().numpy()
1758*da0073e9SAndroid Build Coastguard Worker            for ord in matrix_ords:
1759*da0073e9SAndroid Build Coastguard Worker                msg = f'ord={ord}, matrix={matrix}'
1760*da0073e9SAndroid Build Coastguard Worker                if is_broken_matrix_norm_case(ord, x):
1761*da0073e9SAndroid Build Coastguard Worker                    continue
1762*da0073e9SAndroid Build Coastguard Worker                else:
1763*da0073e9SAndroid Build Coastguard Worker                    result_n = np.linalg.norm(x_n, ord=ord)
1764*da0073e9SAndroid Build Coastguard Worker                    result = torch.linalg.norm(x, ord=ord)
1765*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result, result_n, msg=msg)
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker    # Test degenerate shape results match numpy for linalg.norm vector norms
1768*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1769*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1770*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1771*da0073e9SAndroid Build Coastguard Worker    def test_norm_vector_degenerate_shapes(self, device, dtype):
1772*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, ord, dim, keepdim):
1773*da0073e9SAndroid Build Coastguard Worker            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1774*da0073e9SAndroid Build Coastguard Worker            if (input.numel() == 0 and
1775*da0073e9SAndroid Build Coastguard Worker                (ord < 0. or ord == inf) and
1776*da0073e9SAndroid Build Coastguard Worker               (dim is None or input.shape[dim] == 0)):
1777*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(RuntimeError):
1778*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.norm(input, ord, dim, keepdim)
1779*da0073e9SAndroid Build Coastguard Worker            else:
1780*da0073e9SAndroid Build Coastguard Worker                input_numpy = input.cpu().numpy()
1781*da0073e9SAndroid Build Coastguard Worker                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1782*da0073e9SAndroid Build Coastguard Worker                result = torch.linalg.norm(input, ord, dim, keepdim)
1783*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_numpy, msg=msg)
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker        ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1786*da0073e9SAndroid Build Coastguard Worker        S = 10
1787*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1788*da0073e9SAndroid Build Coastguard Worker            # input size, dim
1789*da0073e9SAndroid Build Coastguard Worker            ((0, ), None),
1790*da0073e9SAndroid Build Coastguard Worker            ((0, S), 0),
1791*da0073e9SAndroid Build Coastguard Worker            ((0, S), 1),
1792*da0073e9SAndroid Build Coastguard Worker            ((S, 0), 0),
1793*da0073e9SAndroid Build Coastguard Worker            ((S, 0), 1),
1794*da0073e9SAndroid Build Coastguard Worker        ]
1795*da0073e9SAndroid Build Coastguard Worker        for keepdim in [True, False]:
1796*da0073e9SAndroid Build Coastguard Worker            for input_size, dim in test_cases:
1797*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(*input_size, dtype=dtype, device=device)
1798*da0073e9SAndroid Build Coastguard Worker                for ord in ord_vector:
1799*da0073e9SAndroid Build Coastguard Worker                    run_test_case(input, ord, dim, keepdim)
1800*da0073e9SAndroid Build Coastguard Worker
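    # Sketch of the empty-input rule exercised above (assumed reasoning): sum-based
    # orders have 0 as an identity element, so they are defined (and zero) on empty
    # tensors, while ord=inf and negative orders have no identity and must raise.
    @staticmethod
    def _empty_vector_norm_sketch(device='cpu'):
        empty = torch.empty(0, device=device)
        zero = torch.linalg.vector_norm(empty, 2)    # tensor(0.): empty sum of squares
        # torch.linalg.vector_norm(empty, inf)       # would raise a RuntimeError
        return zero
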
1801*da0073e9SAndroid Build Coastguard Worker    # Test degenerate shape results match numpy for linalg.norm matrix norms
1802*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1803*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1804*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1805*da0073e9SAndroid Build Coastguard Worker    def test_norm_matrix_degenerate_shapes(self, device, dtype):
1806*da0073e9SAndroid Build Coastguard Worker        def run_test_case(input, ord, dim, keepdim, should_error):
1807*da0073e9SAndroid Build Coastguard Worker            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1808*da0073e9SAndroid Build Coastguard Worker            input_numpy = input.cpu().numpy()
1809*da0073e9SAndroid Build Coastguard Worker            ops = [torch.linalg.norm]
1810*da0073e9SAndroid Build Coastguard Worker
1811*da0073e9SAndroid Build Coastguard Worker            if ord is not None and dim is not None:
1812*da0073e9SAndroid Build Coastguard Worker                ops.append(torch.linalg.matrix_norm)
1813*da0073e9SAndroid Build Coastguard Worker
1814*da0073e9SAndroid Build Coastguard Worker            if should_error:
1815*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(ValueError):
1816*da0073e9SAndroid Build Coastguard Worker                    np.linalg.norm(input_numpy, ord, dim, keepdim)
1817*da0073e9SAndroid Build Coastguard Worker                for op in ops:
1818*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(IndexError):
1819*da0073e9SAndroid Build Coastguard Worker                        op(input, ord, dim, keepdim)
1820*da0073e9SAndroid Build Coastguard Worker            else:
1821*da0073e9SAndroid Build Coastguard Worker                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1822*da0073e9SAndroid Build Coastguard Worker                for op in ops:
1823*da0073e9SAndroid Build Coastguard Worker                    result = op(input, ord, dim, keepdim)
1824*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result, result_numpy, msg=msg)
1825*da0073e9SAndroid Build Coastguard Worker
1826*da0073e9SAndroid Build Coastguard Worker        ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None]
1827*da0073e9SAndroid Build Coastguard Worker        S = 10
1828*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1829*da0073e9SAndroid Build Coastguard Worker            # input size, p settings that cause error, dim
1830*da0073e9SAndroid Build Coastguard Worker            ((0, 0), [1, 2, inf, -1, -2, -inf], None),
1831*da0073e9SAndroid Build Coastguard Worker            ((0, S), [2, inf, -2, -inf], None),
1832*da0073e9SAndroid Build Coastguard Worker            ((S, 0), [1, 2, -1, -2], None),
1833*da0073e9SAndroid Build Coastguard Worker            ((S, S, 0), [], (0, 1)),
1834*da0073e9SAndroid Build Coastguard Worker            ((1, S, 0), [], (0, 1)),
1835*da0073e9SAndroid Build Coastguard Worker            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)),
1836*da0073e9SAndroid Build Coastguard Worker            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)),
1837*da0073e9SAndroid Build Coastguard Worker        ]
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Worker        for keepdim in [True, False]:
1840*da0073e9SAndroid Build Coastguard Worker            for input_size, error_ords, dim in test_cases:
1841*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(*input_size, dtype=dtype, device=device)
1842*da0073e9SAndroid Build Coastguard Worker                for ord in ord_matrix:
1843*da0073e9SAndroid Build Coastguard Worker                    run_test_case(input, ord, dim, keepdim, ord in error_ords)
1844*da0073e9SAndroid Build Coastguard Worker
1845*da0073e9SAndroid Build Coastguard Worker    def test_norm_fastpaths(self, device):
1846*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 5, device=device)
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker        # slow path
1849*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.norm(x, 4.5, 1)
1850*da0073e9SAndroid Build Coastguard Worker        expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
1851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
1852*da0073e9SAndroid Build Coastguard Worker
1853*da0073e9SAndroid Build Coastguard Worker        # fast 0-norm
1854*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.norm(x, 0, 1)
1855*da0073e9SAndroid Build Coastguard Worker        expected = (x != 0).type_as(x).sum(1)
1856*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker        # fast 1-norm
1859*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.norm(x, 1, 1)
1860*da0073e9SAndroid Build Coastguard Worker        expected = x.abs().sum(1)
1861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker        # fast 2-norm
1864*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.norm(x, 2, 1)
1865*da0073e9SAndroid Build Coastguard Worker        expected = torch.sqrt(x.pow(2).sum(1))
1866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker        # fast 3-norm
1869*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.norm(x, 3, 1)
1870*da0073e9SAndroid Build Coastguard Worker        expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
1871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
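
        # Illustrative extra check, not one of the fast paths exercised above: a
        # minimal sketch showing that the vector inf-norm along dim=1 reduces to
        # the elementwise maximum of absolute values (reusing the same x as above).
        result = torch.linalg.norm(x, inf, 1)
        expected = x.abs().amax(1)
        self.assertEqual(result, expected)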
1872*da0073e9SAndroid Build Coastguard Worker
1873*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1874*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1875*da0073e9SAndroid Build Coastguard Worker    # NumPy computes only in float64 and complex128 precisions
1876*da0073e9SAndroid Build Coastguard Worker    # for float32 or complex64, results might differ significantly from float64 or complex128
1877*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.complex128)
1878*da0073e9SAndroid Build Coastguard Worker    def test_eig_numpy(self, device, dtype):
1879*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, *, symmetric=False):
1880*da0073e9SAndroid Build Coastguard Worker            from torch.testing._internal.common_utils import random_symmetric_matrix
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex and symmetric:
1883*da0073e9SAndroid Build Coastguard Worker                # for symmetric real-valued inputs, eigenvalues and eigenvectors have zero imaginary part
1884*da0073e9SAndroid Build Coastguard Worker                # unlike NumPy, the result is not cast to float32 or float64 dtype in this case
1885*da0073e9SAndroid Build Coastguard Worker                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1886*da0073e9SAndroid Build Coastguard Worker            else:
1887*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, dtype=dtype, device=device)
1888*da0073e9SAndroid Build Coastguard Worker
1889*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.eig(a)
1890*da0073e9SAndroid Build Coastguard Worker
1891*da0073e9SAndroid Build Coastguard Worker            # compare with NumPy
1892*da0073e9SAndroid Build Coastguard Worker            # the eigenvalues are not necessarily ordered
1893*da0073e9SAndroid Build Coastguard Worker            # so the order of NumPy and PyTorch results can differ
1894*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.eig(a.cpu().numpy())
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker            # sort NumPy output
1897*da0073e9SAndroid Build Coastguard Worker            ind = np.argsort(expected[0], axis=-1)[::-1]
1898*da0073e9SAndroid Build Coastguard Worker            expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1))
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker            # sort PyTorch output
1901*da0073e9SAndroid Build Coastguard Worker            # torch.argsort doesn't work with complex inputs, so NumPy sorting on CPU is used instead
1902*da0073e9SAndroid Build Coastguard Worker            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
1903*da0073e9SAndroid Build Coastguard Worker            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
1904*da0073e9SAndroid Build Coastguard Worker            ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1]
1905*da0073e9SAndroid Build Coastguard Worker            actual_np = [x.cpu().numpy() for x in actual]
1906*da0073e9SAndroid Build Coastguard Worker            sorted_actual = (
1907*da0073e9SAndroid Build Coastguard Worker                np.take_along_axis(actual_np[0], ind, axis=-1),
1908*da0073e9SAndroid Build Coastguard Worker                np.take_along_axis(actual_np[1], ind[:, None], axis=-1))
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False)
1911*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False)
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker        shapes = [(0, 0),  # Empty matrix
1914*da0073e9SAndroid Build Coastguard Worker                  (5, 5),  # Single matrix
1915*da0073e9SAndroid Build Coastguard Worker                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1916*da0073e9SAndroid Build Coastguard Worker                  (2, 5, 5),  # 3-dim tensors
1917*da0073e9SAndroid Build Coastguard Worker                  (2, 1, 5, 5)]  # 4-dim tensors
1918*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
1919*da0073e9SAndroid Build Coastguard Worker            run_test(shape)
1920*da0073e9SAndroid Build Coastguard Worker            run_test(shape, symmetric=True)
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1923*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1924*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1925*da0073e9SAndroid Build Coastguard Worker    def test_eig_compare_backends(self, device, dtype):
1926*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, *, symmetric=False):
1927*da0073e9SAndroid Build Coastguard Worker            from torch.testing._internal.common_utils import random_symmetric_matrix
1928*da0073e9SAndroid Build Coastguard Worker
1929*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex and symmetric:
1930*da0073e9SAndroid Build Coastguard Worker                # for symmetric real-valued inputs, eigenvalues and eigenvectors have zero imaginary part
1931*da0073e9SAndroid Build Coastguard Worker                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1932*da0073e9SAndroid Build Coastguard Worker            else:
1933*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, dtype=dtype, device=device)
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.eig(a)
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Worker            complementary_device = 'cpu'
1938*da0073e9SAndroid Build Coastguard Worker
1939*da0073e9SAndroid Build Coastguard Worker            # compare with CPU
1940*da0073e9SAndroid Build Coastguard Worker            expected = torch.linalg.eig(a.to(complementary_device))
1941*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected[0], actual[0])
1942*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected[1], actual[1])
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker        shapes = [(0, 0),  # Empty matrix
1945*da0073e9SAndroid Build Coastguard Worker                  (5, 5),  # Single matrix
1946*da0073e9SAndroid Build Coastguard Worker                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1947*da0073e9SAndroid Build Coastguard Worker                  (2, 5, 5),  # 3-dim tensors
1948*da0073e9SAndroid Build Coastguard Worker                  (2, 1, 5, 5)]  # 4-dim tensors
1949*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
1950*da0073e9SAndroid Build Coastguard Worker            run_test(shape)
1951*da0073e9SAndroid Build Coastguard Worker            run_test(shape, symmetric=True)
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker    @slowTest
1954*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1955*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1956*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
1957*da0073e9SAndroid Build Coastguard Worker    def test_eig_check_magma(self, device, dtype):
1958*da0073e9SAndroid Build Coastguard Worker        # For CUDA inputs, only matrices larger than 2048x2048 actually call the MAGMA library
1959*da0073e9SAndroid Build Coastguard Worker        shape = (2049, 2049)
1960*da0073e9SAndroid Build Coastguard Worker        a = make_tensor(shape, dtype=dtype, device=device)
1961*da0073e9SAndroid Build Coastguard Worker        w, v = torch.linalg.eig(a)
1962*da0073e9SAndroid Build Coastguard Worker        # check correctness using eigendecomposition identity
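        # (A @ V == V @ diag(w); broadcasting w * v scales the j-th eigenvector
        # column v[:, j] by w[j], which is exactly V @ diag(w))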
1963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)
1964*da0073e9SAndroid Build Coastguard Worker
1965*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
1966*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
1967*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
1968*da0073e9SAndroid Build Coastguard Worker    def test_eig_errors_and_warnings(self, device, dtype):
1969*da0073e9SAndroid Build Coastguard Worker        # eig requires the input to be at least a 2-dimensional tensor
1970*da0073e9SAndroid Build Coastguard Worker        a = make_tensor(2, dtype=dtype, device=device)
1971*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
1972*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eig(a)
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Worker        # eig requires a square matrix
1975*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((2, 3), dtype=dtype, device=device)
1976*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1977*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eig(a)
1978*da0073e9SAndroid Build Coastguard Worker
1979*da0073e9SAndroid Build Coastguard Worker        # if an out tensor with a floating-point dtype is passed for complex output, an error is thrown
1980*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_complex:
1981*da0073e9SAndroid Build Coastguard Worker            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
1982*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
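            # Worked out: trace(a) = 3 + (-1) = 2 and det(a) = 3*(-1) - (-2)*4 = 5,
            # so p(lambda) = lambda^2 - trace*lambda + det = lambda^2 - 2*lambda + 5
            # and the roots (2 +- sqrt(4 - 20)) / 2 = 1 +- 2i are genuinely complex.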
1983*da0073e9SAndroid Build Coastguard Worker            out0 = torch.empty(0, device=device, dtype=dtype)
1984*da0073e9SAndroid Build Coastguard Worker            out1 = torch.empty(0, device=device, dtype=dtype)
1985*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
1986*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eig(a, out=(out0, out1))
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker            out0 = torch.empty(0, device=device, dtype=torch.complex128)
1989*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"):
1990*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eig(a, out=(out0, out1))
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
1993*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((3, 3), dtype=dtype, device=device)
1994*da0073e9SAndroid Build Coastguard Worker        out0 = torch.empty(0, dtype=torch.int, device=device)
1995*da0073e9SAndroid Build Coastguard Worker        out1 = torch.empty(0, dtype=torch.int, device=device)
1996*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
1997*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eig(a, out=(out0, out1))
1998*da0073e9SAndroid Build Coastguard Worker
1999*da0073e9SAndroid Build Coastguard Worker        out0 = torch.empty(0, dtype=torch.complex128, device=device)
2000*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
2001*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eig(a, out=(out0, out1))
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker        # if a non-empty out tensor with the wrong shape is passed, a warning is given
2004*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((3, 3), dtype=dtype, device=device)
2005*da0073e9SAndroid Build Coastguard Worker        out0 = torch.empty(1, device=device, dtype=torch.complex128)
2006*da0073e9SAndroid Build Coastguard Worker        out1 = torch.empty(1, device=device, dtype=torch.complex128)
2007*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
2008*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
2009*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eig(a, out=(out0, out1))
2010*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
2011*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
2012*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2013*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
2014*da0073e9SAndroid Build Coastguard Worker
2015*da0073e9SAndroid Build Coastguard Worker        # device should match
2016*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
2017*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2018*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2019*da0073e9SAndroid Build Coastguard Worker            out_v = torch.empty(0, device=device, dtype=torch.complex128)
2020*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2021*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eig(a, out=(out_w, out_v))
2022*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty(0, device=device, dtype=torch.complex128)
2023*da0073e9SAndroid Build Coastguard Worker            out_v = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2024*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2025*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eig(a, out=(out_w, out_v))
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2028*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2029*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2030*da0073e9SAndroid Build Coastguard Worker    def test_eig_with_nan(self, device, dtype):
2031*da0073e9SAndroid Build Coastguard Worker        for val in [np.inf, np.nan]:
2032*da0073e9SAndroid Build Coastguard Worker            for batch_dim in [(), (10,)]:
2033*da0073e9SAndroid Build Coastguard Worker                a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype)
2034*da0073e9SAndroid Build Coastguard Worker                a[..., -1, -1] = val
2035*da0073e9SAndroid Build Coastguard Worker
2036*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"):
2037*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.eig(a)
2038*da0073e9SAndroid Build Coastguard Worker
2039*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2040*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2041*da0073e9SAndroid Build Coastguard Worker    # NumPy computes only in float64 and complex128 precisions
2042*da0073e9SAndroid Build Coastguard Worker    # for float32 or complex64, results might differ significantly from float64 or complex128
2043*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.complex128)
2044*da0073e9SAndroid Build Coastguard Worker    def test_eigvals_numpy(self, device, dtype):
2045*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, *, symmetric=False):
2046*da0073e9SAndroid Build Coastguard Worker            from torch.testing._internal.common_utils import random_symmetric_matrix
2047*da0073e9SAndroid Build Coastguard Worker
2048*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex and symmetric:
2049*da0073e9SAndroid Build Coastguard Worker                # for symmetric real-valued inputs, eigenvalues and eigenvectors have zero imaginary part
2050*da0073e9SAndroid Build Coastguard Worker                # unlike NumPy, the result is not cast to float32 or float64 dtype in this case
2051*da0073e9SAndroid Build Coastguard Worker                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2052*da0073e9SAndroid Build Coastguard Worker            else:
2053*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, dtype=dtype, device=device)
2054*da0073e9SAndroid Build Coastguard Worker
2055*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.eigvals(a)
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker            # compare with NumPy
2058*da0073e9SAndroid Build Coastguard Worker            # the eigenvalues are not necessarily ordered
2059*da0073e9SAndroid Build Coastguard Worker            # so the order of NumPy and PyTorch results can differ
2060*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.eigvals(a.cpu().numpy())
2061*da0073e9SAndroid Build Coastguard Worker
2062*da0073e9SAndroid Build Coastguard Worker            # sort NumPy output
2063*da0073e9SAndroid Build Coastguard Worker            ind = np.argsort(expected, axis=-1)[::-1]
2064*da0073e9SAndroid Build Coastguard Worker            expected = np.take_along_axis(expected, ind, axis=-1)
2065*da0073e9SAndroid Build Coastguard Worker
2066*da0073e9SAndroid Build Coastguard Worker            # sort PyTorch output
2067*da0073e9SAndroid Build Coastguard Worker            # torch.argsort doesn't work with complex inputs, so NumPy sorting on CPU is used instead
2068*da0073e9SAndroid Build Coastguard Worker            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
2069*da0073e9SAndroid Build Coastguard Worker            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
2070*da0073e9SAndroid Build Coastguard Worker            ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1]
2071*da0073e9SAndroid Build Coastguard Worker            actual_np = actual.cpu().numpy()
2072*da0073e9SAndroid Build Coastguard Worker            sorted_actual = np.take_along_axis(actual_np, ind, axis=-1)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, sorted_actual, exact_dtype=False)
2075*da0073e9SAndroid Build Coastguard Worker
2076*da0073e9SAndroid Build Coastguard Worker        shapes = [(0, 0),  # Empty matrix
2077*da0073e9SAndroid Build Coastguard Worker                  (5, 5),  # Single matrix
2078*da0073e9SAndroid Build Coastguard Worker                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2079*da0073e9SAndroid Build Coastguard Worker                  (2, 5, 5),  # 3-dim tensors
2080*da0073e9SAndroid Build Coastguard Worker                  (2, 1, 5, 5)]  # 4-dim tensors
2081*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2082*da0073e9SAndroid Build Coastguard Worker            run_test(shape)
2083*da0073e9SAndroid Build Coastguard Worker            run_test(shape, symmetric=True)
2084*da0073e9SAndroid Build Coastguard Worker
2085*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2086*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2087*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2088*da0073e9SAndroid Build Coastguard Worker    def test_eigvals_compare_backends(self, device, dtype):
2089*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, *, symmetric=False):
2090*da0073e9SAndroid Build Coastguard Worker            from torch.testing._internal.common_utils import random_symmetric_matrix
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex and symmetric:
2093*da0073e9SAndroid Build Coastguard Worker                # for symmetric real-valued inputs, eigenvalues and eigenvectors have zero imaginary part
2094*da0073e9SAndroid Build Coastguard Worker                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2095*da0073e9SAndroid Build Coastguard Worker            else:
2096*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, dtype=dtype, device=device)
2097*da0073e9SAndroid Build Coastguard Worker
2098*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.eigvals(a)
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker            complementary_device = 'cpu'
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker            # compare with CPU
2103*da0073e9SAndroid Build Coastguard Worker            expected = torch.linalg.eigvals(a.to(complementary_device))
2104*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker            # check out= variant
2107*da0073e9SAndroid Build Coastguard Worker            complex_dtype = dtype
2108*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex:
2109*da0073e9SAndroid Build Coastguard Worker                complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64
2110*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=complex_dtype, device=device)
2111*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.eigvals(a, out=out)
2112*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
2113*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected.to(complex_dtype), out)
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Worker            # check non-contiguous out
2116*da0073e9SAndroid Build Coastguard Worker            if a.numel() > 0:
2117*da0073e9SAndroid Build Coastguard Worker                out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2]
2118*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(out.is_contiguous())
2119*da0073e9SAndroid Build Coastguard Worker                ans = torch.linalg.eigvals(a, out=out)
2120*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ans, out)
2121*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected.to(complex_dtype), out)
2122*da0073e9SAndroid Build Coastguard Worker
2123*da0073e9SAndroid Build Coastguard Worker        shapes = [(0, 0),  # Empty matrix
2124*da0073e9SAndroid Build Coastguard Worker                  (5, 5),  # Single matrix
2125*da0073e9SAndroid Build Coastguard Worker                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2126*da0073e9SAndroid Build Coastguard Worker                  (2, 5, 5),  # 3-dim tensors
2127*da0073e9SAndroid Build Coastguard Worker                  (2, 1, 5, 5)]  # 4-dim tensors
2128*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2129*da0073e9SAndroid Build Coastguard Worker            run_test(shape)
2130*da0073e9SAndroid Build Coastguard Worker            run_test(shape, symmetric=True)
2131*da0073e9SAndroid Build Coastguard Worker
2132*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2133*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2134*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2135*da0073e9SAndroid Build Coastguard Worker    def test_eigvals_errors_and_warnings(self, device, dtype):
2136*da0073e9SAndroid Build Coastguard Worker        # eigvals requires the input to be at least a 2-dimensional tensor
2137*da0073e9SAndroid Build Coastguard Worker        a = make_tensor(2, dtype=dtype, device=device)
2138*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
2139*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvals(a)
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker        # eigvals requires a square matrix
2142*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((2, 3), dtype=dtype, device=device)
2143*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2144*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvals(a)
2145*da0073e9SAndroid Build Coastguard Worker
2146*da0073e9SAndroid Build Coastguard Worker        # if an out tensor with a floating-point dtype is passed for complex output, an error is thrown
2147*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_complex:
2148*da0073e9SAndroid Build Coastguard Worker            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
2149*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
2150*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, device=device, dtype=dtype)
2151*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
2152*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigvals(a, out=out)
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
2155*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((3, 3), dtype=dtype, device=device)
2156*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.int, device=device)
2157*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
2158*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvals(a, out=out)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker        # if a non-empty out tensor with the wrong shape is passed, a warning is given
2161*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(1, device=device, dtype=torch.complex128)
2162*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
2163*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
2164*da0073e9SAndroid Build Coastguard Worker            torch.linalg.eigvals(a, out=out)
2165*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
2166*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
2167*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker        # device should match
2170*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
2171*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2172*da0073e9SAndroid Build Coastguard Worker            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2173*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2174*da0073e9SAndroid Build Coastguard Worker                torch.linalg.eigvals(a, out=out_w)
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2177*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2178*da0073e9SAndroid Build Coastguard Worker    def test_norm_old(self, device):
2179*da0073e9SAndroid Build Coastguard Worker        def gen_error_message(input_size, p, keepdim, dim=None):
2180*da0073e9SAndroid Build Coastguard Worker            return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2181*da0073e9SAndroid Build Coastguard Worker
2182*da0073e9SAndroid Build Coastguard Worker        # 'nuc' norm uses SVD, and thus its precision is much lower than that of other norms.
2183*da0073e9SAndroid Build Coastguard Worker        # test_svd takes @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}),
2184*da0073e9SAndroid Build Coastguard Worker        # and here we are doing the same thing for the nuc norm.
2185*da0073e9SAndroid Build Coastguard Worker        class PrecisionContext:
2186*da0073e9SAndroid Build Coastguard Worker            def __init__(self, test, norm):
2187*da0073e9SAndroid Build Coastguard Worker                self.norm = norm
2188*da0073e9SAndroid Build Coastguard Worker                self.saved_overrides = getattr(test, 'precision_overrides', None)
2189*da0073e9SAndroid Build Coastguard Worker                self.target_test = test
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
2192*da0073e9SAndroid Build Coastguard Worker                if 'nuc' != self.norm:
2193*da0073e9SAndroid Build Coastguard Worker                    return None
2194*da0073e9SAndroid Build Coastguard Worker                self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4}
2195*da0073e9SAndroid Build Coastguard Worker                return self.target_test.precision_overrides
2196*da0073e9SAndroid Build Coastguard Worker
2197*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type, value, tb) -> bool:
2198*da0073e9SAndroid Build Coastguard Worker                if 'nuc' != self.norm:
2199*da0073e9SAndroid Build Coastguard Worker                    return False  # propagate any exception raised inside the with-block
2200*da0073e9SAndroid Build Coastguard Worker                if self.saved_overrides is None:
2201*da0073e9SAndroid Build Coastguard Worker                    delattr(self.target_test, 'precision_overrides')
2202*da0073e9SAndroid Build Coastguard Worker                else:
2203*da0073e9SAndroid Build Coastguard Worker                    self.target_test.precision_overrides = self.saved_overrides
2204*da0073e9SAndroid Build Coastguard Worker                return False
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker        for keepdim in [False, True]:
2207*da0073e9SAndroid Build Coastguard Worker            # full reduction
2208*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, device=device)
2209*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2210*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]:
2211*da0073e9SAndroid Build Coastguard Worker                res = x.norm(p, keepdim=keepdim).cpu()
2212*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2213*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim))
2214*da0073e9SAndroid Build Coastguard Worker
2215*da0073e9SAndroid Build Coastguard Worker            # one dimension
2216*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, 25, device=device)
2217*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2218*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]:
2219*da0073e9SAndroid Build Coastguard Worker                dim = 1
2220*da0073e9SAndroid Build Coastguard Worker                res = x.norm(p, dim, keepdim=keepdim).cpu()
2221*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, p, dim, keepdims=keepdim)
2222*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), p, keepdim, dim)
2223*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.shape, expected.shape, msg=msg)
2224*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, msg=msg)
2225*da0073e9SAndroid Build Coastguard Worker
2226*da0073e9SAndroid Build Coastguard Worker            # matrix norm
2227*da0073e9SAndroid Build Coastguard Worker            for p in ['fro', 'nuc']:
2228*da0073e9SAndroid Build Coastguard Worker                res = x.norm(p, keepdim=keepdim).cpu()
2229*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2230*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), p, keepdim)
2231*da0073e9SAndroid Build Coastguard Worker                with PrecisionContext(self, p):
2232*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res.shape, expected.shape, msg=msg)
2233*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res, expected, msg=msg)
2234*da0073e9SAndroid Build Coastguard Worker
2235*da0073e9SAndroid Build Coastguard Worker            # zero dimensions
2236*da0073e9SAndroid Build Coastguard Worker            x = torch.randn((), device=device)
2237*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2238*da0073e9SAndroid Build Coastguard Worker            res = x.norm(keepdim=keepdim).cpu()
2239*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.norm(xn, keepdims=keepdim)
2240*da0073e9SAndroid Build Coastguard Worker            msg = gen_error_message(x.size(), None, keepdim)
2241*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, expected.shape, msg=msg)
2242*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected, msg=msg)
2243*da0073e9SAndroid Build Coastguard Worker
2244*da0073e9SAndroid Build Coastguard Worker            # larger tensor sanity check
2245*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2246*da0073e9SAndroid Build Coastguard Worker                2 * torch.norm(torch.ones(10000), keepdim=keepdim),
2247*da0073e9SAndroid Build Coastguard Worker                torch.norm(torch.ones(40000), keepdim=keepdim))
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker            # matrix norm with non-square >2-D tensors, all combinations of reduction dims
2250*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 6, 7, 8, device=device)
2251*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2252*da0073e9SAndroid Build Coastguard Worker            for p in ['fro', 'nuc']:
2253*da0073e9SAndroid Build Coastguard Worker                for dim in itertools.product(*[list(range(4))] * 2):
2254*da0073e9SAndroid Build Coastguard Worker                    if dim[0] == dim[1]:
2255*da0073e9SAndroid Build Coastguard Worker                        continue
2256*da0073e9SAndroid Build Coastguard Worker                    res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu()
2257*da0073e9SAndroid Build Coastguard Worker                    expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim)
2258*da0073e9SAndroid Build Coastguard Worker                    msg = gen_error_message(x.size(), p, keepdim, dim)
2259*da0073e9SAndroid Build Coastguard Worker                    with PrecisionContext(self, p):
2260*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(res.shape, expected.shape, msg=msg)
2261*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(res, expected, msg=msg)
2262*da0073e9SAndroid Build Coastguard Worker
2263*da0073e9SAndroid Build Coastguard Worker    # Test that torch.norm with p=+/-inf propagates NaN
2264*da0073e9SAndroid Build Coastguard Worker    def test_norm_old_nan_propagation(self, device):
2265*da0073e9SAndroid Build Coastguard Worker        ords = [inf, -inf]
2266*da0073e9SAndroid Build Coastguard Worker        for pair in itertools.product([0.0, nan, 1.0], repeat=2):
2267*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(list(pair), device=device)
2268*da0073e9SAndroid Build Coastguard Worker            for ord in ords:
2269*da0073e9SAndroid Build Coastguard Worker                result = torch.norm(x, p=ord)
2270*da0073e9SAndroid Build Coastguard Worker                result_check = torch.linalg.norm(x, ord=ord)
2271*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_check)
2272*da0073e9SAndroid Build Coastguard Worker
2273*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2274*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2275*da0073e9SAndroid Build Coastguard Worker    def test_norm_complex_old(self, device):
2276*da0073e9SAndroid Build Coastguard Worker        def gen_error_message(input_size, p, keepdim, dim=None):
2277*da0073e9SAndroid Build Coastguard Worker            return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2278*da0073e9SAndroid Build Coastguard Worker
2279*da0073e9SAndroid Build Coastguard Worker        for keepdim in [False, True]:
2280*da0073e9SAndroid Build Coastguard Worker            # vector norm
2281*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device)
2282*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2283*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]:
2284*da0073e9SAndroid Build Coastguard Worker                res = x.norm(p, keepdim=keepdim).cpu()
2285*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2286*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), p, keepdim)
2287*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.shape, expected.shape, msg=msg)
2288*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, msg=msg)
2289*da0073e9SAndroid Build Coastguard Worker
2290*da0073e9SAndroid Build Coastguard Worker            # matrix norm
2291*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device)
2292*da0073e9SAndroid Build Coastguard Worker            xn = x.cpu().numpy()
2293*da0073e9SAndroid Build Coastguard Worker            for p in ['nuc', 'fro']:
2294*da0073e9SAndroid Build Coastguard Worker                res = x.norm(p, keepdim=keepdim).cpu()
2295*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2296*da0073e9SAndroid Build Coastguard Worker                msg = gen_error_message(x.size(), p, keepdim)
2297*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.shape, expected.shape, msg=msg)
2298*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4)
2299*da0073e9SAndroid Build Coastguard Worker
2300*da0073e9SAndroid Build Coastguard Worker    # Ensure that torch.norm with p='fro' and with p=2 gives the same results for mutually supported input combinations
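    # (both settings reduce to sqrt(sum(|x|**2)) over the reduced elements here,
    # which is why the two results are expected to match)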
2301*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
2302*da0073e9SAndroid Build Coastguard Worker    def test_norm_fro_2_equivalence_old(self, device, dtype):
2303*da0073e9SAndroid Build Coastguard Worker        input_sizes = [
2304*da0073e9SAndroid Build Coastguard Worker            (0,),
2305*da0073e9SAndroid Build Coastguard Worker            (10,),
2306*da0073e9SAndroid Build Coastguard Worker            (0, 0),
2307*da0073e9SAndroid Build Coastguard Worker            (4, 30),
2308*da0073e9SAndroid Build Coastguard Worker            (0, 45),
2309*da0073e9SAndroid Build Coastguard Worker            (100, 0),
2310*da0073e9SAndroid Build Coastguard Worker            (45, 10, 23),
2311*da0073e9SAndroid Build Coastguard Worker            (0, 23, 59),
2312*da0073e9SAndroid Build Coastguard Worker            (23, 0, 37),
2313*da0073e9SAndroid Build Coastguard Worker            (34, 58, 0),
2314*da0073e9SAndroid Build Coastguard Worker            (0, 0, 348),
2315*da0073e9SAndroid Build Coastguard Worker            (0, 3434, 0),
2316*da0073e9SAndroid Build Coastguard Worker            (0, 0, 0),
2317*da0073e9SAndroid Build Coastguard Worker            (5, 3, 8, 1, 3, 5)]
2318*da0073e9SAndroid Build Coastguard Worker
2319*da0073e9SAndroid Build Coastguard Worker        for input_size in input_sizes:
2320*da0073e9SAndroid Build Coastguard Worker            a = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
2321*da0073e9SAndroid Build Coastguard Worker
2322*da0073e9SAndroid Build Coastguard Worker            # Try full reduction
2323*da0073e9SAndroid Build Coastguard Worker            dim_settings = [None]
2324*da0073e9SAndroid Build Coastguard Worker
2325*da0073e9SAndroid Build Coastguard Worker            # Try all possible 1-D reductions
2326*da0073e9SAndroid Build Coastguard Worker            dim_settings += list(range(-a.dim(), a.dim()))
2327*da0073e9SAndroid Build Coastguard Worker
2328*da0073e9SAndroid Build Coastguard Worker            def wrap_dim(dim, ndims):
2329*da0073e9SAndroid Build Coastguard Worker                assert (dim < ndims) and (dim >= -ndims)
2330*da0073e9SAndroid Build Coastguard Worker                if dim >= 0:
2331*da0073e9SAndroid Build Coastguard Worker                    return dim
2332*da0073e9SAndroid Build Coastguard Worker                else:
2333*da0073e9SAndroid Build Coastguard Worker                    return dim + ndims
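            # e.g. wrap_dim(-1, 3) == 2 and wrap_dim(2, 3) == 2, so the filter below
            # treats the negative and positive spellings of the same axis as equal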
2334*da0073e9SAndroid Build Coastguard Worker
2335*da0073e9SAndroid Build Coastguard Worker            # Try all possible 2-D reductions
2336*da0073e9SAndroid Build Coastguard Worker            dim_settings += [
2337*da0073e9SAndroid Build Coastguard Worker                (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2)
2338*da0073e9SAndroid Build Coastguard Worker                if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())]
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker            for dim in dim_settings:
2341*da0073e9SAndroid Build Coastguard Worker                for keepdim in [True, False]:
2342*da0073e9SAndroid Build Coastguard Worker                    a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim)
2343*da0073e9SAndroid Build Coastguard Worker                    a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim)
2344*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a_norm_fro, a_norm_2)
2345*da0073e9SAndroid Build Coastguard Worker
2346*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2347*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2348*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2349*da0073e9SAndroid Build Coastguard Worker    def test_nuclear_norm_axes_small_brute_force_old(self, device):
2350*da0073e9SAndroid Build Coastguard Worker        def check_single_nuclear_norm(x, axes):
2351*da0073e9SAndroid Build Coastguard Worker            if self.device_type != 'cpu' and randrange(100) < 95:
2352*da0073e9SAndroid Build Coastguard Worker                return  # too many cpu <==> device copies
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker            a = np.array(x.cpu(), copy=False)
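            # np.linalg.norm with ord="nuc" (the sum of the singular values of the
            # matrix formed by the two dims in `axes`) serves as the reference here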
2355*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.norm(a, "nuc", axis=axes)
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker            ans = torch.norm(x, "nuc", dim=axes)
2358*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ans.is_contiguous())
2359*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans.shape, expected.shape)
2360*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker            out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
2363*da0073e9SAndroid Build Coastguard Worker            ans = torch.norm(x, "nuc", dim=axes, out=out)
2364*da0073e9SAndroid Build Coastguard Worker            self.assertIs(ans, out)
2365*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ans.is_contiguous())
2366*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans.shape, expected.shape)
2367*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker        for n in range(1, 3):
2370*da0073e9SAndroid Build Coastguard Worker            for m in range(1, 3):
2371*da0073e9SAndroid Build Coastguard Worker                for axes in itertools.permutations([0, 1], 2):
2372*da0073e9SAndroid Build Coastguard Worker                    # 2d, inner dimensions C
2373*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn(n, m, device=device)
2374*da0073e9SAndroid Build Coastguard Worker                    check_single_nuclear_norm(x, axes)
2375*da0073e9SAndroid Build Coastguard Worker
2376*da0073e9SAndroid Build Coastguard Worker                    # 2d, inner dimensions Fortran
2377*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn(m, n, device=device).mT
2378*da0073e9SAndroid Build Coastguard Worker                    check_single_nuclear_norm(x, axes)
2379*da0073e9SAndroid Build Coastguard Worker
2380*da0073e9SAndroid Build Coastguard Worker                    # 2d, inner dimensions non-contiguous
2381*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn(n, 2 * m, device=device)[:, ::2]
2382*da0073e9SAndroid Build Coastguard Worker                    check_single_nuclear_norm(x, axes)
2383*da0073e9SAndroid Build Coastguard Worker
2384*da0073e9SAndroid Build Coastguard Worker                    # 2d, all dimensions non-contiguous
2385*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
2386*da0073e9SAndroid Build Coastguard Worker                    check_single_nuclear_norm(x, axes)
2387*da0073e9SAndroid Build Coastguard Worker
2388*da0073e9SAndroid Build Coastguard Worker                for o in range(1, 3):
2389*da0073e9SAndroid Build Coastguard Worker                    for axes in itertools.permutations([0, 1, 2], 2):
2390*da0073e9SAndroid Build Coastguard Worker                        # 3d, inner dimensions C
2391*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(o, n, m, device=device)
2392*da0073e9SAndroid Build Coastguard Worker                        check_single_nuclear_norm(x, axes)
2393*da0073e9SAndroid Build Coastguard Worker
2394*da0073e9SAndroid Build Coastguard Worker                        # 3d, inner dimensions Fortran
2395*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(o, m, n, device=device).mT
2396*da0073e9SAndroid Build Coastguard Worker                        check_single_nuclear_norm(x, axes)
2397*da0073e9SAndroid Build Coastguard Worker
2398*da0073e9SAndroid Build Coastguard Worker                        # 3d, inner dimensions non-contiguous
2399*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
2400*da0073e9SAndroid Build Coastguard Worker                        check_single_nuclear_norm(x, axes)
2401*da0073e9SAndroid Build Coastguard Worker
2402*da0073e9SAndroid Build Coastguard Worker                        # 3d, all dimensions non-contiguous
2403*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
2404*da0073e9SAndroid Build Coastguard Worker                        check_single_nuclear_norm(x, axes)
2405*da0073e9SAndroid Build Coastguard Worker
2406*da0073e9SAndroid Build Coastguard Worker                    for r in range(1, 3):
2407*da0073e9SAndroid Build Coastguard Worker                        for axes in itertools.permutations([0, 1, 2, 3], 2):
2408*da0073e9SAndroid Build Coastguard Worker                            # 4d, inner dimensions C
2409*da0073e9SAndroid Build Coastguard Worker                            x = torch.randn(r, o, n, m, device=device)
2410*da0073e9SAndroid Build Coastguard Worker                            check_single_nuclear_norm(x, axes)
2411*da0073e9SAndroid Build Coastguard Worker
2412*da0073e9SAndroid Build Coastguard Worker                            # 4d, inner dimensions Fortran
2413*da0073e9SAndroid Build Coastguard Worker                            x = torch.randn(r, o, n, m, device=device).mT
2414*da0073e9SAndroid Build Coastguard Worker                            check_single_nuclear_norm(x, axes)
2415*da0073e9SAndroid Build Coastguard Worker
2416*da0073e9SAndroid Build Coastguard Worker                            # 4d, inner dimensions non-contiguous
2417*da0073e9SAndroid Build Coastguard Worker                            x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
2418*da0073e9SAndroid Build Coastguard Worker                            check_single_nuclear_norm(x, axes)
2419*da0073e9SAndroid Build Coastguard Worker
2420*da0073e9SAndroid Build Coastguard Worker                            # 4d, all dimensions non-contiguous
2421*da0073e9SAndroid Build Coastguard Worker                            x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
2422*da0073e9SAndroid Build Coastguard Worker                            check_single_nuclear_norm(x, axes)
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2425*da0073e9SAndroid Build Coastguard Worker    def test_nuclear_norm_exceptions_old(self, device):
2426*da0073e9SAndroid Build Coastguard Worker        for lst in [], [1], [1, 2]:
2427*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(lst, dtype=torch.double, device=device)
2428*da0073e9SAndroid Build Coastguard Worker            for axes in (), (0,):
2429*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
2430*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1))
2431*da0073e9SAndroid Build Coastguard Worker
2432*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
2433*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0))
2434*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))
2435*da0073e9SAndroid Build Coastguard Worker
2436*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
2437*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2438*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2439*da0073e9SAndroid Build Coastguard Worker    def test_svd_lowrank(self, device, dtype):
2440*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
2441*da0073e9SAndroid Build Coastguard Worker
2442*da0073e9SAndroid Build Coastguard Worker        def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options):
2443*da0073e9SAndroid Build Coastguard Worker            density = options.pop('density', 1)
2444*da0073e9SAndroid Build Coastguard Worker            if isinstance(matrix_size, int):
2445*da0073e9SAndroid Build Coastguard Worker                rows = columns = matrix_size
2446*da0073e9SAndroid Build Coastguard Worker            else:
2447*da0073e9SAndroid Build Coastguard Worker                rows, columns = matrix_size
2448*da0073e9SAndroid Build Coastguard Worker            if density == 1:
2449*da0073e9SAndroid Build Coastguard Worker                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
2450*da0073e9SAndroid Build Coastguard Worker                a = a_input
2451*da0073e9SAndroid Build Coastguard Worker            else:
2452*da0073e9SAndroid Build Coastguard Worker                assert batches == ()
2453*da0073e9SAndroid Build Coastguard Worker                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
2454*da0073e9SAndroid Build Coastguard Worker                a = a_input.to_dense()
2455*da0073e9SAndroid Build Coastguard Worker
2456*da0073e9SAndroid Build Coastguard Worker            q = min(rows, columns)
2457*da0073e9SAndroid Build Coastguard Worker            u, s, v = svd_lowrank(a_input, q=q, **options)
2458*da0073e9SAndroid Build Coastguard Worker
2459*da0073e9SAndroid Build Coastguard Worker            # check that (u, s, v) is an SVD of a
2460*da0073e9SAndroid Build Coastguard Worker            u, s, v = u[..., :q], s[..., :q], v[..., :q]
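            # s.unsqueeze(-2) has shape (..., 1, q), so u * s.unsqueeze(-2) scales the
            # j-th column of u by s[j]; the product below is therefore U @ diag(S) @ V^H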
2461*da0073e9SAndroid Build Coastguard Worker            A = (u * s.unsqueeze(-2)).matmul(v.mH)
2462*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A, a, rtol=1e-7, atol=2e-7)
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker            # check that svd_lowrank produces the same singular values as torch.linalg.svd
2465*da0073e9SAndroid Build Coastguard Worker            U, S, Vh = torch.linalg.svd(a, full_matrices=False)
2466*da0073e9SAndroid Build Coastguard Worker            V = Vh.mH
2467*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s, S)
2468*da0073e9SAndroid Build Coastguard Worker
2469*da0073e9SAndroid Build Coastguard Worker            if density == 1:
2470*da0073e9SAndroid Build Coastguard Worker                # actual_rank is known only for dense inputs
2471*da0073e9SAndroid Build Coastguard Worker                #
2472*da0073e9SAndroid Build Coastguard Worker                # check if pairs (u, U) and (v, V) span the same
2473*da0073e9SAndroid Build Coastguard Worker                # subspaces, respectively
2474*da0073e9SAndroid Build Coastguard Worker                u, v = u[..., :actual_rank], v[..., :actual_rank]
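                # both u and U have orthonormal columns, so u.mH @ U is unitary exactly
                # when the two column spaces coincide, making |det(u.mH @ U)| == 1 the
                # numerical criterion checked below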
2475*da0073e9SAndroid Build Coastguard Worker                U, V = U[..., :actual_rank], V[..., :actual_rank]
2476*da0073e9SAndroid Build Coastguard Worker                expected_ones = u.mH.matmul(U).det().abs()
2477*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_ones, torch.ones_like(expected_ones))
2478*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones))
2479*da0073e9SAndroid Build Coastguard Worker
2480*da0073e9SAndroid Build Coastguard Worker        all_batches = [(), (1,), (3,), (2, 3)]
2481*da0073e9SAndroid Build Coastguard Worker        for actual_rank, size, all_batches in [  # noqa: B020
2482*da0073e9SAndroid Build Coastguard Worker                (2, (17, 4), all_batches),
2483*da0073e9SAndroid Build Coastguard Worker                (4, (17, 4), all_batches),
2484*da0073e9SAndroid Build Coastguard Worker                (4, (17, 17), all_batches),
2485*da0073e9SAndroid Build Coastguard Worker                (10, (100, 40), all_batches),
2486*da0073e9SAndroid Build Coastguard Worker                (7, (1000, 1000), [()]),
2487*da0073e9SAndroid Build Coastguard Worker        ]:
2488*da0073e9SAndroid Build Coastguard Worker            # dense input
2489*da0073e9SAndroid Build Coastguard Worker            for batches in all_batches:
2490*da0073e9SAndroid Build Coastguard Worker                run_subtest(actual_rank, size, batches, device, torch.svd_lowrank)
2491*da0073e9SAndroid Build Coastguard Worker                if size != size[::-1]:
2492*da0073e9SAndroid Build Coastguard Worker                    run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank)
2493*da0073e9SAndroid Build Coastguard Worker
2494*da0073e9SAndroid Build Coastguard Worker        # sparse input
2495*da0073e9SAndroid Build Coastguard Worker        for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]:
2496*da0073e9SAndroid Build Coastguard Worker            for density in [0.005, 0.1]:
2497*da0073e9SAndroid Build Coastguard Worker                run_subtest(None, size, (), device, torch.svd_lowrank, density=density)
2498*da0073e9SAndroid Build Coastguard Worker
2499*da0073e9SAndroid Build Coastguard Worker        # jitting support
2500*da0073e9SAndroid Build Coastguard Worker        jitted = torch.jit.script(torch.svd_lowrank)
2501*da0073e9SAndroid Build Coastguard Worker        actual_rank, size, batches = 2, (17, 4), ()
2502*da0073e9SAndroid Build Coastguard Worker        run_subtest(actual_rank, size, batches, device, jitted)
2503*da0073e9SAndroid Build Coastguard Worker
2504*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2505*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2506*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4})
2507*da0073e9SAndroid Build Coastguard Worker    @setLinalgBackendsToDefaultFinally
2508*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2509*da0073e9SAndroid Build Coastguard Worker    @serialTest()
2510*da0073e9SAndroid Build Coastguard Worker    def test_svd(self, device, dtype):
2511*da0073e9SAndroid Build Coastguard Worker        # tests linalg.svd, svd, linalg.svdvals
2512*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
2513*da0073e9SAndroid Build Coastguard Worker
2514*da0073e9SAndroid Build Coastguard Worker        backends = ["default"]
2515*da0073e9SAndroid Build Coastguard Worker
2516*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'cuda':
2517*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.has_magma:
2518*da0073e9SAndroid Build Coastguard Worker                backends.append("magma")
2519*da0073e9SAndroid Build Coastguard Worker            if has_cusolver() or has_hipsolver():
2520*da0073e9SAndroid Build Coastguard Worker                backends.append("cusolver")
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker        ns = (12, 4, 2, 0)
2523*da0073e9SAndroid Build Coastguard Worker        batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
2524*da0073e9SAndroid Build Coastguard Worker        drivers = (None, 'gesvd', 'gesvdj', 'gesvda')
2525*da0073e9SAndroid Build Coastguard Worker
2526*da0073e9SAndroid Build Coastguard Worker        for backend in backends:
2527*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.preferred_linalg_library(backend)
2528*da0073e9SAndroid Build Coastguard Worker
2529*da0073e9SAndroid Build Coastguard Worker            for batch, m, n, driver in product(batches, ns, ns, drivers):
2530*da0073e9SAndroid Build Coastguard Worker                if not (backend == 'cusolver' or driver is None):
2531*da0073e9SAndroid Build Coastguard Worker                    # only run the combinations below and skip everything else:
2532*da0073e9SAndroid Build Coastguard Worker                    # - backend == 'cusolver' (driver can be anything)
2533*da0073e9SAndroid Build Coastguard Worker                    # - backend != 'cusolver' (driver should only be None)
2534*da0073e9SAndroid Build Coastguard Worker                    continue
2535*da0073e9SAndroid Build Coastguard Worker
2536*da0073e9SAndroid Build Coastguard Worker                shape = batch + (m, n)
2537*da0073e9SAndroid Build Coastguard Worker                k = min(m, n)
2538*da0073e9SAndroid Build Coastguard Worker                A = make_arg(shape)
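                # torch.linalg.svd returns real singular values even for complex A,
                # so cast S to A's dtype before diag_embed to keep the reconstruction
                # matmul in a single dtype.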
2539*da0073e9SAndroid Build Coastguard Worker                U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver)
2540*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A)
2541*da0073e9SAndroid Build Coastguard Worker
2542*da0073e9SAndroid Build Coastguard Worker                U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver)
2543*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(S_f, S)
2544*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A)
2545*da0073e9SAndroid Build Coastguard Worker
2546*da0073e9SAndroid Build Coastguard Worker                S_s = torch.linalg.svdvals(A, driver=driver)
2547*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(S_s, S)
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker                U, S, V = torch.svd(A, some=True)
2550*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A)
2551*da0073e9SAndroid Build Coastguard Worker
2552*da0073e9SAndroid Build Coastguard Worker                U_f, S_f, V_f = torch.svd(A, some=False)
2553*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(S_f, S)
2554*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A)
2555*da0073e9SAndroid Build Coastguard Worker
2556*da0073e9SAndroid Build Coastguard Worker                S_s = torch.svd(A, compute_uv=False).S
2557*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(S_s, S)
2558*da0073e9SAndroid Build Coastguard Worker
2559*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2560*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2561*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex128)
2562*da0073e9SAndroid Build Coastguard Worker    def test_invariance_error_spectral_decompositions(self, device, dtype):
2563*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
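        # U and Vh from the SVD (and the eigenvectors returned by eig/eigh) are only
        # determined up to a per-column phase, so a loss that depends on that phase has
        # an ill-defined gradient and autograd is expected to raise for each case below.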
2564*da0073e9SAndroid Build Coastguard Worker        A = make_arg((3, 3))
2565*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2566*da0073e9SAndroid Build Coastguard Worker            U, _, Vh = torch.linalg.svd(A, full_matrices=False)
2567*da0073e9SAndroid Build Coastguard Worker            (U + Vh).sum().abs().backward()
2568*da0073e9SAndroid Build Coastguard Worker
2569*da0073e9SAndroid Build Coastguard Worker        A = make_arg((3, 3))
2570*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2571*da0073e9SAndroid Build Coastguard Worker            V = torch.linalg.eig(A).eigenvectors
2572*da0073e9SAndroid Build Coastguard Worker            V.sum().abs().backward()
2573*da0073e9SAndroid Build Coastguard Worker
2574*da0073e9SAndroid Build Coastguard Worker        A = make_arg((3, 3))
2575*da0073e9SAndroid Build Coastguard Worker        A = A + A.mH
2576*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2577*da0073e9SAndroid Build Coastguard Worker            Q = torch.linalg.eigh(A).eigenvectors
2578*da0073e9SAndroid Build Coastguard Worker            Q.sum().abs().backward()
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver  # MAGMA backend doesn't work in this case
2581*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
2582*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2583*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2584*da0073e9SAndroid Build Coastguard Worker    def test_svd_memory_allocation(self, device, dtype):
2585*da0073e9SAndroid Build Coastguard Worker        # test for https://github.com/pytorch/pytorch/issues/61949
2586*da0073e9SAndroid Build Coastguard Worker        # the problem was that tensors of incorrect size were allocated and then narrowed
2587*da0073e9SAndroid Build Coastguard Worker        m = 3
2588*da0073e9SAndroid Build Coastguard Worker        n = 2**20
2589*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((m, n), dtype=dtype, device=device)
2590*da0073e9SAndroid Build Coastguard Worker        # the following should run without errors
2591*da0073e9SAndroid Build Coastguard Worker        S = torch.linalg.svdvals(a)
2592*da0073e9SAndroid Build Coastguard Worker        result = torch.linalg.svd(a, full_matrices=False)
2593*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.S, S)
2594*da0073e9SAndroid Build Coastguard Worker
2595*da0073e9SAndroid Build Coastguard Worker    def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
2596*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2597*da0073e9SAndroid Build Coastguard Worker
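        # Build a random right-hand side b, a random Hermitian positive-definite A,
        # and its Cholesky factor L (upper- or lower-triangular depending on `upper`).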
2598*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(*b_dims, dtype=dtype, device=device)
2599*da0073e9SAndroid Build Coastguard Worker        A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device)
2600*da0073e9SAndroid Build Coastguard Worker        L = torch.cholesky(A, upper=upper)
2601*da0073e9SAndroid Build Coastguard Worker        return b, A, L
2602*da0073e9SAndroid Build Coastguard Worker
2603*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2604*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2605*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2606*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2607*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
2608*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve(self, device, dtype):
2609*da0073e9SAndroid Build Coastguard Worker        for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]):
2610*da0073e9SAndroid Build Coastguard Worker            b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype)
2611*da0073e9SAndroid Build Coastguard Worker            x = torch.cholesky_solve(b, L, upper=upper)
2612*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
2613*da0073e9SAndroid Build Coastguard Worker
2614*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2615*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2616*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2617*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2618*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
2619*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve_batched(self, device, dtype):
2620*da0073e9SAndroid Build Coastguard Worker        def cholesky_solve_batch_helper(A_dims, b_dims, upper):
2621*da0073e9SAndroid Build Coastguard Worker            b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2622*da0073e9SAndroid Build Coastguard Worker            x_exp_list = []
2623*da0073e9SAndroid Build Coastguard Worker            for i in range(b_dims[0]):
2624*da0073e9SAndroid Build Coastguard Worker                x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper))
2625*da0073e9SAndroid Build Coastguard Worker            x_exp = torch.stack(x_exp_list)  # Stacked output
2626*da0073e9SAndroid Build Coastguard Worker            x_act = torch.cholesky_solve(b, L, upper=upper)  # Actual output
2627*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_act, x_exp)  # Equality check
2628*da0073e9SAndroid Build Coastguard Worker            Ax = np.matmul(A.cpu(), x_act.cpu())
2629*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b, Ax)  # Correctness check
2630*da0073e9SAndroid Build Coastguard Worker
2631*da0073e9SAndroid Build Coastguard Worker        for upper, batchsize in itertools.product([True, False], [1, 3, 4]):
2632*da0073e9SAndroid Build Coastguard Worker            cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper)
2633*da0073e9SAndroid Build Coastguard Worker
2634*da0073e9SAndroid Build Coastguard Worker    @slowTest
2635*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2636*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2637*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2638*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2639*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
2640*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve_batched_many_batches(self, device, dtype):
2641*da0073e9SAndroid Build Coastguard Worker        for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]):
2642*da0073e9SAndroid Build Coastguard Worker            for upper in [True, False]:
2643*da0073e9SAndroid Build Coastguard Worker                b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2644*da0073e9SAndroid Build Coastguard Worker                x = torch.cholesky_solve(b, L, upper)
2645*da0073e9SAndroid Build Coastguard Worker                Ax = torch.matmul(A, x)
2646*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(Ax, b.expand_as(Ax))
2647*da0073e9SAndroid Build Coastguard Worker
2648*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2649*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2650*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2651*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2652*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
2653*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve_batched_broadcasting(self, device, dtype):
2654*da0073e9SAndroid Build Coastguard Worker        from numpy.linalg import solve
2655*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2656*da0073e9SAndroid Build Coastguard Worker
2657*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, b_dims, upper):
2658*da0073e9SAndroid Build Coastguard Worker            A_matrix_size = A_dims[-1]
2659*da0073e9SAndroid Build Coastguard Worker            A_batch_dims = A_dims[:-2]
2660*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims,
2661*da0073e9SAndroid Build Coastguard Worker                                           dtype=dtype, device='cpu')
2662*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(*b_dims, dtype=dtype, device='cpu')
2663*da0073e9SAndroid Build Coastguard Worker            x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device)
2664*da0073e9SAndroid Build Coastguard Worker            A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device)
2665*da0073e9SAndroid Build Coastguard Worker            L = torch.linalg.cholesky(A, upper=upper)
2666*da0073e9SAndroid Build Coastguard Worker            x = torch.cholesky_solve(b, L, upper=upper)
2667*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x_exp)
2668*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/42695
2669*da0073e9SAndroid Build Coastguard Worker            x = torch.cholesky_solve(b, L, upper=upper, out=x)
2670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x_exp)
2671*da0073e9SAndroid Build Coastguard Worker
2672*da0073e9SAndroid Build Coastguard Worker        # test against numpy.linalg.solve
2673*da0073e9SAndroid Build Coastguard Worker        for upper in [True, False]:
2674*da0073e9SAndroid Build Coastguard Worker            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper)  # no broadcasting
2675*da0073e9SAndroid Build Coastguard Worker            run_test((2, 1, 3, 4, 4), (4, 6), upper)  # broadcasting b
2676*da0073e9SAndroid Build Coastguard Worker            run_test((4, 4), (2, 1, 3, 4, 2), upper)  # broadcasting A
2677*da0073e9SAndroid Build Coastguard Worker            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper)  # broadcasting A & b
2678*da0073e9SAndroid Build Coastguard Worker
2679*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2680*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2681*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2682*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
2683*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
2684*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(2, dtype=dtype, device=device)
2685*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 1, dtype=dtype, device=device)
2686*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.int, device=device)
2687*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2688*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_solve(b, a, out=out)
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker        # device should match
2691*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
2692*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2693*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=wrong_device)
2694*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2695*da0073e9SAndroid Build Coastguard Worker                torch.cholesky_solve(b, a, out=out)
2696*da0073e9SAndroid Build Coastguard Worker
2697*da0073e9SAndroid Build Coastguard Worker        # if out tensor with wrong shape is passed a warning is given
2698*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
2699*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(1, dtype=dtype, device=device)
2700*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
2701*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_solve(b, a, out=out)
2702*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
2703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
2704*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2705*da0073e9SAndroid Build Coastguard Worker
2706*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2707*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2708*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
2709*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_solve_backward(self, device, dtype):
2710*da0073e9SAndroid Build Coastguard Worker        b_dims = (5, 2)
2711*da0073e9SAndroid Build Coastguard Worker        L_dims = (5, 5)
2712*da0073e9SAndroid Build Coastguard Worker
2713*da0073e9SAndroid Build Coastguard Worker        for test_L_grad in (False, True):
2714*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True)
2715*da0073e9SAndroid Build Coastguard Worker            L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad)
2716*da0073e9SAndroid Build Coastguard Worker            if test_L_grad:
2717*da0073e9SAndroid Build Coastguard Worker                torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L))
2718*da0073e9SAndroid Build Coastguard Worker            else:
2719*da0073e9SAndroid Build Coastguard Worker                torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,))
2720*da0073e9SAndroid Build Coastguard Worker
2721*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2722*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2723*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2724*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2725*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
2726*da0073e9SAndroid Build Coastguard Worker    def test_inverse(self, device, dtype):
2727*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2728*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2729*da0073e9SAndroid Build Coastguard Worker
2730*da0073e9SAndroid Build Coastguard Worker        def run_test(torch_inverse, matrix, batches, n):
2731*da0073e9SAndroid Build Coastguard Worker            matrix_inverse = torch_inverse(matrix)
2732*da0073e9SAndroid Build Coastguard Worker
2733*da0073e9SAndroid Build Coastguard Worker            # Compare against NumPy output
2734*da0073e9SAndroid Build Coastguard Worker            # NumPy uses the 'gesv' LAPACK routine, solving the equation A A_inv = I,
2735*da0073e9SAndroid Build Coastguard Worker            # while PyTorch uses 'getrf' + 'getrs'. As such, there may be some element-wise differences
2736*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.inv(matrix.cpu().numpy())
2737*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision)
2738*da0073e9SAndroid Build Coastguard Worker
2739*da0073e9SAndroid Build Coastguard Worker            # Additional correctness tests, check matrix*matrix_inverse == identity
2740*da0073e9SAndroid Build Coastguard Worker            identity = torch.eye(n, dtype=dtype, device=device)
2741*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu()))
2742*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu()))
2743*da0073e9SAndroid Build Coastguard Worker
2744*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
2745*da0073e9SAndroid Build Coastguard Worker            # prepare an out tensor in batched column-major (Fortran-contiguous) layout
2746*da0073e9SAndroid Build Coastguard Worker            matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device)
2747*da0073e9SAndroid Build Coastguard Worker            matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format)
2748*da0073e9SAndroid Build Coastguard Worker            matrix_inverse_out = matrix_inverse_out_t.mT
2749*da0073e9SAndroid Build Coastguard Worker            ans = torch_inverse(matrix, out=matrix_inverse_out)
2750*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0)
2751*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)
2752*da0073e9SAndroid Build Coastguard Worker
2753*da0073e9SAndroid Build Coastguard Worker            # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix
2754*da0073e9SAndroid Build Coastguard Worker            if matrix.ndim > 2 and batches[0] != 0:
2755*da0073e9SAndroid Build Coastguard Worker                expected_inv_list = []
2756*da0073e9SAndroid Build Coastguard Worker                p = int(np.prod(batches))  # use `p` instead of -1, so that the test works for empty input as well
2757*da0073e9SAndroid Build Coastguard Worker                for mat in matrix.contiguous().view(p, n, n):
2758*da0073e9SAndroid Build Coastguard Worker                    expected_inv_list.append(torch_inverse(mat))
2759*da0073e9SAndroid Build Coastguard Worker                expected_inv = torch.stack(expected_inv_list).view(*batches, n, n)
2760*da0073e9SAndroid Build Coastguard Worker                if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]:
2761*da0073e9SAndroid Build Coastguard Worker                    # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA
2762*da0073e9SAndroid Build Coastguard Worker                    # individual values can be significantly different for fp32, hence rather high rtol is used
2763*da0073e9SAndroid Build Coastguard Worker                    # the important thing is that torch_inverse passes above checks with identity
2764*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2)
2765*da0073e9SAndroid Build Coastguard Worker                else:
2766*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(matrix_inverse, expected_inv)
2767*da0073e9SAndroid Build Coastguard Worker
2768*da0073e9SAndroid Build Coastguard Worker        # helper function for testing torch.linalg.inv_ex
2769*da0073e9SAndroid Build Coastguard Worker        def test_inv_ex(input, out=None):
2770*da0073e9SAndroid Build Coastguard Worker            if out is not None:
2771*da0073e9SAndroid Build Coastguard Worker                info = torch.empty(0, dtype=torch.int32, device=device)
2772*da0073e9SAndroid Build Coastguard Worker                return torch.linalg.inv_ex(input, out=(out, info)).inverse
2773*da0073e9SAndroid Build Coastguard Worker            return torch.linalg.inv_ex(input).inverse
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker        for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]:
2776*da0073e9SAndroid Build Coastguard Worker            for batches, n in itertools.product(
2777*da0073e9SAndroid Build Coastguard Worker                [[], [0], [2], [2, 1]],
2778*da0073e9SAndroid Build Coastguard Worker                [0, 5]
2779*da0073e9SAndroid Build Coastguard Worker            ):
2780*da0073e9SAndroid Build Coastguard Worker                matrices = make_arg(*batches, n, n)
2781*da0073e9SAndroid Build Coastguard Worker                run_test(torch_inverse, matrices, batches, n)
2782*da0073e9SAndroid Build Coastguard Worker
2783*da0073e9SAndroid Build Coastguard Worker                # test non-contiguous input
2784*da0073e9SAndroid Build Coastguard Worker                run_test(torch_inverse, matrices.mT, batches, n)
2785*da0073e9SAndroid Build Coastguard Worker                if n > 0:
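                    # a strided every-other-row/column view of a larger matrix exercises
                    # the non-contiguous input path without making a copy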
2786*da0073e9SAndroid Build Coastguard Worker                    run_test(
2787*da0073e9SAndroid Build Coastguard Worker                        torch_inverse,
2788*da0073e9SAndroid Build Coastguard Worker                        make_arg(*batches, 2 * n, 2 * n)
2789*da0073e9SAndroid Build Coastguard Worker                        .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
2790*da0073e9SAndroid Build Coastguard Worker                        batches, n
2791*da0073e9SAndroid Build Coastguard Worker                    )
2792*da0073e9SAndroid Build Coastguard Worker
2793*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2794*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2795*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2796*da0073e9SAndroid Build Coastguard Worker    def test_inv_ex_info_device(self, device, dtype):
2797*da0073e9SAndroid Build Coastguard Worker        A = torch.eye(3, 3, dtype=dtype, device=device)
2798*da0073e9SAndroid Build Coastguard Worker        info = torch.linalg.inv_ex(A).info
2799*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(info.device == A.device)
2800*da0073e9SAndroid Build Coastguard Worker
2801*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2802*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2803*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2804*da0073e9SAndroid Build Coastguard Worker    def test_inv_ex_singular(self, device, dtype):
2805*da0073e9SAndroid Build Coastguard Worker        # if the input matrix is not invertible, a positive integer is returned in info
2806*da0073e9SAndroid Build Coastguard Worker        A = torch.eye(3, 3, dtype=dtype, device=device)
2807*da0073e9SAndroid Build Coastguard Worker        A[-1, -1] = 0  # Now A is singular
2808*da0073e9SAndroid Build Coastguard Worker        info = torch.linalg.inv_ex(A).info
2809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(info, 3)
2810*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.linalg.LinAlgError,
2811*da0073e9SAndroid Build Coastguard Worker                                    r'diagonal element 3 is zero, the inversion could not be completed'):
2812*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv_ex(A, check_errors=True)
2813*da0073e9SAndroid Build Coastguard Worker
2814*da0073e9SAndroid Build Coastguard Worker        # if at least one matrix in the batch is not invertible,
2815*da0073e9SAndroid Build Coastguard Worker        # the batched info contains a positive integer at the corresponding position
2816*da0073e9SAndroid Build Coastguard Worker        A = torch.eye(3, 3, dtype=dtype, device=device)
2817*da0073e9SAndroid Build Coastguard Worker        A = A.reshape((1, 3, 3))
2818*da0073e9SAndroid Build Coastguard Worker        A = A.repeat(5, 1, 1)
2819*da0073e9SAndroid Build Coastguard Worker        A[3, -2, -2] = 0  # Now A[3] is singular
2820*da0073e9SAndroid Build Coastguard Worker        info = torch.linalg.inv_ex(A).info
2821*da0073e9SAndroid Build Coastguard Worker
2822*da0073e9SAndroid Build Coastguard Worker        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
2823*da0073e9SAndroid Build Coastguard Worker        expected_info[3] = 2
2824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(info, expected_info)
2825*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
2826*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv_ex(A, check_errors=True)
2827*da0073e9SAndroid Build Coastguard Worker
2828*da0073e9SAndroid Build Coastguard Worker    @slowTest
2829*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2830*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2831*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2832*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2833*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-5, torch.complex128: 1e-5})
2834*da0073e9SAndroid Build Coastguard Worker    def test_inverse_many_batches(self, device, dtype):
2835*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2836*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2837*da0073e9SAndroid Build Coastguard Worker
2838*da0073e9SAndroid Build Coastguard Worker        def test_inverse_many_batches_helper(torch_inverse, b, n):
2839*da0073e9SAndroid Build Coastguard Worker            matrices = make_arg(b, n, n)
2840*da0073e9SAndroid Build Coastguard Worker            matrices_inverse = torch_inverse(matrices)
2841*da0073e9SAndroid Build Coastguard Worker
2842*da0073e9SAndroid Build Coastguard Worker            # Compare against NumPy output
2843*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.inv(matrices.cpu().numpy())
2844*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)
2845*da0073e9SAndroid Build Coastguard Worker
2846*da0073e9SAndroid Build Coastguard Worker        for torch_inverse in [torch.inverse, torch.linalg.inv]:
2847*da0073e9SAndroid Build Coastguard Worker            test_inverse_many_batches_helper(torch_inverse, 5, 256)
2848*da0073e9SAndroid Build Coastguard Worker            test_inverse_many_batches_helper(torch_inverse, 3, 512)
2849*da0073e9SAndroid Build Coastguard Worker
2850*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2851*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2852*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2853*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2854*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
2855*da0073e9SAndroid Build Coastguard Worker    def test_inverse_errors(self, device, dtype):
2856*da0073e9SAndroid Build Coastguard Worker        # inverse expects batches of square matrices as input
2857*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2858*da0073e9SAndroid Build Coastguard Worker            torch.inverse(torch.randn(2, 3, 4, 3))
2859*da0073e9SAndroid Build Coastguard Worker
2860*da0073e9SAndroid Build Coastguard Worker        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
2861*da0073e9SAndroid Build Coastguard Worker        def run_test_singular_input(batch_dim, n):
2862*da0073e9SAndroid Build Coastguard Worker            x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
2863*da0073e9SAndroid Build Coastguard Worker            x[n, -1, -1] = 0
2864*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
2865*da0073e9SAndroid Build Coastguard Worker                torch.inverse(x)
2866*da0073e9SAndroid Build Coastguard Worker
2867*da0073e9SAndroid Build Coastguard Worker        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
2868*da0073e9SAndroid Build Coastguard Worker            run_test_singular_input(*params)
2869*da0073e9SAndroid Build Coastguard Worker
2870*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
2871*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2872*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2873*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2874*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2875*da0073e9SAndroid Build Coastguard Worker    def test_inverse_errors_large(self, device, dtype):
2876*da0073e9SAndroid Build Coastguard Worker        # Test batched inverse of singular matrices reports errors without crashing (gh-51930)
2877*da0073e9SAndroid Build Coastguard Worker        x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
2878*da0073e9SAndroid Build Coastguard Worker        x[:] = torch.eye(616, dtype=dtype, device=device)
2879*da0073e9SAndroid Build Coastguard Worker        x[..., 10, 10] = 0
2880*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
2881*da0073e9SAndroid Build Coastguard Worker            torch.inverse(x)
2882*da0073e9SAndroid Build Coastguard Worker
2883*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
2884*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2885*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2886*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2887*da0073e9SAndroid Build Coastguard Worker    def test_pinv(self, device, dtype):
2888*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2889*da0073e9SAndroid Build Coastguard Worker
2890*da0073e9SAndroid Build Coastguard Worker        def run_test_main(A, hermitian):
2891*da0073e9SAndroid Build Coastguard Worker            # Testing against definition for pseudo-inverses
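            # i.e. the four Moore-Penrose conditions:
            #   A @ A_pinv @ A == A,   A_pinv @ A @ A_pinv == A_pinv,
            #   and both A @ A_pinv and A_pinv @ A are Hermitian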
2892*da0073e9SAndroid Build Coastguard Worker            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
2893*da0073e9SAndroid Build Coastguard Worker            np_A = A.cpu().numpy()
2894*da0073e9SAndroid Build Coastguard Worker            np_A_pinv = A_pinv.cpu().numpy()
2895*da0073e9SAndroid Build Coastguard Worker            if A.numel() > 0:
2896*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision)
2897*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision)
2898*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1))
2899*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1))
2900*da0073e9SAndroid Build Coastguard Worker            else:
2901*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
2902*da0073e9SAndroid Build Coastguard Worker
2903*da0073e9SAndroid Build Coastguard Worker            # Check out= variant
2904*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(A_pinv)
2905*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
2906*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
2907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, A_pinv)
2908*da0073e9SAndroid Build Coastguard Worker
2909*da0073e9SAndroid Build Coastguard Worker        def run_test_numpy(A, hermitian):
2910*da0073e9SAndroid Build Coastguard Worker            # Check against NumPy output
2911*da0073e9SAndroid Build Coastguard Worker            # Test a plain float rcond first; per-matrix rcond tensors are appended below
2912*da0073e9SAndroid Build Coastguard Worker            rconds = [float(torch.rand(1)), ]
2913*da0073e9SAndroid Build Coastguard Worker            # Test different types of rcond tensor
2914*da0073e9SAndroid Build Coastguard Worker            for rcond_type in all_types():
2915*da0073e9SAndroid Build Coastguard Worker                rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type))
2916*da0073e9SAndroid Build Coastguard Worker            # Test broadcasting of rcond
2917*da0073e9SAndroid Build Coastguard Worker            if A.ndim > 2:
2918*da0073e9SAndroid Build Coastguard Worker                rconds.append(torch.rand(A.shape[-3], device=device))
2919*da0073e9SAndroid Build Coastguard Worker            for rcond in rconds:
2920*da0073e9SAndroid Build Coastguard Worker                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
2921*da0073e9SAndroid Build Coastguard Worker                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
2922*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, torch_rtol)
2923*da0073e9SAndroid Build Coastguard Worker                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
2924*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
2925*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5)
2926*da0073e9SAndroid Build Coastguard Worker
2927*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2928*da0073e9SAndroid Build Coastguard Worker                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
2929*da0073e9SAndroid Build Coastguard Worker                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
2930*da0073e9SAndroid Build Coastguard Worker                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
2931*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(*sizes, dtype=dtype, device=device)
2932*da0073e9SAndroid Build Coastguard Worker            hermitian = False
2933*da0073e9SAndroid Build Coastguard Worker            run_test_main(A, hermitian)
2934*da0073e9SAndroid Build Coastguard Worker            run_test_numpy(A, hermitian)
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker        # Check hermitian = True
2937*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2938*da0073e9SAndroid Build Coastguard Worker                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
2939*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
2940*da0073e9SAndroid Build Coastguard Worker            hermitian = True
2941*da0073e9SAndroid Build Coastguard Worker            run_test_main(A, hermitian)
2942*da0073e9SAndroid Build Coastguard Worker            run_test_numpy(A, hermitian)
2943*da0073e9SAndroid Build Coastguard Worker
2944*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
2945*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2946*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2947*da0073e9SAndroid Build Coastguard Worker    def test_pinv_errors_and_warnings(self, device, dtype):
2948*da0073e9SAndroid Build Coastguard Worker        # pinv requires at least 2D tensor
2949*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, device=device, dtype=dtype)
2950*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
2951*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a)
2952*da0073e9SAndroid Build Coastguard Worker
2953*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
2954*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3, dtype=dtype, device=device)
2955*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(7, 7, dtype=dtype, device=device)
2956*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
2957*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
2958*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a, out=out)
2959*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
2960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
2961*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2962*da0073e9SAndroid Build Coastguard Worker
2963*da0073e9SAndroid Build Coastguard Worker        # dtypes of out and input should be safely castable
2964*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(a).to(torch.int)
2965*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2966*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a, out=out)
2967*da0073e9SAndroid Build Coastguard Worker
2968*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
2969*da0073e9SAndroid Build Coastguard Worker            # device of out and input should match
2970*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2971*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(a).to(wrong_device)
2972*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"):
2973*da0073e9SAndroid Build Coastguard Worker                torch.linalg.pinv(a, out=out)
2974*da0073e9SAndroid Build Coastguard Worker
2975*da0073e9SAndroid Build Coastguard Worker            # device of rcond and input should match
2976*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2977*da0073e9SAndroid Build Coastguard Worker            rcond = torch.full((), 1e-2, device=wrong_device)
2978*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
2979*da0073e9SAndroid Build Coastguard Worker                torch.linalg.pinv(a, rcond=rcond)
2980*da0073e9SAndroid Build Coastguard Worker
2981*da0073e9SAndroid Build Coastguard Worker        # rcond can't be complex
2982*da0073e9SAndroid Build Coastguard Worker        rcond = torch.full((), 1j, device=device)
2983*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"):
2984*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a, rcond=rcond)
2985*da0073e9SAndroid Build Coastguard Worker
2986*da0073e9SAndroid Build Coastguard Worker        # atol can't be complex
2987*da0073e9SAndroid Build Coastguard Worker        atol = torch.full((), 1j, device=device)
2988*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"):
2989*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a, atol=atol)
2990*da0073e9SAndroid Build Coastguard Worker
2991*da0073e9SAndroid Build Coastguard Worker        # rtol can't be complex
2992*da0073e9SAndroid Build Coastguard Worker        rtol = torch.full((), 1j, device=device)
2993*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"):
2994*da0073e9SAndroid Build Coastguard Worker            torch.linalg.pinv(a, rtol=rtol)
2995*da0073e9SAndroid Build Coastguard Worker
2996*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
2997*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
2998*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
2999*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
3000*da0073e9SAndroid Build Coastguard Worker    def test_inv_errors_and_warnings(self, device, dtype):
3001*da0073e9SAndroid Build Coastguard Worker        # inv expects batches of square matrices as input
3002*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
3003*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
3004*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv(a)
3005*da0073e9SAndroid Build Coastguard Worker
3006*da0073e9SAndroid Build Coastguard Worker        # inv requires the input to be at least 2 dimensional tensor
3007*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, device=device, dtype=dtype)
3008*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
3009*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv(a)
3010*da0073e9SAndroid Build Coastguard Worker
3011*da0073e9SAndroid Build Coastguard Worker        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
3012*da0073e9SAndroid Build Coastguard Worker        def run_test_singular_input(batch_dim, n):
3013*da0073e9SAndroid Build Coastguard Worker            a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
3014*da0073e9SAndroid Build Coastguard Worker            a[n, -1, -1] = 0
3015*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
3016*da0073e9SAndroid Build Coastguard Worker                torch.linalg.inv(a)
3017*da0073e9SAndroid Build Coastguard Worker
3018*da0073e9SAndroid Build Coastguard Worker        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
3019*da0073e9SAndroid Build Coastguard Worker            run_test_singular_input(*params)
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker        # dtypes should match
3022*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(2, dtype=dtype, device=device)
3023*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.int, device=device)
3024*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
3025*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv(a, out=out)
3026*da0073e9SAndroid Build Coastguard Worker
3027*da0073e9SAndroid Build Coastguard Worker        # device should match
3028*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3029*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3030*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, device=wrong_device, dtype=dtype)
3031*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3032*da0073e9SAndroid Build Coastguard Worker                torch.linalg.inv(a, out=out)
3033*da0073e9SAndroid Build Coastguard Worker
3034*da0073e9SAndroid Build Coastguard Worker        # if out tensor with wrong shape is passed a warning is given
3035*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
3036*da0073e9SAndroid Build Coastguard Worker            a = torch.eye(2, dtype=dtype, device=device)
3037*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(1, dtype=dtype, device=device)
3038*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
3039*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv(a, out=out)
3040*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
3041*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
3042*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3043*da0073e9SAndroid Build Coastguard Worker
3044*da0073e9SAndroid Build Coastguard Worker        # if an out tensor in batched column-major format but with the wrong shape is passed, a warning is given
3045*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
3046*da0073e9SAndroid Build Coastguard Worker            a = torch.eye(2, dtype=dtype, device=device)
3047*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(3, 3, dtype=dtype, device=device)
3048*da0073e9SAndroid Build Coastguard Worker            out = out.mT.clone(memory_format=torch.contiguous_format)
3049*da0073e9SAndroid Build Coastguard Worker            out = out.mT
3050*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.mT.is_contiguous())
3051*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
3052*da0073e9SAndroid Build Coastguard Worker            torch.linalg.inv(a, out=out)
3053*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
3054*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
3055*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3056*da0073e9SAndroid Build Coastguard Worker
3057*da0073e9SAndroid Build Coastguard Worker    def solve_test_helper(self, A_dims, b_dims, device, dtype):
3058*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3059*da0073e9SAndroid Build Coastguard Worker        make_A = partial(make_fullrank, device=device, dtype=dtype)
3060*da0073e9SAndroid Build Coastguard Worker
3061*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(*b_dims, dtype=dtype, device=device)
3062*da0073e9SAndroid Build Coastguard Worker        A = make_A(*A_dims)
3063*da0073e9SAndroid Build Coastguard Worker        return b, A
3064*da0073e9SAndroid Build Coastguard Worker
3065*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3066*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3067*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3068*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
3069*da0073e9SAndroid Build Coastguard Worker    def test_solve(self, device, dtype):
3070*da0073e9SAndroid Build Coastguard Worker        def run_test(n, batch, rhs):
3071*da0073e9SAndroid Build Coastguard Worker            A_dims = (*batch, n, n)
3072*da0073e9SAndroid Build Coastguard Worker            b_dims = (*batch, n, *rhs)
3073*da0073e9SAndroid Build Coastguard Worker            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
3074*da0073e9SAndroid Build Coastguard Worker
3075*da0073e9SAndroid Build Coastguard Worker            # Correctness test
3076*da0073e9SAndroid Build Coastguard Worker            x = torch.linalg.solve(A, b)
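            # when rhs == (), b holds batches of vectors rather than matrices, so it is
            # unsqueezed to a column for the explicit matmul check and squeezed back after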
3077*da0073e9SAndroid Build Coastguard Worker            if rhs == ():
3078*da0073e9SAndroid Build Coastguard Worker                Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
3079*da0073e9SAndroid Build Coastguard Worker                Ax.squeeze_(-1)
3080*da0073e9SAndroid Build Coastguard Worker            else:
3081*da0073e9SAndroid Build Coastguard Worker                Ax = np.matmul(A.cpu(), x.cpu())
3082*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.expand_as(Ax), Ax)
3083*da0073e9SAndroid Build Coastguard Worker
3084*da0073e9SAndroid Build Coastguard Worker            # Check against NumPy
3085*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
3086*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, expected)
3087*da0073e9SAndroid Build Coastguard Worker
3088*da0073e9SAndroid Build Coastguard Worker        batches = [(), (0, ), (3, ), (2, 3)]
3089*da0073e9SAndroid Build Coastguard Worker        ns = [0, 5, 32]
3090*da0073e9SAndroid Build Coastguard Worker        nrhs = [(), (1, ), (5, )]
3091*da0073e9SAndroid Build Coastguard Worker        for n, batch, rhs in itertools.product(ns, batches, nrhs):
3092*da0073e9SAndroid Build Coastguard Worker            run_test(n, batch, rhs)
3093*da0073e9SAndroid Build Coastguard Worker
3094*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
3095*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3096*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3097*da0073e9SAndroid Build Coastguard Worker    def test_solve_batched_broadcasting(self, device, dtype):
3098*da0073e9SAndroid Build Coastguard Worker        from numpy.linalg import solve
3099*da0073e9SAndroid Build Coastguard Worker
3100*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, B_dims):
3101*da0073e9SAndroid Build Coastguard Worker            A_matrix_size = A_dims[-1]
3102*da0073e9SAndroid Build Coastguard Worker            A_batch_dims = A_dims[:-2]
3103*da0073e9SAndroid Build Coastguard Worker            B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
3104*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.solve(A, B)
3105*da0073e9SAndroid Build Coastguard Worker            expected = solve(A.cpu().numpy(), B.cpu().numpy())
3106*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
3107*da0073e9SAndroid Build Coastguard Worker
3108*da0073e9SAndroid Build Coastguard Worker        # test against numpy.linalg.solve
3109*da0073e9SAndroid Build Coastguard Worker        run_test((5, 5), (2, 0, 5, 3))  # broadcasting with 0 batch dim
3110*da0073e9SAndroid Build Coastguard Worker        run_test((2, 0, 5, 5), (5, 3))  # broadcasting with 0 batch dim
3111*da0073e9SAndroid Build Coastguard Worker        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting B
3112*da0073e9SAndroid Build Coastguard Worker        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
3113*da0073e9SAndroid Build Coastguard Worker        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & B
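        # The batch dimensions of A and B broadcast against each other (matmul-style),
        # e.g. A of shape (2, 1, 3, 4, 4) with B of shape (4, 6) solves 2 * 1 * 3
        # systems that all share the same right-hand side.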
3114*da0073e9SAndroid Build Coastguard Worker
3115*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3116*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3117*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3118*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
3119*da0073e9SAndroid Build Coastguard Worker    def test_tensorsolve(self, device, dtype):
3120*da0073e9SAndroid Build Coastguard Worker        def run_test(a_shape, dims):
3121*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(a_shape, dtype=dtype, device=device)
3122*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
3123*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.tensorsolve(a, b, dims=dims)
3124*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
3125*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
3126*da0073e9SAndroid Build Coastguard Worker
3127*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
3128*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(result)
3129*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
3130*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
3131*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, result)
3132*da0073e9SAndroid Build Coastguard Worker
3133*da0073e9SAndroid Build Coastguard Worker        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
3134*da0073e9SAndroid Build Coastguard Worker        dims = [None, (0, 2)]
3135*da0073e9SAndroid Build Coastguard Worker        for a_shape, d in itertools.product(a_shapes, dims):
3136*da0073e9SAndroid Build Coastguard Worker            run_test(a_shape, d)
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3139*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3140*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3141*da0073e9SAndroid Build Coastguard Worker    def test_tensorsolve_empty(self, device, dtype):
3142*da0073e9SAndroid Build Coastguard Worker        # Check for empty inputs. NumPy does not work for these cases.
3143*da0073e9SAndroid Build Coastguard Worker        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3144*da0073e9SAndroid Build Coastguard Worker        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
3145*da0073e9SAndroid Build Coastguard Worker        x = torch.linalg.tensorsolve(a, b)
3146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
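        # The assertion above encodes tensorsolve's defining property:
        # tensordot(a, x, dims=x.ndim) == b, with x taking the shape of the trailing
        # dimensions of a that are not matched by b.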
3147*da0073e9SAndroid Build Coastguard Worker
3148*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3149*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3150*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
3151*da0073e9SAndroid Build Coastguard Worker    def test_tensorsolve_errors_and_warnings(self, device, dtype):
3152*da0073e9SAndroid Build Coastguard Worker        # tensorsolve expects an input that can be reshaped into a square matrix
3153*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4))
3154*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(8, 4, dtype=dtype, device=device)
3155*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
3156*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
3157*da0073e9SAndroid Build Coastguard Worker            torch.linalg.tensorsolve(a, b)
3158*da0073e9SAndroid Build Coastguard Worker
3159*da0073e9SAndroid Build Coastguard Worker        # if a non-empty out tensor with the wrong shape is passed, a warning is given
3160*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(a)
3161*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(6, 4, dtype=dtype, device=device)
3162*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
3163*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
3164*da0073e9SAndroid Build Coastguard Worker            torch.linalg.tensorsolve(a, b, out=out)
3165*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
3166*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
3167*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3168*da0073e9SAndroid Build Coastguard Worker
3169*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
3170*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(a).to(torch.int)
3171*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3172*da0073e9SAndroid Build Coastguard Worker            torch.linalg.tensorsolve(a, b, out=out)
3173*da0073e9SAndroid Build Coastguard Worker
3174*da0073e9SAndroid Build Coastguard Worker        # device should match
3175*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3176*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3177*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=wrong_device)
3178*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3179*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorsolve(a, b, out=out)
3180*da0073e9SAndroid Build Coastguard Worker
3181*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3182*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3183*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3184*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
3185*da0073e9SAndroid Build Coastguard Worker    def test_tensorinv(self, device, dtype):
3186*da0073e9SAndroid Build Coastguard Worker
3187*da0073e9SAndroid Build Coastguard Worker        def run_test(a_shape, ind):
3188*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(a_shape, dtype=dtype, device=device)
3189*da0073e9SAndroid Build Coastguard Worker            a_numpy = a.cpu().numpy()
3190*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.tensorinv(a, ind=ind)
3191*da0073e9SAndroid Build Coastguard Worker            expected = np.linalg.tensorinv(a_numpy, ind=ind)
3192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
3193*da0073e9SAndroid Build Coastguard Worker
3194*da0073e9SAndroid Build Coastguard Worker            # check the out= variant
3195*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(result)
3196*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.tensorinv(a, ind=ind, out=out)
3197*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
3198*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, result)
3199*da0073e9SAndroid Build Coastguard Worker
3200*da0073e9SAndroid Build Coastguard Worker        # compare to NumPy output
3201*da0073e9SAndroid Build Coastguard Worker        run_test((12, 3, 4), ind=1)
3202*da0073e9SAndroid Build Coastguard Worker        run_test((3, 8, 24), ind=2)
3203*da0073e9SAndroid Build Coastguard Worker        run_test((18, 3, 3, 2), ind=1)
3204*da0073e9SAndroid Build Coastguard Worker        run_test((1, 4, 2, 2), ind=2)
3205*da0073e9SAndroid Build Coastguard Worker        run_test((2, 3, 5, 30), ind=3)
3206*da0073e9SAndroid Build Coastguard Worker        run_test((24, 2, 2, 3, 2), ind=1)
3207*da0073e9SAndroid Build Coastguard Worker        run_test((3, 4, 2, 3, 2), ind=2)
3208*da0073e9SAndroid Build Coastguard Worker        run_test((1, 2, 3, 2, 3), ind=3)
3209*da0073e9SAndroid Build Coastguard Worker        run_test((3, 2, 1, 2, 12), ind=4)
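        # In each case prod(a.shape[:ind]) == prod(a.shape[ind:]) holds, e.g. for
        # (12, 3, 4) with ind=1 we have 12 == 3 * 4; the inverse returned by
        # tensorinv then has shape a.shape[ind:] + a.shape[:ind].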
3210*da0073e9SAndroid Build Coastguard Worker
3211*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3212*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3213*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3214*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3215*da0073e9SAndroid Build Coastguard Worker    def test_tensorinv_empty(self, device, dtype):
3216*da0073e9SAndroid Build Coastguard Worker        for ind in range(1, 4):
3217*da0073e9SAndroid Build Coastguard Worker            # Check for empty inputs. NumPy does not work for these cases.
3218*da0073e9SAndroid Build Coastguard Worker            a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3219*da0073e9SAndroid Build Coastguard Worker            a_inv = torch.linalg.tensorinv(a, ind=ind)
3220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3223*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3224*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3225*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3226*da0073e9SAndroid Build Coastguard Worker    def test_tensorinv_errors_and_warnings(self, device, dtype):
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Worker        def check_shape(a_shape, ind):
3229*da0073e9SAndroid Build Coastguard Worker            # tensorinv requires the input to satisfy
3230*da0073e9SAndroid Build Coastguard Worker            # prod(a.shape[ind:]) == prod(a.shape[:ind])
3231*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(a_shape, dtype=dtype, device=device)
3232*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"):
3233*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorinv(a, ind=ind)
3234*da0073e9SAndroid Build Coastguard Worker
3235*da0073e9SAndroid Build Coastguard Worker        def check_ind(a_shape, ind):
3236*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(a_shape, dtype=dtype, device=device)
3237*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"):
3238*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorinv(a, ind=ind)
3239*da0073e9SAndroid Build Coastguard Worker
3240*da0073e9SAndroid Build Coastguard Worker        def check_out(a_shape, ind):
3241*da0073e9SAndroid Build Coastguard Worker            # if a non-empty out tensor with the wrong shape is passed, a warning is given
3242*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(a_shape, dtype=dtype, device=device)
3243*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(a)
3244*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
3245*da0073e9SAndroid Build Coastguard Worker                # Trigger warning
3246*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorinv(a, ind=ind, out=out)
3247*da0073e9SAndroid Build Coastguard Worker                # Check warning occurs
3248*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 1)
3249*da0073e9SAndroid Build Coastguard Worker                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3250*da0073e9SAndroid Build Coastguard Worker
3251*da0073e9SAndroid Build Coastguard Worker            # dtypes should be safely castable
3252*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=torch.int, device=device)
3253*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3254*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorinv(a, ind=ind, out=out)
3255*da0073e9SAndroid Build Coastguard Worker
3256*da0073e9SAndroid Build Coastguard Worker            # device should match
3257*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
3258*da0073e9SAndroid Build Coastguard Worker                wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3259*da0073e9SAndroid Build Coastguard Worker                out = torch.empty(0, dtype=dtype, device=wrong_device)
3260*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3261*da0073e9SAndroid Build Coastguard Worker                    torch.linalg.tensorinv(a, ind=ind, out=out)
3262*da0073e9SAndroid Build Coastguard Worker
3263*da0073e9SAndroid Build Coastguard Worker        # test for invalid shape
3264*da0073e9SAndroid Build Coastguard Worker        check_shape((2, 3, 4), ind=1)
3265*da0073e9SAndroid Build Coastguard Worker        check_shape((1, 2, 3, 4), ind=3)
3266*da0073e9SAndroid Build Coastguard Worker
3267*da0073e9SAndroid Build Coastguard Worker        # test for invalid ind
3268*da0073e9SAndroid Build Coastguard Worker        check_ind((12, 3, 4), ind=-1)
3269*da0073e9SAndroid Build Coastguard Worker        check_ind((18, 3, 3, 2), ind=0)
3270*da0073e9SAndroid Build Coastguard Worker
3271*da0073e9SAndroid Build Coastguard Worker        # test for invalid out tensor
3272*da0073e9SAndroid Build Coastguard Worker        check_out((12, 3, 4), ind=1)
3273*da0073e9SAndroid Build Coastguard Worker        check_out((3, 8, 24), ind=2)
3274*da0073e9SAndroid Build Coastguard Worker
3275*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3276*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3277*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3278*da0073e9SAndroid Build Coastguard Worker    def test_tensorinv_singular_input(self, device, dtype):
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker        def check_singular_input(a_shape, ind):
3281*da0073e9SAndroid Build Coastguard Worker            prod_ind_end = np.prod(a_shape[ind:])
3282*da0073e9SAndroid Build Coastguard Worker            a = torch.eye(prod_ind_end, dtype=dtype, device=device)
3283*da0073e9SAndroid Build Coastguard Worker            a[-1, -1] = 0   # Now `a` is singular
3284*da0073e9SAndroid Build Coastguard Worker            a = a.reshape(a_shape)
3285*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
3286*da0073e9SAndroid Build Coastguard Worker                torch.linalg.tensorinv(a, ind=ind)
3287*da0073e9SAndroid Build Coastguard Worker
3288*da0073e9SAndroid Build Coastguard Worker        # test for non-invertible input
3289*da0073e9SAndroid Build Coastguard Worker        check_singular_input((12, 3, 4), ind=1)
3290*da0073e9SAndroid Build Coastguard Worker        check_singular_input((3, 6, 18), ind=2)
3291*da0073e9SAndroid Build Coastguard Worker
3292*da0073e9SAndroid Build Coastguard Worker    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
3293*da0073e9SAndroid Build Coastguard Worker        def check(x, y):
3294*da0073e9SAndroid Build Coastguard Worker            # Compare with numpy
3295*da0073e9SAndroid Build Coastguard Worker            res = torch_fn(x, y)
3296*da0073e9SAndroid Build Coastguard Worker            if x.dtype == torch.bfloat16:
3297*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
3298*da0073e9SAndroid Build Coastguard Worker            else:
3299*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
3300*da0073e9SAndroid Build Coastguard Worker            if res.dtype == torch.bfloat16:
3301*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.cpu(), ref.bfloat16())
3302*da0073e9SAndroid Build Coastguard Worker            else:
3303*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.cpu(), ref)
3304*da0073e9SAndroid Build Coastguard Worker
3305*da0073e9SAndroid Build Coastguard Worker            # Test out variant
3306*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(res)
3307*da0073e9SAndroid Build Coastguard Worker            torch_fn(x, y, out=out)
3308*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, res)
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        # Empty
3311*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([], dtype=dtype, device=device)
3312*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([], dtype=dtype, device=device)
3313*da0073e9SAndroid Build Coastguard Worker        check(x, y)
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker        # Contiguous
3316*da0073e9SAndroid Build Coastguard Worker        x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3317*da0073e9SAndroid Build Coastguard Worker        y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3318*da0073e9SAndroid Build Coastguard Worker        check(x, y)
3319*da0073e9SAndroid Build Coastguard Worker
3320*da0073e9SAndroid Build Coastguard Worker        # 0 strided
3321*da0073e9SAndroid Build Coastguard Worker        y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
3322*da0073e9SAndroid Build Coastguard Worker        check(x, y)
3323*da0073e9SAndroid Build Coastguard Worker
3324*da0073e9SAndroid Build Coastguard Worker        # 2 strided
3325*da0073e9SAndroid Build Coastguard Worker        check(x[::2], y[::2])
3326*da0073e9SAndroid Build Coastguard Worker
3327*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16)
3328*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.cfloat)
3329*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
3330*da0073e9SAndroid Build Coastguard Worker    def test_dot_vs_numpy(self, device, dtype):
3331*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)
3332*da0073e9SAndroid Build Coastguard Worker
3333*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.cfloat)
3334*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
3335*da0073e9SAndroid Build Coastguard Worker    def test_vdot_vs_numpy(self, device, dtype):
3336*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)
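        # For real inputs dot and vdot agree; for complex inputs vdot conjugates its
        # first argument (matching np.vdot), which is why the two ops are compared
        # against their respective NumPy references.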
3337*da0073e9SAndroid Build Coastguard Worker
3338*da0073e9SAndroid Build Coastguard Worker    def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
3339*da0073e9SAndroid Build Coastguard Worker        def check(x, y, regex):
3340*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, regex):
3341*da0073e9SAndroid Build Coastguard Worker                torch_fn(x, y)
3342*da0073e9SAndroid Build Coastguard Worker
3343*da0073e9SAndroid Build Coastguard Worker        if complex_dtypes:
3344*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, dtype=torch.cfloat, device=device)
3345*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(3, dtype=torch.cdouble, device=device)
3346*da0073e9SAndroid Build Coastguard Worker        else:
3347*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, dtype=torch.float, device=device)
3348*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(3, dtype=torch.double, device=device)
3349*da0073e9SAndroid Build Coastguard Worker
3350*da0073e9SAndroid Build Coastguard Worker        check(x, y, 'dot : expected both vectors to have same dtype')
3351*da0073e9SAndroid Build Coastguard Worker        check(x.reshape(1, 1), y, '1D tensors expected')
3352*da0073e9SAndroid Build Coastguard Worker        check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')
3353*da0073e9SAndroid Build Coastguard Worker
3354*da0073e9SAndroid Build Coastguard Worker        if self.device_type != 'cpu':
3355*da0073e9SAndroid Build Coastguard Worker            x_cpu = x.expand(3).cpu()
3356*da0073e9SAndroid Build Coastguard Worker            check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device')
3357*da0073e9SAndroid Build Coastguard Worker
3358*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3359*da0073e9SAndroid Build Coastguard Worker    def test_vdot_invalid_args(self, device):
3360*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_invalid_args(device, torch.vdot)
3361*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)
3362*da0073e9SAndroid Build Coastguard Worker
3363*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3364*da0073e9SAndroid Build Coastguard Worker    def test_dot_invalid_args(self, device):
3365*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_invalid_args(device, torch.dot)
3366*da0073e9SAndroid Build Coastguard Worker        self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)
3367*da0073e9SAndroid Build Coastguard Worker
3368*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3369*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3370*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3371*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank(self, device, dtype):
3372*da0073e9SAndroid Build Coastguard Worker        matrix_rank = torch.linalg.matrix_rank
3373*da0073e9SAndroid Build Coastguard Worker
3374*da0073e9SAndroid Build Coastguard Worker        def run_test(shape0, shape1, batch):
3375*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3376*da0073e9SAndroid Build Coastguard Worker            rank_a = matrix_rank(a)
3377*da0073e9SAndroid Build Coastguard Worker
3378*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, matrix_rank(a.mH))
3379*da0073e9SAndroid Build Coastguard Worker            aaH = torch.matmul(a, a.mH)
3380*da0073e9SAndroid Build Coastguard Worker            rank_aaH = matrix_rank(aaH)
3381*da0073e9SAndroid Build Coastguard Worker            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3382*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3383*da0073e9SAndroid Build Coastguard Worker            aHa = torch.matmul(a.mH, a)
3384*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3385*da0073e9SAndroid Build Coastguard Worker
3386*da0073e9SAndroid Build Coastguard Worker            # check against NumPy
3387*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
3388*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
3389*da0073e9SAndroid Build Coastguard Worker
3390*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
3391*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
3392*da0073e9SAndroid Build Coastguard Worker
3393*da0073e9SAndroid Build Coastguard Worker            # hermitian flag for NumPy was added in 1.14.0
3394*da0073e9SAndroid Build Coastguard Worker            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
3395*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rank_aaH_hermitian,
3396*da0073e9SAndroid Build Coastguard Worker                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
3397*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matrix_rank(aaH, 0.01, True),
3398*da0073e9SAndroid Build Coastguard Worker                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
3399*da0073e9SAndroid Build Coastguard Worker
3400*da0073e9SAndroid Build Coastguard Worker            # check out= variant
3401*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
3402*da0073e9SAndroid Build Coastguard Worker            ans = matrix_rank(a, out=out)
3403*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
3404*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, rank_a)
3405*da0073e9SAndroid Build Coastguard Worker
3406*da0073e9SAndroid Build Coastguard Worker        shapes = (3, 13)
3407*da0073e9SAndroid Build Coastguard Worker        batches = ((), (0, ), (4, ), (3, 5, ))
3408*da0073e9SAndroid Build Coastguard Worker        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3409*da0073e9SAndroid Build Coastguard Worker            run_test(shape0, shape1, batch)
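        # The checks above rely on rank(A) == rank(A^H) and on the Gram matrices
        # A @ A^H and A^H @ A having the same rank as A; for those Hermitian products
        # hermitian=True is expected to give the same answer via a cheaper
        # symmetric-eigenvalue path.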
3410*da0073e9SAndroid Build Coastguard Worker
3411*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3412*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3413*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3414*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank_atol(self, device, dtype):
3415*da0073e9SAndroid Build Coastguard Worker
3416*da0073e9SAndroid Build Coastguard Worker        def run_test_atol(shape0, shape1, batch):
3417*da0073e9SAndroid Build Coastguard Worker            a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
3418*da0073e9SAndroid Build Coastguard Worker            # Check against NumPy output
3419*da0073e9SAndroid Build Coastguard Worker            # Test float tol, and specific value for each matrix
3420*da0073e9SAndroid Build Coastguard Worker            # Test a float tol and a specific value for each matrix
3421*da0073e9SAndroid Build Coastguard Worker            # Test different types of tol tensor
3422*da0073e9SAndroid Build Coastguard Worker            for tol_type in all_types():
3423*da0073e9SAndroid Build Coastguard Worker                tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
3424*da0073e9SAndroid Build Coastguard Worker            # Test broadcasting of tol
3425*da0073e9SAndroid Build Coastguard Worker            if a.ndim > 2:
3426*da0073e9SAndroid Build Coastguard Worker                tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
3427*da0073e9SAndroid Build Coastguard Worker            for tol in tolerances:
3428*da0073e9SAndroid Build Coastguard Worker                actual = torch.linalg.matrix_rank(a, atol=tol)
3429*da0073e9SAndroid Build Coastguard Worker                actual_tol = torch.linalg.matrix_rank(a, tol=tol)
3430*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, actual_tol)
3431*da0073e9SAndroid Build Coastguard Worker                numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
3432*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
3433*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
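                # atol= and the older tol= argument are expected to behave identically
                # here, which is what the actual/actual_tol comparison above verifies.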
3434*da0073e9SAndroid Build Coastguard Worker
3435*da0073e9SAndroid Build Coastguard Worker        shapes = (3, 13)
3436*da0073e9SAndroid Build Coastguard Worker        batches = ((), (0, ), (4, ), (3, 5, ))
3437*da0073e9SAndroid Build Coastguard Worker        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3438*da0073e9SAndroid Build Coastguard Worker            run_test_atol(shape0, shape1, batch)
3439*da0073e9SAndroid Build Coastguard Worker
3440*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3441*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3442*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64)
3443*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank_atol_rtol(self, device, dtype):
3444*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3445*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_fullrank, device=device, dtype=dtype)
3446*da0073e9SAndroid Build Coastguard Worker
3447*da0073e9SAndroid Build Coastguard Worker        # creates a matrix with singular values rank=n and singular values in range [2/3, 3/2]
3448*da0073e9SAndroid Build Coastguard Worker        # creates an n x n matrix with rank n and singular values in the range [2/3, 3/2]
3449*da0073e9SAndroid Build Coastguard Worker        n = 9
3450*da0073e9SAndroid Build Coastguard Worker        a = make_arg(n, n)
3451*da0073e9SAndroid Build Coastguard Worker
3452*da0073e9SAndroid Build Coastguard Worker        # test float and tensor variants
3453*da0073e9SAndroid Build Coastguard Worker        for tol_value in [0.81, torch.tensor(0.81, device=device)]:
3454*da0073e9SAndroid Build Coastguard Worker            # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
3455*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.matrix_rank(a, rtol=tol_value)
3456*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, 2)  # there are 2 singular values above 1.5*0.81 = 1.215
3457*da0073e9SAndroid Build Coastguard Worker
3458*da0073e9SAndroid Build Coastguard Worker            # atol is used directly to compare with singular values
3459*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.matrix_rank(a, atol=tol_value)
3460*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, 7)  # there are 7 singular values above 0.81
3461*da0073e9SAndroid Build Coastguard Worker
3462*da0073e9SAndroid Build Coastguard Worker            # when both are specified the maximum tolerance is used
3463*da0073e9SAndroid Build Coastguard Worker            result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
3464*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, 2)  # there are 2 singular values above max(0.81, 1.5*0.81)
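        # In other words the effective cutoff is max(atol, rtol * sigma_max); with
        # sigma_max == 1.5 here, rtol=0.81 yields a cutoff of 1.215 (only the singular
        # values 1.5 and 1.25 exceed it), while atol=0.81 alone keeps the 7 singular
        # values above 0.81.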
3465*da0073e9SAndroid Build Coastguard Worker
3466*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3467*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3468*da0073e9SAndroid Build Coastguard Worker    @skipCUDAVersionIn([(11, 6), (11, 7)])  # https://github.com/pytorch/pytorch/issues/75391
3469*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3470*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank_empty(self, device, dtype):
3471*da0073e9SAndroid Build Coastguard Worker        matrix_rank = torch.linalg.matrix_rank
3472*da0073e9SAndroid Build Coastguard Worker
3473*da0073e9SAndroid Build Coastguard Worker        # NumPy doesn't work for inputs with no elements
3474*da0073e9SAndroid Build Coastguard Worker        def run_test(shape0, shape1, batch):
3475*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3476*da0073e9SAndroid Build Coastguard Worker            rank_a = matrix_rank(a)
3477*da0073e9SAndroid Build Coastguard Worker            expected = torch.zeros(batch, dtype=torch.int64, device=device)
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, matrix_rank(a.mH))
3480*da0073e9SAndroid Build Coastguard Worker
3481*da0073e9SAndroid Build Coastguard Worker            aaH = torch.matmul(a, a.mH)
3482*da0073e9SAndroid Build Coastguard Worker            rank_aaH = matrix_rank(aaH)
3483*da0073e9SAndroid Build Coastguard Worker            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3484*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3485*da0073e9SAndroid Build Coastguard Worker
3486*da0073e9SAndroid Build Coastguard Worker            aHa = torch.matmul(a.mH, a)
3487*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3488*da0073e9SAndroid Build Coastguard Worker
3489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, expected)
3490*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(a, 0.01), expected)
3491*da0073e9SAndroid Build Coastguard Worker
3492*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, expected)
3493*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aaH, 0.01), expected)
3494*da0073e9SAndroid Build Coastguard Worker
3495*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH_hermitian, expected)
3496*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aaH, 0.01, True), expected)
3497*da0073e9SAndroid Build Coastguard Worker
3498*da0073e9SAndroid Build Coastguard Worker        batches = ((), (4, ), (3, 5, ))
3499*da0073e9SAndroid Build Coastguard Worker        for batch in batches:
3500*da0073e9SAndroid Build Coastguard Worker            run_test(0, 0, batch)
3501*da0073e9SAndroid Build Coastguard Worker            run_test(0, 3, batch)
3502*da0073e9SAndroid Build Coastguard Worker            run_test(3, 0, batch)
3503*da0073e9SAndroid Build Coastguard Worker
3504*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3505*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3506*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3507*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
3508*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
3509*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(2, dtype=dtype, device=device)
3510*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, dtype=torch.bool, device=device)
3511*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"):
3512*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_rank(a, out=out)
3513*da0073e9SAndroid Build Coastguard Worker
3514*da0073e9SAndroid Build Coastguard Worker        # device should match
3515*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3516*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3517*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=wrong_device)
3518*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3519*da0073e9SAndroid Build Coastguard Worker                torch.linalg.matrix_rank(a, out=out)
3520*da0073e9SAndroid Build Coastguard Worker
3521*da0073e9SAndroid Build Coastguard Worker        # if an out tensor with the wrong shape is passed, a warning is given
3522*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
3523*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(3, dtype=dtype, device=device)
3524*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
3525*da0073e9SAndroid Build Coastguard Worker            torch.linalg.matrix_rank(a, out=out)
3526*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
3527*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
3528*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3529*da0073e9SAndroid Build Coastguard Worker
3530*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
3531*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3532*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3533*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank_basic(self, device, dtype):
3534*da0073e9SAndroid Build Coastguard Worker        matrix_rank = torch.linalg.matrix_rank
3535*da0073e9SAndroid Build Coastguard Worker
3536*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(10, dtype=dtype, device=device)
3537*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matrix_rank(a).item(), 10)
3538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matrix_rank(a, hermitian=True).item(), 10)
3539*da0073e9SAndroid Build Coastguard Worker
3540*da0073e9SAndroid Build Coastguard Worker        a[5, 5] = 0
3541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matrix_rank(a).item(), 9)
3542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matrix_rank(a, hermitian=True).item(), 9)
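        # Zeroing a single diagonal entry of the identity removes exactly one nonzero
        # singular value, so the rank drops from 10 to 9 on both code paths.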
3543*da0073e9SAndroid Build Coastguard Worker
3544*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3545*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
3546*da0073e9SAndroid Build Coastguard Worker    # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot, for which it is an "alias".
3547*da0073e9SAndroid Build Coastguard Worker    def test_chain_matmul(self, device, dtype):
3548*da0073e9SAndroid Build Coastguard Worker        # chain_matmul accepts a single input tensor while multi_dot does not
3549*da0073e9SAndroid Build Coastguard Worker        t = make_tensor((2, 2), dtype=dtype, device=device)
3550*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, torch.chain_matmul(t))
3551*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
3552*da0073e9SAndroid Build Coastguard Worker            torch.chain_matmul()
3553*da0073e9SAndroid Build Coastguard Worker
3554*da0073e9SAndroid Build Coastguard Worker        # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
3555*da0073e9SAndroid Build Coastguard Worker        # be either 1D or 2D
3556*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
3557*da0073e9SAndroid Build Coastguard Worker            torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device))
3558*da0073e9SAndroid Build Coastguard Worker
3559*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3560*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3561*da0073e9SAndroid Build Coastguard Worker    def test_multi_dot(self, device, dtype):
3562*da0073e9SAndroid Build Coastguard Worker        def check(*shapes):
3563*da0073e9SAndroid Build Coastguard Worker            tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes]
3564*da0073e9SAndroid Build Coastguard Worker            np_arrays = [tensor.cpu().numpy() for tensor in tensors]
3565*da0073e9SAndroid Build Coastguard Worker            res = torch.linalg.multi_dot(tensors).cpu()
3566*da0073e9SAndroid Build Coastguard Worker            ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
3567*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, ref)
3568*da0073e9SAndroid Build Coastguard Worker
3569*da0073e9SAndroid Build Coastguard Worker        # test for inputs with empty dimensions
3570*da0073e9SAndroid Build Coastguard Worker        check([0], [0])
3571*da0073e9SAndroid Build Coastguard Worker        check([2], [2, 0])
3572*da0073e9SAndroid Build Coastguard Worker        check([1, 0], [0])
3573*da0073e9SAndroid Build Coastguard Worker        check([0, 2], [2, 1])
3574*da0073e9SAndroid Build Coastguard Worker        check([2, 2], [2, 0])
3575*da0073e9SAndroid Build Coastguard Worker        check([2, 0], [0, 3])
3576*da0073e9SAndroid Build Coastguard Worker        check([0, 0], [0, 1])
3577*da0073e9SAndroid Build Coastguard Worker        check([4, 2], [2, 0], [0, 3], [3, 2])
3578*da0073e9SAndroid Build Coastguard Worker
3579*da0073e9SAndroid Build Coastguard Worker        # test variable output shapes
3580*da0073e9SAndroid Build Coastguard Worker        check([2], [2])
3581*da0073e9SAndroid Build Coastguard Worker        check([1, 2], [2])
3582*da0073e9SAndroid Build Coastguard Worker        check([2], [2, 1])
3583*da0073e9SAndroid Build Coastguard Worker        check([1, 2], [2, 1])
3584*da0073e9SAndroid Build Coastguard Worker        check([3, 2], [2, 4])
3585*da0073e9SAndroid Build Coastguard Worker
3586*da0073e9SAndroid Build Coastguard Worker        # test multiple input tensors
3587*da0073e9SAndroid Build Coastguard Worker        check([3], [3, 4], [4, 2], [2, 5], [5])
3588*da0073e9SAndroid Build Coastguard Worker        check([1, 2], [2, 2], [2, 3], [3, 1])
3589*da0073e9SAndroid Build Coastguard Worker
3590*da0073e9SAndroid Build Coastguard Worker        # test large tensors
3591*da0073e9SAndroid Build Coastguard Worker        check([10, 100], [100, 5], [5, 50])
3592*da0073e9SAndroid Build Coastguard Worker        check([10, 20], [20, 30], [30, 5])
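        # multi_dot is free to reassociate the chain to minimize the number of scalar
        # multiplications, so these shape mixes exercise different multiplication
        # orders; the mathematical result is order-independent and must match NumPy.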
3593*da0073e9SAndroid Build Coastguard Worker
3594*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3595*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3596*da0073e9SAndroid Build Coastguard Worker    def test_multi_dot_errors(self, device, dtype):
3597*da0073e9SAndroid Build Coastguard Worker        def check(tensors, out, msg):
3598*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
3599*da0073e9SAndroid Build Coastguard Worker                torch.linalg.multi_dot(tensors, out=out)
3600*da0073e9SAndroid Build Coastguard Worker
3601*da0073e9SAndroid Build Coastguard Worker        a = make_tensor(2, dtype=dtype, device=device)
3602*da0073e9SAndroid Build Coastguard Worker
3603*da0073e9SAndroid Build Coastguard Worker        check([], None, "expected at least 2 tensors")
3604*da0073e9SAndroid Build Coastguard Worker        check([a], None, "expected at least 2 tensors")
3605*da0073e9SAndroid Build Coastguard Worker
3606*da0073e9SAndroid Build Coastguard Worker        check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
3607*da0073e9SAndroid Build Coastguard Worker        check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
3608*da0073e9SAndroid Build Coastguard Worker
3609*da0073e9SAndroid Build Coastguard Worker        check([a, a, a], None, "tensor 1 must be 2D")
3610*da0073e9SAndroid Build Coastguard Worker        check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D")
3611*da0073e9SAndroid Build Coastguard Worker
3612*da0073e9SAndroid Build Coastguard Worker        check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype")
3613*da0073e9SAndroid Build Coastguard Worker        check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
3614*da0073e9SAndroid Build Coastguard Worker
3615*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
3616*da0073e9SAndroid Build Coastguard Worker            check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device")
3617*da0073e9SAndroid Build Coastguard Worker            check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
3618*da0073e9SAndroid Build Coastguard Worker
3619*da0073e9SAndroid Build Coastguard Worker        check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied")
3620*da0073e9SAndroid Build Coastguard Worker        check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied")
3621*da0073e9SAndroid Build Coastguard Worker
3622*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
3623*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
3624*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3625*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3626*da0073e9SAndroid Build Coastguard Worker    def test_qr(self, device, dtype):
3627*da0073e9SAndroid Build Coastguard Worker        def run_test(tensor_dims, some):
3628*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(*tensor_dims, dtype=dtype, device=device)
3629*da0073e9SAndroid Build Coastguard Worker            Q, R = torch.qr(A, some=some)
3630*da0073e9SAndroid Build Coastguard Worker
3631*da0073e9SAndroid Build Coastguard Worker            # Check0: Q.shape[-2:] == (m, n_columns), R.shape[-2:] == (n_columns, n)
3632*da0073e9SAndroid Build Coastguard Worker            m, n = tensor_dims[-2:]
3633*da0073e9SAndroid Build Coastguard Worker            n_columns = m if (not some) and m > n else min(m, n)
3634*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Q.size(-2), m)
3635*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(R.size(-1), n)
3636*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Q.size(-1), n_columns)
3637*da0073e9SAndroid Build Coastguard Worker
3638*da0073e9SAndroid Build Coastguard Worker            A_ = A.cpu().numpy()
3639*da0073e9SAndroid Build Coastguard Worker            Q_ = Q.cpu().numpy()
3640*da0073e9SAndroid Build Coastguard Worker            R_ = R.cpu().numpy()
3641*da0073e9SAndroid Build Coastguard Worker
3642*da0073e9SAndroid Build Coastguard Worker            # Check1: A = QR
3643*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A_, np.matmul(Q_, R_))
3644*da0073e9SAndroid Build Coastguard Worker
3645*da0073e9SAndroid Build Coastguard Worker            # Check2: A = QR (with out)
3646*da0073e9SAndroid Build Coastguard Worker            Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan)
3647*da0073e9SAndroid Build Coastguard Worker            torch.qr(A, some=some, out=(Q_out, R_out))
3648*da0073e9SAndroid Build Coastguard Worker            Q_out_ = Q_out.cpu().numpy()
3649*da0073e9SAndroid Build Coastguard Worker            R_out_ = R_out.cpu().numpy()
3650*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A_, np.matmul(Q_out_, R_out_))
3651*da0073e9SAndroid Build Coastguard Worker
3652*da0073e9SAndroid Build Coastguard Worker            # Check3: Q == Q_out, R == R_out
3653*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Q_, Q_out_)
3654*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(R_, R_out_)
3655*da0073e9SAndroid Build Coastguard Worker
3656*da0073e9SAndroid Build Coastguard Worker            # Check4: Q^{T}Q = I, triu(R) = R
3657*da0073e9SAndroid Build Coastguard Worker            eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy()
3658*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye)
3659*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(R.triu(), R)
3660*da0073e9SAndroid Build Coastguard Worker
3661*da0073e9SAndroid Build Coastguard Worker        tensor_dims_list = [(0, 5), (0, 0), (5, 0),  # Empty Tensors
3662*da0073e9SAndroid Build Coastguard Worker                            (2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5),  # Batched empty Tensors
3663*da0073e9SAndroid Build Coastguard Worker                            (3, 5), (5, 5), (5, 3),  # Single matrix
3664*da0073e9SAndroid Build Coastguard Worker                            (7, 3, 5), (7, 5, 5), (7, 5, 3),  # 3-dim Tensors
3665*da0073e9SAndroid Build Coastguard Worker                            (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)]  # 4-dim Tensors
3666*da0073e9SAndroid Build Coastguard Worker        for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]):
3667*da0073e9SAndroid Build Coastguard Worker            run_test(tensor_dims, some)
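        # some=True corresponds to the reduced factorization (Q is m x min(m, n)) and
        # some=False to the complete one (Q is m x m), which is what the n_columns
        # computation above encodes.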
3668*da0073e9SAndroid Build Coastguard Worker
3669*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
3670*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3671*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3672*da0073e9SAndroid Build Coastguard Worker    def test_qr_vs_numpy(self, device, dtype):
3673*da0073e9SAndroid Build Coastguard Worker        """
3674*da0073e9SAndroid Build Coastguard Worker        test torch.linalg.qr vs numpy.linalg.qr
3675*da0073e9SAndroid Build Coastguard Worker        """
3676*da0073e9SAndroid Build Coastguard Worker        sizes_to_test = [
3677*da0073e9SAndroid Build Coastguard Worker            (7, 5),
3678*da0073e9SAndroid Build Coastguard Worker            (5, 7),
3679*da0073e9SAndroid Build Coastguard Worker            (5, 0),    # empty
3680*da0073e9SAndroid Build Coastguard Worker            (0, 5),    # empty
3681*da0073e9SAndroid Build Coastguard Worker        ]
3682*da0073e9SAndroid Build Coastguard Worker        for size in sizes_to_test:
3683*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(size, device=device, dtype=dtype)
3684*da0073e9SAndroid Build Coastguard Worker            np_t = t.cpu().numpy()
3685*da0073e9SAndroid Build Coastguard Worker            for mode in ['reduced', 'complete']:
3686*da0073e9SAndroid Build Coastguard Worker                exp_q, exp_r = np.linalg.qr(np_t, mode=mode)
3687*da0073e9SAndroid Build Coastguard Worker                q, r = torch.linalg.qr(t, mode=mode)
3688*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(q, exp_q)
3689*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r, exp_r)
3690*da0073e9SAndroid Build Coastguard Worker            #
3691*da0073e9SAndroid Build Coastguard Worker            # for mode='r' we need special logic because numpy returns only r
3692*da0073e9SAndroid Build Coastguard Worker            exp_r = np.linalg.qr(np_t, mode='r')
3693*da0073e9SAndroid Build Coastguard Worker            q, r = torch.linalg.qr(t, mode='r')
3694*da0073e9SAndroid Build Coastguard Worker            # check that q is empty
3695*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(q.shape, (0,))
3696*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(q.dtype, t.dtype)
3697*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(q.device, t.device)
3698*da0073e9SAndroid Build Coastguard Worker            # check r
3699*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r, exp_r)
3700*da0073e9SAndroid Build Coastguard Worker
3701*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
3702*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3703*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3704*da0073e9SAndroid Build Coastguard Worker    def test_linalg_qr_autograd_errors(self, device, dtype):
3705*da0073e9SAndroid Build Coastguard Worker        # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
3706*da0073e9SAndroid Build Coastguard Worker        # without 'q' you cannot compute the backward pass. Check that
3707*da0073e9SAndroid Build Coastguard Worker        # linalg_qr_backward complains cleanly in that case.
3708*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
3709*da0073e9SAndroid Build Coastguard Worker        q, r = torch.linalg.qr(inp, mode='r')
3710*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(q.shape, (0,))  # empty tensor
3711*da0073e9SAndroid Build Coastguard Worker        b = torch.sum(r)
3712*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
3713*da0073e9SAndroid Build Coastguard Worker                                    "The derivative of linalg.qr depends on Q"):
3714*da0073e9SAndroid Build Coastguard Worker            b.backward()
3715*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
3716*da0073e9SAndroid Build Coastguard Worker        q, r = torch.linalg.qr(inp, mode='complete')
3717*da0073e9SAndroid Build Coastguard Worker        b = torch.sum(r)
3718*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
3719*da0073e9SAndroid Build Coastguard Worker                                    "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"):
3720*da0073e9SAndroid Build Coastguard Worker            b.backward()
3721*da0073e9SAndroid Build Coastguard Worker
3722*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
3723*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3724*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3725*da0073e9SAndroid Build Coastguard Worker    def test_qr_batched(self, device, dtype):
3726*da0073e9SAndroid Build Coastguard Worker        """
3727*da0073e9SAndroid Build Coastguard Worker        Test torch.linalg.qr against numpy.linalg.qr. Special logic is needed
3728*da0073e9SAndroid Build Coastguard Worker        because numpy does not support batched qr.
3729*da0073e9SAndroid Build Coastguard Worker        """
3730*da0073e9SAndroid Build Coastguard Worker        def np_qr_batched(a, mode):
3731*da0073e9SAndroid Build Coastguard Worker            """poor man's batched version of np.linalg.qr"""
3732*da0073e9SAndroid Build Coastguard Worker            all_q = []
3733*da0073e9SAndroid Build Coastguard Worker            all_r = []
3734*da0073e9SAndroid Build Coastguard Worker            for matrix in a:
3735*da0073e9SAndroid Build Coastguard Worker                result = np.linalg.qr(matrix, mode=mode)
3736*da0073e9SAndroid Build Coastguard Worker                if mode == 'r':
3737*da0073e9SAndroid Build Coastguard Worker                    all_r.append(result)
3738*da0073e9SAndroid Build Coastguard Worker                else:
3739*da0073e9SAndroid Build Coastguard Worker                    q, r = result
3740*da0073e9SAndroid Build Coastguard Worker                    all_q.append(q)
3741*da0073e9SAndroid Build Coastguard Worker                    all_r.append(r)
3742*da0073e9SAndroid Build Coastguard Worker            if mode == 'r':
3743*da0073e9SAndroid Build Coastguard Worker                return np.array(all_r)
3744*da0073e9SAndroid Build Coastguard Worker            else:
3745*da0073e9SAndroid Build Coastguard Worker                return np.array(all_q), np.array(all_r)
3746*da0073e9SAndroid Build Coastguard Worker
3747*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((3, 7, 5), device=device, dtype=dtype)
3748*da0073e9SAndroid Build Coastguard Worker        np_t = t.cpu().numpy()
3749*da0073e9SAndroid Build Coastguard Worker        for mode in ['reduced', 'complete']:
3750*da0073e9SAndroid Build Coastguard Worker            exp_q, exp_r = np_qr_batched(np_t, mode=mode)
3751*da0073e9SAndroid Build Coastguard Worker            q, r = torch.linalg.qr(t, mode=mode)
3752*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(q, exp_q)
3753*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r, exp_r)
3754*da0073e9SAndroid Build Coastguard Worker        # for mode='r' we need special handling because numpy returns only r
3755*da0073e9SAndroid Build Coastguard Worker        exp_r = np_qr_batched(np_t, mode='r')
3756*da0073e9SAndroid Build Coastguard Worker        q, r = torch.linalg.qr(t, mode='r')
3757*da0073e9SAndroid Build Coastguard Worker        # check that q is empty
3758*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(q.shape, (0,))
3759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(q.dtype, t.dtype)
3760*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(q.device, t.device)
3761*da0073e9SAndroid Build Coastguard Worker        # check r
3762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, exp_r)
3763*da0073e9SAndroid Build Coastguard Worker
3764*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
3765*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
3766*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3767*da0073e9SAndroid Build Coastguard Worker    def test_qr_error_cases(self, device, dtype):
3768*da0073e9SAndroid Build Coastguard Worker        t1 = torch.randn(5, device=device, dtype=dtype)
3769*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'):
3770*da0073e9SAndroid Build Coastguard Worker            torch.linalg.qr(t1)
3771*da0073e9SAndroid Build Coastguard Worker        t2 = torch.randn((5, 7), device=device, dtype=dtype)
3772*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
3773*da0073e9SAndroid Build Coastguard Worker            torch.linalg.qr(t2, mode='hello')
3774*da0073e9SAndroid Build Coastguard Worker
3775*da0073e9SAndroid Build Coastguard Worker    def _check_einsum(self, *args, np_args=None):
3776*da0073e9SAndroid Build Coastguard Worker        if np_args is None:
3777*da0073e9SAndroid Build Coastguard Worker            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
3778*da0073e9SAndroid Build Coastguard Worker        ref = np.einsum(*np_args)
3779*da0073e9SAndroid Build Coastguard Worker        res = torch.einsum(*args)
3780*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
3781*da0073e9SAndroid Build Coastguard Worker
3782*da0073e9SAndroid Build Coastguard Worker        # Check that the other opt_einsum variations work too
3783*da0073e9SAndroid Build Coastguard Worker        if TEST_OPT_EINSUM:
3784*da0073e9SAndroid Build Coastguard Worker            with opt_einsum.flags(enabled=False):
3785*da0073e9SAndroid Build Coastguard Worker                res = torch.einsum(*args)
3786*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, res)
3787*da0073e9SAndroid Build Coastguard Worker
3788*da0073e9SAndroid Build Coastguard Worker            with opt_einsum.flags(enabled=True, strategy='greedy'):
3789*da0073e9SAndroid Build Coastguard Worker                res = torch.einsum(*args)
3790*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, res)
3791*da0073e9SAndroid Build Coastguard Worker
3792*da0073e9SAndroid Build Coastguard Worker            with opt_einsum.flags(enabled=True, strategy='optimal'):
3793*da0073e9SAndroid Build Coastguard Worker                res = torch.einsum(*args)
3794*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, res)
3795*da0073e9SAndroid Build Coastguard Worker
3796*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3797*da0073e9SAndroid Build Coastguard Worker    def test_einsum(self, device, dtype):
3798*da0073e9SAndroid Build Coastguard Worker        # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
3799*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((5,), dtype=dtype, device=device)
3800*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((7,), dtype=dtype, device=device)
3801*da0073e9SAndroid Build Coastguard Worker        A = make_tensor((3, 5), dtype=dtype, device=device)
3802*da0073e9SAndroid Build Coastguard Worker        B = make_tensor((2, 5), dtype=dtype, device=device)
3803*da0073e9SAndroid Build Coastguard Worker        C = make_tensor((2, 3, 5), dtype=dtype, device=device)
3804*da0073e9SAndroid Build Coastguard Worker        D = make_tensor((2, 5, 7), dtype=dtype, device=device)
3805*da0073e9SAndroid Build Coastguard Worker        E = make_tensor((7, 9), dtype=dtype, device=device)
3806*da0073e9SAndroid Build Coastguard Worker        F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device)
3807*da0073e9SAndroid Build Coastguard Worker        G = make_tensor((5, 4, 6), dtype=dtype, device=device)
3808*da0073e9SAndroid Build Coastguard Worker        H = make_tensor((4, 4), dtype=dtype, device=device)
3809*da0073e9SAndroid Build Coastguard Worker        I = make_tensor((2, 3, 2), dtype=dtype, device=device)
3810*da0073e9SAndroid Build Coastguard Worker
3811*da0073e9SAndroid Build Coastguard Worker        # Vector operations
3812*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('i->', x)                     # sum
3813*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('i,i->', x, x)                # dot
3814*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('i,i->i', x, x)               # vector element-wise mul
3815*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('i,j->ij', x, y)              # outer
3816*da0073e9SAndroid Build Coastguard Worker
3817*da0073e9SAndroid Build Coastguard Worker        # Matrix operations
3818*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij->ji", A)                  # transpose
3819*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij->j", A)                   # sum over rows (per-column sums)
3820*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij->i", A)                   # sum over columns (per-row sums)
3821*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij,ij->ij", A, A)            # matrix element-wise mul
3822*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij,j->i", A, x)              # matrix vector multiplication
3823*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij,kj->ik", A, B)            # matmul
3824*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ij,ab->ijab", A, E)          # matrix outer product
3825*da0073e9SAndroid Build Coastguard Worker
3826*da0073e9SAndroid Build Coastguard Worker        # Tensor operations
3827*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("Aij,Ajk->Aik", C, D)         # batch matmul
3828*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ijk,jk->i", C, A)            # tensor matrix contraction
3829*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("aij,jk->aik", D, E)          # tensor matrix contraction
3830*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("abCd,dFg->abCFg", F, G)      # tensor tensor contraction
3831*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ijk,jk->ik", C, A)           # tensor matrix contraction with double indices
3832*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ijk,jk->ij", C, A)           # tensor matrix contraction with double indices
3833*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ijk,ik->j", C, B)            # non-contiguous
3834*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ijk,ik->jk", C, B)           # non-contiguous with double indices
3835*da0073e9SAndroid Build Coastguard Worker
3836*da0073e9SAndroid Build Coastguard Worker        # Test diagonals
3837*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ii", H)                      # trace
3838*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ii->i", H)                   # diagonal
3839*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('iji->j', I)                  # non-contiguous trace
3840*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device))
3841*da0073e9SAndroid Build Coastguard Worker
3842*da0073e9SAndroid Build Coastguard Worker        # Test ellipsis
3843*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("i...->...", H)
3844*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("ki,...k->i...", A.t(), B)
3845*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("k...,jk->...", A.t(), B)
3846*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('...ik, ...j -> ...ij', C, x)
3847*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device))
3848*da0073e9SAndroid Build Coastguard Worker        self._check_einsum('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device))
3849*da0073e9SAndroid Build Coastguard Worker
3850*da0073e9SAndroid Build Coastguard Worker        # torch.bilinear with noncontiguous tensors
3851*da0073e9SAndroid Build Coastguard Worker        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3852*da0073e9SAndroid Build Coastguard Worker        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3853*da0073e9SAndroid Build Coastguard Worker        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3854*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("bn,anm,bm->ba", l, w, r)
3855*da0073e9SAndroid Build Coastguard Worker
3856*da0073e9SAndroid Build Coastguard Worker        # with strided tensors
3857*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2])
3858*da0073e9SAndroid Build Coastguard Worker
3859*da0073e9SAndroid Build Coastguard Worker        # test multiple inputs
3860*da0073e9SAndroid Build Coastguard Worker        self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F)
3861*da0073e9SAndroid Build Coastguard Worker
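    # Illustrative sketch (not part of the test suite): each einsum equation above
    # has a direct tensor-op counterpart; for example, with the tensors A (3, 5),
    # B (2, 5), C (2, 3, 5), D (2, 5, 7), and H (4, 4) built inside test_einsum,
    #
    #   torch.einsum("Aij,Ajk->Aik", C, D)   # equals torch.bmm(C, D)
    #   torch.einsum("ij,kj->ik", A, B)      # equals A @ B.mT
    #   torch.einsum("ii->i", H)             # equals torch.diagonal(H)
    #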
3862*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3863*da0073e9SAndroid Build Coastguard Worker    def test_einsum_sublist_format(self, device, dtype):
3864*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((5,), dtype=dtype, device=device)
3865*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((7,), dtype=dtype, device=device)
3866*da0073e9SAndroid Build Coastguard Worker        A = make_tensor((3, 5), dtype=dtype, device=device)
3867*da0073e9SAndroid Build Coastguard Worker        B = make_tensor((2, 5), dtype=dtype, device=device)
3868*da0073e9SAndroid Build Coastguard Worker        C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device)
3869*da0073e9SAndroid Build Coastguard Worker
3870*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(x, [0])
3871*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(x, [0], [])
3872*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(x, [0], y, [1], [0, 1])
3873*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A, [0, 1], [1, 0])
3874*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A, [0, 1], x, [1], [0])
3875*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A, [0, 1], B, [2, 1])
3876*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A, [0, 1], B, [2, 1], [0, 2])
3877*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
3878*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0])
3879*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
3880*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
3881*da0073e9SAndroid Build Coastguard Worker
3882*da0073e9SAndroid Build Coastguard Worker        # torch.bilinear with noncontiguous tensors
3883*da0073e9SAndroid Build Coastguard Worker        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3884*da0073e9SAndroid Build Coastguard Worker        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3885*da0073e9SAndroid Build Coastguard Worker        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3886*da0073e9SAndroid Build Coastguard Worker        self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
3887*da0073e9SAndroid Build Coastguard Worker
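    # Minimal sketch (illustrative only): in the sublist format, integer labels stand
    # in for subscript letters and Ellipsis stands in for '...', so the two calls
    # below express the same matrix-vector product:
    #
    #   torch.einsum('ij,j->i', A, x)
    #   torch.einsum(A, [0, 1], x, [1], [0])
    #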
3888*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3889*da0073e9SAndroid Build Coastguard Worker    def test_einsum_random(self, device, dtype):
3890*da0073e9SAndroid Build Coastguard Worker        def convert_label(label):
3891*da0073e9SAndroid Build Coastguard Worker            if label == ...:
3892*da0073e9SAndroid Build Coastguard Worker                return '...'
3893*da0073e9SAndroid Build Coastguard Worker            elif label < 26:
3894*da0073e9SAndroid Build Coastguard Worker                return chr(ord('A') + label)
3895*da0073e9SAndroid Build Coastguard Worker            else:
3896*da0073e9SAndroid Build Coastguard Worker                return chr(ord('a') + label - 26)
3897*da0073e9SAndroid Build Coastguard Worker
3898*da0073e9SAndroid Build Coastguard Worker        def convert_sublist(sublist):
3899*da0073e9SAndroid Build Coastguard Worker            return ''.join(convert_label(label) for label in sublist)
3900*da0073e9SAndroid Build Coastguard Worker
3901*da0073e9SAndroid Build Coastguard Worker        def test(n=10,                       # how many tests to generate
3902*da0073e9SAndroid Build Coastguard Worker                 n_labels=5,                 # how many labels available
3903*da0073e9SAndroid Build Coastguard Worker                 min_ops=1, max_ops=4,       # min and max number of operands per test
3904*da0073e9SAndroid Build Coastguard Worker                 min_dims=1, max_dims=3,     # min and max number of dimensions per operand
3905*da0073e9SAndroid Build Coastguard Worker                 min_size=1, max_size=8,     # min and max size of each dimension
3906*da0073e9SAndroid Build Coastguard Worker                 max_out_dim=3,              # max number of dimensions for the output
3907*da0073e9SAndroid Build Coastguard Worker                 enable_diagonals=True,      # controls if labels can be repeated for diagonals
3908*da0073e9SAndroid Build Coastguard Worker                 ellipsis_prob=0.5,          # probability of including ellipsis in operand
3909*da0073e9SAndroid Build Coastguard Worker                 broadcasting_prob=0.1):     # probability of turning some dim sizes 1 for broadcasting
3910*da0073e9SAndroid Build Coastguard Worker
3911*da0073e9SAndroid Build Coastguard Worker            all_labels = torch.arange(52)
3912*da0073e9SAndroid Build Coastguard Worker
3913*da0073e9SAndroid Build Coastguard Worker            assert 0 <= n
3914*da0073e9SAndroid Build Coastguard Worker            assert 0 <= n_labels < len(all_labels)
3915*da0073e9SAndroid Build Coastguard Worker            assert 0 < min_ops <= max_ops
3916*da0073e9SAndroid Build Coastguard Worker            assert 0 <= min_dims <= max_dims
3917*da0073e9SAndroid Build Coastguard Worker            assert 0 <= min_size <= max_size
3918*da0073e9SAndroid Build Coastguard Worker            assert 0 <= max_out_dim
3919*da0073e9SAndroid Build Coastguard Worker            assert enable_diagonals or max_dims <= n_labels
3920*da0073e9SAndroid Build Coastguard Worker
3921*da0073e9SAndroid Build Coastguard Worker            for _ in range(n):
3922*da0073e9SAndroid Build Coastguard Worker
3923*da0073e9SAndroid Build Coastguard Worker                # Select a subset of labels for this test and give them random sizes
3924*da0073e9SAndroid Build Coastguard Worker                possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]]
3925*da0073e9SAndroid Build Coastguard Worker                labels_size = torch.randint_like(all_labels, min_size, max_size + 1)
3926*da0073e9SAndroid Build Coastguard Worker                ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,))
3927*da0073e9SAndroid Build Coastguard Worker
3928*da0073e9SAndroid Build Coastguard Worker                operands = []
3929*da0073e9SAndroid Build Coastguard Worker                sublists = []
3930*da0073e9SAndroid Build Coastguard Worker
3931*da0073e9SAndroid Build Coastguard Worker                ell_size = 0
3932*da0073e9SAndroid Build Coastguard Worker                valid_labels = set()
3933*da0073e9SAndroid Build Coastguard Worker
3934*da0073e9SAndroid Build Coastguard Worker                # create random input operands
3935*da0073e9SAndroid Build Coastguard Worker                for _ in range(random.randint(min_ops, max_ops)):
3936*da0073e9SAndroid Build Coastguard Worker                    n_dim = random.randint(min_dims, max_dims)
3937*da0073e9SAndroid Build Coastguard Worker                    labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals)
3938*da0073e9SAndroid Build Coastguard Worker                    labels = possible_labels[labels_idx]
3939*da0073e9SAndroid Build Coastguard Worker                    valid_labels.update(labels.tolist())
3940*da0073e9SAndroid Build Coastguard Worker                    shape = labels_size[labels]
3941*da0073e9SAndroid Build Coastguard Worker
3942*da0073e9SAndroid Build Coastguard Worker                    # turn some dimensions to size 1 for testing broadcasting
3943*da0073e9SAndroid Build Coastguard Worker                    mask = Binomial(probs=broadcasting_prob).sample((n_dim,))
3944*da0073e9SAndroid Build Coastguard Worker                    broadcast_labels = torch.unique(labels[mask == 1])
3945*da0073e9SAndroid Build Coastguard Worker                    shape[(labels[..., None] == broadcast_labels).any(-1)] = 1
3946*da0073e9SAndroid Build Coastguard Worker
3947*da0073e9SAndroid Build Coastguard Worker                    labels = labels.tolist()
3948*da0073e9SAndroid Build Coastguard Worker                    shape = shape.tolist()
3949*da0073e9SAndroid Build Coastguard Worker
3950*da0073e9SAndroid Build Coastguard Worker                    # include ellipsis if not all dimensions were assigned a label already
3951*da0073e9SAndroid Build Coastguard Worker                    if n_dim < max_dims and torch.rand(1) < ellipsis_prob:
3952*da0073e9SAndroid Build Coastguard Worker                        ell_num_dim = random.randint(1, max_dims - n_dim)
3953*da0073e9SAndroid Build Coastguard Worker                        ell_size = max(ell_size, ell_num_dim)
3954*da0073e9SAndroid Build Coastguard Worker                        ell_shape = ellipsis_shape[-ell_num_dim:]
3955*da0073e9SAndroid Build Coastguard Worker                        # again, turn some dimensions to size 1 for broadcasting
3956*da0073e9SAndroid Build Coastguard Worker                        mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,))
3957*da0073e9SAndroid Build Coastguard Worker                        ell_shape[mask == 1] = 1
3958*da0073e9SAndroid Build Coastguard Worker                        ell_index = random.randint(0, n_dim)
3959*da0073e9SAndroid Build Coastguard Worker                        shape[ell_index:ell_index] = ell_shape
3960*da0073e9SAndroid Build Coastguard Worker                        labels.insert(ell_index, ...)
3961*da0073e9SAndroid Build Coastguard Worker
3962*da0073e9SAndroid Build Coastguard Worker                    operands.append(make_tensor(shape, dtype=dtype, device=device))
3963*da0073e9SAndroid Build Coastguard Worker                    sublists.append(labels)
3964*da0073e9SAndroid Build Coastguard Worker
3965*da0073e9SAndroid Build Coastguard Worker                # NumPy has a bug with the sublist format so for now we compare PyTorch sublist
3966*da0073e9SAndroid Build Coastguard Worker                # implementation against the equation format implementation of NumPy
3967*da0073e9SAndroid Build Coastguard Worker                # see https://github.com/numpy/numpy/issues/10926
3968*da0073e9SAndroid Build Coastguard Worker                np_operands = [op.cpu().numpy() for op in operands]
3969*da0073e9SAndroid Build Coastguard Worker
3970*da0073e9SAndroid Build Coastguard Worker                # test equation format
3971*da0073e9SAndroid Build Coastguard Worker                equation = ','.join(convert_sublist(l) for l in sublists)
3972*da0073e9SAndroid Build Coastguard Worker                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3973*da0073e9SAndroid Build Coastguard Worker
3974*da0073e9SAndroid Build Coastguard Worker                # test sublist format
3975*da0073e9SAndroid Build Coastguard Worker                args = list(itertools.chain.from_iterable(zip(operands, sublists)))
3976*da0073e9SAndroid Build Coastguard Worker                self._check_einsum(*args, np_args=(equation, *np_operands))
3977*da0073e9SAndroid Build Coastguard Worker
3978*da0073e9SAndroid Build Coastguard Worker                # generate an explicit output
3979*da0073e9SAndroid Build Coastguard Worker                out_sublist = []
3980*da0073e9SAndroid Build Coastguard Worker                num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size)
3981*da0073e9SAndroid Build Coastguard Worker                if num_out_labels > 0:
3982*da0073e9SAndroid Build Coastguard Worker                    out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels)
3983*da0073e9SAndroid Build Coastguard Worker                    out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist()
3984*da0073e9SAndroid Build Coastguard Worker                out_sublist.insert(random.randint(0, num_out_labels), ...)
3985*da0073e9SAndroid Build Coastguard Worker
3986*da0073e9SAndroid Build Coastguard Worker                # test equation format with explicit output
3987*da0073e9SAndroid Build Coastguard Worker                equation += '->' + convert_sublist(out_sublist)
3988*da0073e9SAndroid Build Coastguard Worker                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker                # test sublist format with explicit output
3991*da0073e9SAndroid Build Coastguard Worker                args.append(out_sublist)
3992*da0073e9SAndroid Build Coastguard Worker                self._check_einsum(*args, np_args=(equation, *np_operands))
3993*da0073e9SAndroid Build Coastguard Worker
3994*da0073e9SAndroid Build Coastguard Worker        test(500)
3995*da0073e9SAndroid Build Coastguard Worker
3996*da0073e9SAndroid Build Coastguard Worker    def test_einsum_corner_cases(self, device):
3997*da0073e9SAndroid Build Coastguard Worker        def check(equation, *operands, expected_output):
3998*da0073e9SAndroid Build Coastguard Worker            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
3999*da0073e9SAndroid Build Coastguard Worker                       else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands]
4000*da0073e9SAndroid Build Coastguard Worker            output = torch.einsum(equation, tensors)
4001*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
4002*da0073e9SAndroid Build Coastguard Worker
4003*da0073e9SAndroid Build Coastguard Worker        # Test equation variations
4004*da0073e9SAndroid Build Coastguard Worker        check(' ', 1, expected_output=1)
4005*da0073e9SAndroid Build Coastguard Worker        check(' -> ', 1, expected_output=1)
4006*da0073e9SAndroid Build Coastguard Worker        check(' , ', 2, 2, expected_output=4)
4007*da0073e9SAndroid Build Coastguard Worker        check(' , , ', 2, 2, 2, expected_output=8)
4008*da0073e9SAndroid Build Coastguard Worker        check(' , -> ', 2, 2, expected_output=4)
4009*da0073e9SAndroid Build Coastguard Worker        check(' i ', [1], expected_output=[1])
4010*da0073e9SAndroid Build Coastguard Worker        check(' i -> ', [1], expected_output=1)
4011*da0073e9SAndroid Build Coastguard Worker        check(' i -> i ', [1], expected_output=[1])
4012*da0073e9SAndroid Build Coastguard Worker        check(' i , i ', [2], [2], expected_output=4)
4013*da0073e9SAndroid Build Coastguard Worker        check(' i , i -> i ', [2], [2], expected_output=[4])
4014*da0073e9SAndroid Build Coastguard Worker
4015*da0073e9SAndroid Build Coastguard Worker        # Test tensors with 0 size dimensions
4016*da0073e9SAndroid Build Coastguard Worker        check('i', [], expected_output=[])
4017*da0073e9SAndroid Build Coastguard Worker        check(' i j -> j', [[], []], expected_output=[])
4018*da0073e9SAndroid Build Coastguard Worker        check('ij->i', [[], []], expected_output=[0., 0.])
4019*da0073e9SAndroid Build Coastguard Worker        check(' i j k  ,  k  -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []])
4020*da0073e9SAndroid Build Coastguard Worker
4021*da0073e9SAndroid Build Coastguard Worker        # Test broadcasting
4022*da0073e9SAndroid Build Coastguard Worker        check('i,j', [2], [1, 2], expected_output=[[2, 4]])
4023*da0073e9SAndroid Build Coastguard Worker        check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]])
4024*da0073e9SAndroid Build Coastguard Worker
4025*da0073e9SAndroid Build Coastguard Worker        # Test ellipsis broadcasting
4026*da0073e9SAndroid Build Coastguard Worker        check('...', 1, expected_output=1)
4027*da0073e9SAndroid Build Coastguard Worker        check('...->', 1, expected_output=1)
4028*da0073e9SAndroid Build Coastguard Worker        check('...->...', 1, expected_output=1)
4029*da0073e9SAndroid Build Coastguard Worker        check('...', [1], expected_output=[1])
4030*da0073e9SAndroid Build Coastguard Worker        check('...->', [1], expected_output=1)
4031*da0073e9SAndroid Build Coastguard Worker        check('z...->z', [1], expected_output=[1])
4032*da0073e9SAndroid Build Coastguard Worker        check('Z...->...Z', [1], expected_output=[1])
4033*da0073e9SAndroid Build Coastguard Worker        check('...a->', [[2], [4]], expected_output=6)
4034*da0073e9SAndroid Build Coastguard Worker        check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
4035*da0073e9SAndroid Build Coastguard Worker
4036*da0073e9SAndroid Build Coastguard Worker    def test_einsum_error_cases(self, device):
4037*da0073e9SAndroid Build Coastguard Worker        def check(*args, regex, exception=RuntimeError):
4038*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
4039*da0073e9SAndroid Build Coastguard Worker                torch.einsum(*args)
4040*da0073e9SAndroid Build Coastguard Worker
4041*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2,), dtype=torch.float32, device=device)
4042*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((2, 3), dtype=torch.float32, device=device)
4043*da0073e9SAndroid Build Coastguard Worker
4044*da0073e9SAndroid Build Coastguard Worker        check('', [], regex=r'at least one operand', exception=ValueError)
4045*da0073e9SAndroid Build Coastguard Worker        check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
4046*da0073e9SAndroid Build Coastguard Worker        check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
4047*da0073e9SAndroid Build Coastguard Worker        check('1', [x], regex=r'invalid subscript given at index 0')
4048*da0073e9SAndroid Build Coastguard Worker        check(',', [x], regex=r'fewer operands were provided than specified in the equation')
4049*da0073e9SAndroid Build Coastguard Worker        check('', [x, x], regex=r'more operands were provided than specified in the equation')
4050*da0073e9SAndroid Build Coastguard Worker        check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
4051*da0073e9SAndroid Build Coastguard Worker              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4052*da0073e9SAndroid Build Coastguard Worker        check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
4053*da0073e9SAndroid Build Coastguard Worker              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4054*da0073e9SAndroid Build Coastguard Worker        check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
4055*da0073e9SAndroid Build Coastguard Worker              r'of dimensions \(1\) for operand 0')
4056*da0073e9SAndroid Build Coastguard Worker        check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
4057*da0073e9SAndroid Build Coastguard Worker        check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
4058*da0073e9SAndroid Build Coastguard Worker        check('a->1', [x], regex=r'invalid subscript given at index 3')
4059*da0073e9SAndroid Build Coastguard Worker        check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
4060*da0073e9SAndroid Build Coastguard Worker        check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
4061*da0073e9SAndroid Build Coastguard Worker        check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
4062*da0073e9SAndroid Build Coastguard Worker        check('...,...', [x, y], regex=r'does not broadcast')
4063*da0073e9SAndroid Build Coastguard Worker        check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast')
4064*da0073e9SAndroid Build Coastguard Worker        check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously'
4065*da0073e9SAndroid Build Coastguard Worker              r' seen size 2')
4066*da0073e9SAndroid Build Coastguard Worker
4067*da0073e9SAndroid Build Coastguard Worker        check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4068*da0073e9SAndroid Build Coastguard Worker        check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4069*da0073e9SAndroid Build Coastguard Worker
4070*da0073e9SAndroid Build Coastguard Worker    def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
4071*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
4072*da0073e9SAndroid Build Coastguard Worker        make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device)
4073*da0073e9SAndroid Build Coastguard Worker        b, n, k = shape
4074*da0073e9SAndroid Build Coastguard Worker        for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
4075*da0073e9SAndroid Build Coastguard Worker            # expand means that we generate a batch of matrices with a stride of zero in the batch dimension
4076*da0073e9SAndroid Build Coastguard Worker            if (conj_a or conj_b) and not dtype.is_complex:
4077*da0073e9SAndroid Build Coastguard Worker                continue
4078*da0073e9SAndroid Build Coastguard Worker            # We only expand along the batch dimension
4079*da0073e9SAndroid Build Coastguard Worker            if (expand_a or expand_b) and b == 1:
4080*da0073e9SAndroid Build Coastguard Worker                continue
4081*da0073e9SAndroid Build Coastguard Worker
4082*da0073e9SAndroid Build Coastguard Worker            size_a = (b, n, n) if left else (b, k, k)
4083*da0073e9SAndroid Build Coastguard Worker            size_b = (b, n, k) if not tr_b else (b, k, n)
4084*da0073e9SAndroid Build Coastguard Worker
4085*da0073e9SAndroid Build Coastguard Worker            # If expand_a or expand_b, we'll expand them to the correct size later
4086*da0073e9SAndroid Build Coastguard Worker            if b == 1 or expand_a:
4087*da0073e9SAndroid Build Coastguard Worker                size_a = size_a[1:]
4088*da0073e9SAndroid Build Coastguard Worker            if b == 1 or expand_b:
4089*da0073e9SAndroid Build Coastguard Worker                size_b = size_b[1:]
4090*da0073e9SAndroid Build Coastguard Worker
4091*da0073e9SAndroid Build Coastguard Worker            if well_conditioned:
4092*da0073e9SAndroid Build Coastguard Worker                PLU = torch.linalg.lu(make_fullrank(*size_a))
4093*da0073e9SAndroid Build Coastguard Worker                if uni:
4094*da0073e9SAndroid Build Coastguard Worker                    # A = L.mT from PLU (unit upper triangular)
4095*da0073e9SAndroid Build Coastguard Worker                    A = PLU[1].transpose(-2, -1).contiguous()
4096*da0073e9SAndroid Build Coastguard Worker                else:
4097*da0073e9SAndroid Build Coastguard Worker                    # A = U from PLU
4098*da0073e9SAndroid Build Coastguard Worker                    A = PLU[2].contiguous()
4099*da0073e9SAndroid Build Coastguard Worker            else:
4100*da0073e9SAndroid Build Coastguard Worker                A = make_arg(size_a)
4101*da0073e9SAndroid Build Coastguard Worker                A.triu_()
4102*da0073e9SAndroid Build Coastguard Worker
4103*da0073e9SAndroid Build Coastguard Worker            diag = A.diagonal(0, -2, -1)
4104*da0073e9SAndroid Build Coastguard Worker            if uni:
4105*da0073e9SAndroid Build Coastguard Worker                diag.fill_(1.)
4106*da0073e9SAndroid Build Coastguard Worker            else:
4107*da0073e9SAndroid Build Coastguard Worker                diag[diag.abs() < 1e-6] = 1.
4108*da0073e9SAndroid Build Coastguard Worker
4109*da0073e9SAndroid Build Coastguard Worker            B = make_arg(size_b)
4110*da0073e9SAndroid Build Coastguard Worker
4111*da0073e9SAndroid Build Coastguard Worker            if tr_a:
4112*da0073e9SAndroid Build Coastguard Worker                A.transpose_(-2, -1)
4113*da0073e9SAndroid Build Coastguard Worker            if tr_b:
4114*da0073e9SAndroid Build Coastguard Worker                B.transpose_(-2, -1)
4115*da0073e9SAndroid Build Coastguard Worker            if conj_a:
4116*da0073e9SAndroid Build Coastguard Worker                A = A.conj()
4117*da0073e9SAndroid Build Coastguard Worker            if conj_b:
4118*da0073e9SAndroid Build Coastguard Worker                B = B.conj()
4119*da0073e9SAndroid Build Coastguard Worker            if expand_a:
4120*da0073e9SAndroid Build Coastguard Worker                A = A.expand(b, *size_a)
4121*da0073e9SAndroid Build Coastguard Worker            if expand_b:
4122*da0073e9SAndroid Build Coastguard Worker                B = B.expand(b, n, k)
4123*da0073e9SAndroid Build Coastguard Worker            yield A, B, left, not tr_a, uni
4124*da0073e9SAndroid Build Coastguard Worker
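    # Note (illustrative only): the expand_a / expand_b branches above rely on
    # torch.Tensor.expand returning a view whose expanded batch dimension has
    # stride 0, e.g. torch.randn(3, 3).expand(4, 3, 3).stride() == (0, 3, 1),
    # which is the "batch of matrices with a stride of zero" described in the
    # generator's first comment.
    #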
4125*da0073e9SAndroid Build Coastguard Worker    def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
4126*da0073e9SAndroid Build Coastguard Worker        X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4127*da0073e9SAndroid Build Coastguard Worker        if left:
4128*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A @ X, B)
4129*da0073e9SAndroid Build Coastguard Worker        else:
4130*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(X @ A, B)
4131*da0073e9SAndroid Build Coastguard Worker        out = B
4132*da0073e9SAndroid Build Coastguard Worker        # B may be an expanded view that cannot be used as the out= tensor, so clone it
4133*da0073e9SAndroid Build Coastguard Worker        if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
4134*da0073e9SAndroid Build Coastguard Worker            out = B.clone()
4135*da0073e9SAndroid Build Coastguard Worker        torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
4136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(X, out)
4137*da0073e9SAndroid Build Coastguard Worker
4138*da0073e9SAndroid Build Coastguard Worker    # Tolerances dictated by widest acceptable range on CPU before failure
4139*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4140*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1,
4141*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8,
4142*da0073e9SAndroid Build Coastguard Worker                        torch.complex64: 1e-1,
4143*da0073e9SAndroid Build Coastguard Worker                        torch.complex128: 1e-8})
4144*da0073e9SAndroid Build Coastguard Worker    def test_linalg_solve_triangular(self, device, dtype):
4145*da0073e9SAndroid Build Coastguard Worker        # This exercises the API + BLAS CPU + batched cuBLAS
4146*da0073e9SAndroid Build Coastguard Worker        ks = (3, 1, 0)
4147*da0073e9SAndroid Build Coastguard Worker        ns = (5, 0)
4148*da0073e9SAndroid Build Coastguard Worker        bs = (1, 2, 0)
4149*da0073e9SAndroid Build Coastguard Worker
4150*da0073e9SAndroid Build Coastguard Worker        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4151*da0073e9SAndroid Build Coastguard Worker        for b, n, k in product(bs, ns, ks):
4152*da0073e9SAndroid Build Coastguard Worker            for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True):
4153*da0073e9SAndroid Build Coastguard Worker                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4154*da0073e9SAndroid Build Coastguard Worker
4155*da0073e9SAndroid Build Coastguard Worker    @slowTest
4156*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
4157*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4158*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma  # Magma needed for the PLU decomposition
4159*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4160*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4161*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
4162*da0073e9SAndroid Build Coastguard Worker    def test_linalg_solve_triangular_large(self, device, dtype):
4163*da0073e9SAndroid Build Coastguard Worker        # Exercises magma and cublas
4164*da0073e9SAndroid Build Coastguard Worker        magma = (9, 513, 1)
4165*da0073e9SAndroid Build Coastguard Worker        iterative_cublas = (2, 64, 1)
4166*da0073e9SAndroid Build Coastguard Worker
4167*da0073e9SAndroid Build Coastguard Worker        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4168*da0073e9SAndroid Build Coastguard Worker        for shape in (magma, iterative_cublas):
4169*da0073e9SAndroid Build Coastguard Worker            for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
4170*da0073e9SAndroid Build Coastguard Worker                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4171*da0073e9SAndroid Build Coastguard Worker
4172*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4173*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4174*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
4175*da0073e9SAndroid Build Coastguard Worker    def test_linalg_solve_triangular_broadcasting(self, device, dtype):
4176*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
4177*da0073e9SAndroid Build Coastguard Worker
4178*da0073e9SAndroid Build Coastguard Worker        sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
4179*da0073e9SAndroid Build Coastguard Worker                 ((2, 1, 3, 4, 4), (4, 6)),
4180*da0073e9SAndroid Build Coastguard Worker                 ((4, 4), (2, 1, 3, 4, 2)),
4181*da0073e9SAndroid Build Coastguard Worker                 ((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
4182*da0073e9SAndroid Build Coastguard Worker        for size_A, size_B in sizes:
4183*da0073e9SAndroid Build Coastguard Worker            for left, upper, uni in itertools.product([True, False], repeat=3):
4184*da0073e9SAndroid Build Coastguard Worker                A = make_arg(size_A)
4185*da0073e9SAndroid Build Coastguard Worker                if upper:
4186*da0073e9SAndroid Build Coastguard Worker                    A.triu_()
4187*da0073e9SAndroid Build Coastguard Worker                else:
4188*da0073e9SAndroid Build Coastguard Worker                    A.tril_()
4189*da0073e9SAndroid Build Coastguard Worker                diag = A.diagonal(0, -2, -1)
4190*da0073e9SAndroid Build Coastguard Worker                if uni:
4191*da0073e9SAndroid Build Coastguard Worker                    diag.fill_(1.)
4192*da0073e9SAndroid Build Coastguard Worker                else:
4193*da0073e9SAndroid Build Coastguard Worker                    diag[diag.abs() < 1e-6] = 1.
4194*da0073e9SAndroid Build Coastguard Worker                B = make_arg(size_B)
4195*da0073e9SAndroid Build Coastguard Worker                if not left:
4196*da0073e9SAndroid Build Coastguard Worker                    B.transpose_(-2, -1)
4197*da0073e9SAndroid Build Coastguard Worker
4198*da0073e9SAndroid Build Coastguard Worker                X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4199*da0073e9SAndroid Build Coastguard Worker                if left:
4200*da0073e9SAndroid Build Coastguard Worker                    B_other = A @ X
4201*da0073e9SAndroid Build Coastguard Worker                else:
4202*da0073e9SAndroid Build Coastguard Worker                    B_other = X @ A
4203*da0073e9SAndroid Build Coastguard Worker
4204*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(*torch.broadcast_tensors(B, B_other))
4205*da0073e9SAndroid Build Coastguard Worker
4206*da0073e9SAndroid Build Coastguard Worker    def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
4207*da0073e9SAndroid Build Coastguard Worker                                     device, dtype):
4208*da0073e9SAndroid Build Coastguard Worker        triangle_function = torch.triu if upper else torch.tril
4209*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(*b_dims, dtype=dtype, device=device)
4210*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(*A_dims, dtype=dtype, device=device)
4211*da0073e9SAndroid Build Coastguard Worker        # create positive definite matrix
4212*da0073e9SAndroid Build Coastguard Worker        A = torch.matmul(A, A.mT)
4213*da0073e9SAndroid Build Coastguard Worker        A_triangular = triangle_function(A)
4214*da0073e9SAndroid Build Coastguard Worker        if unitriangular:
4215*da0073e9SAndroid Build Coastguard Worker            A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.)
4216*da0073e9SAndroid Build Coastguard Worker        return b, A_triangular
4217*da0073e9SAndroid Build Coastguard Worker
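    # Minimal sketch (illustrative only, reusing the helper above): the tests below
    # exercise the legacy torch.triangular_solve(b, A, ...) API, which solves
    # A x = b and returns a (solution, cloned_coefficient) namedtuple, whereas the
    # earlier tests use torch.linalg.solve_triangular(A, B, upper=...), which takes
    # the matrix first and returns only the solution:
    #
    #   b, A = self.triangular_solve_test_helper((5, 5), (5, 3), True, False, device, dtype)
    #   x_old = torch.triangular_solve(b, A, upper=True).solution
    #   x_new = torch.linalg.solve_triangular(A, b, upper=True)
    #   # both x_old and x_new satisfy A @ x == b (up to floating point error)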
4218*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
4219*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
4220*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("flaky, needs investigation")
4221*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4222*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4223*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
4224*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve(self, device, dtype):
4225*da0073e9SAndroid Build Coastguard Worker        ks = [0, 1, 3]
4226*da0073e9SAndroid Build Coastguard Worker        ns = [0, 5]
4227*da0073e9SAndroid Build Coastguard Worker        for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns,
4228*da0073e9SAndroid Build Coastguard Worker                                                                         itertools.product([True, False], repeat=3)):
4229*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper((n, n), (n, k), upper,
4230*da0073e9SAndroid Build Coastguard Worker                                                     unitriangular, device, dtype)
4231*da0073e9SAndroid Build Coastguard Worker            x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4232*da0073e9SAndroid Build Coastguard Worker            if transpose:
4233*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu()))
4234*da0073e9SAndroid Build Coastguard Worker            else:
4235*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
4236*da0073e9SAndroid Build Coastguard Worker
4237*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
4238*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
4239*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4240*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4241*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
4242*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve_batched(self, device, dtype):
4243*da0073e9SAndroid Build Coastguard Worker        def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4244*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4245*da0073e9SAndroid Build Coastguard Worker                                                     unitriangular, device, dtype)
4246*da0073e9SAndroid Build Coastguard Worker            x_exp_list = []
4247*da0073e9SAndroid Build Coastguard Worker            for i in range(b_dims[0]):
4248*da0073e9SAndroid Build Coastguard Worker                x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper,
4249*da0073e9SAndroid Build Coastguard Worker                                                         unitriangular=unitriangular,
4250*da0073e9SAndroid Build Coastguard Worker                                                         transpose=transpose)[0])
4251*da0073e9SAndroid Build Coastguard Worker            x_exp = torch.stack(x_exp_list)  # Stacked output
4252*da0073e9SAndroid Build Coastguard Worker            x_act = torch.triangular_solve(b, A, upper=upper,
4253*da0073e9SAndroid Build Coastguard Worker                                           unitriangular=unitriangular,
4254*da0073e9SAndroid Build Coastguard Worker                                           transpose=transpose)[0]  # Actual output
4255*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_act, x_exp)  # Equality check
4256*da0073e9SAndroid Build Coastguard Worker            if transpose:
4257*da0073e9SAndroid Build Coastguard Worker                A = A.mT
4258*da0073e9SAndroid Build Coastguard Worker
4259*da0073e9SAndroid Build Coastguard Worker            Ax = np.matmul(A.cpu(), x_act.cpu())
4260*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b, Ax)
4261*da0073e9SAndroid Build Coastguard Worker
4262*da0073e9SAndroid Build Coastguard Worker        def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4263*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4264*da0073e9SAndroid Build Coastguard Worker                                                     unitriangular, device, dtype)
4265*da0073e9SAndroid Build Coastguard Worker            x = torch.triangular_solve(b, A, upper=upper,
4266*da0073e9SAndroid Build Coastguard Worker                                       unitriangular=unitriangular,
4267*da0073e9SAndroid Build Coastguard Worker                                       transpose=transpose)[0]
4268*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.shape == b.shape)
4269*da0073e9SAndroid Build Coastguard Worker
4270*da0073e9SAndroid Build Coastguard Worker        for upper, unitriangular, transpose in itertools.product([True, False], repeat=3):
4271*da0073e9SAndroid Build Coastguard Worker            batchsize = 3
4272*da0073e9SAndroid Build Coastguard Worker            triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4273*da0073e9SAndroid Build Coastguard Worker                                          upper, unitriangular, transpose)
4274*da0073e9SAndroid Build Coastguard Worker
4275*da0073e9SAndroid Build Coastguard Worker            # test empty input
4276*da0073e9SAndroid Build Coastguard Worker            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10),
4277*da0073e9SAndroid Build Coastguard Worker                                          upper, unitriangular, transpose)
4278*da0073e9SAndroid Build Coastguard Worker            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0),
4279*da0073e9SAndroid Build Coastguard Worker                                          upper, unitriangular, transpose)
4280*da0073e9SAndroid Build Coastguard Worker
4281*da0073e9SAndroid Build Coastguard Worker            # test zero batch case
4282*da0073e9SAndroid Build Coastguard Worker            batchsize = 0
4283*da0073e9SAndroid Build Coastguard Worker            triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4284*da0073e9SAndroid Build Coastguard Worker                                               upper, unitriangular, transpose)
4285*da0073e9SAndroid Build Coastguard Worker
4286*da0073e9SAndroid Build Coastguard Worker
4287*da0073e9SAndroid Build Coastguard Worker    @slowTest
4288*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
4289*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
4290*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4291*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4292*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
4293*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve_batched_many_batches(self, device, dtype):
4294*da0073e9SAndroid Build Coastguard Worker        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4295*da0073e9SAndroid Build Coastguard Worker            # test batched A case
4296*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1),
4297*da0073e9SAndroid Build Coastguard Worker                                                     upper, unitriangular, device, dtype)
4298*da0073e9SAndroid Build Coastguard Worker            x, _ = torch.triangular_solve(b, A,
4299*da0073e9SAndroid Build Coastguard Worker                                          upper=upper, transpose=transpose, unitriangular=unitriangular)
4300*da0073e9SAndroid Build Coastguard Worker            if transpose:
4301*da0073e9SAndroid Build Coastguard Worker                A = A.mT
4302*da0073e9SAndroid Build Coastguard Worker
4303*da0073e9SAndroid Build Coastguard Worker            Ax = torch.matmul(A, x)
4304*da0073e9SAndroid Build Coastguard Worker
4305*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision
4306*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol)
4307*da0073e9SAndroid Build Coastguard Worker
4308*da0073e9SAndroid Build Coastguard Worker            # test batched b case
4309*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1),
4310*da0073e9SAndroid Build Coastguard Worker                                                     upper, unitriangular, device, dtype)
4311*da0073e9SAndroid Build Coastguard Worker            x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose,
4312*da0073e9SAndroid Build Coastguard Worker                                          unitriangular=unitriangular)
4313*da0073e9SAndroid Build Coastguard Worker            if transpose:
4314*da0073e9SAndroid Build Coastguard Worker                A = A.mT
4315*da0073e9SAndroid Build Coastguard Worker
4316*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.matmul(A, x), b)
4317*da0073e9SAndroid Build Coastguard Worker
4318*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
4319*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
4320*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
4321*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("flaky, needs investigation")
4322*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4323*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve_batched_broadcasting(self, device, dtype):
4324*da0073e9SAndroid Build Coastguard Worker        from scipy.linalg import solve_triangular as tri_solve
4325*da0073e9SAndroid Build Coastguard Worker
4326*da0073e9SAndroid Build Coastguard Worker        def scipy_tri_solve_batched(A, B, upper, trans, diag):
4327*da0073e9SAndroid Build Coastguard Worker            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4328*da0073e9SAndroid Build Coastguard Worker            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4329*da0073e9SAndroid Build Coastguard Worker            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4330*da0073e9SAndroid Build Coastguard Worker                                                     torch.Size(batch_dims_B)))
4331*da0073e9SAndroid Build Coastguard Worker            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4332*da0073e9SAndroid Build Coastguard Worker            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4333*da0073e9SAndroid Build Coastguard Worker            flat_A = expand_A.reshape((-1,) + single_dim_A)
4334*da0073e9SAndroid Build Coastguard Worker            flat_B = expand_B.reshape((-1,) + single_dim_B)
4335*da0073e9SAndroid Build Coastguard Worker            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
4336*da0073e9SAndroid Build Coastguard Worker                                for a, b in zip(flat_A, flat_B)])
4337*da0073e9SAndroid Build Coastguard Worker            return flat_X.reshape(expand_B.shape)
4338*da0073e9SAndroid Build Coastguard Worker
4339*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, b_dims, device, upper, transpose, unitriangular):
4340*da0073e9SAndroid Build Coastguard Worker            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4341*da0073e9SAndroid Build Coastguard Worker                                                     unitriangular, device, dtype)
4342*da0073e9SAndroid Build Coastguard Worker            x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(),
4343*da0073e9SAndroid Build Coastguard Worker                                                            upper, transpose, unitriangular))
4344*da0073e9SAndroid Build Coastguard Worker            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x_exp.to(device))
4347*da0073e9SAndroid Build Coastguard Worker
4348*da0073e9SAndroid Build Coastguard Worker        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4349*da0073e9SAndroid Build Coastguard Worker            # test against scipy.linalg.solve_triangular
4350*da0073e9SAndroid Build Coastguard Worker            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular)  # no broadcasting
4351*da0073e9SAndroid Build Coastguard Worker            run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular)  # broadcasting b
4352*da0073e9SAndroid Build Coastguard Worker            run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)  # broadcasting A
4353*da0073e9SAndroid Build Coastguard Worker            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)  # broadcasting A & b
4354*da0073e9SAndroid Build Coastguard Worker
4355*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4356*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
4357*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve_large(self, device, dtype):
4358*da0073e9SAndroid Build Coastguard Worker        # Repro for https://github.com/pytorch/pytorch/issues/79191
4359*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_()
4360*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(1, 2, 524281, device=device, dtype=dtype)
4361*da0073e9SAndroid Build Coastguard Worker        X = torch.linalg.solve_triangular(A, B, upper=False)
4362*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(A @ X, B)
4363*da0073e9SAndroid Build Coastguard Worker
4364*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
4365*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
4366*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
4367*da0073e9SAndroid Build Coastguard Worker    def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
4368*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
4369*da0073e9SAndroid Build Coastguard Worker        a = torch.eye(2, dtype=dtype, device=device)
4370*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 1, dtype=dtype, device=device)
4371*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(b).to(torch.int)
4372*da0073e9SAndroid Build Coastguard Worker        clone_a = torch.empty_like(a)
4373*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4374*da0073e9SAndroid Build Coastguard Worker            torch.triangular_solve(b, a, out=(out, clone_a))
4375*da0073e9SAndroid Build Coastguard Worker
4376*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(b)
4377*da0073e9SAndroid Build Coastguard Worker        clone_a = clone_a.to(torch.int)
4378*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4379*da0073e9SAndroid Build Coastguard Worker            torch.triangular_solve(b, a, out=(out, clone_a))
4380*da0073e9SAndroid Build Coastguard Worker
4381*da0073e9SAndroid Build Coastguard Worker        # device should match
4382*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
4383*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
4384*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=wrong_device)
4385*da0073e9SAndroid Build Coastguard Worker            clone_a = torch.empty_like(a)
4386*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4387*da0073e9SAndroid Build Coastguard Worker                torch.triangular_solve(b, a, out=(out, clone_a))
4388*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, dtype=dtype, device=device)
4389*da0073e9SAndroid Build Coastguard Worker            clone_a = torch.empty_like(a).to(wrong_device)
4390*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4391*da0073e9SAndroid Build Coastguard Worker                torch.triangular_solve(b, a, out=(out, clone_a))
4392*da0073e9SAndroid Build Coastguard Worker
4393*da0073e9SAndroid Build Coastguard Worker        # Trigger the WARN_ONCE deprecation warning
4394*da0073e9SAndroid Build Coastguard Worker        torch.triangular_solve(b, a)
4395*da0073e9SAndroid Build Coastguard Worker
4396*da0073e9SAndroid Build Coastguard Worker        # if an out tensor with the wrong shape is passed, a warning is given
4397*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
4398*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(1, dtype=dtype, device=device)
4399*da0073e9SAndroid Build Coastguard Worker            clone_a = torch.empty(1, dtype=dtype, device=device)
4400*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
4401*da0073e9SAndroid Build Coastguard Worker            torch.triangular_solve(b, a, out=(out, clone_a))
4402*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
4403*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
4404*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
4405*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
4406*da0073e9SAndroid Build Coastguard Worker
4407*da0073e9SAndroid Build Coastguard Worker
4408*da0073e9SAndroid Build Coastguard Worker    def check_single_matmul(self, x, y):
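        # Compare torch.matmul against NumPy for the same operands, both for the
        # functional form and for the out= overload. For floating point and complex
        # dtypes the absolute tolerance is scaled with the contraction length.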
4409*da0073e9SAndroid Build Coastguard Worker
4410*da0073e9SAndroid Build Coastguard Worker        def assertEqual(answer, expected):
4411*da0073e9SAndroid Build Coastguard Worker            if x.dtype.is_floating_point or x.dtype.is_complex:
4412*da0073e9SAndroid Build Coastguard Worker                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
4413*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(answer, expected,
4414*da0073e9SAndroid Build Coastguard Worker                                 msg=f"{x.shape} x {y.shape} = {answer.shape}",
4415*da0073e9SAndroid Build Coastguard Worker                                 atol=k * 5e-5,
4416*da0073e9SAndroid Build Coastguard Worker                                 rtol=1e-4)
4417*da0073e9SAndroid Build Coastguard Worker            else:
4418*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}")
4419*da0073e9SAndroid Build Coastguard Worker
4420*da0073e9SAndroid Build Coastguard Worker        # test x @ y
4421*da0073e9SAndroid Build Coastguard Worker        expected = np.matmul(x.cpu(), y.cpu())
4422*da0073e9SAndroid Build Coastguard Worker        ans = torch.matmul(x, y)
4423*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ans.is_contiguous())
4424*da0073e9SAndroid Build Coastguard Worker        assertEqual(ans, expected)
4425*da0073e9SAndroid Build Coastguard Worker
4426*da0073e9SAndroid Build Coastguard Worker        # test out
4427*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(ans)
4428*da0073e9SAndroid Build Coastguard Worker        ans = torch.matmul(x, y, out=out)
4429*da0073e9SAndroid Build Coastguard Worker        self.assertIs(ans, out)
4430*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ans.is_contiguous())
4431*da0073e9SAndroid Build Coastguard Worker        assertEqual(ans, expected)
4432*da0073e9SAndroid Build Coastguard Worker
4433*da0073e9SAndroid Build Coastguard Worker    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
4434*da0073e9SAndroid Build Coastguard Worker        """
4435*da0073e9SAndroid Build Coastguard Worker        Generates sequences of shape tuples (size_x, size_y) with len(size_x) == x_dim and
4436*da0073e9SAndroid Build Coastguard Worker        len(size_y) <= y_dim that are compatible w.r.t. matmul
4437*da0073e9SAndroid Build Coastguard Worker        """
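        # Dimension sizes range over [0, matrix_size) and [0, batch_size), so empty
        # matrices and zero-sized batch dimensions are exercised as well.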
4438*da0073e9SAndroid Build Coastguard Worker        assert x_dim >= 1
4439*da0073e9SAndroid Build Coastguard Worker        assert y_dim >= 2
4440*da0073e9SAndroid Build Coastguard Worker        x = x_dim
4441*da0073e9SAndroid Build Coastguard Worker        for y in range(1, y_dim + 1):
4442*da0073e9SAndroid Build Coastguard Worker            for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
4443*da0073e9SAndroid Build Coastguard Worker                                     product(range(matrix_size), repeat=min(y, 2))):
4444*da0073e9SAndroid Build Coastguard Worker                if x == 1:
4445*da0073e9SAndroid Build Coastguard Worker                    size_x = mn[:1]
4446*da0073e9SAndroid Build Coastguard Worker                    size_y = batch + mn
4447*da0073e9SAndroid Build Coastguard Worker                    yield size_x, size_y
4448*da0073e9SAndroid Build Coastguard Worker                else:
4449*da0073e9SAndroid Build Coastguard Worker                    for k in range(matrix_size):
4450*da0073e9SAndroid Build Coastguard Worker                        size_x = (k,) + mn[:1]
4451*da0073e9SAndroid Build Coastguard Worker                        if x > 2:
4452*da0073e9SAndroid Build Coastguard Worker                            size_x = batch[-(x - 2):] + size_x
4453*da0073e9SAndroid Build Coastguard Worker                        size_y = mn
4454*da0073e9SAndroid Build Coastguard Worker                        if y > 2:
4455*da0073e9SAndroid Build Coastguard Worker                            size_y = batch[-(y - 2):] + size_y
4456*da0073e9SAndroid Build Coastguard Worker                        yield size_x, size_y
4457*da0073e9SAndroid Build Coastguard Worker
4458*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4459*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int64, torch.float, torch.complex64)
4460*da0073e9SAndroid Build Coastguard Worker    @setBlasBackendsToDefaultFinally
4461*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
4462*da0073e9SAndroid Build Coastguard Worker        for backend in ["cublas", "cublaslt"]:
4463*da0073e9SAndroid Build Coastguard Worker            if torch.device(device).type == 'cuda':
4464*da0073e9SAndroid Build Coastguard Worker                torch.backends.cuda.preferred_blas_library(backend)
4465*da0073e9SAndroid Build Coastguard Worker
4466*da0073e9SAndroid Build Coastguard Worker            make_arg = partial(make_tensor, device=device, dtype=dtype)
4467*da0073e9SAndroid Build Coastguard Worker
4468*da0073e9SAndroid Build Coastguard Worker            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4469*da0073e9SAndroid Build Coastguard Worker                x = make_arg(size_x, noncontiguous=nctg_x)
4470*da0073e9SAndroid Build Coastguard Worker                y = make_arg(size_y, noncontiguous=nctg_y)
4471*da0073e9SAndroid Build Coastguard Worker                self.check_single_matmul(x, y)
4472*da0073e9SAndroid Build Coastguard Worker
4473*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4474*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int64, torch.float, torch.complex64)
4475*da0073e9SAndroid Build Coastguard Worker    @setBlasBackendsToDefaultFinally
4476*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
4477*da0073e9SAndroid Build Coastguard Worker        for backend in ["cublas", "cublaslt"]:
4478*da0073e9SAndroid Build Coastguard Worker            if torch.device(device).type == 'cuda':
4479*da0073e9SAndroid Build Coastguard Worker                torch.backends.cuda.preferred_blas_library(backend)
4480*da0073e9SAndroid Build Coastguard Worker
4481*da0073e9SAndroid Build Coastguard Worker            make_arg = partial(make_tensor, device=device, dtype=dtype)
4482*da0073e9SAndroid Build Coastguard Worker
4483*da0073e9SAndroid Build Coastguard Worker            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
4484*da0073e9SAndroid Build Coastguard Worker                x = make_arg(size_x, noncontiguous=nctg_x)
4485*da0073e9SAndroid Build Coastguard Worker                y = make_arg(size_y, noncontiguous=nctg_y)
4486*da0073e9SAndroid Build Coastguard Worker                self.check_single_matmul(x, y)
4487*da0073e9SAndroid Build Coastguard Worker
4488*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4489*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int64, torch.float, torch.complex64)
4490*da0073e9SAndroid Build Coastguard Worker    @setBlasBackendsToDefaultFinally
4491*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
4492*da0073e9SAndroid Build Coastguard Worker        for backend in ["cublas", "cublaslt"]:
4493*da0073e9SAndroid Build Coastguard Worker            if torch.device(device).type == 'cuda':
4494*da0073e9SAndroid Build Coastguard Worker                torch.backends.cuda.preferred_blas_library(backend)
4495*da0073e9SAndroid Build Coastguard Worker
4496*da0073e9SAndroid Build Coastguard Worker            make_arg = partial(make_tensor, device=device, dtype=dtype)
4497*da0073e9SAndroid Build Coastguard Worker
4498*da0073e9SAndroid Build Coastguard Worker            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)):
4499*da0073e9SAndroid Build Coastguard Worker                x = make_arg(size_x, noncontiguous=nctg_x)
4500*da0073e9SAndroid Build Coastguard Worker                y = make_arg(size_y, noncontiguous=nctg_y)
4501*da0073e9SAndroid Build Coastguard Worker                self.check_single_matmul(x, y)
4502*da0073e9SAndroid Build Coastguard Worker
4503*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4504*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
4505*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_tunableop(self, device, dtype):
4506*da0073e9SAndroid Build Coastguard Worker        # disable tunableop buffer rotation for all tests everywhere, it can be slow
4507*da0073e9SAndroid Build Coastguard Worker        import os
4508*da0073e9SAndroid Build Coastguard Worker        os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
4509*da0073e9SAndroid Build Coastguard Worker        set_tunableop_defaults()
4510*da0073e9SAndroid Build Coastguard Worker
4511*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable()
4512*da0073e9SAndroid Build Coastguard Worker        # set these to single iterations to keep it short but still exercise the code
4513*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_max_tuning_duration(1)
4514*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_max_tuning_iterations(1)
4515*da0073e9SAndroid Build Coastguard Worker
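        # Each unique GEMM shape encountered below is tuned and recorded; the results
        # are then written to CSV files and read back to verify round-tripping.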
4516*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
4517*da0073e9SAndroid Build Coastguard Worker
4518*da0073e9SAndroid Build Coastguard Worker        for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4519*da0073e9SAndroid Build Coastguard Worker            x = make_arg(size_x, noncontiguous=nctg_x)
4520*da0073e9SAndroid Build Coastguard Worker            y = make_arg(size_y, noncontiguous=nctg_y)
4521*da0073e9SAndroid Build Coastguard Worker            self.check_single_matmul(x, y)
4522*da0073e9SAndroid Build Coastguard Worker
4523*da0073e9SAndroid Build Coastguard Worker        filename1 = torch.cuda.tunable.get_filename()
4524*da0073e9SAndroid Build Coastguard Worker        filename2 = "tunableop_results_tmp1.csv"
4525*da0073e9SAndroid Build Coastguard Worker        filename3 = "tunableop_results_tmp2.csv"
4526*da0073e9SAndroid Build Coastguard Worker        ordinal = torch.cuda.current_device()
4527*da0073e9SAndroid Build Coastguard Worker        assert filename1 == f"tunableop_results{ordinal}.csv"
4528*da0073e9SAndroid Build Coastguard Worker        assert len(torch.cuda.tunable.get_validators()) > 0
4529*da0073e9SAndroid Build Coastguard Worker        validators = {}
4530*da0073e9SAndroid Build Coastguard Worker        for key, value in torch.cuda.tunable.get_validators():
4531*da0073e9SAndroid Build Coastguard Worker            validators[key] = value
4532*da0073e9SAndroid Build Coastguard Worker        if torch.version.hip:
4533*da0073e9SAndroid Build Coastguard Worker            assert "HIPBLASLT_VERSION" in validators
4534*da0073e9SAndroid Build Coastguard Worker            assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
4535*da0073e9SAndroid Build Coastguard Worker        assert len(torch.cuda.tunable.get_results()) > 0
4536*da0073e9SAndroid Build Coastguard Worker
4537*da0073e9SAndroid Build Coastguard Worker        assert torch.cuda.tunable.write_file()  # use default filename
4538*da0073e9SAndroid Build Coastguard Worker        assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
4539*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_filename(filename3)
4540*da0073e9SAndroid Build Coastguard Worker        assert torch.cuda.tunable.write_file()  # use previously set filename
4541*da0073e9SAndroid Build Coastguard Worker        assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True
4542*da0073e9SAndroid Build Coastguard Worker
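        # All three files should contain identical tuning results, regardless of how
        # the filename was chosen (default, one-time argument, or set_filename).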
4543*da0073e9SAndroid Build Coastguard Worker        with open(filename1) as file1:
4544*da0073e9SAndroid Build Coastguard Worker            file1_contents = file1.read()
4545*da0073e9SAndroid Build Coastguard Worker        with open(filename2) as file2:
4546*da0073e9SAndroid Build Coastguard Worker            file2_contents = file2.read()
4547*da0073e9SAndroid Build Coastguard Worker        with open(filename3) as file3:
4548*da0073e9SAndroid Build Coastguard Worker            file3_contents = file3.read()
4549*da0073e9SAndroid Build Coastguard Worker        assert file1_contents == file2_contents
4550*da0073e9SAndroid Build Coastguard Worker        assert file1_contents == file3_contents
4551*da0073e9SAndroid Build Coastguard Worker
4552*da0073e9SAndroid Build Coastguard Worker        # remove the files created above to avoid error 'Build left local git repository checkout dirty', ignore errors
4553*da0073e9SAndroid Build Coastguard Worker        for filename in [filename1, filename2, filename3]:
4554*da0073e9SAndroid Build Coastguard Worker            try:
4555*da0073e9SAndroid Build Coastguard Worker                import os
4556*da0073e9SAndroid Build Coastguard Worker                os.remove(filename)
4557*da0073e9SAndroid Build Coastguard Worker            except FileNotFoundError:
4558*da0073e9SAndroid Build Coastguard Worker                pass
4559*da0073e9SAndroid Build Coastguard Worker
4560*da0073e9SAndroid Build Coastguard Worker        # disables TunableOp
4561*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable(False)
4562*da0073e9SAndroid Build Coastguard Worker
4563*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4564*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
4565*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
4566*da0073e9SAndroid Build Coastguard Worker    def test_bmm_tunableop_rocm(self, device, dtype):
4567*da0073e9SAndroid Build Coastguard Worker        # buffer rotation (on by default) with strided batched GEMM in TunableOp was causing a memory fault
4568*da0073e9SAndroid Build Coastguard Worker        set_tunableop_defaults()
4569*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable(True)
4570*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_max_tuning_iterations(10)
4571*da0073e9SAndroid Build Coastguard Worker        # the following cases cover all previous failure cases and are here to catch regressions
4572*da0073e9SAndroid Build Coastguard Worker        B = 16
4573*da0073e9SAndroid Build Coastguard Worker        N = M = K = 256
4574*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
4575*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cuda:0")
4576*da0073e9SAndroid Build Coastguard Worker        # case 1
4577*da0073e9SAndroid Build Coastguard Worker        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4578*da0073e9SAndroid Build Coastguard Worker        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4579*da0073e9SAndroid Build Coastguard Worker        out = torch.bmm(i1, i2)
4580*da0073e9SAndroid Build Coastguard Worker        # case 2
4581*da0073e9SAndroid Build Coastguard Worker        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4582*da0073e9SAndroid Build Coastguard Worker        i1 = torch.permute(i1, (1, 2, 0))
4583*da0073e9SAndroid Build Coastguard Worker        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4584*da0073e9SAndroid Build Coastguard Worker        i2 = torch.permute(i2, (1, 0, 2))
4585*da0073e9SAndroid Build Coastguard Worker        out = torch.bmm(i1, i2)
4586*da0073e9SAndroid Build Coastguard Worker        # case 3
4587*da0073e9SAndroid Build Coastguard Worker        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
4588*da0073e9SAndroid Build Coastguard Worker        i1 = torch.permute(i1, (1, 0, 2))
4589*da0073e9SAndroid Build Coastguard Worker        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
4590*da0073e9SAndroid Build Coastguard Worker        i2 = torch.permute(i2, (1, 2, 0))
4591*da0073e9SAndroid Build Coastguard Worker        out = torch.bmm(i1, i2)
4592*da0073e9SAndroid Build Coastguard Worker        # case 4
4593*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
4594*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.as_strided(
4595*da0073e9SAndroid Build Coastguard Worker            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
4596*da0073e9SAndroid Build Coastguard Worker        )
4597*da0073e9SAndroid Build Coastguard Worker        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
4598*da0073e9SAndroid Build Coastguard Worker        batch1_tensor = torch.as_strided(
4599*da0073e9SAndroid Build Coastguard Worker            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
4600*da0073e9SAndroid Build Coastguard Worker        )
4601*da0073e9SAndroid Build Coastguard Worker        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
4602*da0073e9SAndroid Build Coastguard Worker        batch2_tensor = torch.as_strided(
4603*da0073e9SAndroid Build Coastguard Worker            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
4604*da0073e9SAndroid Build Coastguard Worker        )
4605*da0073e9SAndroid Build Coastguard Worker        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
4606*da0073e9SAndroid Build Coastguard Worker        # clean up, remove any file that was generated
4607*da0073e9SAndroid Build Coastguard Worker        try:
4608*da0073e9SAndroid Build Coastguard Worker            import os
4609*da0073e9SAndroid Build Coastguard Worker            filename = torch.cuda.tunable.get_filename()
4610*da0073e9SAndroid Build Coastguard Worker            os.remove(filename)
4611*da0073e9SAndroid Build Coastguard Worker        except FileNotFoundError:
4612*da0073e9SAndroid Build Coastguard Worker            pass
4613*da0073e9SAndroid Build Coastguard Worker
4614*da0073e9SAndroid Build Coastguard Worker        # disable TunableOp
4615*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable(False)
4616*da0073e9SAndroid Build Coastguard Worker
4617*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4618*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
4619*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
4620*da0073e9SAndroid Build Coastguard Worker    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
4621*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
4622*da0073e9SAndroid Build Coastguard Worker        import os
4623*da0073e9SAndroid Build Coastguard Worker        # run operator first without tuning to ensure all rocm libs are loaded,
4624*da0073e9SAndroid Build Coastguard Worker        # otherwise false positive mem leak
4625*da0073e9SAndroid Build Coastguard Worker        B = 16
4626*da0073e9SAndroid Build Coastguard Worker        N = M = K = 256
4627*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
4628*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cuda:0")
4629*da0073e9SAndroid Build Coastguard Worker        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4630*da0073e9SAndroid Build Coastguard Worker        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4631*da0073e9SAndroid Build Coastguard Worker        out = torch.bmm(i1, i2)
4632*da0073e9SAndroid Build Coastguard Worker        # enable tunableop numeric check via env variable.
4633*da0073e9SAndroid Build Coastguard Worker        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
4634*da0073e9SAndroid Build Coastguard Worker        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
4635*da0073e9SAndroid Build Coastguard Worker        try:
4636*da0073e9SAndroid Build Coastguard Worker            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
4637*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable(True)
4638*da0073e9SAndroid Build Coastguard Worker            ordinal = torch.cuda.current_device()
4639*da0073e9SAndroid Build Coastguard Worker            filename = f"tunableop_results{ordinal}.csv"
4640*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.set_filename(filename)
4641*da0073e9SAndroid Build Coastguard Worker            iterations = torch.cuda.tunable.get_max_tuning_iterations()
4642*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.set_max_tuning_iterations(10)
4643*da0073e9SAndroid Build Coastguard Worker            with CudaMemoryLeakCheck(self):
4644*da0073e9SAndroid Build Coastguard Worker                out = torch.bmm(i1, i2)
4645*da0073e9SAndroid Build Coastguard Worker                torch.cuda.tunable.set_max_tuning_iterations(iterations)
4646*da0073e9SAndroid Build Coastguard Worker                torch.cuda.tunable.enable(False)
4647*da0073e9SAndroid Build Coastguard Worker                # clean up, remove any file that was generated
4648*da0073e9SAndroid Build Coastguard Worker                try:
4649*da0073e9SAndroid Build Coastguard Worker                    os.remove(filename)
4650*da0073e9SAndroid Build Coastguard Worker                except FileNotFoundError:
4651*da0073e9SAndroid Build Coastguard Worker                    pass
4652*da0073e9SAndroid Build Coastguard Worker        finally:
4653*da0073e9SAndroid Build Coastguard Worker            if prev_val is None:
4654*da0073e9SAndroid Build Coastguard Worker                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
4655*da0073e9SAndroid Build Coastguard Worker            else:
4656*da0073e9SAndroid Build Coastguard Worker                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val
4657*da0073e9SAndroid Build Coastguard Worker
4658*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4659*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
4660*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
4661*da0073e9SAndroid Build Coastguard Worker    def test_validator_tunableop_rocm(self, device, dtype):
4662*da0073e9SAndroid Build Coastguard Worker        # Test that the validator on ROCM has exactly 5 lines
4663*da0073e9SAndroid Build Coastguard Worker        # Format of the Validator is as follows:
4664*da0073e9SAndroid Build Coastguard Worker        # Validator,PT_VERSION,X.Y.Z
4665*da0073e9SAndroid Build Coastguard Worker        # Validator,ROCBLAS_VERSION,X.Y.Z
4666*da0073e9SAndroid Build Coastguard Worker        # Validator,HIPBLASLT_VERSION,X.Y.Z
4667*da0073e9SAndroid Build Coastguard Worker        # Validator,ROCM_VERSION,X.Y.Z
4668*da0073e9SAndroid Build Coastguard Worker        # Validator,GCN_ARCH_NAME,<architecture name>
4669*da0073e9SAndroid Build Coastguard Worker        validator_num_lines = 5
4670*da0073e9SAndroid Build Coastguard Worker
4671*da0073e9SAndroid Build Coastguard Worker        # Test in try-finally block to avoid leaking state
4672*da0073e9SAndroid Build Coastguard Worker        # if test is interrupted.
4673*da0073e9SAndroid Build Coastguard Worker        try:
4674*da0073e9SAndroid Build Coastguard Worker            set_tunableop_defaults()
4675*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable()
4676*da0073e9SAndroid Build Coastguard Worker            # set these to single iterations to keep it short but still exercise the code
4677*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.set_max_tuning_iterations(1)
4678*da0073e9SAndroid Build Coastguard Worker
4679*da0073e9SAndroid Build Coastguard Worker            N = M = K = 4
4680*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(N, K, device=device, dtype=dtype)
4681*da0073e9SAndroid Build Coastguard Worker            B = torch.randn(K, M, device=device, dtype=dtype)
4682*da0073e9SAndroid Build Coastguard Worker            C = torch.matmul(A, B)
4683*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
4684*da0073e9SAndroid Build Coastguard Worker        finally:
4685*da0073e9SAndroid Build Coastguard Worker            # disable TunableOp
4686*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable(False)
4687*da0073e9SAndroid Build Coastguard Worker
4688*da0073e9SAndroid Build Coastguard Worker            # clean up, remove any file that was generated
4689*da0073e9SAndroid Build Coastguard Worker            try:
4690*da0073e9SAndroid Build Coastguard Worker                import os
4691*da0073e9SAndroid Build Coastguard Worker                filename = torch.cuda.tunable.get_filename()
4692*da0073e9SAndroid Build Coastguard Worker                os.remove(filename)
4693*da0073e9SAndroid Build Coastguard Worker            except FileNotFoundError:
4694*da0073e9SAndroid Build Coastguard Worker                pass
4695*da0073e9SAndroid Build Coastguard Worker
4696*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4697*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)
4698*da0073e9SAndroid Build Coastguard Worker    def test_minimum_tuning_iteration_tunableop(self, device, dtype):
4699*da0073e9SAndroid Build Coastguard Worker        # Make sure that there is at least one tuning iteration under various scenarios
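        # Even when the tuning duration or the iteration count is configured to zero,
        # TunableOp is expected to still perform at least one tuning iteration
        # (get_max_tuning_iterations() stays positive) and record a new result.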
4700*da0073e9SAndroid Build Coastguard Worker
4701*da0073e9SAndroid Build Coastguard Worker        # Test in try-finally block to avoid leaking state
4702*da0073e9SAndroid Build Coastguard Worker        # if test is interrupted.
4703*da0073e9SAndroid Build Coastguard Worker        try:
4704*da0073e9SAndroid Build Coastguard Worker            set_tunableop_defaults()
4705*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable()
4706*da0073e9SAndroid Build Coastguard Worker            # set these to single iterations to keep it short but still exercise the code
4707*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.set_max_tuning_iterations(1)
4708*da0073e9SAndroid Build Coastguard Worker
4709*da0073e9SAndroid Build Coastguard Worker            # Set tuning duration to zero milliseconds
4710*da0073e9SAndroid Build Coastguard Worker            # Tune a single GEMM and verify that we get a new tuning result
4711*da0073e9SAndroid Build Coastguard Worker            import os
4712*da0073e9SAndroid Build Coastguard Worker            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
4713*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4714*da0073e9SAndroid Build Coastguard Worker            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"  # reset to default
4715*da0073e9SAndroid Build Coastguard Worker
4716*da0073e9SAndroid Build Coastguard Worker            # Reference number of results
4717*da0073e9SAndroid Build Coastguard Worker            ref_num_results = len(torch.cuda.tunable.get_results())
4718*da0073e9SAndroid Build Coastguard Worker
4719*da0073e9SAndroid Build Coastguard Worker            N = M = K = 8
4720*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(N, K, device=device, dtype=dtype)
4721*da0073e9SAndroid Build Coastguard Worker            B = torch.randn(K, M, device=device, dtype=dtype)
4722*da0073e9SAndroid Build Coastguard Worker            C = torch.matmul(A, B)
4723*da0073e9SAndroid Build Coastguard Worker
4724*da0073e9SAndroid Build Coastguard Worker            # This stores the total number of cumulative results
4725*da0073e9SAndroid Build Coastguard Worker            total_num_results = len(torch.cuda.tunable.get_results())
4726*da0073e9SAndroid Build Coastguard Worker
4727*da0073e9SAndroid Build Coastguard Worker            # There must be a new tuning result
4728*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((total_num_results - ref_num_results), 1)
4729*da0073e9SAndroid Build Coastguard Worker
4730*da0073e9SAndroid Build Coastguard Worker            # Set tuning iterations to zero
4731*da0073e9SAndroid Build Coastguard Worker            # Tune a single GEMM and verify that we get a new tuning result
4732*da0073e9SAndroid Build Coastguard Worker            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
4733*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4734*da0073e9SAndroid Build Coastguard Worker            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"  # reset to default
4735*da0073e9SAndroid Build Coastguard Worker
4736*da0073e9SAndroid Build Coastguard Worker            # Reference number of results
4737*da0073e9SAndroid Build Coastguard Worker            ref_num_results = total_num_results
4738*da0073e9SAndroid Build Coastguard Worker
4739*da0073e9SAndroid Build Coastguard Worker            N = M = K = 16
4740*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(N, K, device=device, dtype=dtype)
4741*da0073e9SAndroid Build Coastguard Worker            B = torch.randn(K, M, device=device, dtype=dtype)
4742*da0073e9SAndroid Build Coastguard Worker            C = torch.matmul(A, B)
4743*da0073e9SAndroid Build Coastguard Worker
4744*da0073e9SAndroid Build Coastguard Worker            # This stores the total number of cumulative results
4745*da0073e9SAndroid Build Coastguard Worker            total_num_results = len(torch.cuda.tunable.get_results())
4746*da0073e9SAndroid Build Coastguard Worker
4747*da0073e9SAndroid Build Coastguard Worker            # There must be a new tuning result
4748*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((total_num_results - ref_num_results), 1)
4749*da0073e9SAndroid Build Coastguard Worker
4750*da0073e9SAndroid Build Coastguard Worker        finally:
4751*da0073e9SAndroid Build Coastguard Worker            # disable TunableOp
4752*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable(False)
4753*da0073e9SAndroid Build Coastguard Worker
4754*da0073e9SAndroid Build Coastguard Worker            # clean up, remove any file that was generated
4755*da0073e9SAndroid Build Coastguard Worker            try:
4756*da0073e9SAndroid Build Coastguard Worker                import os
4757*da0073e9SAndroid Build Coastguard Worker                filename = torch.cuda.tunable.get_filename()
4758*da0073e9SAndroid Build Coastguard Worker                os.remove(filename)
4759*da0073e9SAndroid Build Coastguard Worker            except FileNotFoundError:
4760*da0073e9SAndroid Build Coastguard Worker                pass
4761*da0073e9SAndroid Build Coastguard Worker
4762*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4763*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)
4764*da0073e9SAndroid Build Coastguard Worker    def test_matmul_check_entries_tunableop(self, device, dtype):
4765*da0073e9SAndroid Build Coastguard Worker        # Tune a couple of matrix multiplies
4766*da0073e9SAndroid Build Coastguard Worker        # Verify we get the correct number of results
4767*da0073e9SAndroid Build Coastguard Worker
4768*da0073e9SAndroid Build Coastguard Worker        try:
4769*da0073e9SAndroid Build Coastguard Worker            set_tunableop_defaults()
4770*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable()
4771*da0073e9SAndroid Build Coastguard Worker            # set these to single iterations to keep it short but still exercise the code
4772*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.set_max_tuning_iterations(1)
4773*da0073e9SAndroid Build Coastguard Worker
4774*da0073e9SAndroid Build Coastguard Worker            # Reference number of results
4775*da0073e9SAndroid Build Coastguard Worker            ref_num_results = len(torch.cuda.tunable.get_results())
4776*da0073e9SAndroid Build Coastguard Worker
4777*da0073e9SAndroid Build Coastguard Worker            # Execute matrix multiplies. We intentionally include the same M value
4778*da0073e9SAndroid Build Coastguard Worker            # twice; the CSV file should only record unique GEMMs
4779*da0073e9SAndroid Build Coastguard Worker            count_matmul = 4
4780*da0073e9SAndroid Build Coastguard Worker            K = 64
4781*da0073e9SAndroid Build Coastguard Worker            for M in [32, 64, 32]:
4782*da0073e9SAndroid Build Coastguard Worker                for N in [32, 64]:
4783*da0073e9SAndroid Build Coastguard Worker                    A = torch.randn(N, K, device=device, dtype=dtype)
4784*da0073e9SAndroid Build Coastguard Worker                    B = torch.randn(K, M, device=device, dtype=dtype)
4785*da0073e9SAndroid Build Coastguard Worker                    C = torch.matmul(A, B)
4786*da0073e9SAndroid Build Coastguard Worker
4787*da0073e9SAndroid Build Coastguard Worker            # This stores the total number of cumulative results
4788*da0073e9SAndroid Build Coastguard Worker            total_num_results = len(torch.cuda.tunable.get_results())
4789*da0073e9SAndroid Build Coastguard Worker
4790*da0073e9SAndroid Build Coastguard Worker            # Take the difference to calculate the number of results from
4791*da0073e9SAndroid Build Coastguard Worker            # this test and verify that it agrees with the number of
4792*da0073e9SAndroid Build Coastguard Worker            # GEMMs.
4793*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((total_num_results - ref_num_results), count_matmul)
4794*da0073e9SAndroid Build Coastguard Worker
4795*da0073e9SAndroid Build Coastguard Worker        finally:
4796*da0073e9SAndroid Build Coastguard Worker            # disable TunableOp
4797*da0073e9SAndroid Build Coastguard Worker            torch.cuda.tunable.enable(False)
4798*da0073e9SAndroid Build Coastguard Worker
4799*da0073e9SAndroid Build Coastguard Worker            # clean up, remove any file that was generated
4800*da0073e9SAndroid Build Coastguard Worker            try:
4801*da0073e9SAndroid Build Coastguard Worker                import os
4802*da0073e9SAndroid Build Coastguard Worker                filename = torch.cuda.tunable.get_filename()
4803*da0073e9SAndroid Build Coastguard Worker                os.remove(filename)
4804*da0073e9SAndroid Build Coastguard Worker            except FileNotFoundError:
4805*da0073e9SAndroid Build Coastguard Worker                pass
4806*da0073e9SAndroid Build Coastguard Worker
4891*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.complex64)
4892*da0073e9SAndroid Build Coastguard Worker    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
4893*da0073e9SAndroid Build Coastguard Worker        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
4894*da0073e9SAndroid Build Coastguard Worker        b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
4895*da0073e9SAndroid Build Coastguard Worker        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
4896*da0073e9SAndroid Build Coastguard Worker
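        # out= is only rejected when autograd is involved: detaching the inputs or
        # running under torch.no_grad() makes the same call succeed.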
4897*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a.detach(), b.detach(), out=c)
4898*da0073e9SAndroid Build Coastguard Worker
4899*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
4900*da0073e9SAndroid Build Coastguard Worker            torch.matmul(a, b, out=c)
4901*da0073e9SAndroid Build Coastguard Worker
4902*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
4903*da0073e9SAndroid Build Coastguard Worker            torch.matmul(a, b, out=c)
4904*da0073e9SAndroid Build Coastguard Worker
4905*da0073e9SAndroid Build Coastguard Worker    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4906*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('16GB', device='cuda')
4907*da0073e9SAndroid Build Coastguard Worker    def test_large_bmm_mm_backward(self, device):
4908*da0073e9SAndroid Build Coastguard Worker        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4909*da0073e9SAndroid Build Coastguard Worker        B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
4910*da0073e9SAndroid Build Coastguard Worker        G = torch.randn([1024, 2, 65536], device="cuda")
4911*da0073e9SAndroid Build Coastguard Worker
4912*da0073e9SAndroid Build Coastguard Worker        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4913*da0073e9SAndroid Build Coastguard Worker        (A @ B).backward(G)
4914*da0073e9SAndroid Build Coastguard Worker
4915*da0073e9SAndroid Build Coastguard Worker    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4916*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('16GB', device='cuda')
4917*da0073e9SAndroid Build Coastguard Worker    def test_large_bmm_backward(self, device):
4918*da0073e9SAndroid Build Coastguard Worker        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4919*da0073e9SAndroid Build Coastguard Worker        B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
4920*da0073e9SAndroid Build Coastguard Worker        G = torch.randn([1024, 2, 65536], device="cuda")
4921*da0073e9SAndroid Build Coastguard Worker
4922*da0073e9SAndroid Build Coastguard Worker        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4923*da0073e9SAndroid Build Coastguard Worker        (A @ B).backward(G)
4924*da0073e9SAndroid Build Coastguard Worker
4925*da0073e9SAndroid Build Coastguard Worker    def test_linear_algebra_scalar_raises(self, device) -> None:
4926*da0073e9SAndroid Build Coastguard Worker        m = torch.randn(5, 5, device=device)
4927*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, device=device)
4928*da0073e9SAndroid Build Coastguard Worker        s = torch.tensor(7, device=device)
4929*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
4930*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
4931*da0073e9SAndroid Build Coastguard Worker
4932*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.complex64)
4933*da0073e9SAndroid Build Coastguard Worker    def test_cross(self, device, dtype):
4934*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4935*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4936*da0073e9SAndroid Build Coastguard Worker        res1 = torch.cross(x, y)
4937*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
4938*da0073e9SAndroid Build Coastguard Worker        torch.cross(x, y, out=res2)
4939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
4940*da0073e9SAndroid Build Coastguard Worker
4941*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.complex64)
4942*da0073e9SAndroid Build Coastguard Worker    def test_linalg_cross(self, device, dtype):
4943*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4944*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4945*da0073e9SAndroid Build Coastguard Worker        res1 = torch.linalg.cross(x, y, dim=1)
4946*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
4947*da0073e9SAndroid Build Coastguard Worker        torch.linalg.cross(x, y, dim=1, out=res2)
4948*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
4949*da0073e9SAndroid Build Coastguard Worker
4950*da0073e9SAndroid Build Coastguard Worker        # test for broadcastable inputs
4951*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 3, 2, dtype=dtype, device=device)
4952*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(4, 3, 1, dtype=dtype, device=device)
4953*da0073e9SAndroid Build Coastguard Worker        res1 = torch.linalg.cross(x, y, dim=1)
4954*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
4955*da0073e9SAndroid Build Coastguard Worker        torch.linalg.cross(x, y, dim=1, out=res2)
4956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
4957*da0073e9SAndroid Build Coastguard Worker
4958*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.complex64)
4959*da0073e9SAndroid Build Coastguard Worker    def test_cross_with_and_without_dim(self, device, dtype):
4960*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 3, dtype=dtype, device=device)
4961*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(100, 3, dtype=dtype, device=device)
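        # Without an explicit dim, torch.cross picks the first dimension of size 3
        # (dim=1 here), so all three results should agree.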
4962*da0073e9SAndroid Build Coastguard Worker        res1 = torch.cross(x, y, dim=1)
4963*da0073e9SAndroid Build Coastguard Worker        res2 = torch.cross(x, y, dim=-1)
4964*da0073e9SAndroid Build Coastguard Worker        res3 = torch.cross(x, y)
4965*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
4966*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res3)
4967*da0073e9SAndroid Build Coastguard Worker
4968*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.complex64)
4969*da0073e9SAndroid Build Coastguard Worker    def test_linalg_cross_with_and_without_dim(self, device, dtype):
4970*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 3, dtype=dtype, device=device)
4971*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(100, 3, dtype=dtype, device=device)
4972*da0073e9SAndroid Build Coastguard Worker        res1 = torch.linalg.cross(x, y, dim=1)
4973*da0073e9SAndroid Build Coastguard Worker        res2 = torch.linalg.cross(x, y, dim=-1)
4974*da0073e9SAndroid Build Coastguard Worker        res3 = torch.linalg.cross(x, y)
4975*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
4976*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res3)
4977*da0073e9SAndroid Build Coastguard Worker
4978*da0073e9SAndroid Build Coastguard Worker    def test_renorm(self, device):
4979*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(20, 20, device=device)  # big enough to exercise vectorized path
4980*da0073e9SAndroid Build Coastguard Worker        res1 = torch.tensor((), device=device)
4981*da0073e9SAndroid Build Coastguard Worker
4982*da0073e9SAndroid Build Coastguard Worker        def renorm(matrix, value, dim, max_norm):
4983*da0073e9SAndroid Build Coastguard Worker            m1 = matrix.transpose(dim, 0).contiguous()
4984*da0073e9SAndroid Build Coastguard Worker            # collapse non-dim dimensions.
4985*da0073e9SAndroid Build Coastguard Worker            m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
4986*da0073e9SAndroid Build Coastguard Worker            norms = m2.norm(value, 1, True)
4987*da0073e9SAndroid Build Coastguard Worker            # clip
4988*da0073e9SAndroid Build Coastguard Worker            new_norms = norms.clone()
4989*da0073e9SAndroid Build Coastguard Worker            new_norms[torch.gt(norms, max_norm)] = max_norm
4990*da0073e9SAndroid Build Coastguard Worker            new_norms.div_(norms.add_(1e-7))
4991*da0073e9SAndroid Build Coastguard Worker            # renormalize
4992*da0073e9SAndroid Build Coastguard Worker            m1.mul_(new_norms.expand_as(m1))
4993*da0073e9SAndroid Build Coastguard Worker            return m1.transpose(dim, 0)
4994*da0073e9SAndroid Build Coastguard Worker
4995*da0073e9SAndroid Build Coastguard Worker        # note that the axis fed to torch.renorm is different (2~=1)
4996*da0073e9SAndroid Build Coastguard Worker        maxnorm = m1.norm(2, 1).mean()
4997*da0073e9SAndroid Build Coastguard Worker        m2 = renorm(m1, 2, 1, maxnorm)
4998*da0073e9SAndroid Build Coastguard Worker        m1.renorm_(2, 1, maxnorm)
4999*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1, m2, atol=1e-5, rtol=0)
5000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0)
5001*da0073e9SAndroid Build Coastguard Worker
5002*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(3, 4, 5, device=device)
5003*da0073e9SAndroid Build Coastguard Worker        m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
5004*da0073e9SAndroid Build Coastguard Worker        maxnorm = m2.norm(2, 0).mean()
5005*da0073e9SAndroid Build Coastguard Worker        m2 = renorm(m2, 2, 1, maxnorm)
5006*da0073e9SAndroid Build Coastguard Worker        m1.renorm_(2, 1, maxnorm)
5007*da0073e9SAndroid Build Coastguard Worker        m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
5008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m3, m2)
5009*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
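
    # Minimal sketch (illustrative values) of renorm's contract: every slice along `dim`
    # whose p-norm exceeds maxnorm is rescaled to have norm maxnorm, while slices already
    # within the bound are left untouched.
    def _example_renorm_semantics(self):
        x = torch.tensor([[1., 0., 0.], [3., 4., 0.]])
        out = x.renorm(p=2, dim=0, maxnorm=2.)
        self.assertEqual(out[0], x[0])  # norm 1 <= 2: unchanged
        self.assertEqual(out[1].norm(), torch.tensor(2.))  # norm 5 > 2: rescaled down to 2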
5010*da0073e9SAndroid Build Coastguard Worker
5011*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5012*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
5013*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5014*da0073e9SAndroid Build Coastguard Worker    def test_ormqr(self, device, dtype):
5015*da0073e9SAndroid Build Coastguard Worker
5016*da0073e9SAndroid Build Coastguard Worker        def run_test(batch, m, n, fortran_contiguous):
5017*da0073e9SAndroid Build Coastguard Worker            A = make_tensor((*batch, m, n), dtype=dtype, device=device)
5018*da0073e9SAndroid Build Coastguard Worker            reflectors, tau = torch.geqrf(A)
5019*da0073e9SAndroid Build Coastguard Worker            if not fortran_contiguous:
5020*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(reflectors.mT.is_contiguous())
5021*da0073e9SAndroid Build Coastguard Worker                reflectors = reflectors.contiguous()
5022*da0073e9SAndroid Build Coastguard Worker
5023*da0073e9SAndroid Build Coastguard Worker            # Q is of size m x m
5024*da0073e9SAndroid Build Coastguard Worker            Q, _ = torch.linalg.qr(A, mode='complete')
5025*da0073e9SAndroid Build Coastguard Worker            C_right = make_tensor((*batch, m, n), dtype=dtype, device=device)
5026*da0073e9SAndroid Build Coastguard Worker            C_left = make_tensor((*batch, n, m), dtype=dtype, device=device)
5027*da0073e9SAndroid Build Coastguard Worker
5028*da0073e9SAndroid Build Coastguard Worker            expected = Q @ C_right
5029*da0073e9SAndroid Build Coastguard Worker            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False)
5030*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
5031*da0073e9SAndroid Build Coastguard Worker
5032*da0073e9SAndroid Build Coastguard Worker            expected = C_left @ Q
5033*da0073e9SAndroid Build Coastguard Worker            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=False)
5034*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
5035*da0073e9SAndroid Build Coastguard Worker
5036*da0073e9SAndroid Build Coastguard Worker            expected = Q.mH @ C_right
5037*da0073e9SAndroid Build Coastguard Worker            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True)
5038*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
5039*da0073e9SAndroid Build Coastguard Worker
5040*da0073e9SAndroid Build Coastguard Worker            expected = C_left @ Q.mH
5041*da0073e9SAndroid Build Coastguard Worker            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True)
5042*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
5043*da0073e9SAndroid Build Coastguard Worker
5044*da0073e9SAndroid Build Coastguard Worker            # if tau is all zeros then the implicit matrix Q is the identity matrix
5045*da0073e9SAndroid Build Coastguard Worker            # so the actual result should be C_right in this case
5046*da0073e9SAndroid Build Coastguard Worker            zero_tau = torch.zeros_like(tau)
5047*da0073e9SAndroid Build Coastguard Worker            actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False)
5048*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(C_right, actual)
5049*da0073e9SAndroid Build Coastguard Worker
5050*da0073e9SAndroid Build Coastguard Worker        batches = [(), (0, ), (2, ), (2, 1)]
5051*da0073e9SAndroid Build Coastguard Worker        ns = [5, 2, 0]
5052*da0073e9SAndroid Build Coastguard Worker        for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]):
5053*da0073e9SAndroid Build Coastguard Worker            run_test(batch, m, n, fortran_contiguous)
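
    # Illustrative sketch (assumed sizes, float64): ormqr applies the Q encoded by geqrf's
    # Householder reflectors without materializing it, so it agrees with the explicit
    # complete Q returned by torch.linalg.qr.
    def _example_ormqr_vs_explicit_q(self):
        A = torch.randn(4, 4, dtype=torch.float64)
        C = torch.randn(4, 3, dtype=torch.float64)
        reflectors, tau = torch.geqrf(A)
        Q, _ = torch.linalg.qr(A, mode='complete')
        self.assertEqual(torch.ormqr(reflectors, tau, C), Q @ C)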
5054*da0073e9SAndroid Build Coastguard Worker
5055*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5056*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
5057*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5058*da0073e9SAndroid Build Coastguard Worker    def test_ormqr_errors_and_warnings(self, device, dtype):
5059*da0073e9SAndroid Build Coastguard Worker        test_cases = [
5060*da0073e9SAndroid Build Coastguard Worker            # input1 size, input2 size, input3 size, error regex
5061*da0073e9SAndroid Build Coastguard Worker            ((10,), (2,), (2,), r"input must have at least 2 dimensions"),
5062*da0073e9SAndroid Build Coastguard Worker            ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"),
5063*da0073e9SAndroid Build Coastguard Worker            ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"),
5064*da0073e9SAndroid Build Coastguard Worker            ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"),
5065*da0073e9SAndroid Build Coastguard Worker            ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"),
5066*da0073e9SAndroid Build Coastguard Worker            ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"),
5067*da0073e9SAndroid Build Coastguard Worker        ]
5068*da0073e9SAndroid Build Coastguard Worker        for a_size, tau_size, c_size, error_regex in test_cases:
5069*da0073e9SAndroid Build Coastguard Worker            a = make_tensor(a_size, dtype=dtype, device=device)
5070*da0073e9SAndroid Build Coastguard Worker            tau = make_tensor(tau_size, dtype=dtype, device=device)
5071*da0073e9SAndroid Build Coastguard Worker            c = make_tensor(c_size, dtype=dtype, device=device)
5072*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, error_regex):
5073*da0073e9SAndroid Build Coastguard Worker                torch.ormqr(a, tau, c)
5074*da0073e9SAndroid Build Coastguard Worker
5075*da0073e9SAndroid Build Coastguard Worker    def test_blas_empty(self, device):
5076*da0073e9SAndroid Build Coastguard Worker        def fn(torchfn, *args, test_out=False, **kwargs):
5077*da0073e9SAndroid Build Coastguard Worker            def call_torch_fn(*args, **kwargs):
5078*da0073e9SAndroid Build Coastguard Worker                return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
5079*da0073e9SAndroid Build Coastguard Worker                                      for shape in args), **kwargs)
5080*da0073e9SAndroid Build Coastguard Worker            result = call_torch_fn(*args, **kwargs)
5081*da0073e9SAndroid Build Coastguard Worker            if not test_out:
5082*da0073e9SAndroid Build Coastguard Worker                return result
5083*da0073e9SAndroid Build Coastguard Worker            else:
5084*da0073e9SAndroid Build Coastguard Worker                out = torch.full_like(result, math.nan)
5085*da0073e9SAndroid Build Coastguard Worker                call_torch_fn(*args, **kwargs, out=out)
5086*da0073e9SAndroid Build Coastguard Worker                return out
5087*da0073e9SAndroid Build Coastguard Worker
5088*da0073e9SAndroid Build Coastguard Worker        # mm, addmm
5089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
5090*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
5091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
5092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
5093*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))
5094*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True))
5095*da0073e9SAndroid Build Coastguard Worker
5096*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
5097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape)
5098*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((5, 6), device=device)
5099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
5100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
5101*da0073e9SAndroid Build Coastguard Worker
5102*da0073e9SAndroid Build Coastguard Worker        # mv, addmv
5103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
5104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
5105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
5106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True))
5107*da0073e9SAndroid Build Coastguard Worker
5108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
5109*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((3,), device=device)
5110*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
5111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
5112*da0073e9SAndroid Build Coastguard Worker
5113*da0073e9SAndroid Build Coastguard Worker        # bmm, baddbmm
5114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
5115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
5116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
5117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)))
5118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True))
5119*da0073e9SAndroid Build Coastguard Worker
5120*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape)
5121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape)
5122*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape)
5123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape)
5124*da0073e9SAndroid Build Coastguard Worker        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
5125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2))  # Issue #33467
5126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True))  # Issue #33467
5127*da0073e9SAndroid Build Coastguard Worker
5128*da0073e9SAndroid Build Coastguard Worker        # addbmm
5129*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
5130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
5131*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((5, 6), device=device)
5132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
5133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
5134*da0073e9SAndroid Build Coastguard Worker
5135*da0073e9SAndroid Build Coastguard Worker        # matmul
5136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,)))
5137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True))
5138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
5139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
5140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
5141*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)))
5142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True))
5143*da0073e9SAndroid Build Coastguard Worker
5144*da0073e9SAndroid Build Coastguard Worker        # dot
5145*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
5146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))
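
    # Small sketch (illustrative): a matmul whose inner dimension is zero reduces over an
    # empty set, so the result is a correctly shaped tensor of zeros rather than an error,
    # which is the behaviour the checks above rely on.
    def _example_empty_inner_dim(self):
        a = torch.randn(5, 0)
        b = torch.randn(0, 6)
        self.assertEqual(a @ b, torch.zeros(5, 6))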
5147*da0073e9SAndroid Build Coastguard Worker
5148*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
5149*da0073e9SAndroid Build Coastguard Worker                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5150*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types_and(
5151*da0073e9SAndroid Build Coastguard Worker                  torch.half,
5152*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16] if SM53OrLater else []
5153*da0073e9SAndroid Build Coastguard Worker                  ))
5154*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bfloat16))
5155*da0073e9SAndroid Build Coastguard Worker    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
5156*da0073e9SAndroid Build Coastguard Worker        # common case
5157*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(128, device=device).to(dtype)
5158*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(2048, 2400, device=device).to(dtype)
5159*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(128, 2400, device=device).to(dtype)
5160*da0073e9SAndroid Build Coastguard Worker        torch.nn.functional.linear(m1, m2, M)
5161*da0073e9SAndroid Build Coastguard Worker        # non-transposed B operand has leading dimension (ld) >> rows
5162*da0073e9SAndroid Build Coastguard Worker        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
5163*da0073e9SAndroid Build Coastguard Worker        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
5164*da0073e9SAndroid Build Coastguard Worker        M = torch.rand([128]).to(dtype).to(device)
5165*da0073e9SAndroid Build Coastguard Worker        torch.addmm(M, m2.t(), m1)
5166*da0073e9SAndroid Build Coastguard Worker        # transposed A operand has leading dimension (ld) >> rows
5167*da0073e9SAndroid Build Coastguard Worker        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
5168*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(2048, 2400, device=device).to(dtype)
5169*da0073e9SAndroid Build Coastguard Worker        M = torch.rand([128]).to(dtype).to(device)
5170*da0073e9SAndroid Build Coastguard Worker        torch.addmm(M, m2, m1)
5171*da0073e9SAndroid Build Coastguard Worker        # large tensor dim > 65535
5172*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(16, device=device).to(dtype)
5173*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(32, 131071, device=device).to(dtype)
5174*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(16, 131071, device=device).to(dtype)
5175*da0073e9SAndroid Build Coastguard Worker        torch.nn.functional.linear(m1, m2, M)
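
    # Illustrative sketch (assumed shapes): slicing a transposed matrix, as done above,
    # yields a view whose stride along the second dimension (the BLAS "leading dimension")
    # is much larger than the number of rows actually used, which is the corner case
    # being exercised.
    def _example_leading_dimension_view(self):
        base = torch.randn(16, 100)
        view = base.t()[20:30]  # a 10 x 16 view with stride (1, 100)
        self.assertEqual(view.stride(), (1, 100))
        self.assertEqual((view @ torch.randn(16, 4)).shape, torch.Size([10, 4]))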
5176*da0073e9SAndroid Build Coastguard Worker
5177*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5178*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
5179*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
5180*da0073e9SAndroid Build Coastguard Worker    def test_hipblaslt_corner_cases_rocm(self, device, dtype):
5181*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.double:
5182*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("hipblasLt doesn't support doubles yet")
5183*da0073e9SAndroid Build Coastguard Worker
5184*da0073e9SAndroid Build Coastguard Worker        # enable hipblaslt path via env variable.
5185*da0073e9SAndroid Build Coastguard Worker        import os
5186*da0073e9SAndroid Build Coastguard Worker        DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT"
5187*da0073e9SAndroid Build Coastguard Worker        prev_val = os.getenv(DISABLE_ADDMM_HIP_LT)
5188*da0073e9SAndroid Build Coastguard Worker        try:
5189*da0073e9SAndroid Build Coastguard Worker            os.environ[DISABLE_ADDMM_HIP_LT] = "0"
5190*da0073e9SAndroid Build Coastguard Worker            # common case
5191*da0073e9SAndroid Build Coastguard Worker            M = torch.randn(128, device=device, dtype=dtype)
5192*da0073e9SAndroid Build Coastguard Worker            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5193*da0073e9SAndroid Build Coastguard Worker            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5194*da0073e9SAndroid Build Coastguard Worker            out1 = torch.nn.functional.linear(m1, m2, M)
5195*da0073e9SAndroid Build Coastguard Worker            M_cpu = M.to('cpu')
5196*da0073e9SAndroid Build Coastguard Worker            m1_cpu = m1.to('cpu')
5197*da0073e9SAndroid Build Coastguard Worker            m2_cpu = m2.to('cpu')
5198*da0073e9SAndroid Build Coastguard Worker            out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu)
5199*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2))
5200*da0073e9SAndroid Build Coastguard Worker
5201*da0073e9SAndroid Build Coastguard Worker            # common case without bias
5202*da0073e9SAndroid Build Coastguard Worker            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5203*da0073e9SAndroid Build Coastguard Worker            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5204*da0073e9SAndroid Build Coastguard Worker            out2 = torch.nn.functional.linear(m1, m2, bias=None)
5205*da0073e9SAndroid Build Coastguard Worker            m1_cpu = m1.to('cpu')
5206*da0073e9SAndroid Build Coastguard Worker            m2_cpu = m2.to('cpu')
5207*da0073e9SAndroid Build Coastguard Worker            out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None)
5208*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2))
5209*da0073e9SAndroid Build Coastguard Worker        finally:
5210*da0073e9SAndroid Build Coastguard Worker            if prev_val is None:
5211*da0073e9SAndroid Build Coastguard Worker                del os.environ[DISABLE_ADDMM_HIP_LT]
5212*da0073e9SAndroid Build Coastguard Worker            else:
5213*da0073e9SAndroid Build Coastguard Worker                os.environ[DISABLE_ADDMM_HIP_LT] = prev_val
5214*da0073e9SAndroid Build Coastguard Worker
5215*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types_and(
5216*da0073e9SAndroid Build Coastguard Worker                  torch.half,
5217*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16] if SM53OrLater else []
5218*da0073e9SAndroid Build Coastguard Worker                  ))
5219*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half))
5220*da0073e9SAndroid Build Coastguard Worker    def test_blas_alpha_beta_empty(self, device, dtype):
5221*da0073e9SAndroid Build Coastguard Worker        # This test is disabled on CUDA 9; see:
5222*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/31006
5223*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bfloat16 and self.device_type == 'xla':
5224*da0073e9SAndroid Build Coastguard Worker            # TODO (@zasdfgbnm): this causes the following error on test
5225*da0073e9SAndroid Build Coastguard Worker            # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16:
5226*da0073e9SAndroid Build Coastguard Worker            #
5227*da0073e9SAndroid Build Coastguard Worker            #   RuntimeError: _th_equal not supported on CPUType for BFloat16
5228*da0073e9SAndroid Build Coastguard Worker            return
5229*da0073e9SAndroid Build Coastguard Worker        # ensure beta is respected
5230*da0073e9SAndroid Build Coastguard Worker        value = 11
5231*da0073e9SAndroid Build Coastguard Worker        input = torch.full((2,), value, dtype=dtype, device=device)
5232*da0073e9SAndroid Build Coastguard Worker        mat = torch.ones((2, 0), dtype=dtype, device=device)
5233*da0073e9SAndroid Build Coastguard Worker        vec = torch.ones((0,), dtype=dtype, device=device)
5234*da0073e9SAndroid Build Coastguard Worker        out = torch.empty((2,), dtype=dtype, device=device)
5235*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
5236*da0073e9SAndroid Build Coastguard Worker            alpha = 6 + 7j
5237*da0073e9SAndroid Build Coastguard Worker            beta = 3 + 4j
5238*da0073e9SAndroid Build Coastguard Worker        else:
5239*da0073e9SAndroid Build Coastguard Worker            alpha = 6
5240*da0073e9SAndroid Build Coastguard Worker            beta = 3
5241*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5242*da0073e9SAndroid Build Coastguard Worker                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta))
5243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5244*da0073e9SAndroid Build Coastguard Worker                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out))
5245*da0073e9SAndroid Build Coastguard Worker
5246*da0073e9SAndroid Build Coastguard Worker        # torch.addmm
5247*da0073e9SAndroid Build Coastguard Worker        input = torch.full((2, 3), value, dtype=dtype, device=device)
5248*da0073e9SAndroid Build Coastguard Worker        mat2 = torch.ones((0, 3), dtype=dtype, device=device)
5249*da0073e9SAndroid Build Coastguard Worker        out = torch.empty((2, 3), dtype=dtype, device=device)
5250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5251*da0073e9SAndroid Build Coastguard Worker                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta))
5252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5253*da0073e9SAndroid Build Coastguard Worker                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
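
    # Minimal sketch (illustrative numbers): with an empty inner dimension alpha * (mat @ vec)
    # contributes nothing, so addmv degenerates to beta * input, which is what is asserted above.
    def _example_beta_only_addmv(self):
        inp = torch.full((2,), 11.)
        mat = torch.ones(2, 0)
        vec = torch.ones(0)
        self.assertEqual(torch.addmv(inp, mat, vec, beta=3, alpha=6), torch.full((2,), 33.))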
5254*da0073e9SAndroid Build Coastguard Worker
5255*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
5256*da0073e9SAndroid Build Coastguard Worker    def test_blas_nan_out(self, device, dtype):
5257*da0073e9SAndroid Build Coastguard Worker        # These functions should work correctly with NaN filled outputs,
5258*da0073e9SAndroid Build Coastguard Worker        # but need special handling, see [NOTE: cpu_zero]
5259*da0073e9SAndroid Build Coastguard Worker        b = 3
5260*da0073e9SAndroid Build Coastguard Worker        n = 5
5261*da0073e9SAndroid Build Coastguard Worker        m = 7
5262*da0073e9SAndroid Build Coastguard Worker        p = 11
5263*da0073e9SAndroid Build Coastguard Worker
5264*da0073e9SAndroid Build Coastguard Worker        # torch.mv
5265*da0073e9SAndroid Build Coastguard Worker        nm = torch.randn((m, n), device=device).t()
5266*da0073e9SAndroid Build Coastguard Worker        _m = torch.randn((), device=device).expand(m)
5267*da0073e9SAndroid Build Coastguard Worker        _m_out = torch.full((m,), float('nan'), device=device)
5268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5269*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum())
5270*da0073e9SAndroid Build Coastguard Worker
5271*da0073e9SAndroid Build Coastguard Worker        # torch.mm
5272*da0073e9SAndroid Build Coastguard Worker        mp = torch.randn((p, m), device=device).t()
5273*da0073e9SAndroid Build Coastguard Worker        np_out = torch.full((n, p), float('nan'), device=device)
5274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))
5275*da0073e9SAndroid Build Coastguard Worker
5276*da0073e9SAndroid Build Coastguard Worker        # torch.bmm
5277*da0073e9SAndroid Build Coastguard Worker        bnm = torch.randn((b, m, n), device=device).transpose(1, 2)
5278*da0073e9SAndroid Build Coastguard Worker        bmp = torch.randn((b, p, m), device=device).transpose(1, 2)
5279*da0073e9SAndroid Build Coastguard Worker        bnp_out = torch.full((b, n, p), float('nan'), device=device)
5280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out))
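
    # Sketch of the pitfall behind [NOTE: cpu_zero]: a NaN-filled `out` cannot simply be
    # scaled by beta == 0, because 0 * NaN is NaN in IEEE arithmetic; the buffer has to be
    # zeroed (or ignored) explicitly, which is what the checks above guard against.
    def _example_zero_times_nan(self):
        self.assertTrue(math.isnan(0.0 * float('nan')))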
5281*da0073e9SAndroid Build Coastguard Worker
5282*da0073e9SAndroid Build Coastguard Worker    @onlyCPU  # not supported by CUBLAS
5283*da0073e9SAndroid Build Coastguard Worker    def test_blas_mv_large_input(self, device):
5284*da0073e9SAndroid Build Coastguard Worker        # This would previously fail if the allocated output had NaNs, see:
5285*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero]
5286*da0073e9SAndroid Build Coastguard Worker        n = 3000
5287*da0073e9SAndroid Build Coastguard Worker        m = 200
5288*da0073e9SAndroid Build Coastguard Worker
5289*da0073e9SAndroid Build Coastguard Worker        nm = torch.randn((m, n), device=device).t()
5290*da0073e9SAndroid Build Coastguard Worker        _m = torch.randn((), device=device).expand(m)
5291*da0073e9SAndroid Build Coastguard Worker        _m_out = torch.full((m,), 0., device=device)
5292*da0073e9SAndroid Build Coastguard Worker
5293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5294*da0073e9SAndroid Build Coastguard Worker
5295*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
5296*da0073e9SAndroid Build Coastguard Worker    def test_renorm_ps(self, device):
5297*da0073e9SAndroid Build Coastguard Worker        # full reduction
5298*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
5299*da0073e9SAndroid Build Coastguard Worker        xn = x.numpy()
5300*da0073e9SAndroid Build Coastguard Worker        for p in [1, 2, 3, 4, inf]:
5301*da0073e9SAndroid Build Coastguard Worker            res = x.renorm(p, 1, 1)
5302*da0073e9SAndroid Build Coastguard Worker            expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
5303*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm")
5304*da0073e9SAndroid Build Coastguard Worker
5305*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5306*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
5307*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5308*da0073e9SAndroid Build Coastguard Worker    def test_householder_product(self, device, dtype):
5309*da0073e9SAndroid Build Coastguard Worker        def generate_reflectors_and_tau(A):
5310*da0073e9SAndroid Build Coastguard Worker            """
5311*da0073e9SAndroid Build Coastguard Worker            This function uses numpy.linalg.qr with mode "raw" to extract the output of LAPACK's geqrf.
5312*da0073e9SAndroid Build Coastguard Worker            There is a torch.geqrf function, but it doesn't work with complex-valued input.
5313*da0073e9SAndroid Build Coastguard Worker            """
5314*da0073e9SAndroid Build Coastguard Worker            if A.numel() > 0:
5315*da0073e9SAndroid Build Coastguard Worker                A_cpu = A.cpu()
5316*da0073e9SAndroid Build Coastguard Worker                flattened_batch_shape = [-1, *A_cpu.shape[-2:]]
5317*da0073e9SAndroid Build Coastguard Worker                reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape)
5318*da0073e9SAndroid Build Coastguard Worker                tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]]
5319*da0073e9SAndroid Build Coastguard Worker                tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1])
5320*da0073e9SAndroid Build Coastguard Worker                for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau):
5321*da0073e9SAndroid Build Coastguard Worker                    reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw'))
5322*da0073e9SAndroid Build Coastguard Worker                    reflectors_i[:] = reflectors_tmp.T
5323*da0073e9SAndroid Build Coastguard Worker                reflectors = reflectors.view(*A_cpu.shape)
5324*da0073e9SAndroid Build Coastguard Worker                tau = tau.view(tau_shape)
5325*da0073e9SAndroid Build Coastguard Worker                return reflectors.to(A.device), tau.to(A.device)
5326*da0073e9SAndroid Build Coastguard Worker
5327*da0073e9SAndroid Build Coastguard Worker            reflectors = torch.empty_like(A)
5328*da0073e9SAndroid Build Coastguard Worker            tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device)
5329*da0073e9SAndroid Build Coastguard Worker            return reflectors, tau
5330*da0073e9SAndroid Build Coastguard Worker
5331*da0073e9SAndroid Build Coastguard Worker        def run_test(shape):
5332*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(*shape, dtype=dtype, device=device)
5333*da0073e9SAndroid Build Coastguard Worker            reflectors, tau = generate_reflectors_and_tau(A)
5334*da0073e9SAndroid Build Coastguard Worker            expected, _ = torch.linalg.qr(A)
5335*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.householder_product(reflectors, tau)
5336*da0073e9SAndroid Build Coastguard Worker            # torch.linalg.qr does not work correctly for zero batch dimension tensors
5337*da0073e9SAndroid Build Coastguard Worker            # see https://github.com/pytorch/pytorch/issues/50576
5338*da0073e9SAndroid Build Coastguard Worker            if (A.numel() > 0):
5339*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, actual)
5340*da0073e9SAndroid Build Coastguard Worker            else:
5341*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(actual.shape == shape)
5342*da0073e9SAndroid Build Coastguard Worker
5343*da0073e9SAndroid Build Coastguard Worker            # if tau is empty and A is not, the result should be a matrix with ones on the diagonal
5344*da0073e9SAndroid Build Coastguard Worker            if (A.numel() > 0):
5345*da0073e9SAndroid Build Coastguard Worker                tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device)
5346*da0073e9SAndroid Build Coastguard Worker                identity_mat = torch.zeros_like(reflectors)
5347*da0073e9SAndroid Build Coastguard Worker                identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1
5348*da0073e9SAndroid Build Coastguard Worker                actual = torch.linalg.householder_product(reflectors, tau_empty)
5349*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, identity_mat)
5350*da0073e9SAndroid Build Coastguard Worker
5351*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(A)
5352*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.householder_product(reflectors, tau, out=out)
5353*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
5354*da0073e9SAndroid Build Coastguard Worker            if (A.numel() > 0):
5355*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, out)
5356*da0073e9SAndroid Build Coastguard Worker
5357*da0073e9SAndroid Build Coastguard Worker        shapes = [(0, 0), (5, 0),  # Empty matrix
5358*da0073e9SAndroid Build Coastguard Worker                  (5, 5), (5, 3),  # Single matrix
5359*da0073e9SAndroid Build Coastguard Worker                  (0, 0, 0), (0, 5, 5), (0, 5, 3),  # Zero batch dimension tensors
5360*da0073e9SAndroid Build Coastguard Worker                  (2, 5, 5), (2, 5, 3),  # 3-dim tensors
5361*da0073e9SAndroid Build Coastguard Worker                  (2, 1, 5, 5), (2, 1, 5, 3)]  # 4-dim tensors
5362*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
5363*da0073e9SAndroid Build Coastguard Worker            run_test(shape)
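
    # Illustrative sketch (assumed sizes, float64): the matrix returned by
    # householder_product is a product of Householder reflections, so its columns are
    # orthonormal.
    def _example_householder_product_orthonormal(self):
        A = torch.randn(6, 4, dtype=torch.float64)
        reflectors, tau = torch.geqrf(A)
        Q = torch.linalg.householder_product(reflectors, tau)
        self.assertEqual(Q.mT @ Q, torch.eye(4, dtype=torch.float64))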
5364*da0073e9SAndroid Build Coastguard Worker
5365*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5366*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
5367*da0073e9SAndroid Build Coastguard Worker    def test_householder_product_errors_and_warnings(self, device):
5368*da0073e9SAndroid Build Coastguard Worker        test_cases = [
5369*da0073e9SAndroid Build Coastguard Worker            # input1 size, input2 size, error regex
5370*da0073e9SAndroid Build Coastguard Worker            ((10,), (2,), r"input must have at least 2 dimensions"),
5371*da0073e9SAndroid Build Coastguard Worker            ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"),
5372*da0073e9SAndroid Build Coastguard Worker            ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"),
5373*da0073e9SAndroid Build Coastguard Worker        ]
5374*da0073e9SAndroid Build Coastguard Worker        for a_size, tau_size, error_regex in test_cases:
5375*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(*a_size, device=device)
5376*da0073e9SAndroid Build Coastguard Worker            tau = torch.rand(*tau_size, device=device)
5377*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, error_regex):
5378*da0073e9SAndroid Build Coastguard Worker                torch.linalg.householder_product(a, tau)
5379*da0073e9SAndroid Build Coastguard Worker
5380*da0073e9SAndroid Build Coastguard Worker        # if out tensor with wrong shape is passed a warning is given
5381*da0073e9SAndroid Build Coastguard Worker        reflectors = torch.randn(3, 3, device=device)
5382*da0073e9SAndroid Build Coastguard Worker        tau = torch.randn(3, device=device)
5383*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(2, 3, device=device)
5384*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5385*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
5386*da0073e9SAndroid Build Coastguard Worker            torch.linalg.householder_product(reflectors, tau, out=out)
5387*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
5388*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
5389*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
5390*da0073e9SAndroid Build Coastguard Worker
5391*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
5392*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(reflectors).to(torch.int)
5393*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
5394*da0073e9SAndroid Build Coastguard Worker            torch.linalg.householder_product(reflectors, tau, out=out)
5395*da0073e9SAndroid Build Coastguard Worker
5396*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"):
5397*da0073e9SAndroid Build Coastguard Worker            torch.linalg.householder_product(reflectors, tau.to(torch.int))
5398*da0073e9SAndroid Build Coastguard Worker
5399*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
5400*da0073e9SAndroid Build Coastguard Worker            # device of out and input should match
5401*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5402*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(reflectors).to(wrong_device)
5403*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5404*da0073e9SAndroid Build Coastguard Worker                torch.linalg.householder_product(reflectors, tau, out=out)
5405*da0073e9SAndroid Build Coastguard Worker
5406*da0073e9SAndroid Build Coastguard Worker            # device of tau and input should match
5407*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5408*da0073e9SAndroid Build Coastguard Worker            tau = tau.to(wrong_device)
5409*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5410*da0073e9SAndroid Build Coastguard Worker                torch.linalg.householder_product(reflectors, tau)
5411*da0073e9SAndroid Build Coastguard Worker
5412*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5413*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
5414*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor")
5415*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5416*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5417*da0073e9SAndroid Build Coastguard Worker    def test_linalg_lu_family(self, device, dtype):
5418*da0073e9SAndroid Build Coastguard Worker        # Tests torch.lu
5419*da0073e9SAndroid Build Coastguard Worker        #       torch.linalg.lu_factor
5420*da0073e9SAndroid Build Coastguard Worker        #       torch.linalg.lu_factor_ex
5421*da0073e9SAndroid Build Coastguard Worker        #       torch.lu_unpack
5422*da0073e9SAndroid Build Coastguard Worker        #       torch.linalg.lu_solve
5423*da0073e9SAndroid Build Coastguard Worker        #       torch.linalg.solve
5424*da0073e9SAndroid Build Coastguard Worker        make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
5425*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
5426*da0073e9SAndroid Build Coastguard Worker
5427*da0073e9SAndroid Build Coastguard Worker        def run_test(A, pivot, singular, fn):
5428*da0073e9SAndroid Build Coastguard Worker            k = min(A.shape[-2:])
5429*da0073e9SAndroid Build Coastguard Worker            batch = A.shape[:-2]
5430*da0073e9SAndroid Build Coastguard Worker            check_errors = (fn == torch.linalg.lu_factor)
5431*da0073e9SAndroid Build Coastguard Worker            if singular and check_errors:
5432*da0073e9SAndroid Build Coastguard Worker                # It may or may not throw as the LU decomposition without pivoting
5433*da0073e9SAndroid Build Coastguard Worker                # may still succeed for singular matrices
5434*da0073e9SAndroid Build Coastguard Worker                try:
5435*da0073e9SAndroid Build Coastguard Worker                    LU, pivots = fn(A, pivot=pivot)
5436*da0073e9SAndroid Build Coastguard Worker                except RuntimeError:
5437*da0073e9SAndroid Build Coastguard Worker                    return
5438*da0073e9SAndroid Build Coastguard Worker            else:
5439*da0073e9SAndroid Build Coastguard Worker                LU, pivots = fn(A, pivot=pivot)[:2]
5440*da0073e9SAndroid Build Coastguard Worker
5441*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(LU.size(), A.shape)
5442*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pivots.size(), batch + (k,))
5443*da0073e9SAndroid Build Coastguard Worker
5444*da0073e9SAndroid Build Coastguard Worker            if not pivot:
5445*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
5446*da0073e9SAndroid Build Coastguard Worker
5447*da0073e9SAndroid Build Coastguard Worker            P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot)
5448*da0073e9SAndroid Build Coastguard Worker
5449*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(P @ L @ U if pivot else L @ U, A)
5450*da0073e9SAndroid Build Coastguard Worker
5451*da0073e9SAndroid Build Coastguard Worker            PLU = torch.linalg.lu(A, pivot=pivot)
5452*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(P, PLU.P)
5453*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(L, PLU.L)
5454*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(U, PLU.U)
5455*da0073e9SAndroid Build Coastguard Worker
5456*da0073e9SAndroid Build Coastguard Worker            if not singular and A.size(-2) == A.size(-1):
5457*da0073e9SAndroid Build Coastguard Worker                nrhs = ((), (1,), (3,))
5458*da0073e9SAndroid Build Coastguard Worker                for left, rhs in product((True, False), nrhs):
5459*da0073e9SAndroid Build Coastguard Worker                    # Vector case when left = False is not allowed
5460*da0073e9SAndroid Build Coastguard Worker                    if not left and rhs == ():
5461*da0073e9SAndroid Build Coastguard Worker                        continue
5462*da0073e9SAndroid Build Coastguard Worker                    if left:
5463*da0073e9SAndroid Build Coastguard Worker                        shape_B = A.shape[:-1] + rhs
5464*da0073e9SAndroid Build Coastguard Worker                    else:
5465*da0073e9SAndroid Build Coastguard Worker                        shape_B = A.shape[:-2] + rhs + A.shape[-1:]
5466*da0073e9SAndroid Build Coastguard Worker                    B = make_arg(shape_B)
5467*da0073e9SAndroid Build Coastguard Worker
5468*da0073e9SAndroid Build Coastguard Worker                    # Test linalg.lu_solve. It does not support vectors as rhs
5469*da0073e9SAndroid Build Coastguard Worker                    # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913
5470*da0073e9SAndroid Build Coastguard Worker                    if rhs != ():
5471*da0073e9SAndroid Build Coastguard Worker                        for adjoint in (True, False):
5472*da0073e9SAndroid Build Coastguard Worker                            X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
5473*da0073e9SAndroid Build Coastguard Worker                            A_adj = A.mH if adjoint else A
5474*da0073e9SAndroid Build Coastguard Worker                            if left:
5475*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(B, A_adj @ X)
5476*da0073e9SAndroid Build Coastguard Worker                            else:
5477*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(B, X @ A_adj)
5478*da0073e9SAndroid Build Coastguard Worker
5479*da0073e9SAndroid Build Coastguard Worker                    # Test linalg.solve
5480*da0073e9SAndroid Build Coastguard Worker                    X = torch.linalg.solve(A, B, left=left)
5481*da0073e9SAndroid Build Coastguard Worker                    X_ = X.unsqueeze(-1) if rhs == () else X
5482*da0073e9SAndroid Build Coastguard Worker                    B_ = B.unsqueeze(-1) if rhs == () else B
5483*da0073e9SAndroid Build Coastguard Worker                    if left:
5484*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(B_, A @ X_)
5485*da0073e9SAndroid Build Coastguard Worker                    else:
5486*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(B_, X_ @ A)
5487*da0073e9SAndroid Build Coastguard Worker
5488*da0073e9SAndroid Build Coastguard Worker
5489*da0073e9SAndroid Build Coastguard Worker        sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
5490*da0073e9SAndroid Build Coastguard Worker        batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
5491*da0073e9SAndroid Build Coastguard Worker        # The non-pivoting variant is only implemented for CUDA
5492*da0073e9SAndroid Build Coastguard Worker        pivots = (True, False) if self.device_type == "cuda" else (True,)
5493*da0073e9SAndroid Build Coastguard Worker        fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
5494*da0073e9SAndroid Build Coastguard Worker        for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
5495*da0073e9SAndroid Build Coastguard Worker            shape = batch + ms
5496*da0073e9SAndroid Build Coastguard Worker            A = make_arg(shape) if singular else make_arg_full(*shape)
5497*da0073e9SAndroid Build Coastguard Worker            # For empty matrices, only run the singular pass (no need to run both)
5498*da0073e9SAndroid Build Coastguard Worker            if A.numel() == 0 and not singular:
5499*da0073e9SAndroid Build Coastguard Worker                continue
5500*da0073e9SAndroid Build Coastguard Worker            run_test(A, pivot, singular, fn)
5501*da0073e9SAndroid Build Coastguard Worker
5502*da0073e9SAndroid Build Coastguard Worker            # Reproducer of a magma bug,
5503*da0073e9SAndroid Build Coastguard Worker            # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
5504*da0073e9SAndroid Build Coastguard Worker            # This is also a bug in cuSOLVER < 11.3
5505*da0073e9SAndroid Build Coastguard Worker            if (dtype == torch.double
5506*da0073e9SAndroid Build Coastguard Worker               and singular):
5507*da0073e9SAndroid Build Coastguard Worker                A = torch.ones(batch + ms, dtype=dtype, device=device)
5508*da0073e9SAndroid Build Coastguard Worker                run_test(A, pivot, singular, fn)
5509*da0073e9SAndroid Build Coastguard Worker
5510*da0073e9SAndroid Build Coastguard Worker        # Info should be positive for rank deficient matrices
5511*da0073e9SAndroid Build Coastguard Worker        A = torch.ones(5, 3, 3, device=device)
5512*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())
5513*da0073e9SAndroid Build Coastguard Worker
5514*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cpu':
5515*da0073e9SAndroid Build Coastguard Worker            # Error checking: the no-pivoting variant is not implemented on CPU
5516*da0073e9SAndroid Build Coastguard Worker            fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu]
5517*da0073e9SAndroid Build Coastguard Worker            for f in fns:
5518*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
5519*da0073e9SAndroid Build Coastguard Worker                    f(torch.empty(1, 2, 2), pivot=False)
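
    # Minimal sketch (assumed size, float64): lu_factor packs the unit lower-triangular L and
    # the upper-triangular U into a single tensor, and lu_unpack recovers P, L, U such that
    # P @ L @ U reproduces the input, which is the identity checked throughout run_test above.
    def _example_lu_roundtrip(self):
        A = torch.randn(4, 4, dtype=torch.float64)
        LU, pivots = torch.linalg.lu_factor(A)
        P, L, U = torch.lu_unpack(LU, pivots)
        self.assertEqual(P @ L @ U, A)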
5520*da0073e9SAndroid Build Coastguard Worker
5521*da0073e9SAndroid Build Coastguard Worker
5522*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5523*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
5524*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5525*da0073e9SAndroid Build Coastguard Worker    @setLinalgBackendsToDefaultFinally
5526*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5527*da0073e9SAndroid Build Coastguard Worker    def test_linalg_lu_solve(self, device, dtype):
5528*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, dtype=dtype, device=device)
5529*da0073e9SAndroid Build Coastguard Worker
5530*da0073e9SAndroid Build Coastguard Worker        backends = ["default"]
5531*da0073e9SAndroid Build Coastguard Worker
5532*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'cuda':
5533*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.has_magma:
5534*da0073e9SAndroid Build Coastguard Worker                backends.append("magma")
5535*da0073e9SAndroid Build Coastguard Worker            if has_cusolver():
5536*da0073e9SAndroid Build Coastguard Worker                backends.append("cusolver")
5537*da0073e9SAndroid Build Coastguard Worker
5538*da0073e9SAndroid Build Coastguard Worker        def gen_matrices():
5539*da0073e9SAndroid Build Coastguard Worker            rhs = 3
5540*da0073e9SAndroid Build Coastguard Worker            ns = (5, 2, 0)
5541*da0073e9SAndroid Build Coastguard Worker            batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
5542*da0073e9SAndroid Build Coastguard Worker            for batch, n in product(batches, ns):
5543*da0073e9SAndroid Build Coastguard Worker                yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs))
5544*da0073e9SAndroid Build Coastguard Worker            # Shapes to exercise all the paths
5545*da0073e9SAndroid Build Coastguard Worker            shapes = ((1, 64), (2, 128), (1025, 2))
5546*da0073e9SAndroid Build Coastguard Worker            for b, n in shapes:
5547*da0073e9SAndroid Build Coastguard Worker                yield make_arg((b, n, n)), make_arg((b, n, rhs))
5548*da0073e9SAndroid Build Coastguard Worker
5549*da0073e9SAndroid Build Coastguard Worker
5550*da0073e9SAndroid Build Coastguard Worker        for A, B in gen_matrices():
5551*da0073e9SAndroid Build Coastguard Worker            LU, pivots = torch.linalg.lu_factor(A)
5552*da0073e9SAndroid Build Coastguard Worker            for backend in backends:
5553*da0073e9SAndroid Build Coastguard Worker                torch.backends.cuda.preferred_linalg_library(backend)
5554*da0073e9SAndroid Build Coastguard Worker
5555*da0073e9SAndroid Build Coastguard Worker                for left, adjoint in product((True, False), repeat=2):
5556*da0073e9SAndroid Build Coastguard Worker                    B_left = B if left else B.mT
5557*da0073e9SAndroid Build Coastguard Worker                    X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, adjoint=adjoint)
5558*da0073e9SAndroid Build Coastguard Worker                    A_adj = A.mH if adjoint else A
5559*da0073e9SAndroid Build Coastguard Worker                    if left:
5560*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(B_left, A_adj @ X)
5561*da0073e9SAndroid Build Coastguard Worker                    else:
5562*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(B_left, X @ A_adj)
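
    # Illustrative sketch (assumed sizes, float64): a single lu_factor call can serve both
    # A x = b and A^H x = b by passing adjoint=True to lu_solve, avoiding a second
    # factorization; this is the property exercised by the adjoint loop above.
    def _example_lu_solve_adjoint(self):
        A = torch.randn(3, 3, dtype=torch.float64)
        B = torch.randn(3, 2, dtype=torch.float64)
        LU, pivots = torch.linalg.lu_factor(A)
        X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
        self.assertEqual(A.mH @ X, B)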
5563*da0073e9SAndroid Build Coastguard Worker
5564*da0073e9SAndroid Build Coastguard Worker
5565*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
5566*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
5567*da0073e9SAndroid Build Coastguard Worker    def test_linalg_lu_cpu_errors(self, device, dtype):
5568*da0073e9SAndroid Build Coastguard Worker        # Square tests
5569*da0073e9SAndroid Build Coastguard Worker        sample = torch.randn(3, 2, 2, device=device, dtype=dtype)
5570*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(3, 2, 2, device=device, dtype=dtype)
5571*da0073e9SAndroid Build Coastguard Worker        LU, pivots = torch.linalg.lu_factor(sample)
5572*da0073e9SAndroid Build Coastguard Worker
5573*da0073e9SAndroid Build Coastguard Worker        # This should run without issues
5574*da0073e9SAndroid Build Coastguard Worker        torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5575*da0073e9SAndroid Build Coastguard Worker        torch.lu_unpack(LU, pivots)
5576*da0073e9SAndroid Build Coastguard Worker
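        # LAPACK pivots are 1-based, so a pivot of 0 or one greater than LU.size(-2)
        # must be rejected by both lu_solve and lu_unpack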
5577*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 0
5578*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"):
5579*da0073e9SAndroid Build Coastguard Worker            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5580*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5581*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5582*da0073e9SAndroid Build Coastguard Worker
5583*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 3
5584*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"):
5585*da0073e9SAndroid Build Coastguard Worker            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5586*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5587*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5588*da0073e9SAndroid Build Coastguard Worker
5589*da0073e9SAndroid Build Coastguard Worker        # Rectangular tests: tall input (more rows than columns)
5590*da0073e9SAndroid Build Coastguard Worker        sample = torch.randn(3, 4, 2, device=device, dtype=dtype)
5591*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(3, 4, 2, device=device, dtype=dtype)
5592*da0073e9SAndroid Build Coastguard Worker        LU, pivots = torch.linalg.lu_factor(sample)
5593*da0073e9SAndroid Build Coastguard Worker
5594*da0073e9SAndroid Build Coastguard Worker        # This should run without issues
5595*da0073e9SAndroid Build Coastguard Worker        torch.lu_unpack(LU, pivots)
5596*da0073e9SAndroid Build Coastguard Worker
5597*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 0
5598*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5599*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5600*da0073e9SAndroid Build Coastguard Worker
5601*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 5
5602*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5603*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5604*da0073e9SAndroid Build Coastguard Worker
5605*da0073e9SAndroid Build Coastguard Worker
5606*da0073e9SAndroid Build Coastguard Worker        # Rectangular tests: wide input (more columns than rows)
5607*da0073e9SAndroid Build Coastguard Worker        sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
5608*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(2, 3, 5, device=device, dtype=dtype)
5609*da0073e9SAndroid Build Coastguard Worker        LU, pivots = torch.linalg.lu_factor(sample)
5610*da0073e9SAndroid Build Coastguard Worker
5611*da0073e9SAndroid Build Coastguard Worker        # This should run without issues
5612*da0073e9SAndroid Build Coastguard Worker        torch.lu_unpack(LU, pivots)
5613*da0073e9SAndroid Build Coastguard Worker
5614*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 0
5615*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5616*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5617*da0073e9SAndroid Build Coastguard Worker
5618*da0073e9SAndroid Build Coastguard Worker        pivots[0] = 4
5619*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5620*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(LU, pivots)
5621*da0073e9SAndroid Build Coastguard Worker
5622*da0073e9SAndroid Build Coastguard Worker
5623*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5624*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
5625*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
5626*da0073e9SAndroid Build Coastguard Worker    def test_lu_unpack_check_input(self, device, dtype):
5627*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(5, 5, 5, device=device, dtype=dtype)
5628*da0073e9SAndroid Build Coastguard Worker        lu_data, lu_pivots = torch.linalg.lu_factor(x)
5629*da0073e9SAndroid Build Coastguard Worker
5630*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
5631*da0073e9SAndroid Build Coastguard Worker            torch.lu_unpack(lu_data, lu_pivots.long())
5632*da0073e9SAndroid Build Coastguard Worker
5633*da0073e9SAndroid Build Coastguard Worker        # check that once the flags are unset, empty tensors are returned
5634*da0073e9SAndroid Build Coastguard Worker        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
5635*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(l.numel() == 0 and u.numel() == 0)
5636*da0073e9SAndroid Build Coastguard Worker        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
5637*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p.numel() == 0)
5638*da0073e9SAndroid Build Coastguard Worker        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
5639*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0)
5640*da0073e9SAndroid Build Coastguard Worker
5641*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
5642*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5643*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
5644*da0073e9SAndroid Build Coastguard Worker    def test_lobpcg_basic(self, device, dtype):
5645*da0073e9SAndroid Build Coastguard Worker        self._test_lobpcg_method(device, dtype, 'basic')
5646*da0073e9SAndroid Build Coastguard Worker
5647*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
5648*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5649*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
5650*da0073e9SAndroid Build Coastguard Worker    def test_lobpcg_ortho(self, device, dtype):
5651*da0073e9SAndroid Build Coastguard Worker        if torch.version.hip:
5652*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.preferred_linalg_library('magma')
5653*da0073e9SAndroid Build Coastguard Worker        self._test_lobpcg_method(device, dtype, 'ortho')
5654*da0073e9SAndroid Build Coastguard Worker        if torch.version.hip:
5655*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.preferred_linalg_library('default')
5656*da0073e9SAndroid Build Coastguard Worker
5657*da0073e9SAndroid Build Coastguard Worker    def _test_lobpcg_method(self, device, dtype, method):
5658*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix
5659*da0073e9SAndroid Build Coastguard Worker        from torch._linalg_utils import matmul, qform
5660*da0073e9SAndroid Build Coastguard Worker        from torch._lobpcg import lobpcg
5661*da0073e9SAndroid Build Coastguard Worker
5662*da0073e9SAndroid Build Coastguard Worker        def test_tracker(worker):
5663*da0073e9SAndroid Build Coastguard Worker            k = worker.iparams['k']
5664*da0073e9SAndroid Build Coastguard Worker            nc = worker.ivars['converged_count']
5665*da0073e9SAndroid Build Coastguard Worker            if k <= nc:
5666*da0073e9SAndroid Build Coastguard Worker                tol = worker.fparams['tol']
5667*da0073e9SAndroid Build Coastguard Worker                rerr = worker.tvars['rerr']
5668*da0073e9SAndroid Build Coastguard Worker                X = worker.X
5669*da0073e9SAndroid Build Coastguard Worker                E = worker.E
5670*da0073e9SAndroid Build Coastguard Worker                B = worker.B
5671*da0073e9SAndroid Build Coastguard Worker                A = worker.A
5672*da0073e9SAndroid Build Coastguard Worker                dtype = X.dtype
5673*da0073e9SAndroid Build Coastguard Worker                device = X.device
5674*da0073e9SAndroid Build Coastguard Worker
5675*da0073e9SAndroid Build Coastguard Worker                # Check convergence
5676*da0073e9SAndroid Build Coastguard Worker                self.assertLessEqual(rerr[:k].max(), tol)
5677*da0073e9SAndroid Build Coastguard Worker
5678*da0073e9SAndroid Build Coastguard Worker                # Check B-orthogonality
5679*da0073e9SAndroid Build Coastguard Worker                I = torch.eye(k, k, dtype=dtype, device=device)
5680*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(qform(B, X[:, :k]), I)
5681*da0073e9SAndroid Build Coastguard Worker
5682*da0073e9SAndroid Build Coastguard Worker                # Check block equation
5683*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0)
5684*da0073e9SAndroid Build Coastguard Worker
5685*da0073e9SAndroid Build Coastguard Worker        orig_lobpcg = lobpcg
5686*da0073e9SAndroid Build Coastguard Worker
5687*da0073e9SAndroid Build Coastguard Worker        def lobpcg(*args, **kwargs):
5688*da0073e9SAndroid Build Coastguard Worker            kwargs['tracker'] = test_tracker
5689*da0073e9SAndroid Build Coastguard Worker            kwargs['niter'] = 1000
5690*da0073e9SAndroid Build Coastguard Worker            kwargs['method'] = method
5691*da0073e9SAndroid Build Coastguard Worker            kwargs['tol'] = 1e-8
5692*da0073e9SAndroid Build Coastguard Worker            return orig_lobpcg(*args, **kwargs)
5693*da0073e9SAndroid Build Coastguard Worker        prec = 5e-4
5694*da0073e9SAndroid Build Coastguard Worker
5695*da0073e9SAndroid Build Coastguard Worker        # check dense input
5696*da0073e9SAndroid Build Coastguard Worker        mm = torch.matmul
5697*da0073e9SAndroid Build Coastguard Worker        for batches in [(), (2,), (2, 3)]:
5698*da0073e9SAndroid Build Coastguard Worker            for m, n, k in [
5699*da0073e9SAndroid Build Coastguard Worker                    (9, 3, 1),
5700*da0073e9SAndroid Build Coastguard Worker                    (9, 3, 2),
5701*da0073e9SAndroid Build Coastguard Worker                    (9, 2, 2),
5702*da0073e9SAndroid Build Coastguard Worker                    (100, 15, 5),
5703*da0073e9SAndroid Build Coastguard Worker            ]:
5704*da0073e9SAndroid Build Coastguard Worker                # skip tests that are known to fail with the basic
5705*da0073e9SAndroid Build Coastguard Worker                # LOBPCG method due to calling cholesky on singular
5706*da0073e9SAndroid Build Coastguard Worker                # input
5707*da0073e9SAndroid Build Coastguard Worker                if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]:
5708*da0073e9SAndroid Build Coastguard Worker                    continue
5709*da0073e9SAndroid Build Coastguard Worker                A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5710*da0073e9SAndroid Build Coastguard Worker                B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5711*da0073e9SAndroid Build Coastguard Worker
5712*da0073e9SAndroid Build Coastguard Worker                # classical eigenvalue problem, smallest eigenvalues
5713*da0073e9SAndroid Build Coastguard Worker                E, V = lobpcg(A, k=k, n=n, largest=False)
5714*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(E.shape, batches + (k,))
5715*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(V.shape, batches + (m, k))
5716*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5717*da0073e9SAndroid Build Coastguard Worker                e = torch.linalg.eigvalsh(A)
5718*da0073e9SAndroid Build Coastguard Worker                e_smallest = e[..., :k]
5719*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(E, e_smallest)
5720*da0073e9SAndroid Build Coastguard Worker
5721*da0073e9SAndroid Build Coastguard Worker                # classical eigenvalue problem, largest eigenvalues
5722*da0073e9SAndroid Build Coastguard Worker                E, V = lobpcg(A, k=k, n=n, largest=True)
5723*da0073e9SAndroid Build Coastguard Worker                e_largest, _ = torch.sort(e[..., -k:], descending=True)
5724*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(E, e_largest, atol=prec, rtol=0)
5725*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5726*da0073e9SAndroid Build Coastguard Worker
5727*da0073e9SAndroid Build Coastguard Worker                # generalized eigenvalue problem, smallest eigenvalues
5728*da0073e9SAndroid Build Coastguard Worker                E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5729*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0)
5730*da0073e9SAndroid Build Coastguard Worker
5731*da0073e9SAndroid Build Coastguard Worker                # generalized eigenvalue problem, largest eigenvalues
5732*da0073e9SAndroid Build Coastguard Worker                E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5733*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5734*da0073e9SAndroid Build Coastguard Worker                                 atol=prec, rtol=0)
5735*da0073e9SAndroid Build Coastguard Worker
5736*da0073e9SAndroid Build Coastguard Worker        # check sparse input
5737*da0073e9SAndroid Build Coastguard Worker        for m, n, k, density in [
5738*da0073e9SAndroid Build Coastguard Worker                (5, 1, 1, 0.8),
5739*da0073e9SAndroid Build Coastguard Worker                (9, 3, 2, 0.5),
5740*da0073e9SAndroid Build Coastguard Worker                (100, 1, 1, 0.1),
5741*da0073e9SAndroid Build Coastguard Worker                (1000, 7, 3, 0.01),
5742*da0073e9SAndroid Build Coastguard Worker        ]:
5743*da0073e9SAndroid Build Coastguard Worker            # skip tests that are known to fail with the basic LOBPCG
5744*da0073e9SAndroid Build Coastguard Worker            # method due to insufficient accuracy
5745*da0073e9SAndroid Build Coastguard Worker            if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]:
5746*da0073e9SAndroid Build Coastguard Worker                continue
5747*da0073e9SAndroid Build Coastguard Worker            A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
5748*da0073e9SAndroid Build Coastguard Worker            B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
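            # random_sparse_pd_matrix is constructed so that its eigenvalues are
            # 1/m, 2/m, ..., 1, which lets the expected extreme eigenvalues be written down directly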
5749*da0073e9SAndroid Build Coastguard Worker            A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m
5750*da0073e9SAndroid Build Coastguard Worker            e_smallest = A_eigenvalues[..., :k]
5751*da0073e9SAndroid Build Coastguard Worker            e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True)
5752*da0073e9SAndroid Build Coastguard Worker
5753*da0073e9SAndroid Build Coastguard Worker            # classical eigenvalue problem, smallest eigenvalues
5754*da0073e9SAndroid Build Coastguard Worker            E, V = lobpcg(A, k=k, n=n, largest=False)
5755*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(E, e_smallest)
5756*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5757*da0073e9SAndroid Build Coastguard Worker
5758*da0073e9SAndroid Build Coastguard Worker            # classical eigenvalue problem, largest eigenvalues
5759*da0073e9SAndroid Build Coastguard Worker            E, V = lobpcg(A, k=k, n=n, largest=True)
5760*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(E, e_largest)
5762*da0073e9SAndroid Build Coastguard Worker
5763*da0073e9SAndroid Build Coastguard Worker            # generalized eigenvalue problem, smallest eigenvalues
5764*da0073e9SAndroid Build Coastguard Worker            E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5765*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0)
5766*da0073e9SAndroid Build Coastguard Worker
5767*da0073e9SAndroid Build Coastguard Worker            # generalized eigenvalue problem, largest eigenvalues
5768*da0073e9SAndroid Build Coastguard Worker            E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5769*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5770*da0073e9SAndroid Build Coastguard Worker                             atol=prec, rtol=0)
5771*da0073e9SAndroid Build Coastguard Worker
5772*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5773*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
5774*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
5775*da0073e9SAndroid Build Coastguard Worker    def test_lobpcg_torchscript(self, device, dtype):
5776*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5777*da0073e9SAndroid Build Coastguard Worker        from torch._linalg_utils import matmul as mm
5778*da0073e9SAndroid Build Coastguard Worker
5779*da0073e9SAndroid Build Coastguard Worker        lobpcg = torch.jit.script(torch.lobpcg)
5780*da0073e9SAndroid Build Coastguard Worker
5781*da0073e9SAndroid Build Coastguard Worker        m = 500
5782*da0073e9SAndroid Build Coastguard Worker        k = 5
5783*da0073e9SAndroid Build Coastguard Worker        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5784*da0073e9SAndroid Build Coastguard Worker        X1 = torch.randn((m, k), dtype=dtype, device=device)
5785*da0073e9SAndroid Build Coastguard Worker        E1, V1 = lobpcg(A1, X=X1)
5786*da0073e9SAndroid Build Coastguard Worker        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5787*da0073e9SAndroid Build Coastguard Worker        self.assertLess(eq_err, 1e-6)
5788*da0073e9SAndroid Build Coastguard Worker
5789*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1")
5790*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
5791*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg")
5792*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
5793*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
5794*da0073e9SAndroid Build Coastguard Worker    def test_lobpcg_scipy(self, device, dtype):
5795*da0073e9SAndroid Build Coastguard Worker        """Compare torch and scipy.sparse.linalg implementations of lobpcg
5796*da0073e9SAndroid Build Coastguard Worker        """
5797*da0073e9SAndroid Build Coastguard Worker        import time
5798*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5799*da0073e9SAndroid Build Coastguard Worker        from torch._linalg_utils import matmul as mm
5800*da0073e9SAndroid Build Coastguard Worker        from scipy.sparse.linalg import lobpcg as scipy_lobpcg
5801*da0073e9SAndroid Build Coastguard Worker        import scipy.sparse
5802*da0073e9SAndroid Build Coastguard Worker
5803*da0073e9SAndroid Build Coastguard Worker        def toscipy(A):
5804*da0073e9SAndroid Build Coastguard Worker            if A.layout == torch.sparse_coo:
5805*da0073e9SAndroid Build Coastguard Worker                values = A.coalesce().values().cpu().numpy().copy()
5806*da0073e9SAndroid Build Coastguard Worker                indices = A.coalesce().indices().cpu().numpy().copy()
5807*da0073e9SAndroid Build Coastguard Worker                return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape)
5808*da0073e9SAndroid Build Coastguard Worker            return A.cpu().numpy().copy()
5809*da0073e9SAndroid Build Coastguard Worker
5810*da0073e9SAndroid Build Coastguard Worker        niter = 1000
5811*da0073e9SAndroid Build Coastguard Worker        repeat = 10
5812*da0073e9SAndroid Build Coastguard Worker        m = 500   # size of the square matrix
5813*da0073e9SAndroid Build Coastguard Worker        k = 7     # the number of requested eigenpairs
5814*da0073e9SAndroid Build Coastguard Worker        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5815*da0073e9SAndroid Build Coastguard Worker        B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5816*da0073e9SAndroid Build Coastguard Worker        X1 = torch.randn((m, k), dtype=dtype, device=device)
5817*da0073e9SAndroid Build Coastguard Worker
5818*da0073e9SAndroid Build Coastguard Worker        A2 = toscipy(A1)
5819*da0073e9SAndroid Build Coastguard Worker        B2 = toscipy(B1)
5820*da0073e9SAndroid Build Coastguard Worker        X2 = toscipy(X1)
5821*da0073e9SAndroid Build Coastguard Worker
5822*da0073e9SAndroid Build Coastguard Worker        lambdas1 = []
5823*da0073e9SAndroid Build Coastguard Worker
5824*da0073e9SAndroid Build Coastguard Worker        def tracker(worker):
5825*da0073e9SAndroid Build Coastguard Worker            lambdas1.append(worker.E[:])
5826*da0073e9SAndroid Build Coastguard Worker
5827*da0073e9SAndroid Build Coastguard Worker        tol = 1e-8
5828*da0073e9SAndroid Build Coastguard Worker        # tol for scipy lobpcg is chosen so that its number of
5829*da0073e9SAndroid Build Coastguard Worker        # iterations is equal to or very close to that of pytorch lobpcg
5830*da0073e9SAndroid Build Coastguard Worker        # (that is, around 170-180)
5831*da0073e9SAndroid Build Coastguard Worker
5832*da0073e9SAndroid Build Coastguard Worker        # Standard eigenvalue problem
5833*da0073e9SAndroid Build Coastguard Worker        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5834*da0073e9SAndroid Build Coastguard Worker        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol)
5835*da0073e9SAndroid Build Coastguard Worker        iters1 = len(lambdas1)
5836*da0073e9SAndroid Build Coastguard Worker        iters2 = len(lambdas2)
5837*da0073e9SAndroid Build Coastguard Worker        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5838*da0073e9SAndroid Build Coastguard Worker
5839*da0073e9SAndroid Build Coastguard Worker        E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False)
5840*da0073e9SAndroid Build Coastguard Worker
5841*da0073e9SAndroid Build Coastguard Worker        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5842*da0073e9SAndroid Build Coastguard Worker        eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5843*da0073e9SAndroid Build Coastguard Worker        self.assertLess(eq_err, 1e-6)        # std
5844*da0073e9SAndroid Build Coastguard Worker        self.assertLess(eq_err_scipy, 1e-6)  # std
5845*da0073e9SAndroid Build Coastguard Worker
5846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5847*da0073e9SAndroid Build Coastguard Worker
5848*da0073e9SAndroid Build Coastguard Worker        # Generalized eigenvalue problem
5849*da0073e9SAndroid Build Coastguard Worker        lambdas1 = []
5850*da0073e9SAndroid Build Coastguard Worker
5851*da0073e9SAndroid Build Coastguard Worker        def tracker(worker):
5852*da0073e9SAndroid Build Coastguard Worker            lambdas1.append(worker.E[:])
5853*da0073e9SAndroid Build Coastguard Worker
5854*da0073e9SAndroid Build Coastguard Worker        E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5855*da0073e9SAndroid Build Coastguard Worker        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol)
5856*da0073e9SAndroid Build Coastguard Worker        E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False)
5857*da0073e9SAndroid Build Coastguard Worker        iters1 = len(lambdas1)
5858*da0073e9SAndroid Build Coastguard Worker        iters2 = len(lambdas2)
5859*da0073e9SAndroid Build Coastguard Worker        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5860*da0073e9SAndroid Build Coastguard Worker
5861*da0073e9SAndroid Build Coastguard Worker        eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5862*da0073e9SAndroid Build Coastguard Worker        eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5863*da0073e9SAndroid Build Coastguard Worker        self.assertLess(eq_err, 1e-6)        # general
5864*da0073e9SAndroid Build Coastguard Worker        self.assertLess(eq_err_scipy, 1e-6)  # general
5865*da0073e9SAndroid Build Coastguard Worker
5866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5867*da0073e9SAndroid Build Coastguard Worker
5868*da0073e9SAndroid Build Coastguard Worker        # Timings
5869*da0073e9SAndroid Build Coastguard Worker        elapsed_ortho = 0
5870*da0073e9SAndroid Build Coastguard Worker        elapsed_ortho_general = 0
5871*da0073e9SAndroid Build Coastguard Worker        elapsed_scipy = 0
5872*da0073e9SAndroid Build Coastguard Worker        elapsed_general_scipy = 0
5873*da0073e9SAndroid Build Coastguard Worker        for i in range(repeat):
5874*da0073e9SAndroid Build Coastguard Worker            start = time.time()
5875*da0073e9SAndroid Build Coastguard Worker            torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol)
5876*da0073e9SAndroid Build Coastguard Worker            end = time.time()
5877*da0073e9SAndroid Build Coastguard Worker            elapsed_ortho += end - start
5878*da0073e9SAndroid Build Coastguard Worker
5879*da0073e9SAndroid Build Coastguard Worker            start = time.time()
5880*da0073e9SAndroid Build Coastguard Worker            torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol)
5881*da0073e9SAndroid Build Coastguard Worker            end = time.time()
5882*da0073e9SAndroid Build Coastguard Worker            elapsed_ortho_general += end - start
5883*da0073e9SAndroid Build Coastguard Worker
5884*da0073e9SAndroid Build Coastguard Worker            start = time.time()
5885*da0073e9SAndroid Build Coastguard Worker            scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol)
5886*da0073e9SAndroid Build Coastguard Worker            end = time.time()
5887*da0073e9SAndroid Build Coastguard Worker            elapsed_scipy += end - start
5888*da0073e9SAndroid Build Coastguard Worker
5889*da0073e9SAndroid Build Coastguard Worker            start = time.time()
5890*da0073e9SAndroid Build Coastguard Worker            scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol)
5891*da0073e9SAndroid Build Coastguard Worker            end = time.time()
5892*da0073e9SAndroid Build Coastguard Worker            elapsed_general_scipy += end - start
5893*da0073e9SAndroid Build Coastguard Worker
5894*da0073e9SAndroid Build Coastguard Worker        elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat
5895*da0073e9SAndroid Build Coastguard Worker        elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat
5896*da0073e9SAndroid Build Coastguard Worker        elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat
5897*da0073e9SAndroid Build Coastguard Worker        elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat
5898*da0073e9SAndroid Build Coastguard Worker
5899*da0073e9SAndroid Build Coastguard Worker        print(f'''
5900*da0073e9SAndroid Build Coastguard WorkerCPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5901*da0073e9SAndroid Build Coastguard Worker-------------------------------------------------------
5902*da0073e9SAndroid Build Coastguard Worker              | standard    | generalized | method
5903*da0073e9SAndroid Build Coastguard Workertorch.lobpcg  | {elapsed_ortho_ms:10.2f}  | {elapsed_ortho_general_ms:10.2f}  | ortho
5904*da0073e9SAndroid Build Coastguard Workerscipy_lobpcg  | {elapsed_scipy_ms:10.2f}  | {elapsed_general_scipy_ms:10.2f}  | N/A
5905*da0073e9SAndroid Build Coastguard Worker-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)-
5906*da0073e9SAndroid Build Coastguard Worker        ''')
5907*da0073e9SAndroid Build Coastguard Worker
5908*da0073e9SAndroid Build Coastguard Worker        # Handling of a very small tolerance
5909*da0073e9SAndroid Build Coastguard Worker        tol = 1e-100
5910*da0073e9SAndroid Build Coastguard Worker
5911*da0073e9SAndroid Build Coastguard Worker        lambdas1 = []
5912*da0073e9SAndroid Build Coastguard Worker
5913*da0073e9SAndroid Build Coastguard Worker        def tracker(worker):
5914*da0073e9SAndroid Build Coastguard Worker            lambdas1.append(worker.E[:])
5915*da0073e9SAndroid Build Coastguard Worker
5916*da0073e9SAndroid Build Coastguard Worker        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5917*da0073e9SAndroid Build Coastguard Worker        iters1 = len(lambdas1)
5918*da0073e9SAndroid Build Coastguard Worker        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5919*da0073e9SAndroid Build Coastguard Worker
5920*da0073e9SAndroid Build Coastguard Worker        try:
5921*da0073e9SAndroid Build Coastguard Worker            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5922*da0073e9SAndroid Build Coastguard Worker            iters2 = len(lambdas2)
5923*da0073e9SAndroid Build Coastguard Worker            eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5924*da0073e9SAndroid Build Coastguard Worker        except Exception as msg:
5925*da0073e9SAndroid Build Coastguard Worker            print('Calling scipy_lobpcg failed [standard]:', msg)
5926*da0073e9SAndroid Build Coastguard Worker            iters2 = -1
5927*da0073e9SAndroid Build Coastguard Worker            eq_err_scipy = -1
5928*da0073e9SAndroid Build Coastguard Worker
5929*da0073e9SAndroid Build Coastguard Worker        lambdas1 = []
5930*da0073e9SAndroid Build Coastguard Worker
5931*da0073e9SAndroid Build Coastguard Worker        def tracker(worker):
5932*da0073e9SAndroid Build Coastguard Worker            lambdas1.append(worker.E[:])
5933*da0073e9SAndroid Build Coastguard Worker
5934*da0073e9SAndroid Build Coastguard Worker        E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol)
5935*da0073e9SAndroid Build Coastguard Worker        iters1_general = len(lambdas1)
5936*da0073e9SAndroid Build Coastguard Worker        eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5937*da0073e9SAndroid Build Coastguard Worker
5938*da0073e9SAndroid Build Coastguard Worker        try:
5939*da0073e9SAndroid Build Coastguard Worker            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5940*da0073e9SAndroid Build Coastguard Worker            iters2_general = len(lambdas2)
5941*da0073e9SAndroid Build Coastguard Worker            eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5942*da0073e9SAndroid Build Coastguard Worker        except Exception as msg:
5943*da0073e9SAndroid Build Coastguard Worker            print('Calling scipy_lobpcg failed [generalized]:', msg)
5944*da0073e9SAndroid Build Coastguard Worker            iters2_general = -1
5945*da0073e9SAndroid Build Coastguard Worker            eq_err_general_scipy = -1
5946*da0073e9SAndroid Build Coastguard Worker
5947*da0073e9SAndroid Build Coastguard Worker        print(f'''\
5948*da0073e9SAndroid Build Coastguard WorkerHandling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5949*da0073e9SAndroid Build Coastguard Worker----------------------------------------------------------------------------
5950*da0073e9SAndroid Build Coastguard Worker              | standard    | generalized |  niter | method
5951*da0073e9SAndroid Build Coastguard Workertorch.lobpcg  | {eq_err:10.2e}  | {eq_err_general:10.2e}  | {iters1:6} | ortho
5952*da0073e9SAndroid Build Coastguard Workerscipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e}  | {iters2:6} | N/A
5953*da0073e9SAndroid Build Coastguard Worker---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})---
5954*da0073e9SAndroid Build Coastguard Worker''')
5955*da0073e9SAndroid Build Coastguard Worker
5956*da0073e9SAndroid Build Coastguard Worker    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None):
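        # Compares f(t, m, v) (optionally with an out= tensor and a relu/gelu epilogue)
        # against a NumPy reference of alpha * (m @ v) + beta * t, computed in a wider
        # dtype for half/bfloat16 inputs.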
5957*da0073e9SAndroid Build Coastguard Worker        dtype = t.dtype
5958*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype
5959*da0073e9SAndroid Build Coastguard Worker        if dtype in {torch.bfloat16, torch.half}:
5960*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = torch.float
5961*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
5962*da0073e9SAndroid Build Coastguard Worker            alpha = 0.9 + 0.3j if alpha is None else alpha
5963*da0073e9SAndroid Build Coastguard Worker            beta = 0.5 + 0.6j if beta is None else beta
5964*da0073e9SAndroid Build Coastguard Worker        else:
5965*da0073e9SAndroid Build Coastguard Worker            alpha = 1.2 if alpha is None else alpha
5966*da0073e9SAndroid Build Coastguard Worker            beta = 0.8 if beta is None else beta
5967*da0073e9SAndroid Build Coastguard Worker        if activation == "gelu":
5968*da0073e9SAndroid Build Coastguard Worker            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
5969*da0073e9SAndroid Build Coastguard Worker        else:
5970*da0073e9SAndroid Build Coastguard Worker            res1 = f(t, m, v, alpha=alpha, beta=beta)
5971*da0073e9SAndroid Build Coastguard Worker        res2 = torch.full_like(res1, math.nan)
5972*da0073e9SAndroid Build Coastguard Worker        if transpose_out:
5973*da0073e9SAndroid Build Coastguard Worker            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
5974*da0073e9SAndroid Build Coastguard Worker        if activation == "gelu":
5975*da0073e9SAndroid Build Coastguard Worker            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
5976*da0073e9SAndroid Build Coastguard Worker        else:
5977*da0073e9SAndroid Build Coastguard Worker            f(t, m, v, alpha=alpha, beta=beta, out=res2)
5978*da0073e9SAndroid Build Coastguard Worker        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
5979*da0073e9SAndroid Build Coastguard Worker        if beta != 0:
5980*da0073e9SAndroid Build Coastguard Worker            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
5981*da0073e9SAndroid Build Coastguard Worker        if activation == "relu":
5982*da0073e9SAndroid Build Coastguard Worker            res3 = res3 * (res3 > 0)
5983*da0073e9SAndroid Build Coastguard Worker        elif activation == "gelu":
5984*da0073e9SAndroid Build Coastguard Worker            res3_t = torch.from_numpy(res3).to(dtype)
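            # the fused CUDA epilogue applies the tanh approximation of GELU, so the
            # reference has to match it; the CPU path uses the exact formulation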
5985*da0073e9SAndroid Build Coastguard Worker            approximate = "tanh" if t.is_cuda else "none"
5986*da0073e9SAndroid Build Coastguard Worker            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
5987*da0073e9SAndroid Build Coastguard Worker            res3 = res3_t.to(numpy_dtype).cpu().numpy()
5988*da0073e9SAndroid Build Coastguard Worker        else:
5989*da0073e9SAndroid Build Coastguard Worker            assert activation is None, f"unsupported activation {activation}"
5990*da0073e9SAndroid Build Coastguard Worker        res3 = torch.from_numpy(res3).to(dtype)
5991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
5992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res3)
5993*da0073e9SAndroid Build Coastguard Worker
5994*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8,
5995*da0073e9SAndroid Build Coastguard Worker                        torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5996*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types_and(
5997*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [],
5998*da0073e9SAndroid Build Coastguard Worker                  torch.half))
5999*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
6000*da0073e9SAndroid Build Coastguard Worker    def test_addmv(self, device, dtype):
6001*da0073e9SAndroid Build Coastguard Worker        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
6002*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
6003*da0073e9SAndroid Build Coastguard Worker        # have to use torch.randn(...).to(bfloat16) instead of
6004*da0073e9SAndroid Build Coastguard Worker        # torch.randn(..., dtype=bfloat16). randn does not support
6005*da0073e9SAndroid Build Coastguard Worker        # bfloat16 yet.
6006*da0073e9SAndroid Build Coastguard Worker        # "*0.2" to reduce errors for low precision
6007*da0073e9SAndroid Build Coastguard Worker        ts = [
6008*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(50, device=device).to(dtype),
6009*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
6010*da0073e9SAndroid Build Coastguard Worker        ]
6011*da0073e9SAndroid Build Coastguard Worker        vs = [
6012*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(100, device=device).to(dtype),
6013*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
6014*da0073e9SAndroid Build Coastguard Worker        ]
6015*da0073e9SAndroid Build Coastguard Worker        ms = [
6016*da0073e9SAndroid Build Coastguard Worker            # 0d
6017*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),  # to reduce errors for low precision
6018*da0073e9SAndroid Build Coastguard Worker            # 1d
6019*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
6020*da0073e9SAndroid Build Coastguard Worker            # this initialization reduces errors for low-precision dtypes with broadcasted matrices
6021*da0073e9SAndroid Build Coastguard Worker            # by making sure that intermediate and result values are exactly representable
6022*da0073e9SAndroid Build Coastguard Worker            # in the low-precision type
6023*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
6024*da0073e9SAndroid Build Coastguard Worker            # 2d
6025*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((50, 100), device=device).to(dtype),
6026*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
6027*da0073e9SAndroid Build Coastguard Worker        ]
6028*da0073e9SAndroid Build Coastguard Worker        for m, v, t in itertools.product(ms, vs, ts):
6029*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, t, m, v)
6030*da0073e9SAndroid Build Coastguard Worker        # Test beta=0, t=nan
6031*da0073e9SAndroid Build Coastguard Worker        t = torch.full((50,), math.nan, device=device).to(dtype)
6032*da0073e9SAndroid Build Coastguard Worker        for m, v in itertools.product(ms, vs):
6033*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
6034*da0073e9SAndroid Build Coastguard Worker
6035*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or
6036*da0073e9SAndroid Build Coastguard Worker                  SM53OrLater else []))
6037*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
6038*da0073e9SAndroid Build Coastguard Worker    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
6039*da0073e9SAndroid Build Coastguard Worker        # tests the (o, s) @ (s,) matrix-vector product, where o is the output size and s is the summed size.
6040*da0073e9SAndroid Build Coastguard Worker        o = 5
6041*da0073e9SAndroid Build Coastguard Worker        s = 3
6042*da0073e9SAndroid Build Coastguard Worker        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
6043*da0073e9SAndroid Build Coastguard Worker        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
6044*da0073e9SAndroid Build Coastguard Worker        y_data = torch.ones(o, device=device, dtype=dtype)
6045*da0073e9SAndroid Build Coastguard Worker        control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype)
6046*da0073e9SAndroid Build Coastguard Worker
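        # _test copies the data into NaN-padded storages so that the views passed to addmv
        # have a non-trivial leading dimension (lda) and non-unit strides (incx, incy),
        # exercising the strided BLAS code paths.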
6047*da0073e9SAndroid Build Coastguard Worker        def _test(row_major, incx, incy, lda_tail):
6048*da0073e9SAndroid Build Coastguard Worker            if row_major:
6049*da0073e9SAndroid Build Coastguard Worker                a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype)
6050*da0073e9SAndroid Build Coastguard Worker            else:
6051*da0073e9SAndroid Build Coastguard Worker                a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0)
6052*da0073e9SAndroid Build Coastguard Worker            a = a_storage[:o, :s].copy_(a_data)
6053*da0073e9SAndroid Build Coastguard Worker
6054*da0073e9SAndroid Build Coastguard Worker            x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype)
6055*da0073e9SAndroid Build Coastguard Worker            x = x_storage[:, 0].copy_(x_data)
6056*da0073e9SAndroid Build Coastguard Worker
6057*da0073e9SAndroid Build Coastguard Worker            y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype)
6058*da0073e9SAndroid Build Coastguard Worker            y = y_storage[:, 0].copy_(y_data)
6059*da0073e9SAndroid Build Coastguard Worker
6060*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, y, a, x)
6061*da0073e9SAndroid Build Coastguard Worker
6062*da0073e9SAndroid Build Coastguard Worker        for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)):
6063*da0073e9SAndroid Build Coastguard Worker            _test(row_major, incx, incy, lda_tail)
6064*da0073e9SAndroid Build Coastguard Worker
6065*da0073e9SAndroid Build Coastguard Worker    def _test_addmm_impl(self, func, activation, device, dtype):
6066*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(10, 25, device=device).to(dtype)
6067*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 50, device=device).to(dtype)
6068*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(50, 25, device=device).to(dtype)
6069*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
6070*da0073e9SAndroid Build Coastguard Worker
6071*da0073e9SAndroid Build Coastguard Worker        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
6072*da0073e9SAndroid Build Coastguard Worker        V = torch.randn(25, device=device).to(dtype)
6073*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
6074*da0073e9SAndroid Build Coastguard Worker
6075*da0073e9SAndroid Build Coastguard Worker        # Test 0-strided
6076*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
6077*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
6078*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(50, 25, device=device).to(dtype)
6079*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
6080*da0073e9SAndroid Build Coastguard Worker
6081*da0073e9SAndroid Build Coastguard Worker        # Test beta=0, M=nan
6082*da0073e9SAndroid Build Coastguard Worker        M = torch.full((10, 25), math.nan, device=device).to(dtype)
6083*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 50, device=device).to(dtype)
6084*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(50, 25, device=device).to(dtype)
6085*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)
6086*da0073e9SAndroid Build Coastguard Worker
6087*da0073e9SAndroid Build Coastguard Worker        # Test transpose
6088*da0073e9SAndroid Build Coastguard Worker        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
6089*da0073e9SAndroid Build Coastguard Worker            def maybe_transpose(cond, m):
6090*da0073e9SAndroid Build Coastguard Worker                if not cond:
6091*da0073e9SAndroid Build Coastguard Worker                    return m
6092*da0073e9SAndroid Build Coastguard Worker                return m.t().clone(memory_format=torch.contiguous_format).t()
6093*da0073e9SAndroid Build Coastguard Worker
6094*da0073e9SAndroid Build Coastguard Worker            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
6095*da0073e9SAndroid Build Coastguard Worker            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
6096*da0073e9SAndroid Build Coastguard Worker            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
6097*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
6098*da0073e9SAndroid Build Coastguard Worker
6099*da0073e9SAndroid Build Coastguard Worker            if t1:
6100*da0073e9SAndroid Build Coastguard Worker                # use vector V instead of matrix M for epilogue fusion in CUDA (V does not depend on t1, so only run when t1 is True to avoid duplicate work)
6101*da0073e9SAndroid Build Coastguard Worker                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
6102*da0073e9SAndroid Build Coastguard Worker
6103*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
6104*da0073e9SAndroid Build Coastguard Worker                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6105*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float32)
6106*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types_and(
6107*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
6108*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6109*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6110*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6111*da0073e9SAndroid Build Coastguard Worker    def test_addmm(self, device, dtype):
6112*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch.addmm, None, device, dtype)
6113*da0073e9SAndroid Build Coastguard Worker
6114*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6115*da0073e9SAndroid Build Coastguard Worker                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6116*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(
6117*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6118*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.bfloat16))
6119*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6120*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6121*da0073e9SAndroid Build Coastguard Worker    def test_addmm_relu(self, device, dtype):
6122*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6123*da0073e9SAndroid Build Coastguard Worker
6124*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6125*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
6126*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6127*da0073e9SAndroid Build Coastguard Worker                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6128*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(
6129*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6130*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.bfloat16))
6131*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6132*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6133*da0073e9SAndroid Build Coastguard Worker    def test_addmm_relu_tunableop_rocm(self, device, dtype):
6134*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable(True)
6135*da0073e9SAndroid Build Coastguard Worker        ordinal = torch.cuda.current_device()
6136*da0073e9SAndroid Build Coastguard Worker        filename = f"tunableop_results{ordinal}.csv"
6137*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_filename(filename)
6138*da0073e9SAndroid Build Coastguard Worker        iterations = torch.cuda.tunable.get_max_tuning_iterations()
6139*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_max_tuning_iterations(10)
6140*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6141*da0073e9SAndroid Build Coastguard Worker        # clean up, remove any file that was generated
6142*da0073e9SAndroid Build Coastguard Worker        try:
6143*da0073e9SAndroid Build Coastguard Worker            import os
6144*da0073e9SAndroid Build Coastguard Worker            os.remove(filename)
6145*da0073e9SAndroid Build Coastguard Worker        except FileNotFoundError:
6146*da0073e9SAndroid Build Coastguard Worker            pass
6147*da0073e9SAndroid Build Coastguard Worker        # reset back to prior settings
6148*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.set_max_tuning_iterations(iterations)
6149*da0073e9SAndroid Build Coastguard Worker        torch.cuda.tunable.enable(False)
6150*da0073e9SAndroid Build Coastguard Worker
6151*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6152*da0073e9SAndroid Build Coastguard Worker                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6153*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(
6154*da0073e9SAndroid Build Coastguard Worker                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6155*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.bfloat16))
6156*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6157*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6158*da0073e9SAndroid Build Coastguard Worker    def test_addmm_gelu(self, device, dtype):
6159*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
6160*da0073e9SAndroid Build Coastguard Worker
6161*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
6162*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types())
6163*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
6164*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
6165*da0073e9SAndroid Build Coastguard Worker    def test_addmm_sizes(self, device, dtype):
6166*da0073e9SAndroid Build Coastguard Worker        for m in [0, 1, 25]:
6167*da0073e9SAndroid Build Coastguard Worker            for n in [0, 1, 10]:
6168*da0073e9SAndroid Build Coastguard Worker                for k in [0, 1, 8]:
6169*da0073e9SAndroid Build Coastguard Worker                    M = torch.randn(n, m, device=device).to(dtype)
6170*da0073e9SAndroid Build Coastguard Worker                    m1 = torch.randn(n, k, device=device).to(dtype)
6171*da0073e9SAndroid Build Coastguard Worker                    m2 = torch.randn(k, m, device=device).to(dtype)
6172*da0073e9SAndroid Build Coastguard Worker                    self._test_addmm_addmv(torch.addmm, M, m1, m2)
6173*da0073e9SAndroid Build Coastguard Worker
6174*da0073e9SAndroid Build Coastguard Worker                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
6175*da0073e9SAndroid Build Coastguard Worker                    m2 = torch.randn(k, m, device=device).to(dtype)
6176*da0073e9SAndroid Build Coastguard Worker                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
6177*da0073e9SAndroid Build Coastguard Worker                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))
6178*da0073e9SAndroid Build Coastguard Worker
6179*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)
6180*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6181*da0073e9SAndroid Build Coastguard Worker    def test_addmm_baddbmm_overflow(self, device, dtype):
6182*da0073e9SAndroid Build Coastguard Worker        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
6183*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
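        # with fp32 accumulation each output element is alpha * 1000 * (100 * 100) = 10000,
        # which is representable in fp16; the unscaled sum (1e7) would overflow if the
        # reduction were performed in fp16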
6184*da0073e9SAndroid Build Coastguard Worker        inp = torch.zeros(128, 128, dtype=torch.half, device=device)
6185*da0073e9SAndroid Build Coastguard Worker        mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
6186*da0073e9SAndroid Build Coastguard Worker        mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
6187*da0073e9SAndroid Build Coastguard Worker        out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6188*da0073e9SAndroid Build Coastguard Worker        # just check for no overflow on ROCm
6189*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6190*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out.isinf().any())
6191*da0073e9SAndroid Build Coastguard Worker        else:
6192*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((out == 10000.).all())
6193*da0073e9SAndroid Build Coastguard Worker        inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
6194*da0073e9SAndroid Build Coastguard Worker        mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
6195*da0073e9SAndroid Build Coastguard Worker        mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
6196*da0073e9SAndroid Build Coastguard Worker        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6197*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6198*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out.isinf().any())
6199*da0073e9SAndroid Build Coastguard Worker        else:
6200*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((out == 10000.).all())
6201*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
6202*da0073e9SAndroid Build Coastguard Worker
6203*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
6204*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
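        # Per the torch.baddbmm documentation, when beta is 0 the input is not read at all,
        # so NaNs in it must not propagate: the result has to match a plain bmm.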
6205*da0073e9SAndroid Build Coastguard Worker        for shape in [[3, 2, 2], [2, 20, 20]]:
6206*da0073e9SAndroid Build Coastguard Worker            mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2))
6207*da0073e9SAndroid Build Coastguard Worker            inputs = [torch.randn(shape, dtype=dtype, device=device),
6208*da0073e9SAndroid Build Coastguard Worker                      torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6209*da0073e9SAndroid Build Coastguard Worker            outs = [None, torch.randn(shape, dtype=dtype, device=device),
6210*da0073e9SAndroid Build Coastguard Worker                    torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6211*da0073e9SAndroid Build Coastguard Worker            options = itertools.product(inputs, outs)
6212*da0073e9SAndroid Build Coastguard Worker            for input, out in options:
6213*da0073e9SAndroid Build Coastguard Worker                y_ref = torch.bmm(mat1, mat2)
6214*da0073e9SAndroid Build Coastguard Worker                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
6215*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_ref, y)
6216*da0073e9SAndroid Build Coastguard Worker
6217*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64)
6218*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
6219*da0073e9SAndroid Build Coastguard Worker        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6220*da0073e9SAndroid Build Coastguard Worker        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6221*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
6222*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.float32:
6223*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
6224*da0073e9SAndroid Build Coastguard Worker                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
6225*da0073e9SAndroid Build Coastguard Worker        else:
6226*da0073e9SAndroid Build Coastguard Worker            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
6227*da0073e9SAndroid Build Coastguard Worker            y_ref = torch.bmm(batch1, batch2)
6228*da0073e9SAndroid Build Coastguard Worker            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
6229*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, y_ref)
6230*da0073e9SAndroid Build Coastguard Worker
6231*da0073e9SAndroid Build Coastguard Worker
6232*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6233*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6234*da0073e9SAndroid Build Coastguard Worker    def test_matmul_45724(self, device):
6235*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/45724
6236*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
6237*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
6238*da0073e9SAndroid Build Coastguard Worker        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
6239*da0073e9SAndroid Build Coastguard Worker        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
6240*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a, b, out=c)
6241*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cpu_result)
6242*da0073e9SAndroid Build Coastguard Worker
6243*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6244*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
6245*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6246*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6247*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [16, 32])
6248*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [16, 32])
6249*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_transpose_a", [True, False])
6250*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_transpose_b", [True, False])
6251*da0073e9SAndroid Build Coastguard Worker    def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
6252*da0073e9SAndroid Build Coastguard Worker        def genf_int_float(x, y, use_transpose):
6253*da0073e9SAndroid Build Coastguard Worker            if use_transpose:
6254*da0073e9SAndroid Build Coastguard Worker                x, y = y, x
6255*da0073e9SAndroid Build Coastguard Worker            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6256*da0073e9SAndroid Build Coastguard Worker            x_float = x_int8.to(torch.float32)
6257*da0073e9SAndroid Build Coastguard Worker            if use_transpose:
6258*da0073e9SAndroid Build Coastguard Worker                return x_int8.t(), x_float.t()
6259*da0073e9SAndroid Build Coastguard Worker            return x_int8, x_float
6260*da0073e9SAndroid Build Coastguard Worker
6261*da0073e9SAndroid Build Coastguard Worker        def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
6262*da0073e9SAndroid Build Coastguard Worker            a_int8, a_float = genf_int_float(m, k, transpose_a)
6263*da0073e9SAndroid Build Coastguard Worker            b_int8, b_float = genf_int_float(k, n, transpose_b)
6264*da0073e9SAndroid Build Coastguard Worker            c_int32 = torch._int_mm(a_int8, b_int8)
6265*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(c_int32.dtype is torch.int32)
6266*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(c_int32.device, torch.device(device))
6267*da0073e9SAndroid Build Coastguard Worker            if test_equal:
6268*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6269*da0073e9SAndroid Build Coastguard Worker            else:
6270*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
6271*da0073e9SAndroid Build Coastguard Worker            c_int32_result = c_int32.new_empty(c_int32.size())
6272*da0073e9SAndroid Build Coastguard Worker            # Checking out variant
6273*da0073e9SAndroid Build Coastguard Worker            torch._int_mm(a_int8, b_int8, out=c_int32_result)
6274*da0073e9SAndroid Build Coastguard Worker            if test_equal:
6275*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6276*da0073e9SAndroid Build Coastguard Worker            else:
6277*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6278*da0073e9SAndroid Build Coastguard Worker
6279*da0073e9SAndroid Build Coastguard Worker        # NOTE: Besides correct results, this deliberately exercises the failure paths below.
6280*da0073e9SAndroid Build Coastguard Worker        version = _get_torch_cuda_version()
6281*da0073e9SAndroid Build Coastguard Worker        SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
6282*da0073e9SAndroid Build Coastguard Worker        SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
6283*da0073e9SAndroid Build Coastguard Worker        SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
6284*da0073e9SAndroid Build Coastguard Worker
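        # The branches below mirror the cublasLt int8 support matrix as exercised here: the
        # "A non-transposed, B transposed" and "both non-transposed" layouts are expected to
        # work on SM80+ (or SM70/SM75 with CUDA >= 12.3), the remaining layouts raise
        # CUBLAS_STATUS_NOT_SUPPORTED, and ROCm takes the unconditional success path.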
6285*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6286*da0073e9SAndroid Build Coastguard Worker            _test(17, k, n, use_transpose_a, use_transpose_b, True)
6287*da0073e9SAndroid Build Coastguard Worker        elif version >= (11, 7):
6288*da0073e9SAndroid Build Coastguard Worker            if not use_transpose_a and use_transpose_b:
6289*da0073e9SAndroid Build Coastguard Worker                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6290*da0073e9SAndroid Build Coastguard Worker                    _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
6291*da0073e9SAndroid Build Coastguard Worker                else:
6292*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError,
6293*da0073e9SAndroid Build Coastguard Worker                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6294*da0073e9SAndroid Build Coastguard Worker                        _test(17, k, n, use_transpose_a, use_transpose_b)
6295*da0073e9SAndroid Build Coastguard Worker
6296*da0073e9SAndroid Build Coastguard Worker            if use_transpose_a and not use_transpose_b:
6297*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError,
6298*da0073e9SAndroid Build Coastguard Worker                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6299*da0073e9SAndroid Build Coastguard Worker                    _test(17, k, n, use_transpose_a, use_transpose_b)
6300*da0073e9SAndroid Build Coastguard Worker
6301*da0073e9SAndroid Build Coastguard Worker            if use_transpose_a and use_transpose_b:
6302*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError,
6303*da0073e9SAndroid Build Coastguard Worker                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6304*da0073e9SAndroid Build Coastguard Worker                    _test(17, k, n, use_transpose_a, use_transpose_b)
6305*da0073e9SAndroid Build Coastguard Worker
6306*da0073e9SAndroid Build Coastguard Worker            if not use_transpose_a and not use_transpose_b:
6307*da0073e9SAndroid Build Coastguard Worker                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6308*da0073e9SAndroid Build Coastguard Worker                    _test(17, k, n, use_transpose_a, use_transpose_b)
6309*da0073e9SAndroid Build Coastguard Worker                else:
6310*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError,
6311*da0073e9SAndroid Build Coastguard Worker                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6312*da0073e9SAndroid Build Coastguard Worker                        _test(17, k, n, use_transpose_a, use_transpose_b)
6313*da0073e9SAndroid Build Coastguard Worker        else:
6314*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
6315*da0073e9SAndroid Build Coastguard Worker                _test(17, k, n, use_transpose_a, use_transpose_b, False)
6316*da0073e9SAndroid Build Coastguard Worker
6317*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6318*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6319*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6320*da0073e9SAndroid Build Coastguard Worker    def test__int_mm_errors(self, device):
6321*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6322*da0073e9SAndroid Build Coastguard Worker            self.skipTest("_int_mm not compiled for ROCM")
6323*da0073e9SAndroid Build Coastguard Worker
6324*da0073e9SAndroid Build Coastguard Worker        version = _get_torch_cuda_version()
6325*da0073e9SAndroid Build Coastguard Worker        if version < (11, 7):
6326*da0073e9SAndroid Build Coastguard Worker            self.skipTest("_int_mm only compiled for CUDA 11.7 and later")
6327*da0073e9SAndroid Build Coastguard Worker
6328*da0073e9SAndroid Build Coastguard Worker        def genf_int(x, y):
6329*da0073e9SAndroid Build Coastguard Worker            return torch.empty((x, y), dtype=torch.int8, device=device)
6330*da0073e9SAndroid Build Coastguard Worker
6331*da0073e9SAndroid Build Coastguard Worker        def _gen_pair(m, k, n):
6332*da0073e9SAndroid Build Coastguard Worker            return genf_int(m, k), genf_int(k, n)
6333*da0073e9SAndroid Build Coastguard Worker
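        # The assertions below exercise the shape/dtype requirements torch._int_mm enforces
        # (per its error messages): self.size(0) > 16, self.size(1) a positive multiple of 8
        # that matches mat2.size(0), mat2.size(1) a positive multiple of 8, int8 inputs, and
        # an int32 out tensor of the right shape.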
6334*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6335*da0073e9SAndroid Build Coastguard Worker                               r"self.size\(0\) needs to be greater than 16, but got 16",
6336*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(*_gen_pair(16, 8, 32)))
6337*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6338*da0073e9SAndroid Build Coastguard Worker                               r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7",
6339*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(*_gen_pair(17, 7, 32)))
6340*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6341*da0073e9SAndroid Build Coastguard Worker                               r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7",
6342*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32)))
6343*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6344*da0073e9SAndroid Build Coastguard Worker                               r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31",
6345*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(*_gen_pair(17, 8, 31)))
6346*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6347*da0073e9SAndroid Build Coastguard Worker                               r"expected scalar type Char but found Float",
6348*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32)))
6349*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6350*da0073e9SAndroid Build Coastguard Worker                               r"expected scalar type Char but found Float",
6351*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float()))
6352*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6353*da0073e9SAndroid Build Coastguard Worker                               r"Expected result dtype to be of type kInt but got float",
6354*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float()))
6355*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6356*da0073e9SAndroid Build Coastguard Worker                               r"Expected result.size\(0\) to be 17 but got 15",
6357*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
6358*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError,
6359*da0073e9SAndroid Build Coastguard Worker                               r"Expected result.size\(0\) to be 17 but got 16",
6360*da0073e9SAndroid Build Coastguard Worker                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
6361*da0073e9SAndroid Build Coastguard Worker
6362*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
6363*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [0, 8, 17])
6364*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [0, 16, 32])
6365*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [16, 32])
6366*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_transpose_a", [True, False])
6367*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_transpose_b", [True, False])
6368*da0073e9SAndroid Build Coastguard Worker    @parametrize("non_contig_type", [0, 1, 2])
6369*da0073e9SAndroid Build Coastguard Worker    def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
6370*da0073e9SAndroid Build Coastguard Worker        # non_contig_type:
6371*da0073e9SAndroid Build Coastguard Worker        # 0: the whole data buffer is contiguous (can be transposed)
6372*da0073e9SAndroid Build Coastguard Worker        # 1: stride of one dimension is 1, but the whole buffer is not contiguous
6373*da0073e9SAndroid Build Coastguard Worker        # 2: Neither stride is 1
6374*da0073e9SAndroid Build Coastguard Worker
6375*da0073e9SAndroid Build Coastguard Worker        def genf_int_float(x, y, use_transpose, non_contig_type):
6376*da0073e9SAndroid Build Coastguard Worker            if use_transpose:
6377*da0073e9SAndroid Build Coastguard Worker                x, y = y, x
6378*da0073e9SAndroid Build Coastguard Worker            if non_contig_type != 0:
6379*da0073e9SAndroid Build Coastguard Worker                y = y * 2
6380*da0073e9SAndroid Build Coastguard Worker            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6381*da0073e9SAndroid Build Coastguard Worker            x_float = x_int8.to(torch.float32)
6382*da0073e9SAndroid Build Coastguard Worker            if non_contig_type == 1:
6383*da0073e9SAndroid Build Coastguard Worker                x_int8 = x_int8[:, : y // 2]
6384*da0073e9SAndroid Build Coastguard Worker                x_float = x_float[:, : y // 2]
6385*da0073e9SAndroid Build Coastguard Worker            elif non_contig_type == 2:
6386*da0073e9SAndroid Build Coastguard Worker                x_int8 = x_int8[:, ::2]
6387*da0073e9SAndroid Build Coastguard Worker                x_float = x_float[:, ::2]
6388*da0073e9SAndroid Build Coastguard Worker            if use_transpose:
6389*da0073e9SAndroid Build Coastguard Worker                return x_int8.t(), x_float.t()
6390*da0073e9SAndroid Build Coastguard Worker            return x_int8, x_float
6391*da0073e9SAndroid Build Coastguard Worker
6392*da0073e9SAndroid Build Coastguard Worker        if non_contig_type != 0 and (m == 0 or k == 0):
6393*da0073e9SAndroid Build Coastguard Worker            return
6394*da0073e9SAndroid Build Coastguard Worker        a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
6395*da0073e9SAndroid Build Coastguard Worker        b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
6396*da0073e9SAndroid Build Coastguard Worker        c_int32 = torch._int_mm(a_int8, b_int8)
6397*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(c_int32.dtype is torch.int32)
6398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_int32.device, torch.device(device))
6399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6400*da0073e9SAndroid Build Coastguard Worker        c_int32_result = c_int32.new_empty(c_int32.size())
6401*da0073e9SAndroid Build Coastguard Worker        # Checking out variant
6402*da0073e9SAndroid Build Coastguard Worker        torch._int_mm(a_int8, b_int8, out=c_int32_result)
6403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6404*da0073e9SAndroid Build Coastguard Worker
6405*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6406*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6407*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6408*da0073e9SAndroid Build Coastguard Worker    def test__convert_weight_to_int4pack(self, device):
6409*da0073e9SAndroid Build Coastguard Worker        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
6410*da0073e9SAndroid Build Coastguard Worker        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
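        # Each entry is (weight shape, innerKTiles); the weight is group-quantized to 4 bits
        # with q_group_size=32 below, then packed with the given innerKTiles.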
6411*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and not SM80OrLater:
6412*da0073e9SAndroid Build Coastguard Worker            self.skipTest("requires SM80 or later")
6413*da0073e9SAndroid Build Coastguard Worker
6414*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6415*da0073e9SAndroid Build Coastguard Worker            if not CDNA2OrLater():
6416*da0073e9SAndroid Build Coastguard Worker                self.skipTest("_convert_weight_to_int4pack is supported only for CDNA2 or later")
6417*da0073e9SAndroid Build Coastguard Worker
6418*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6419*da0073e9SAndroid Build Coastguard Worker        for shape, innerKTiles in test_list:
6420*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
6421*da0073e9SAndroid Build Coastguard Worker            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
6422*da0073e9SAndroid Build Coastguard Worker            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
6423*da0073e9SAndroid Build Coastguard Worker            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
6424*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
6425*da0073e9SAndroid Build Coastguard Worker
6426*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6427*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6428*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6429*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [32, 64])
6430*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 64])
6431*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [48, 64])
6432*da0073e9SAndroid Build Coastguard Worker    def test__int4_mm(self, device, m, k, n):
6433*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and not SM80OrLater:
6434*da0073e9SAndroid Build Coastguard Worker            self.skipTest("requires SM80 or later")
6435*da0073e9SAndroid Build Coastguard Worker
6436*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6437*da0073e9SAndroid Build Coastguard Worker            if not CDNA2OrLater():
6438*da0073e9SAndroid Build Coastguard Worker                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6439*da0073e9SAndroid Build Coastguard Worker
6440*da0073e9SAndroid Build Coastguard Worker        q_group = 32
6441*da0073e9SAndroid Build Coastguard Worker        inner_k_tiles = 2
6442*da0073e9SAndroid Build Coastguard Worker
6443*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6444*da0073e9SAndroid Build Coastguard Worker        a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6445*da0073e9SAndroid Build Coastguard Worker        b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6446*da0073e9SAndroid Build Coastguard Worker
6447*da0073e9SAndroid Build Coastguard Worker        def convert_weight_to_int4pack(b):
6448*da0073e9SAndroid Build Coastguard Worker            b_uint8, b_scales_and_zeros = _group_quantize_tensor(
6449*da0073e9SAndroid Build Coastguard Worker                b, n_bit=4, q_group_size=q_group
6450*da0073e9SAndroid Build Coastguard Worker            )
6451*da0073e9SAndroid Build Coastguard Worker            b_int4pack = torch._convert_weight_to_int4pack(
6452*da0073e9SAndroid Build Coastguard Worker                b_uint8, inner_k_tiles
6453*da0073e9SAndroid Build Coastguard Worker            )
6454*da0073e9SAndroid Build Coastguard Worker
6455*da0073e9SAndroid Build Coastguard Worker            return b_int4pack, b_scales_and_zeros
6456*da0073e9SAndroid Build Coastguard Worker
6457*da0073e9SAndroid Build Coastguard Worker        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
6458*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int4pack_mm(
6459*da0073e9SAndroid Build Coastguard Worker                a, b_int4pack, q_group, b_scales_and_zeros
6460*da0073e9SAndroid Build Coastguard Worker            )
6461*da0073e9SAndroid Build Coastguard Worker
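        # End-to-end flow: group-quantize the bf16 weight (4-bit, groups of 32), pack it with
        # _convert_weight_to_int4pack, then _weight_int4pack_mm multiplies the activation by
        # the packed weight using the per-group scales and zero points. The mean relative
        # error against a full-precision mm must stay below 5%.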
6462*da0073e9SAndroid Build Coastguard Worker        b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
6463*da0073e9SAndroid Build Coastguard Worker
6464*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
6465*da0073e9SAndroid Build Coastguard Worker            a = a_bf16.to(dtype=dtype)
6466*da0073e9SAndroid Build Coastguard Worker            b = b_bf16.to(dtype=dtype)
6467*da0073e9SAndroid Build Coastguard Worker            b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
6468*da0073e9SAndroid Build Coastguard Worker            ref = torch.mm(a, b)
6469*da0073e9SAndroid Build Coastguard Worker            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
6470*da0073e9SAndroid Build Coastguard Worker
6471*da0073e9SAndroid Build Coastguard Worker            mean_err = ((res - ref).abs() / ref).mean()
6472*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mean_err < 0.05)
6473*da0073e9SAndroid Build Coastguard Worker
6474*da0073e9SAndroid Build Coastguard Worker
6475*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6476*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6477*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6478*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [32, 64])
6479*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 64])
6480*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [48, 64])
6481*da0073e9SAndroid Build Coastguard Worker    def test_compile_int4_mm(self, device, m, k, n):
6482*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and not SM80OrLater:
6483*da0073e9SAndroid Build Coastguard Worker            self.skipTest("requires SM80 or later")
6484*da0073e9SAndroid Build Coastguard Worker
6485*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
6486*da0073e9SAndroid Build Coastguard Worker            if not CDNA2OrLater():
6487*da0073e9SAndroid Build Coastguard Worker                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6488*da0073e9SAndroid Build Coastguard Worker
6489*da0073e9SAndroid Build Coastguard Worker        q_group = 32
6490*da0073e9SAndroid Build Coastguard Worker        inner_k_tiles = 2
6491*da0073e9SAndroid Build Coastguard Worker
6492*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6493*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6494*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6495*da0073e9SAndroid Build Coastguard Worker
6496*da0073e9SAndroid Build Coastguard Worker        b_int32, b_scales_and_zeros = _group_quantize_tensor(
6497*da0073e9SAndroid Build Coastguard Worker            b, n_bit=4, q_group_size=q_group
6498*da0073e9SAndroid Build Coastguard Worker        )
6499*da0073e9SAndroid Build Coastguard Worker
6500*da0073e9SAndroid Build Coastguard Worker        @torch.compile
6501*da0073e9SAndroid Build Coastguard Worker        def int4_mm(a, b_int32, b_scales_and_zeros):
6502*da0073e9SAndroid Build Coastguard Worker            b_int4pack = torch._convert_weight_to_int4pack(
6503*da0073e9SAndroid Build Coastguard Worker                b_int32, inner_k_tiles
6504*da0073e9SAndroid Build Coastguard Worker            )
6505*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int4pack_mm(
6506*da0073e9SAndroid Build Coastguard Worker                a, b_int4pack, q_group, b_scales_and_zeros
6507*da0073e9SAndroid Build Coastguard Worker            )
6508*da0073e9SAndroid Build Coastguard Worker
6509*da0073e9SAndroid Build Coastguard Worker        res = int4_mm(a, b_int32, b_scales_and_zeros)
6510*da0073e9SAndroid Build Coastguard Worker        ref = torch.mm(a, b)
6511*da0073e9SAndroid Build Coastguard Worker
6512*da0073e9SAndroid Build Coastguard Worker        mean_err = ((res - ref).abs() / ref).mean()
6513*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mean_err < 0.05)
6514*da0073e9SAndroid Build Coastguard Worker
6515*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
6516*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [32, 64])
6517*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 64])
6518*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [48, 64])
6519*da0073e9SAndroid Build Coastguard Worker    def test__int8_mm(self, device, m, k, n):
6520*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6521*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6522*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6523*da0073e9SAndroid Build Coastguard Worker
6524*da0073e9SAndroid Build Coastguard Worker        def convert_weight_to_int8pack(b):
6525*da0073e9SAndroid Build Coastguard Worker            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6526*da0073e9SAndroid Build Coastguard Worker                b, -128, 127, torch.int8
6527*da0073e9SAndroid Build Coastguard Worker            )
6528*da0073e9SAndroid Build Coastguard Worker            return b_int8pack, b_scales
6529*da0073e9SAndroid Build Coastguard Worker
6530*da0073e9SAndroid Build Coastguard Worker        def weight_int8pack_mm(a, b_int8pack, b_scales):
6531*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int8pack_mm(
6532*da0073e9SAndroid Build Coastguard Worker                a, b_int8pack, b_scales
6533*da0073e9SAndroid Build Coastguard Worker            )
6534*da0073e9SAndroid Build Coastguard Worker
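        # The (n, k) weight is dynamically quantized per channel to int8 with matching scales;
        # _weight_int8pack_mm is expected to compute a @ dequantized(b).T, which is why the
        # reference below is torch.mm(a, b.transpose(0, 1)).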
6535*da0073e9SAndroid Build Coastguard Worker        b_int8pack, b_scales = convert_weight_to_int8pack(b)
6536*da0073e9SAndroid Build Coastguard Worker        res = weight_int8pack_mm(a, b_int8pack, b_scales)
6537*da0073e9SAndroid Build Coastguard Worker        ref = torch.mm(a, b.transpose(0, 1))
6538*da0073e9SAndroid Build Coastguard Worker
6539*da0073e9SAndroid Build Coastguard Worker        mean_err = ((res - ref).abs() / ref).mean()
6540*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mean_err < 0.05)
6541*da0073e9SAndroid Build Coastguard Worker
6542*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
6543*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [32, 64])
6544*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 64])
6545*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [48, 64])
6546*da0073e9SAndroid Build Coastguard Worker    def test_compile_int8_mm(self, device, m, k, n):
6547*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6548*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6549*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6550*da0073e9SAndroid Build Coastguard Worker
6551*da0073e9SAndroid Build Coastguard Worker        b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6552*da0073e9SAndroid Build Coastguard Worker            b, -128, 127, torch.int8
6553*da0073e9SAndroid Build Coastguard Worker        )
6554*da0073e9SAndroid Build Coastguard Worker
6555*da0073e9SAndroid Build Coastguard Worker        @torch.compile
6556*da0073e9SAndroid Build Coastguard Worker        def int8_mm(a, b_int8pack, b_scales):
6557*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int8pack_mm(
6558*da0073e9SAndroid Build Coastguard Worker                a, b_int8pack, b_scales
6559*da0073e9SAndroid Build Coastguard Worker            )
6560*da0073e9SAndroid Build Coastguard Worker
6561*da0073e9SAndroid Build Coastguard Worker        res = int8_mm(a, b_int8pack, b_scales)
6562*da0073e9SAndroid Build Coastguard Worker        ref = torch.mm(a, b.transpose(0, 1))
6563*da0073e9SAndroid Build Coastguard Worker
6564*da0073e9SAndroid Build Coastguard Worker        mean_err = ((res - ref).abs() / ref).mean()
6565*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mean_err < 0.05)
6566*da0073e9SAndroid Build Coastguard Worker
6567*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
6568*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [32, 35, 36, 40, 64])
6569*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 35, 36, 40, 64])
6570*da0073e9SAndroid Build Coastguard Worker    # NOTE: This is intended to cover fp16_gemv_trans in
6571*da0073e9SAndroid Build Coastguard Worker    # BlasKernel.cpp. Currently, bounds being divisible by 32, 8-but-not-32, and 4-but-not-8
6572*da0073e9SAndroid Build Coastguard Worker    # BlasKernel.cpp. Currently, sizes that are divisible by 32, divisible by 8 but not 32,
6573*da0073e9SAndroid Build Coastguard Worker    # and divisible by 4 but not 8 all matter.
6574*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
6575*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((m, k), dtype=torch.half, device=device)
6576*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((1, k), dtype=torch.half, device=device)
6577*da0073e9SAndroid Build Coastguard Worker
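        # b has shape (1, k), so mm(a, b.t()) is really a matrix-vector product; per the NOTE
        # above it is meant to hit fp16_gemv_trans in BlasKernel.cpp. The reference result is
        # computed with reduced-precision fp16 reduction disabled and then compared, under a
        # loose tolerance, against the result with it enabled.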
6578*da0073e9SAndroid Build Coastguard Worker        prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction()
6579*da0073e9SAndroid Build Coastguard Worker        try:
6580*da0073e9SAndroid Build Coastguard Worker            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False)
6581*da0073e9SAndroid Build Coastguard Worker            ref = torch.mm(a, b.t())
6582*da0073e9SAndroid Build Coastguard Worker            try:
6583*da0073e9SAndroid Build Coastguard Worker                torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True)
6584*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
6585*da0073e9SAndroid Build Coastguard Worker                raise unittest.SkipTest from e
6586*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(a, b.t())
6587*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
6588*da0073e9SAndroid Build Coastguard Worker        finally:
6589*da0073e9SAndroid Build Coastguard Worker            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)
6590*da0073e9SAndroid Build Coastguard Worker
6591*da0073e9SAndroid Build Coastguard Worker    @slowTest
6592*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6593*da0073e9SAndroid Build Coastguard Worker    # bfloat16 doesn't have sufficient precision to pass this test
6594*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
6595*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
6596*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.01)
6597*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.01)
6598*da0073e9SAndroid Build Coastguard Worker    def test_mm(self, device, dtype):
6599*da0073e9SAndroid Build Coastguard Worker        def _test_mm(n, m, p, dtype, genf):
6600*da0073e9SAndroid Build Coastguard Worker            # helper function
6601*da0073e9SAndroid Build Coastguard Worker            def matrixmultiply(mat1, mat2):
6602*da0073e9SAndroid Build Coastguard Worker                n = mat1.size(0)
6603*da0073e9SAndroid Build Coastguard Worker                m = mat1.size(1)
6604*da0073e9SAndroid Build Coastguard Worker                p = mat2.size(1)
6605*da0073e9SAndroid Build Coastguard Worker                dtype_ = torch.float if dtype == torch.half else dtype
6606*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
6607*da0073e9SAndroid Build Coastguard Worker                    mat1 = mat1.float()
6608*da0073e9SAndroid Build Coastguard Worker                    mat2 = mat2.float()
6609*da0073e9SAndroid Build Coastguard Worker                res = torch.zeros(n, p, dtype=dtype_, device=device)
6610*da0073e9SAndroid Build Coastguard Worker                for i, j in iter_indices(res):
6611*da0073e9SAndroid Build Coastguard Worker                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
6612*da0073e9SAndroid Build Coastguard Worker                return res.half() if dtype == torch.half else res
6613*da0073e9SAndroid Build Coastguard Worker
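            # matrixmultiply above is the naive elementwise reference; half inputs are upcast
            # to float for the accumulation and cast back, so torch.mm is checked against an
            # fp32-accumulated reference in the half case.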
6614*da0073e9SAndroid Build Coastguard Worker            # contiguous case
6615*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
6616*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
6617*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
6618*da0073e9SAndroid Build Coastguard Worker
6619*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6620*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6621*da0073e9SAndroid Build Coastguard Worker
6622*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 1
6623*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
6624*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
6625*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
6626*da0073e9SAndroid Build Coastguard Worker
6627*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6628*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6629*da0073e9SAndroid Build Coastguard Worker
6630*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 2
6631*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
6632*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
6633*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
6634*da0073e9SAndroid Build Coastguard Worker
6635*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6636*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6637*da0073e9SAndroid Build Coastguard Worker
6638*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 3
6639*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
6640*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
6641*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
6642*da0073e9SAndroid Build Coastguard Worker
6643*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6644*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6645*da0073e9SAndroid Build Coastguard Worker
6646*da0073e9SAndroid Build Coastguard Worker            # test with zero stride
6647*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
6648*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, 1).expand(m, p)
6649*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
6650*da0073e9SAndroid Build Coastguard Worker
6651*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6653*da0073e9SAndroid Build Coastguard Worker
6654*da0073e9SAndroid Build Coastguard Worker            # explicitly exercise the _out variant in torch.mm().
6655*da0073e9SAndroid Build Coastguard Worker            # contiguous case
6656*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
6657*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
6658*da0073e9SAndroid Build Coastguard Worker            res = genf(n, p)
6659*da0073e9SAndroid Build Coastguard Worker            torch.mm(mat1, mat2, out=res)
6660*da0073e9SAndroid Build Coastguard Worker
6661*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6663*da0073e9SAndroid Build Coastguard Worker
6664*da0073e9SAndroid Build Coastguard Worker            # explicitly exercise the _out variant in torch.mm().
6665*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 3
6666*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
6667*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
6668*da0073e9SAndroid Build Coastguard Worker            res = genf(n, p)
6669*da0073e9SAndroid Build Coastguard Worker            torch.mm(mat1, mat2, out=res)
6670*da0073e9SAndroid Build Coastguard Worker
6671*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
6672*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
6673*da0073e9SAndroid Build Coastguard Worker
6674*da0073e9SAndroid Build Coastguard Worker        def genf_int(x, y):
6675*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
6676*da0073e9SAndroid Build Coastguard Worker
6677*da0073e9SAndroid Build Coastguard Worker        def genf_bfloat(x, y):
6678*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
6679*da0073e9SAndroid Build Coastguard Worker
6680*da0073e9SAndroid Build Coastguard Worker        def genf_float(x, y):
6681*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=dtype, device=device)
6682*da0073e9SAndroid Build Coastguard Worker
6683*da0073e9SAndroid Build Coastguard Worker        def genf_Half(x, y):
6684*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=dtype, device=device)
6685*da0073e9SAndroid Build Coastguard Worker
6686*da0073e9SAndroid Build Coastguard Worker        for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
6687*da0073e9SAndroid Build Coastguard Worker            if (dtype == torch.int32) or (dtype == torch.int64):
6688*da0073e9SAndroid Build Coastguard Worker                genf = genf_int
6689*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bfloat16):
6690*da0073e9SAndroid Build Coastguard Worker                genf = genf_bfloat
6691*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.half):
6692*da0073e9SAndroid Build Coastguard Worker                genf = genf_Half
6693*da0073e9SAndroid Build Coastguard Worker            else:
6694*da0073e9SAndroid Build Coastguard Worker                genf = genf_float
6695*da0073e9SAndroid Build Coastguard Worker
6696*da0073e9SAndroid Build Coastguard Worker            _test_mm(n, m, p, dtype, genf)
6697*da0073e9SAndroid Build Coastguard Worker
6698*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6699*da0073e9SAndroid Build Coastguard Worker    def test_mm_bmm_non_memory_dense(self, device):
6700*da0073e9SAndroid Build Coastguard Worker        def _slice(tensor, fn):
6701*da0073e9SAndroid Build Coastguard Worker            return fn(tensor)[..., ::2]
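        # Slicing every other column yields tensors that are not memory-dense, and going
        # through torch.conj keeps the conjugation lazy (a flag on the view) while
        # torch.conj_physical materializes it; mm/bmm must produce the same result either way,
        # including when writing into a transposed (non-contiguous) out tensor.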
6702*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
6703*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6704*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
6705*da0073e9SAndroid Build Coastguard Worker        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
6706*da0073e9SAndroid Build Coastguard Worker        A_conj = _slice(A, torch.conj)
6707*da0073e9SAndroid Build Coastguard Worker        A_conj_physical = _slice(A, torch.conj_physical)
6708*da0073e9SAndroid Build Coastguard Worker
6709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
6710*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
6711*da0073e9SAndroid Build Coastguard Worker
6712*da0073e9SAndroid Build Coastguard Worker        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
6713*da0073e9SAndroid Build Coastguard Worker        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
6714*da0073e9SAndroid Build Coastguard Worker        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
6715*da0073e9SAndroid Build Coastguard Worker        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT
6716*da0073e9SAndroid Build Coastguard Worker
6717*da0073e9SAndroid Build Coastguard Worker        Ab_conj = _slice(Ab, torch.conj)
6718*da0073e9SAndroid Build Coastguard Worker        Ab_conj_physical = _slice(Ab, torch.conj_physical)
6719*da0073e9SAndroid Build Coastguard Worker
6720*da0073e9SAndroid Build Coastguard Worker        def t_b(tensor):
6721*da0073e9SAndroid Build Coastguard Worker            return tensor.mT
6722*da0073e9SAndroid Build Coastguard Worker
6723*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
6724*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))
6725*da0073e9SAndroid Build Coastguard Worker
6726*da0073e9SAndroid Build Coastguard Worker        # test broadcasting
6727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
6728*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
6729*da0073e9SAndroid Build Coastguard Worker
6730*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6731*da0073e9SAndroid Build Coastguard Worker    def test_mm_conjtranspose(self, device):
6732*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6733*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
6734*da0073e9SAndroid Build Coastguard Worker
6735*da0073e9SAndroid Build Coastguard Worker        # A conjtranspose
6736*da0073e9SAndroid Build Coastguard Worker        out1 = torch.mm(A.t().conj(), B)
6737*da0073e9SAndroid Build Coastguard Worker        out1_ref = torch.mm(A.t().conj_physical(), B)
6738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out1_ref)
6739*da0073e9SAndroid Build Coastguard Worker
6740*da0073e9SAndroid Build Coastguard Worker        # B conjtranspose
6741*da0073e9SAndroid Build Coastguard Worker        out1 = torch.mm(A, B.t().conj())
6742*da0073e9SAndroid Build Coastguard Worker        out1_ref = torch.mm(A, B.t().conj_physical())
6743*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out1_ref)
6744*da0073e9SAndroid Build Coastguard Worker
6745*da0073e9SAndroid Build Coastguard Worker        # A&B conjtranspose
6746*da0073e9SAndroid Build Coastguard Worker        out1 = torch.mm(A.t().conj(), B.t().conj())
6747*da0073e9SAndroid Build Coastguard Worker        out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical())
6748*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out1_ref)
6749*da0073e9SAndroid Build Coastguard Worker
6750*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6751*da0073e9SAndroid Build Coastguard Worker    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
6752*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
6753*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 20, dtype=torch.float32, device=device)
6754*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
6755*da0073e9SAndroid Build Coastguard Worker            torch.mm(a, b)
6756*da0073e9SAndroid Build Coastguard Worker
6757*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6758*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
6759*da0073e9SAndroid Build Coastguard Worker    def test_strided_mm_bmm(self, device, dtype):
6760*da0073e9SAndroid Build Coastguard Worker        # Tests strided view case with stride smaller than corresponding dimension size
6761*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
6762*da0073e9SAndroid Build Coastguard Worker        new_shape = [2, 2, 2]
6763*da0073e9SAndroid Build Coastguard Worker        new_stride = [3, 1, 1]
6764*da0073e9SAndroid Build Coastguard Worker        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
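        # With size [2, 2, 2] and stride [3, 1, 1], rows within each batch overlap in memory
        # (e.g. batch 0 is [[x00, x01], [x01, x02]]), so this checks mm/bmm on views whose
        # strides are smaller than the corresponding dimension sizes, against NumPy.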
6765*da0073e9SAndroid Build Coastguard Worker
6766*da0073e9SAndroid Build Coastguard Worker        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
6767*da0073e9SAndroid Build Coastguard Worker        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
6768*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch_fn, np_fn, sx)
6769*da0073e9SAndroid Build Coastguard Worker
6770*da0073e9SAndroid Build Coastguard Worker        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
6771*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch_fn, np_fn, sx[0])
6772*da0073e9SAndroid Build Coastguard Worker
6773*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
6774*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6775*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6776*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6777*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6778*da0073e9SAndroid Build Coastguard Worker    def test_bmm(self, device, dtype):
6779*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6780*da0073e9SAndroid Build Coastguard Worker            # cuBLAS does not guarantee BFloat16 support on SM < 53.
6781*da0073e9SAndroid Build Coastguard Worker            # So on PyTorch, we consider BFloat16 support on SM < 53 as
6782*da0073e9SAndroid Build Coastguard Worker            # So in PyTorch, we consider BFloat16 support on SM < 53 as
6783*da0073e9SAndroid Build Coastguard Worker            # undefined behavior.
6784*da0073e9SAndroid Build Coastguard Worker
6785*da0073e9SAndroid Build Coastguard Worker        batch_sizes = [1, 10]
6786*da0073e9SAndroid Build Coastguard Worker        M, N, O = 23, 15, 12
6787*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
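        # NumPy has no bfloat16, so for that dtype the reference matmul below is computed in
        # float32 and the result is cast back for comparison.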
6788*da0073e9SAndroid Build Coastguard Worker
6789*da0073e9SAndroid Build Coastguard Worker        is_supported = True
6790*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6791*da0073e9SAndroid Build Coastguard Worker            is_supported = TEST_WITH_ROCM or SM53OrLater
6792*da0073e9SAndroid Build Coastguard Worker
6793*da0073e9SAndroid Build Coastguard Worker        if not is_supported:
6794*da0073e9SAndroid Build Coastguard Worker            for num_batches in batch_sizes:
6795*da0073e9SAndroid Build Coastguard Worker                b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
6796*da0073e9SAndroid Build Coastguard Worker                b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
6797*da0073e9SAndroid Build Coastguard Worker                self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6798*da0073e9SAndroid Build Coastguard Worker                                       lambda: torch.bmm(b1, b2))
6799*da0073e9SAndroid Build Coastguard Worker            return
6800*da0073e9SAndroid Build Coastguard Worker
6801*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
6802*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
6803*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
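        # permute(perm).contiguous().permute(invert_perm(perm)) below keeps the logical shape
        # (num_batches, M, N) but lays the data out in a permuted (non-default) memory order,
        # so bmm is exercised on every stride ordering of its inputs and of its out tensor.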
6804*da0073e9SAndroid Build Coastguard Worker
6805*da0073e9SAndroid Build Coastguard Worker        def generate_inputs(num_batches):
6806*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
6807*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6808*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1)
6809*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1)
6810*da0073e9SAndroid Build Coastguard Worker                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6811*da0073e9SAndroid Build Coastguard Worker                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6812*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
6813*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
6814*da0073e9SAndroid Build Coastguard Worker            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
6815*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
6816*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
6817*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N)
6818*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O)
6819*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
6820*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
6821*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6822*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6823*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6824*da0073e9SAndroid Build Coastguard Worker                b1 = torch.randn(shape1, dtype=dtype, device=device)
6825*da0073e9SAndroid Build Coastguard Worker                b2 = torch.randn(shape2, dtype=dtype, device=device)
6826*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
6827*da0073e9SAndroid Build Coastguard Worker
6828*da0073e9SAndroid Build Coastguard Worker        for num_batches in batch_sizes:
6829*da0073e9SAndroid Build Coastguard Worker            for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
6830*da0073e9SAndroid Build Coastguard Worker                res1 = torch.bmm(b1, b2)
6831*da0073e9SAndroid Build Coastguard Worker                res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
6832*da0073e9SAndroid Build Coastguard Worker                    .permute(perm3).contiguous().permute(invert_perm(perm3))
6833*da0073e9SAndroid Build Coastguard Worker                torch.bmm(b1, b2, out=res2)
6834*da0073e9SAndroid Build Coastguard Worker                expect = torch.from_numpy(
6835*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6836*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expect, res1)
6837*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expect, res2)
6838*da0073e9SAndroid Build Coastguard Worker
6839*da0073e9SAndroid Build Coastguard Worker                if self.device_type == 'cuda':
6840*da0073e9SAndroid Build Coastguard Worker                    # check that mixed arguments are rejected
6841*da0073e9SAndroid Build Coastguard Worker                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
6842*da0073e9SAndroid Build Coastguard Worker                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
6843*da0073e9SAndroid Build Coastguard Worker                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))
6844*da0073e9SAndroid Build Coastguard Worker
6845*da0073e9SAndroid Build Coastguard Worker    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
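        # Shared checks for addbmm/baddbmm: the in-place variant, the deprecated
        # positional beta/alpha overloads (which must warn), the functional and
        # out= variants, and that beta=0 ignores NaNs in the input tensor.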
6846*da0073e9SAndroid Build Coastguard Worker        getattr(out_tensor, func + "_")(b1, b2)
6847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref)
6848*da0073e9SAndroid Build Coastguard Worker        res3 = out_tensor.clone()
6849*da0073e9SAndroid Build Coastguard Worker
6850*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
6851*da0073e9SAndroid Build Coastguard Worker                UserWarning, f"This overload of {func}_ is deprecated"):
6852*da0073e9SAndroid Build Coastguard Worker            getattr(out_tensor, func + "_")(1, b1, b2)
6853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref * 2)
6854*da0073e9SAndroid Build Coastguard Worker        getattr(res3, func + "_")(b1, b2, beta=1)
6855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, res3)
6856*da0073e9SAndroid Build Coastguard Worker
6857*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
6858*da0073e9SAndroid Build Coastguard Worker                UserWarning, f"This overload of {func}_ is deprecated"):
6859*da0073e9SAndroid Build Coastguard Worker            getattr(out_tensor, func + "_")(1., .5, b1, b2)
6860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref * 2.5)
6861*da0073e9SAndroid Build Coastguard Worker        getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5)
6862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, res3)
6863*da0073e9SAndroid Build Coastguard Worker
6864*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
6865*da0073e9SAndroid Build Coastguard Worker                UserWarning, f"This overload of {func} is deprecated"):
6866*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
6867*da0073e9SAndroid Build Coastguard Worker
6868*da0073e9SAndroid Build Coastguard Worker        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
6869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res4, ref * 3)
6870*da0073e9SAndroid Build Coastguard Worker
6871*da0073e9SAndroid Build Coastguard Worker        nan = torch.full_like(out_tensor, math.nan)
6872*da0073e9SAndroid Build Coastguard Worker        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
6873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res5, ref)
6874*da0073e9SAndroid Build Coastguard Worker
6875*da0073e9SAndroid Build Coastguard Worker        if b1.is_complex():
6876*da0073e9SAndroid Build Coastguard Worker            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j)
6877*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res6, out_tensor * .1j + .5j * ref)
6878*da0073e9SAndroid Build Coastguard Worker        else:
6879*da0073e9SAndroid Build Coastguard Worker            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5)
6880*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res6, out_tensor * .1 + .5 * ref)
6881*da0073e9SAndroid Build Coastguard Worker
6882*da0073e9SAndroid Build Coastguard Worker        res7 = torch.full_like(out_tensor, math.nan)
6883*da0073e9SAndroid Build Coastguard Worker        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
6884*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res7, ref)
6885*da0073e9SAndroid Build Coastguard Worker
6886*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
6887*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6888*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6889*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6890*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6891*da0073e9SAndroid Build Coastguard Worker    def test_addbmm(self, device, dtype):
6892*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6893*da0073e9SAndroid Build Coastguard Worker            # cuBLAS does not guarantee BFloat16 support on SM < 53,
6894*da0073e9SAndroid Build Coastguard Worker            # so PyTorch treats BFloat16 support on SM < 53 as
6895*da0073e9SAndroid Build Coastguard Worker            # undefined behavior.
6896*da0073e9SAndroid Build Coastguard Worker            return
6897*da0073e9SAndroid Build Coastguard Worker
6898*da0073e9SAndroid Build Coastguard Worker        num_batches = 2
6899*da0073e9SAndroid Build Coastguard Worker        M, N, O = 16, 17, 18
6900*da0073e9SAndroid Build Coastguard Worker
6901*da0073e9SAndroid Build Coastguard Worker        is_supported = True
6902*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bfloat16:
6903*da0073e9SAndroid Build Coastguard Worker            if self.device_type == 'cpu':
6904*da0073e9SAndroid Build Coastguard Worker                self.precision = 1  # loosen tolerance for bfloat16 on CPU (observed 43 vs 43.75)
6905*da0073e9SAndroid Build Coastguard Worker            else:
6906*da0073e9SAndroid Build Coastguard Worker                is_supported = TEST_WITH_ROCM or SM53OrLater
6907*da0073e9SAndroid Build Coastguard Worker
6908*da0073e9SAndroid Build Coastguard Worker        if not is_supported:
6909*da0073e9SAndroid Build Coastguard Worker            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6910*da0073e9SAndroid Build Coastguard Worker            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6911*da0073e9SAndroid Build Coastguard Worker            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
6912*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6913*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch.addbmm(t, b1, b2))
6914*da0073e9SAndroid Build Coastguard Worker            return
6915*da0073e9SAndroid Build Coastguard Worker
6916*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
6917*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
6918*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
6919*da0073e9SAndroid Build Coastguard Worker
6920*da0073e9SAndroid Build Coastguard Worker        def generate_tensor():
6921*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
6922*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
6923*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6924*da0073e9SAndroid Build Coastguard Worker                for perm3 in itertools.permutations((0, 1)):
6925*da0073e9SAndroid Build Coastguard Worker                    b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1
6926*da0073e9SAndroid Build Coastguard Worker                    b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1
6927*da0073e9SAndroid Build Coastguard Worker                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6928*da0073e9SAndroid Build Coastguard Worker                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6929*da0073e9SAndroid Build Coastguard Worker                    ref = torch.from_numpy(
6930*da0073e9SAndroid Build Coastguard Worker                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6931*da0073e9SAndroid Build Coastguard Worker                    ).to(device=device, dtype=dtype).sum(0)
6932*da0073e9SAndroid Build Coastguard Worker                    out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
6933*da0073e9SAndroid Build Coastguard Worker                    yield b1, b2, ref, out_tensor
6934*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
6935*da0073e9SAndroid Build Coastguard Worker            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
6936*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
6937*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
6938*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1
6939*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1
6940*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
6941*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6942*da0073e9SAndroid Build Coastguard Worker                ).to(device=device, dtype=dtype).sum(0)
6943*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
6944*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
6945*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
6946*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6947*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6948*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6949*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1
6950*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1
6951*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
6952*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6953*da0073e9SAndroid Build Coastguard Worker                ).to(device=device, dtype=dtype).sum(0)
6954*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
6955*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
6956*da0073e9SAndroid Build Coastguard Worker
6957*da0073e9SAndroid Build Coastguard Worker        for b1, b2, ref, out_tensor in generate_tensor():
6958*da0073e9SAndroid Build Coastguard Worker            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
6959*da0073e9SAndroid Build Coastguard Worker
6960*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
6961*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6962*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6963*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.05)
6964*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.05)
6965*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm(self, device, dtype):
6966*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6967*da0073e9SAndroid Build Coastguard Worker            # cuBLAS does not guarantee BFloat16 support on SM < 53,
6968*da0073e9SAndroid Build Coastguard Worker            # so PyTorch treats BFloat16 support on SM < 53 as
6969*da0073e9SAndroid Build Coastguard Worker            # undefined behavior.
6970*da0073e9SAndroid Build Coastguard Worker            return
6971*da0073e9SAndroid Build Coastguard Worker
6972*da0073e9SAndroid Build Coastguard Worker        num_batches = 10
6973*da0073e9SAndroid Build Coastguard Worker        M, N, O = 12, 8, 50
6974*da0073e9SAndroid Build Coastguard Worker
6975*da0073e9SAndroid Build Coastguard Worker        is_supported = True
6976*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6977*da0073e9SAndroid Build Coastguard Worker            is_supported = TEST_WITH_ROCM or SM53OrLater
6978*da0073e9SAndroid Build Coastguard Worker
6979*da0073e9SAndroid Build Coastguard Worker        if not is_supported:
6980*da0073e9SAndroid Build Coastguard Worker            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6981*da0073e9SAndroid Build Coastguard Worker            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6982*da0073e9SAndroid Build Coastguard Worker            t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1)
6983*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6984*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch.baddbmm(t, b1, b2))
6985*da0073e9SAndroid Build Coastguard Worker            return
6986*da0073e9SAndroid Build Coastguard Worker
6987*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
6988*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
6989*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
6990*da0073e9SAndroid Build Coastguard Worker
6991*da0073e9SAndroid Build Coastguard Worker        def generate_tensor():
6992*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
6993*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
6994*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3):
6995*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6996*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6997*da0073e9SAndroid Build Coastguard Worker                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6998*da0073e9SAndroid Build Coastguard Worker                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6999*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
7000*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7001*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
7002*da0073e9SAndroid Build Coastguard Worker                out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
7003*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
7004*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
7005*da0073e9SAndroid Build Coastguard Worker            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
7006*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
7007*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
7008*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
7009*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
7010*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
7011*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7012*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
7013*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
7014*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
7015*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
7016*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
7017*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
7018*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
7019*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
7020*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
7021*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
7022*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
7023*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
7024*da0073e9SAndroid Build Coastguard Worker
7025*da0073e9SAndroid Build Coastguard Worker        for b1, b2, ref, out_tensor in generate_tensor():
7026*da0073e9SAndroid Build Coastguard Worker            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
7027*da0073e9SAndroid Build Coastguard Worker
7028*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
7029*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7030*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7031*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
7032*da0073e9SAndroid Build Coastguard Worker    def test_pinverse(self, device, dtype):
7033*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7034*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_fullrank, device=device, dtype=dtype)
7035*da0073e9SAndroid Build Coastguard Worker
7036*da0073e9SAndroid Build Coastguard Worker        def run_test(M):
7037*da0073e9SAndroid Build Coastguard Worker            # Test against the defining (Moore-Penrose) conditions of the pseudo-inverse
7038*da0073e9SAndroid Build Coastguard Worker            MPI = torch.pinverse(M)
7039*da0073e9SAndroid Build Coastguard Worker            MPI_ = MPI.cpu().numpy()
7040*da0073e9SAndroid Build Coastguard Worker            M_ = M.cpu().numpy()
7041*da0073e9SAndroid Build Coastguard Worker            if M.numel() > 0:
7042*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_))
7043*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_))
7044*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj())
7045*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj())
7046*da0073e9SAndroid Build Coastguard Worker            else:
7047*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
7048*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5),  # square matrices
7049*da0073e9SAndroid Build Coastguard Worker                      (3, 2), (5, 3, 2), (7, 5, 3, 2),  # tall matrices (more rows than columns)
7050*da0073e9SAndroid Build Coastguard Worker                      (2, 3), (5, 2, 3), (7, 5, 2, 3),  # wide matrices (more columns than rows)
7051*da0073e9SAndroid Build Coastguard Worker                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
7052*da0073e9SAndroid Build Coastguard Worker            M = torch.randn(*sizes, dtype=dtype, device=device)
7053*da0073e9SAndroid Build Coastguard Worker            run_test(M)
7054*da0073e9SAndroid Build Coastguard Worker
7055*da0073e9SAndroid Build Coastguard Worker        # For an invertible matrix, pinverse(M) @ M should equal the identity
7056*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
7057*da0073e9SAndroid Build Coastguard Worker            matsize = sizes[-1]
7058*da0073e9SAndroid Build Coastguard Worker            batchdims = sizes[:-2]
7059*da0073e9SAndroid Build Coastguard Worker            M = make_arg(*batchdims, matsize, matsize)
7060*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
7061*da0073e9SAndroid Build Coastguard Worker                             atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
7062*da0073e9SAndroid Build Coastguard Worker
7063*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7064*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
7065*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
7066*da0073e9SAndroid Build Coastguard Worker    def test_matrix_power_non_negative(self, device, dtype):
7067*da0073e9SAndroid Build Coastguard Worker        def check(*size):
7068*da0073e9SAndroid Build Coastguard Worker            t = make_tensor(size, dtype=dtype, device=device)
7069*da0073e9SAndroid Build Coastguard Worker            for n in range(8):
7070*da0073e9SAndroid Build Coastguard Worker                res = torch.linalg.matrix_power(t, n)
7071*da0073e9SAndroid Build Coastguard Worker                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
7072*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.cpu(), torch.from_numpy(ref))
7073*da0073e9SAndroid Build Coastguard Worker
7074*da0073e9SAndroid Build Coastguard Worker        check(0, 0)
7075*da0073e9SAndroid Build Coastguard Worker        check(1, 1)
7076*da0073e9SAndroid Build Coastguard Worker        check(5, 5)
7077*da0073e9SAndroid Build Coastguard Worker        check(0, 3, 3)
7078*da0073e9SAndroid Build Coastguard Worker        check(2, 3, 3)
7079*da0073e9SAndroid Build Coastguard Worker
7080*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7081*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
7082*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
7083*da0073e9SAndroid Build Coastguard Worker    def test_matrix_power_negative(self, device, dtype):
7084*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7085*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_fullrank, device=device, dtype=dtype)
7086*da0073e9SAndroid Build Coastguard Worker
7087*da0073e9SAndroid Build Coastguard Worker        def check(*size):
7088*da0073e9SAndroid Build Coastguard Worker            t = make_arg(*size)
7089*da0073e9SAndroid Build Coastguard Worker            for n in range(-7, 0):
7090*da0073e9SAndroid Build Coastguard Worker                res = torch.linalg.matrix_power(t, n)
7091*da0073e9SAndroid Build Coastguard Worker                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
7092*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.cpu(), torch.from_numpy(ref))
7093*da0073e9SAndroid Build Coastguard Worker
7094*da0073e9SAndroid Build Coastguard Worker        check(0, 0)
7095*da0073e9SAndroid Build Coastguard Worker        check(5, 5)
7096*da0073e9SAndroid Build Coastguard Worker        check(2, 0, 0)
7097*da0073e9SAndroid Build Coastguard Worker        check(0, 3, 3)
7098*da0073e9SAndroid Build Coastguard Worker        check(2, 3, 3)
7099*da0073e9SAndroid Build Coastguard Worker        check(2, 3, 5, 5)
7100*da0073e9SAndroid Build Coastguard Worker
7101*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7102*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7103*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.complex64)
7104*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_utils(self, device, dtype):
7105*da0073e9SAndroid Build Coastguard Worker        # test linear combination
7106*da0073e9SAndroid Build Coastguard Worker        def run_test(coeff_shape, data_shape):
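            # torch._compute_linear_combination(x, coeffs) is checked against an
            # explicit broadcast-and-sum reference: out[i] = sum_j coeffs[i, j] * x[j].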
7107*da0073e9SAndroid Build Coastguard Worker            coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float)
7108*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype)
7109*da0073e9SAndroid Build Coastguard Worker
7110*da0073e9SAndroid Build Coastguard Worker            res1 = torch._compute_linear_combination(x, coeffs)
7111*da0073e9SAndroid Build Coastguard Worker            res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1)
7112*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2, atol=1e-5, rtol=0.0)
7113*da0073e9SAndroid Build Coastguard Worker
7114*da0073e9SAndroid Build Coastguard Worker            # check `out=` version
7115*da0073e9SAndroid Build Coastguard Worker            res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7116*da0073e9SAndroid Build Coastguard Worker            torch._compute_linear_combination(x, coeffs, out=res3)
7117*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res3, atol=1e-5, rtol=0.0)
7118*da0073e9SAndroid Build Coastguard Worker
7119*da0073e9SAndroid Build Coastguard Worker            res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7120*da0073e9SAndroid Build Coastguard Worker            torch._compute_linear_combination(x, coeffs, out=res4)
7121*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0)
7122*da0073e9SAndroid Build Coastguard Worker
7123*da0073e9SAndroid Build Coastguard Worker            res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7124*da0073e9SAndroid Build Coastguard Worker            res5_clone = res5.clone()
7125*da0073e9SAndroid Build Coastguard Worker            torch._compute_linear_combination(x, coeffs, out=res5)
7126*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0)
7127*da0073e9SAndroid Build Coastguard Worker
7128*da0073e9SAndroid Build Coastguard Worker        run_test([1, 3], [2, 2])
7129*da0073e9SAndroid Build Coastguard Worker        run_test([3, 1], [2, 2])
7130*da0073e9SAndroid Build Coastguard Worker        run_test([1, 10], [10, 10])
7131*da0073e9SAndroid Build Coastguard Worker        run_test([10, 1], [10, 10])
7132*da0073e9SAndroid Build Coastguard Worker        run_test([5, 3], [2, 2])
7133*da0073e9SAndroid Build Coastguard Worker        run_test([5, 3], [100, 100])
7134*da0073e9SAndroid Build Coastguard Worker        run_test([3, 4], [3, 3, 3])
7135*da0073e9SAndroid Build Coastguard Worker        run_test([3, 4], [3, 3, 3, 3])
7136*da0073e9SAndroid Build Coastguard Worker
7137*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/94124
7138*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
7139*da0073e9SAndroid Build Coastguard Worker            x = torch.rand([], device=device, dtype=dtype)
7140*da0073e9SAndroid Build Coastguard Worker            coeffs = torch.rand([2, 2], device=device, dtype=dtype)
7141*da0073e9SAndroid Build Coastguard Worker            res = torch._compute_linear_combination(x, coeffs)
7142*da0073e9SAndroid Build Coastguard Worker
7143*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
7144*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7145*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64)
7146*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_no_warnings(self, device, dtype):
7147*da0073e9SAndroid Build Coastguard Worker        # this tests https://github.com/pytorch/pytorch/issues/80948
7148*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
7149*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(42)
7150*da0073e9SAndroid Build Coastguard Worker            tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device)
7151*da0073e9SAndroid Build Coastguard Worker            tens = (0.5 * (tens.transpose(-1, -2) + tens))
7152*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
7153*da0073e9SAndroid Build Coastguard Worker                tens.imag = torch.matrix_exp(tens.imag)
7154*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(len(w))
7155*da0073e9SAndroid Build Coastguard Worker
7156*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7157*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7158*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7159*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_boundary_cases(self, device, dtype):
7160*da0073e9SAndroid Build Coastguard Worker        expm = torch.linalg.matrix_exp
7161*da0073e9SAndroid Build Coastguard Worker
7162*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"):
7163*da0073e9SAndroid Build Coastguard Worker            expm(torch.randn(3, 3).type(torch.int))
7164*da0073e9SAndroid Build Coastguard Worker
7165*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
7166*da0073e9SAndroid Build Coastguard Worker            expm(torch.randn(3))
7167*da0073e9SAndroid Build Coastguard Worker
7168*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
7169*da0073e9SAndroid Build Coastguard Worker            expm(torch.randn(3, 2, 1))
7170*da0073e9SAndroid Build Coastguard Worker
7171*da0073e9SAndroid Build Coastguard Worker        # check 1x1 matrices
7172*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, 1, 1)
7173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expm(x), x.exp())
7174*da0073e9SAndroid Build Coastguard Worker
7175*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7176*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7177*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7178*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype):
7179*da0073e9SAndroid Build Coastguard Worker        expm = torch.linalg.matrix_exp
7180*da0073e9SAndroid Build Coastguard Worker
7181*da0073e9SAndroid Build Coastguard Worker        def with_nan(x):
7182*da0073e9SAndroid Build Coastguard Worker            x[0, 0, 0] = torch.nan
7183*da0073e9SAndroid Build Coastguard Worker            return x
7184*da0073e9SAndroid Build Coastguard Worker
7185*da0073e9SAndroid Build Coastguard Worker        # Check small batches
7186*da0073e9SAndroid Build Coastguard Worker        x = with_nan(torch.randn(1, 1, 1))
7187*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isnan(expm(x)).any())
7188*da0073e9SAndroid Build Coastguard Worker        x = with_nan(torch.randn(1, 2, 2))
7189*da0073e9SAndroid Build Coastguard Worker        for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]:
7190*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.isnan(expm(x / v)).any())
7191*da0073e9SAndroid Build Coastguard Worker
7192*da0073e9SAndroid Build Coastguard Worker        # Check large batches
7193*da0073e9SAndroid Build Coastguard Worker        x = with_nan(torch.randn(2, 2, 2))
7194*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isnan(expm(x)).any())
7195*da0073e9SAndroid Build Coastguard Worker        x = with_nan(torch.randn(4096, 2, 2))
7196*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isnan(expm(x)).any())
7197*da0073e9SAndroid Build Coastguard Worker
7198*da0073e9SAndroid Build Coastguard Worker    @slowTest
7199*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7200*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7201*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7202*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_analytic(self, device, dtype):
7203*da0073e9SAndroid Build Coastguard Worker        expm = torch.linalg.matrix_exp
7204*da0073e9SAndroid Build Coastguard Worker        # check zero matrix
7205*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(20, 20, dtype=dtype, device=device)
7206*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item())
7207*da0073e9SAndroid Build Coastguard Worker
7208*da0073e9SAndroid Build Coastguard Worker        def normalize_to_1_operator_norm(sample, desired_norm):
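            # sample.abs().sum(-2).max(-1) computes the induced (operator) 1-norm,
            # i.e. the maximum absolute column sum.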
7209*da0073e9SAndroid Build Coastguard Worker            sample_norm, _ = sample.abs().sum(-2).max(-1)
7210*da0073e9SAndroid Build Coastguard Worker            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7211*da0073e9SAndroid Build Coastguard Worker            return sample_to_1_norm * desired_norm
7212*da0073e9SAndroid Build Coastguard Worker
7213*da0073e9SAndroid Build Coastguard Worker        def gen_good_cond_number_matrices(*n):
7214*da0073e9SAndroid Build Coastguard Worker            """
7215*da0073e9SAndroid Build Coastguard Worker            Generates a diagonally dominant matrix
7216*da0073e9SAndroid Build Coastguard Worker            with the eigenvalues centered at 1
7217*da0073e9SAndroid Build Coastguard Worker            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
7218*da0073e9SAndroid Build Coastguard Worker            """
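            # The diagonal is exactly 1 and the off-diagonal entries are small
            # (on the order of 1 / n[-1] ** 2), so the Gershgorin bound above holds.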
7219*da0073e9SAndroid Build Coastguard Worker            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7220*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7221*da0073e9SAndroid Build Coastguard Worker            x = (x - x * identity) + identity
7222*da0073e9SAndroid Build Coastguard Worker            return x
7223*da0073e9SAndroid Build Coastguard Worker
7224*da0073e9SAndroid Build Coastguard Worker        def run_test(*n):
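            # `thetas` are norm thresholds that are expected to correspond to the
            # approximation degrees used by matrix_exp (see the `# deg N` labels);
            # sample norms are generated around these thresholds further below.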
7225*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.float:
7226*da0073e9SAndroid Build Coastguard Worker                thetas = [
7227*da0073e9SAndroid Build Coastguard Worker                    1.192092800768788e-07,  # deg 1
7228*da0073e9SAndroid Build Coastguard Worker                    5.978858893805233e-04,  # deg 2
7229*da0073e9SAndroid Build Coastguard Worker                    5.116619363445086e-02,  # deg 4
7230*da0073e9SAndroid Build Coastguard Worker                    5.800524627688768e-01,  # deg 8
7231*da0073e9SAndroid Build Coastguard Worker                    1.461661507209034e+00,  # deg 12
7232*da0073e9SAndroid Build Coastguard Worker                    3.010066362817634e+00   # deg 18
7233*da0073e9SAndroid Build Coastguard Worker                ]
7234*da0073e9SAndroid Build Coastguard Worker            else:  # if torch.double
7235*da0073e9SAndroid Build Coastguard Worker                thetas = [
7236*da0073e9SAndroid Build Coastguard Worker                    2.220446049250313e-16,  # deg 1
7237*da0073e9SAndroid Build Coastguard Worker                    2.580956802971767e-08,  # deg 2
7238*da0073e9SAndroid Build Coastguard Worker                    3.397168839976962e-04,  # deg 4
7239*da0073e9SAndroid Build Coastguard Worker                    4.991228871115323e-02,  # deg 8
7240*da0073e9SAndroid Build Coastguard Worker                    2.996158913811580e-01,  # deg 12
7241*da0073e9SAndroid Build Coastguard Worker                    1.090863719290036e+00   # deg 18
7242*da0073e9SAndroid Build Coastguard Worker                ]
7243*da0073e9SAndroid Build Coastguard Worker
7244*da0073e9SAndroid Build Coastguard Worker            # generate input
7245*da0073e9SAndroid Build Coastguard Worker            q = gen_good_cond_number_matrices(*n)
7246*da0073e9SAndroid Build Coastguard Worker            q_ = q.cpu().numpy()
7247*da0073e9SAndroid Build Coastguard Worker            qinv = torch.inverse(q)
7248*da0073e9SAndroid Build Coastguard Worker            qinv_ = qinv.cpu().numpy()
7249*da0073e9SAndroid Build Coastguard Worker            d = torch.randn(n[:-1], dtype=dtype, device=device)
7250*da0073e9SAndroid Build Coastguard Worker            x = torch.from_numpy(
7251*da0073e9SAndroid Build Coastguard Worker                np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
7252*da0073e9SAndroid Build Coastguard Worker            x_norm, _ = x.abs().sum(-2).max(-1)
7253*da0073e9SAndroid Build Coastguard Worker
7254*da0073e9SAndroid Build Coastguard Worker            # test against the analytic result for the matrix as generated (no normalization)
7255*da0073e9SAndroid Build Coastguard Worker            mexp = expm(x)
7256*da0073e9SAndroid Build Coastguard Worker            mexp_analytic = np.matmul(
7257*da0073e9SAndroid Build Coastguard Worker                q_,
7258*da0073e9SAndroid Build Coastguard Worker                np.matmul(
7259*da0073e9SAndroid Build Coastguard Worker                    torch.diag_embed(d.exp()).cpu().numpy(),
7260*da0073e9SAndroid Build Coastguard Worker                    qinv_
7261*da0073e9SAndroid Build Coastguard Worker                )
7262*da0073e9SAndroid Build Coastguard Worker            )
7263*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7264*da0073e9SAndroid Build Coastguard Worker
7265*da0073e9SAndroid Build Coastguard Worker            # generate norms to test different degree expansions
7266*da0073e9SAndroid Build Coastguard Worker            sample_norms = []
7267*da0073e9SAndroid Build Coastguard Worker            for i in range(len(thetas) - 1):
7268*da0073e9SAndroid Build Coastguard Worker                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7269*da0073e9SAndroid Build Coastguard Worker            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7270*da0073e9SAndroid Build Coastguard Worker
7271*da0073e9SAndroid Build Coastguard Worker            # normalize the matrix to each sample norm and compare against the analytic result
7272*da0073e9SAndroid Build Coastguard Worker            for sample_norm in sample_norms:
7273*da0073e9SAndroid Build Coastguard Worker                x_normalized = normalize_to_1_operator_norm(x, sample_norm)
7274*da0073e9SAndroid Build Coastguard Worker
7275*da0073e9SAndroid Build Coastguard Worker                mexp = expm(x_normalized)
7276*da0073e9SAndroid Build Coastguard Worker                mexp_analytic = np.matmul(
7277*da0073e9SAndroid Build Coastguard Worker                    q_,
7278*da0073e9SAndroid Build Coastguard Worker                    np.matmul(
7279*da0073e9SAndroid Build Coastguard Worker                        torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
7280*da0073e9SAndroid Build Coastguard Worker                        qinv_
7281*da0073e9SAndroid Build Coastguard Worker                    )
7282*da0073e9SAndroid Build Coastguard Worker                )
7283*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7284*da0073e9SAndroid Build Coastguard Worker
7285*da0073e9SAndroid Build Coastguard Worker        # single matrix
7286*da0073e9SAndroid Build Coastguard Worker        run_test(2, 2)
7287*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3)
7288*da0073e9SAndroid Build Coastguard Worker        run_test(4, 4)
7289*da0073e9SAndroid Build Coastguard Worker        run_test(5, 5)
7290*da0073e9SAndroid Build Coastguard Worker        run_test(100, 100)
7291*da0073e9SAndroid Build Coastguard Worker        run_test(200, 200)
7292*da0073e9SAndroid Build Coastguard Worker
7293*da0073e9SAndroid Build Coastguard Worker        # small batch of matrices
7294*da0073e9SAndroid Build Coastguard Worker        run_test(3, 2, 2)
7295*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3)
7296*da0073e9SAndroid Build Coastguard Worker        run_test(3, 4, 4)
7297*da0073e9SAndroid Build Coastguard Worker        run_test(3, 5, 5)
7298*da0073e9SAndroid Build Coastguard Worker        run_test(3, 100, 100)
7299*da0073e9SAndroid Build Coastguard Worker        run_test(3, 200, 200)
7300*da0073e9SAndroid Build Coastguard Worker
7301*da0073e9SAndroid Build Coastguard Worker        # large batch of matrices
7302*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 2, 2)
7303*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3, 3)
7304*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 4, 4)
7305*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 5, 5)
7306*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 100, 100)
7307*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 200, 200)
7308*da0073e9SAndroid Build Coastguard Worker
7309*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7310*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7311*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
7312*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_batch(self, device, dtype):
7313*da0073e9SAndroid Build Coastguard Worker
7314*da0073e9SAndroid Build Coastguard Worker        def run_test(*n):
7315*da0073e9SAndroid Build Coastguard Worker            tensors_batch = torch.zeros(n, dtype=dtype, device=device)
7316*da0073e9SAndroid Build Coastguard Worker            tensors_batch = tensors_batch.view(-1, n[-2], n[-1])
7317*da0073e9SAndroid Build Coastguard Worker
7318*da0073e9SAndroid Build Coastguard Worker            num_matrices = tensors_batch.size(0)
7319*da0073e9SAndroid Build Coastguard Worker            tensors_list = []
7320*da0073e9SAndroid Build Coastguard Worker            for i in range(num_matrices):
7321*da0073e9SAndroid Build Coastguard Worker                tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device))
7322*da0073e9SAndroid Build Coastguard Worker
7323*da0073e9SAndroid Build Coastguard Worker            for i in range(num_matrices):
7324*da0073e9SAndroid Build Coastguard Worker                tensors_batch[i, ...] = tensors_list[i]
7325*da0073e9SAndroid Build Coastguard Worker
7326*da0073e9SAndroid Build Coastguard Worker            tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list)
7327*da0073e9SAndroid Build Coastguard Worker            tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch)
7328*da0073e9SAndroid Build Coastguard Worker
7329*da0073e9SAndroid Build Coastguard Worker            for i, tensor_exp in enumerate(tensors_exp_map):
7330*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tensors_exp_batch[i, ...], tensor_exp)
7331*da0073e9SAndroid Build Coastguard Worker
7332*da0073e9SAndroid Build Coastguard Worker        # small batch of matrices
7333*da0073e9SAndroid Build Coastguard Worker        run_test(3, 2, 2)
7334*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3)
7335*da0073e9SAndroid Build Coastguard Worker        run_test(3, 4, 4)
7336*da0073e9SAndroid Build Coastguard Worker        run_test(3, 5, 5)
7337*da0073e9SAndroid Build Coastguard Worker
7338*da0073e9SAndroid Build Coastguard Worker        # large batch of matrices
7339*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 2, 2)
7340*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3, 3)
7341*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 4, 4)
7342*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 5, 5)
7343*da0073e9SAndroid Build Coastguard Worker
7344*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7345*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7346*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7347*da0073e9SAndroid Build Coastguard Worker    def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):
7348*da0073e9SAndroid Build Coastguard Worker
7349*da0073e9SAndroid Build Coastguard Worker        def normalize_to_1_operator_norm(sample, desired_norm):
7350*da0073e9SAndroid Build Coastguard Worker            sample_norm, _ = sample.abs().sum(-2).max(-1)
7351*da0073e9SAndroid Build Coastguard Worker            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7352*da0073e9SAndroid Build Coastguard Worker            return sample_to_1_norm * desired_norm
7353*da0073e9SAndroid Build Coastguard Worker
7354*da0073e9SAndroid Build Coastguard Worker        def gen_good_cond_number_matrices(*n):
7355*da0073e9SAndroid Build Coastguard Worker            """
7356*da0073e9SAndroid Build Coastguard Worker            Generates a diagonally dominant matrix
7357*da0073e9SAndroid Build Coastguard Worker            with the eigenvalues centered at 1
7358*da0073e9SAndroid Build Coastguard Worker            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
7359*da0073e9SAndroid Build Coastguard Worker            """
7360*da0073e9SAndroid Build Coastguard Worker            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7361*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7362*da0073e9SAndroid Build Coastguard Worker            x = (x - x * identity) + identity
7363*da0073e9SAndroid Build Coastguard Worker            return x
7364*da0073e9SAndroid Build Coastguard Worker
7365*da0073e9SAndroid Build Coastguard Worker        def get_taylor_approximation(a, deg):
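            # Reference implementation: accumulate the degree-`deg` Taylor
            # polynomial of exp(a), i.e. sum over i of a**i / i!.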
7366*da0073e9SAndroid Build Coastguard Worker            a_ = a.cpu().numpy()
7367*da0073e9SAndroid Build Coastguard Worker            identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a)
7368*da0073e9SAndroid Build Coastguard Worker            res = identity.cpu().numpy()
7369*da0073e9SAndroid Build Coastguard Worker            taylor_term = identity.cpu().numpy()
7370*da0073e9SAndroid Build Coastguard Worker
7371*da0073e9SAndroid Build Coastguard Worker            for i in range(1, deg + 1):
7372*da0073e9SAndroid Build Coastguard Worker                taylor_term = np.matmul(a_, taylor_term) / i
7373*da0073e9SAndroid Build Coastguard Worker                res = res + taylor_term
7374*da0073e9SAndroid Build Coastguard Worker
7375*da0073e9SAndroid Build Coastguard Worker            return res
7376*da0073e9SAndroid Build Coastguard Worker
7377*da0073e9SAndroid Build Coastguard Worker        def scale_square(a, deg):
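            # Reference scaling-and-squaring: scale `a` by 2**-s so that the Taylor
            # approximation is accurate, then square the result s times.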
7378*da0073e9SAndroid Build Coastguard Worker            if a.abs().pow(2).sum().sqrt() < 1.0:
7379*da0073e9SAndroid Build Coastguard Worker                return get_taylor_approximation(a, 12)
7380*da0073e9SAndroid Build Coastguard Worker            else:
7381*da0073e9SAndroid Build Coastguard Worker                s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item())
7382*da0073e9SAndroid Build Coastguard Worker                b = a / (2 ** s)
7383*da0073e9SAndroid Build Coastguard Worker                b = get_taylor_approximation(b, 18)
7384*da0073e9SAndroid Build Coastguard Worker                for _ in range(s):
7385*da0073e9SAndroid Build Coastguard Worker                    b = np.matmul(b, b)
7386*da0073e9SAndroid Build Coastguard Worker                return torch.from_numpy(b).to(a.device)
7387*da0073e9SAndroid Build Coastguard Worker
7388*da0073e9SAndroid Build Coastguard Worker        def run_test(*n):
7389*da0073e9SAndroid Build Coastguard Worker            degs = [1, 2, 4, 8, 12, 18]
7390*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.float:
7391*da0073e9SAndroid Build Coastguard Worker                thetas = [
7392*da0073e9SAndroid Build Coastguard Worker                    1.192092800768788e-07,  # deg 1
7393*da0073e9SAndroid Build Coastguard Worker                    5.978858893805233e-04,  # deg 2
7394*da0073e9SAndroid Build Coastguard Worker                    5.116619363445086e-02,  # deg 4
7395*da0073e9SAndroid Build Coastguard Worker                    5.800524627688768e-01,  # deg 8
7396*da0073e9SAndroid Build Coastguard Worker                    1.461661507209034e+00,  # deg 12
7397*da0073e9SAndroid Build Coastguard Worker                    3.010066362817634e+00   # deg 18
7398*da0073e9SAndroid Build Coastguard Worker                ]
7399*da0073e9SAndroid Build Coastguard Worker            else:  # if torch.double
7400*da0073e9SAndroid Build Coastguard Worker                thetas = [
7401*da0073e9SAndroid Build Coastguard Worker                    2.220446049250313e-16,  # deg 1
7402*da0073e9SAndroid Build Coastguard Worker                    2.580956802971767e-08,  # deg 2
7403*da0073e9SAndroid Build Coastguard Worker                    3.397168839976962e-04,  # deg 4
7404*da0073e9SAndroid Build Coastguard Worker                    4.991228871115323e-02,  # deg 8
7405*da0073e9SAndroid Build Coastguard Worker                    2.996158913811580e-01,  # deg 12
7406*da0073e9SAndroid Build Coastguard Worker                    1.090863719290036e+00   # deg 18
7407*da0073e9SAndroid Build Coastguard Worker                ]
7408*da0073e9SAndroid Build Coastguard Worker
7409*da0073e9SAndroid Build Coastguard Worker            # generate norms to test different degree expansions
7410*da0073e9SAndroid Build Coastguard Worker            sample_norms = []
7411*da0073e9SAndroid Build Coastguard Worker            for i in range(len(thetas) - 1):
7412*da0073e9SAndroid Build Coastguard Worker                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7413*da0073e9SAndroid Build Coastguard Worker            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7414*da0073e9SAndroid Build Coastguard Worker            degs = [degs[0]] + degs
7415*da0073e9SAndroid Build Coastguard Worker
7416*da0073e9SAndroid Build Coastguard Worker            for sample_norm, deg in zip(sample_norms, degs):
7417*da0073e9SAndroid Build Coastguard Worker                x = gen_good_cond_number_matrices(*n)
7418*da0073e9SAndroid Build Coastguard Worker                x = normalize_to_1_operator_norm(x, sample_norm)
7419*da0073e9SAndroid Build Coastguard Worker
7420*da0073e9SAndroid Build Coastguard Worker                mexp = torch.linalg.matrix_exp(x)
7421*da0073e9SAndroid Build Coastguard Worker                mexp_taylor = scale_square(x, deg)
7422*da0073e9SAndroid Build Coastguard Worker
7423*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)
7424*da0073e9SAndroid Build Coastguard Worker
7425*da0073e9SAndroid Build Coastguard Worker        # single matrix
7426*da0073e9SAndroid Build Coastguard Worker        run_test(2, 2)
7427*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3)
7428*da0073e9SAndroid Build Coastguard Worker        run_test(4, 4)
7429*da0073e9SAndroid Build Coastguard Worker        run_test(5, 5)
7430*da0073e9SAndroid Build Coastguard Worker
7431*da0073e9SAndroid Build Coastguard Worker        # small batch of matrices
7432*da0073e9SAndroid Build Coastguard Worker        run_test(3, 2, 2)
7433*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3)
7434*da0073e9SAndroid Build Coastguard Worker        run_test(3, 4, 4)
7435*da0073e9SAndroid Build Coastguard Worker        run_test(3, 5, 5)
7436*da0073e9SAndroid Build Coastguard Worker
7437*da0073e9SAndroid Build Coastguard Worker        # large batch of matrices
7438*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 2, 2)
7439*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 3, 3)
7440*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 4, 4)
7441*da0073e9SAndroid Build Coastguard Worker        run_test(3, 3, 5, 5)
7442*da0073e9SAndroid Build Coastguard Worker
7443*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7444*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7445*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
7446*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
7447*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
7448*da0073e9SAndroid Build Coastguard Worker    def test_slogdet(self, device, dtype):
7449*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix,
7450*da0073e9SAndroid Build Coastguard Worker                                                          random_hermitian_pd_matrix, random_square_matrix_of_rank)
7451*da0073e9SAndroid Build Coastguard Worker
7452*da0073e9SAndroid Build Coastguard Worker        # mat_chars denotes matrix characteristics
7453*da0073e9SAndroid Build Coastguard Worker        # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular
7454*da0073e9SAndroid Build Coastguard Worker        def run_test(matsize, batchdims, mat_chars):
7455*da0073e9SAndroid Build Coastguard Worker            num_matrices = np.prod(batchdims)
7456*da0073e9SAndroid Build Coastguard Worker            list_of_matrices = []
7457*da0073e9SAndroid Build Coastguard Worker            if num_matrices != 0:
7458*da0073e9SAndroid Build Coastguard Worker                for idx in range(num_matrices):
7459*da0073e9SAndroid Build Coastguard Worker                    mat_type = idx % len(mat_chars)
7460*da0073e9SAndroid Build Coastguard Worker                    if mat_chars[mat_type] == 'hermitian':
7461*da0073e9SAndroid Build Coastguard Worker                        list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device))
7462*da0073e9SAndroid Build Coastguard Worker                    elif mat_chars[mat_type] == 'hermitian_psd':
7463*da0073e9SAndroid Build Coastguard Worker                        list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device))
7464*da0073e9SAndroid Build Coastguard Worker                    elif mat_chars[mat_type] == 'hermitian_pd':
7465*da0073e9SAndroid Build Coastguard Worker                        list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device))
7466*da0073e9SAndroid Build Coastguard Worker                    elif mat_chars[mat_type] == 'singular':
7467*da0073e9SAndroid Build Coastguard Worker                        list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7468*da0073e9SAndroid Build Coastguard Worker                    elif mat_chars[mat_type] == 'non_singular':
7469*da0073e9SAndroid Build Coastguard Worker                        list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7470*da0073e9SAndroid Build Coastguard Worker                full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7471*da0073e9SAndroid Build Coastguard Worker            else:
7472*da0073e9SAndroid Build Coastguard Worker                full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device)
7473*da0073e9SAndroid Build Coastguard Worker
7474*da0073e9SAndroid Build Coastguard Worker            actual_value = torch.linalg.slogdet(full_tensor)
7475*da0073e9SAndroid Build Coastguard Worker            expected_value = np.linalg.slogdet(full_tensor.cpu().numpy())
7476*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision)
7477*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision)
7478*da0073e9SAndroid Build Coastguard Worker
7479*da0073e9SAndroid Build Coastguard Worker            # test out=variant
7480*da0073e9SAndroid Build Coastguard Worker            sign_out = torch.empty_like(actual_value[0])
7481*da0073e9SAndroid Build Coastguard Worker            logabsdet_out = torch.empty_like(actual_value[1])
7482*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out))
7483*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans[0], sign_out)
7484*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans[1], logabsdet_out)
7485*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sign_out, actual_value[0])
7486*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(logabsdet_out, actual_value[1])
7487*da0073e9SAndroid Build Coastguard Worker
7488*da0073e9SAndroid Build Coastguard Worker        for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]):
7489*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['hermitian_pd'])
7490*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['singular'])
7491*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['non_singular'])
7492*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd'])
7493*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['singular', 'non_singular'])
7494*da0073e9SAndroid Build Coastguard Worker
7495*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7496*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7497*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
7498*da0073e9SAndroid Build Coastguard Worker    def test_slogdet_errors_and_warnings(self, device, dtype):
7499*da0073e9SAndroid Build Coastguard Worker        # slogdet requires the input to be a square matrix or batch of square matrices
7500*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device, dtype=dtype)
7501*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
7502*da0073e9SAndroid Build Coastguard Worker            torch.linalg.slogdet(a)
7503*da0073e9SAndroid Build Coastguard Worker
7504*da0073e9SAndroid Build Coastguard Worker        # slogdet requires the input to be a tensor with at least 2 dimensions
7505*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, device=device, dtype=dtype)
7506*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
7507*da0073e9SAndroid Build Coastguard Worker            torch.linalg.slogdet(a)
7508*da0073e9SAndroid Build Coastguard Worker
7509*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, device=device, dtype=torch.bfloat16)
7510*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'):
7511*da0073e9SAndroid Build Coastguard Worker            torch.linalg.slogdet(a)
7512*da0073e9SAndroid Build Coastguard Worker
7513*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
7514*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 3, device=device, dtype=dtype)
7515*da0073e9SAndroid Build Coastguard Worker        sign_out = torch.empty(1, device=device, dtype=dtype)
7516*da0073e9SAndroid Build Coastguard Worker        real_dtype = a.real.dtype if dtype.is_complex else dtype
7517*da0073e9SAndroid Build Coastguard Worker        logabsdet_out = torch.empty(1, device=device, dtype=real_dtype)
7518*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
7519*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
7520*da0073e9SAndroid Build Coastguard Worker            torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7521*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
7522*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
7523*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
7524*da0073e9SAndroid Build Coastguard Worker
7525*da0073e9SAndroid Build Coastguard Worker        # device should match
7526*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7527*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
7528*da0073e9SAndroid Build Coastguard Worker            sign_out = torch.empty(0, device=wrong_device, dtype=dtype)
7529*da0073e9SAndroid Build Coastguard Worker            logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype)
7530*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
7531*da0073e9SAndroid Build Coastguard Worker                torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7532*da0073e9SAndroid Build Coastguard Worker
7533*da0073e9SAndroid Build Coastguard Worker    # FIXME One of the backends of lu_factor fails on Windows. I haven't investigated which one or why
7534*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/75225
7535*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
7536*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
7537*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7538*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
7539*da0073e9SAndroid Build Coastguard Worker    def test_det_logdet_slogdet(self, device, dtype):
7540*da0073e9SAndroid Build Coastguard Worker        def reference_slogdet(M):
7541*da0073e9SAndroid Build Coastguard Worker            sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
7542*da0073e9SAndroid Build Coastguard Worker            return M.new_tensor(sdet), M.new_tensor(logabsdet)
7543*da0073e9SAndroid Build Coastguard Worker
7544*da0073e9SAndroid Build Coastguard Worker        def test_single_det(M, target, desc):
7545*da0073e9SAndroid Build Coastguard Worker            target_sdet, target_logabsdet = target
7546*da0073e9SAndroid Build Coastguard Worker
7547*da0073e9SAndroid Build Coastguard Worker            det = M.det()
7548*da0073e9SAndroid Build Coastguard Worker            logdet = M.logdet()
7549*da0073e9SAndroid Build Coastguard Worker            sdet, logabsdet = M.slogdet()
7550*da0073e9SAndroid Build Coastguard Worker            linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M)
7551*da0073e9SAndroid Build Coastguard Worker
7552*da0073e9SAndroid Build Coastguard Worker            # Test det
7553*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(det, target_sdet * target_logabsdet.exp(),
7554*da0073e9SAndroid Build Coastguard Worker                             atol=1e-6, rtol=0, msg=f'{desc} (det)')
7555*da0073e9SAndroid Build Coastguard Worker
7556*da0073e9SAndroid Build Coastguard Worker            # Test slogdet
7557*da0073e9SAndroid Build Coastguard Worker            # Compare the overall value rather than individual parts because of
7558*da0073e9SAndroid Build Coastguard Worker            # precision issues when det is near zero.
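            # (Added elaboration: when det is ~0, tiny numeric differences can flip the
            # reported sign between implementations while sign * exp(logabsdet) stays ~0,
            # so only the recombined value is compared.)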
7559*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7560*da0073e9SAndroid Build Coastguard Worker                             atol=1e-6, rtol=0, msg=f'{desc} (slogdet)')
7561*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7562*da0073e9SAndroid Build Coastguard Worker                             atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)')
7563*da0073e9SAndroid Build Coastguard Worker
7564*da0073e9SAndroid Build Coastguard Worker            # Test logdet
7565*da0073e9SAndroid Build Coastguard Worker            # Compare logdet against our own pytorch slogdet because they should
7566*da0073e9SAndroid Build Coastguard Worker            # be consistent, while it may behave slightly differently with other
7567*da0073e9SAndroid Build Coastguard Worker            # slogdet implementations when det is near zero due to precision
7568*da0073e9SAndroid Build Coastguard Worker            # issues.
7569*da0073e9SAndroid Build Coastguard Worker            if sdet.item() < 0:
7570*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isnan(logdet.item()), f'{desc} (logdet negative case)')  # NaN: log of a negative det
7571*da0073e9SAndroid Build Coastguard Worker            else:
7572*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(logdet.exp(), target_logabsdet.exp(),
7573*da0073e9SAndroid Build Coastguard Worker                                 atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)')
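            # (Added illustration of the negative-determinant branch above: for the
            # permutation matrix torch.tensor([[0., 1.], [1., 0.]]), det == -1, so
            # logdet() is nan while slogdet() returns sign == -1 and logabsdet == 0.)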
7574*da0073e9SAndroid Build Coastguard Worker
7575*da0073e9SAndroid Build Coastguard Worker        eye = torch.eye(5, dtype=dtype, device=device)
7576*da0073e9SAndroid Build Coastguard Worker        test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
7577*da0073e9SAndroid Build Coastguard Worker        # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061)
7578*da0073e9SAndroid Build Coastguard Worker        for n in range(250, 551, 100):
7579*da0073e9SAndroid Build Coastguard Worker            mat = torch.randn(n, n, dtype=dtype, device=device)
7580*da0073e9SAndroid Build Coastguard Worker            q, _ = torch.qr(mat)
7581*da0073e9SAndroid Build Coastguard Worker            ref_det, ref_logabsdet = reference_slogdet(q)
7582*da0073e9SAndroid Build Coastguard Worker            test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal')
7583*da0073e9SAndroid Build Coastguard Worker
7584*da0073e9SAndroid Build Coastguard Worker        def test(M):
7585*da0073e9SAndroid Build Coastguard Worker            assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
7586*da0073e9SAndroid Build Coastguard Worker            M = M.to(device)
7587*da0073e9SAndroid Build Coastguard Worker
7588*da0073e9SAndroid Build Coastguard Worker            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
7589*da0073e9SAndroid Build Coastguard Worker
7590*da0073e9SAndroid Build Coastguard Worker            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
7591*da0073e9SAndroid Build Coastguard Worker            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
7592*da0073e9SAndroid Build Coastguard Worker                M_inv = M.inverse()
7593*da0073e9SAndroid Build Coastguard Worker                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
7594*da0073e9SAndroid Build Coastguard Worker
7595*da0073e9SAndroid Build Coastguard Worker            test_single_det(M.t(), (ref_M_sdet, ref_M_logabsdet), 'transpose')  # det(M.t()) == det(M)
7596*da0073e9SAndroid Build Coastguard Worker
7597*da0073e9SAndroid Build Coastguard Worker            for x in [0, 2, 4]:
7598*da0073e9SAndroid Build Coastguard Worker                for scale in [-2, -0.1, 0, 10]:
7599*da0073e9SAndroid Build Coastguard Worker                    if scale > 0:
7600*da0073e9SAndroid Build Coastguard Worker                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
7601*da0073e9SAndroid Build Coastguard Worker                    elif scale == 0:
7602*da0073e9SAndroid Build Coastguard Worker                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7603*da0073e9SAndroid Build Coastguard Worker                    else:
7604*da0073e9SAndroid Build Coastguard Worker                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
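                    # (Added rationale: multiplying a single row or column by `scale`
                    # multiplies det(M) by `scale`, so logabsdet shifts by log(|scale|),
                    # the sign flips for scale < 0, and det collapses to 0 for scale == 0.)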
7605*da0073e9SAndroid Build Coastguard Worker
7606*da0073e9SAndroid Build Coastguard Worker                    # scale a single column (indexes dim 1)
7607*da0073e9SAndroid Build Coastguard Worker                    M_clone = M.clone()
7608*da0073e9SAndroid Build Coastguard Worker                    M_clone[:, x] *= scale
7609*da0073e9SAndroid Build Coastguard Worker                    test_single_det(M_clone, target, 'scale a column')
7610*da0073e9SAndroid Build Coastguard Worker                    # scale a single row (indexes dim 0)
7611*da0073e9SAndroid Build Coastguard Worker                    M_clone = M.clone()
7612*da0073e9SAndroid Build Coastguard Worker                    M_clone[x, :] *= scale
7613*da0073e9SAndroid Build Coastguard Worker                    test_single_det(M_clone, target, 'scale a row')
7614*da0073e9SAndroid Build Coastguard Worker
7615*da0073e9SAndroid Build Coastguard Worker            for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
7616*da0073e9SAndroid Build Coastguard Worker                assert x1 != x2, 'x1 and x2 need to be different for this test'
7617*da0073e9SAndroid Build Coastguard Worker                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7618*da0073e9SAndroid Build Coastguard Worker                # duplicate a column (det becomes 0)
7619*da0073e9SAndroid Build Coastguard Worker                M_clone = M.clone()
7620*da0073e9SAndroid Build Coastguard Worker                M_clone[:, x2] = M_clone[:, x1]
7621*da0073e9SAndroid Build Coastguard Worker                test_single_det(M_clone, target, 'two columns are same')
7622*da0073e9SAndroid Build Coastguard Worker                # duplicate a row (det becomes 0)
7623*da0073e9SAndroid Build Coastguard Worker                M_clone = M.clone()
7624*da0073e9SAndroid Build Coastguard Worker                M_clone[x2, :] = M_clone[x1, :]
7625*da0073e9SAndroid Build Coastguard Worker                test_single_det(M_clone, target, 'two rows are same')
7626*da0073e9SAndroid Build Coastguard Worker
7627*da0073e9SAndroid Build Coastguard Worker                for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
7628*da0073e9SAndroid Build Coastguard Worker                    det_scale = scale1 * scale2 * -1
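                    # (Added rationale: each block below maps the pair (l1, l2) of columns
                    # or rows to (l1 + scale2 * l2, scale1 * l1); by multilinearity plus the
                    # implicit swap this multiplies det(M) by -scale1 * scale2 == det_scale.)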
7629*da0073e9SAndroid Build Coastguard Worker                    if det_scale > 0:
7630*da0073e9SAndroid Build Coastguard Worker                        target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
7631*da0073e9SAndroid Build Coastguard Worker                    elif det_scale == 0:
7632*da0073e9SAndroid Build Coastguard Worker                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7633*da0073e9SAndroid Build Coastguard Worker                    else:
7634*da0073e9SAndroid Build Coastguard Worker                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
7635*da0073e9SAndroid Build Coastguard Worker
7636*da0073e9SAndroid Build Coastguard Worker                    # scaled exchange of two columns
7637*da0073e9SAndroid Build Coastguard Worker                    M_clone = M.clone()
7638*da0073e9SAndroid Build Coastguard Worker                    t = M_clone[:, x1] * scale1
7639*da0073e9SAndroid Build Coastguard Worker                    M_clone[:, x1] += M_clone[:, x2] * scale2
7640*da0073e9SAndroid Build Coastguard Worker                    M_clone[:, x2] = t
7641*da0073e9SAndroid Build Coastguard Worker                    test_single_det(M_clone, target, 'exchanging columns')
7642*da0073e9SAndroid Build Coastguard Worker                    # scaled exchange of two rows
7643*da0073e9SAndroid Build Coastguard Worker                    M_clone = M.clone()
7644*da0073e9SAndroid Build Coastguard Worker                    t = M_clone[x1, :] * scale1
7645*da0073e9SAndroid Build Coastguard Worker                    M_clone[x1, :] += M_clone[x2, :] * scale2
7646*da0073e9SAndroid Build Coastguard Worker                    M_clone[x2, :] = t
7647*da0073e9SAndroid Build Coastguard Worker                    test_single_det(M_clone, target, 'exchanging rows')
7648*da0073e9SAndroid Build Coastguard Worker
7649*da0073e9SAndroid Build Coastguard Worker        def get_random_mat_scale(n):
7650*da0073e9SAndroid Build Coastguard Worker            # For matrices with values i.i.d. with 0 mean, unit variance, and
7651*da0073e9SAndroid Build Coastguard Worker            # subexponential tail, we have:
7652*da0073e9SAndroid Build Coastguard Worker            #   E[log det(A^2)] \approx log((n-1)!)
7653*da0073e9SAndroid Build Coastguard Worker            #
7654*da0073e9SAndroid Build Coastguard Worker            # Notice:
7655*da0073e9SAndroid Build Coastguard Worker            #   log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)]
7656*da0073e9SAndroid Build Coastguard Worker            #
7657*da0073e9SAndroid Build Coastguard Worker            # So:
7658*da0073e9SAndroid Build Coastguard Worker            #   stddev[det(A)] >= sqrt( (n-1)! )
7659*da0073e9SAndroid Build Coastguard Worker            #
7660*da0073e9SAndroid Build Coastguard Worker            # We use this as an intuitive guideline to scale randomly generated
7661*da0073e9SAndroid Build Coastguard Worker            # matrices so our closeness tests can work more robustly:
7662*da0073e9SAndroid Build Coastguard Worker            #   scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
7663*da0073e9SAndroid Build Coastguard Worker            #
7664*da0073e9SAndroid Build Coastguard Worker            # source: https://arxiv.org/pdf/1112.0752.pdf
7665*da0073e9SAndroid Build Coastguard Worker
7666*da0073e9SAndroid Build Coastguard Worker            # TODO: technically we need a subexponential distribution for this to hold,
7667*da0073e9SAndroid Build Coastguard Worker            #       but we mostly use Gaussian entries below. Consider switching
7668*da0073e9SAndroid Build Coastguard Worker            #       to Chi-sq if this turns out not to be stable enough, since Chi-sq
7669*da0073e9SAndroid Build Coastguard Worker            #       is easy enough to sample from.
7670*da0073e9SAndroid Build Coastguard Worker            return math.factorial(n - 1) ** (-1.0 / (2 * n))
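        # Rough feel for the guideline above (approximate, rounded values added for
        # illustration): get_random_mat_scale(5) ~ 0.73 and get_random_mat_scale(25) ~ 0.33,
        # i.e. entries are shrunk only mildly so det stays around O(1) and the
        # atol=1e-6 comparisons in test_single_det remain meaningful.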
7671*da0073e9SAndroid Build Coastguard Worker
7672*da0073e9SAndroid Build Coastguard Worker        for n in [5, 10, 25]:
7673*da0073e9SAndroid Build Coastguard Worker            scale = get_random_mat_scale(n)
7674*da0073e9SAndroid Build Coastguard Worker            test(torch.randn(n, n, dtype=dtype, device=device) * scale)
7675*da0073e9SAndroid Build Coastguard Worker            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7676*da0073e9SAndroid Build Coastguard Worker            # symmetric psd
7677*da0073e9SAndroid Build Coastguard Worker            test(r.mm(r.t()))
7678*da0073e9SAndroid Build Coastguard Worker            # symmetric pd
7679*da0073e9SAndroid Build Coastguard Worker            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7680*da0073e9SAndroid Build Coastguard Worker            test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6)
7681*da0073e9SAndroid Build Coastguard Worker            # symmetric
7682*da0073e9SAndroid Build Coastguard Worker            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7683*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
7684*da0073e9SAndroid Build Coastguard Worker                for j in range(i):
7685*da0073e9SAndroid Build Coastguard Worker                    r[i, j] = r[j, i]
7686*da0073e9SAndroid Build Coastguard Worker            test(r)
7687*da0073e9SAndroid Build Coastguard Worker            # non-contiguous
7688*da0073e9SAndroid Build Coastguard Worker            test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:])
7689*da0073e9SAndroid Build Coastguard Worker            # det = 0
7690*da0073e9SAndroid Build Coastguard Worker            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7691*da0073e9SAndroid Build Coastguard Worker            u, s, v = r.svd()
7692*da0073e9SAndroid Build Coastguard Worker            if reference_slogdet(u)[0] < 0:
7693*da0073e9SAndroid Build Coastguard Worker                u = -u
7694*da0073e9SAndroid Build Coastguard Worker            if reference_slogdet(v)[0] < 0:
7695*da0073e9SAndroid Build Coastguard Worker                v = -v
7696*da0073e9SAndroid Build Coastguard Worker            s[0] *= -1
7697*da0073e9SAndroid Build Coastguard Worker            s[-1] = 0
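            # (Added note: zeroing the last singular value makes u @ diag(s) @ v exactly
            # singular, so this exercises the det == 0 / logabsdet == -inf branch; the
            # sign flip on s[0] only varies the construction, det is 0 either way.)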
7698*da0073e9SAndroid Build Coastguard Worker            test(u.mm(s.diag()).mm(v))
7699*da0073e9SAndroid Build Coastguard Worker
7700*da0073e9SAndroid Build Coastguard Worker        # Small values to test numerical stability. Note that we don't scale
7701*da0073e9SAndroid Build Coastguard Worker        # this matrix.
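        # (Added context: the lines below set every singular value to
        # 1 / (100 * 512) ~ 2e-5, so det ~ (2e-5) ** 512 underflows to 0.0 in double
        # precision while logabsdet ~ 512 * log(2e-5) ~ -5.5e3 is still finite, which
        # is the regime slogdet is meant to handle.)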
7702*da0073e9SAndroid Build Coastguard Worker        r = torch.randn(512, 512, dtype=dtype, device=device)
7703*da0073e9SAndroid Build Coastguard Worker        u, s, v = r.svd()
7704*da0073e9SAndroid Build Coastguard Worker        s.fill_(1. / (100 * s.numel()))
7705*da0073e9SAndroid Build Coastguard Worker        test(u.mm(s.diag()).mm(v))
7706*da0073e9SAndroid Build Coastguard Worker
7707*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7708*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7709*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
7710*da0073e9SAndroid Build Coastguard Worker    def test_det_logdet_slogdet_batched(self, device, dtype):
7711*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
7712*da0073e9SAndroid Build Coastguard Worker                                                          random_symmetric_pd_matrix, random_square_matrix_of_rank)
7713*da0073e9SAndroid Build Coastguard Worker
7714*da0073e9SAndroid Build Coastguard Worker        # mat_chars denotes matrix characteristics
7715*da0073e9SAndroid Build Coastguard Worker        # possible values are: sym, sym_psd, sym_pd, sing, non_sing
7716*da0073e9SAndroid Build Coastguard Worker        def run_test(matsize, batchdims, mat_chars):
7717*da0073e9SAndroid Build Coastguard Worker            num_matrices = reduce(operator.mul, batchdims, 1)
7718*da0073e9SAndroid Build Coastguard Worker            list_of_matrices = []
7719*da0073e9SAndroid Build Coastguard Worker
7720*da0073e9SAndroid Build Coastguard Worker            for idx in range(num_matrices):
7721*da0073e9SAndroid Build Coastguard Worker                mat_type = idx % len(mat_chars)
7722*da0073e9SAndroid Build Coastguard Worker                if mat_chars[mat_type] == 'sym':
7723*da0073e9SAndroid Build Coastguard Worker                    list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device))
7724*da0073e9SAndroid Build Coastguard Worker                elif mat_chars[mat_type] == 'sym_psd':
7725*da0073e9SAndroid Build Coastguard Worker                    list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device))
7726*da0073e9SAndroid Build Coastguard Worker                elif mat_chars[mat_type] == 'sym_pd':
7727*da0073e9SAndroid Build Coastguard Worker                    list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device))
7728*da0073e9SAndroid Build Coastguard Worker                elif mat_chars[mat_type] == 'sing':
7729*da0073e9SAndroid Build Coastguard Worker                    list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7730*da0073e9SAndroid Build Coastguard Worker                elif mat_chars[mat_type] == 'non_sing':
7731*da0073e9SAndroid Build Coastguard Worker                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7732*da0073e9SAndroid Build Coastguard Worker            full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7733*da0073e9SAndroid Build Coastguard Worker            # Scaling adapted from `get_random_mat_scale` in test_det_logdet_slogdet
7734*da0073e9SAndroid Build Coastguard Worker            full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize)))
7735*da0073e9SAndroid Build Coastguard Worker
7736*da0073e9SAndroid Build Coastguard Worker            for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
7737*da0073e9SAndroid Build Coastguard Worker                expected_value = []
7738*da0073e9SAndroid Build Coastguard Worker                actual_value = fn(full_tensor)
7739*da0073e9SAndroid Build Coastguard Worker                for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
7740*da0073e9SAndroid Build Coastguard Worker                    expected_value.append(fn(full_tensor[full_idx]))
7741*da0073e9SAndroid Build Coastguard Worker
7742*da0073e9SAndroid Build Coastguard Worker                if fn == torch.slogdet or fn == torch.linalg.slogdet:
7743*da0073e9SAndroid Build Coastguard Worker                    sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
7744*da0073e9SAndroid Build Coastguard Worker                    expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
7745*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(sign_value, actual_value[0])
7746*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_value, actual_value[1])
7747*da0073e9SAndroid Build Coastguard Worker                else:
7748*da0073e9SAndroid Build Coastguard Worker                    expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
7749*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual_value, expected_value)
7750*da0073e9SAndroid Build Coastguard Worker
7751*da0073e9SAndroid Build Coastguard Worker        for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]):
7752*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['sym_pd'])
7753*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['sing'])
7754*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['non_sing'])
7755*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
7756*da0073e9SAndroid Build Coastguard Worker            run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])
7757*da0073e9SAndroid Build Coastguard Worker
7758*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7759*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7760*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
7761*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_inverse(self, device, dtype):
7762*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
7763*da0073e9SAndroid Build Coastguard Worker
7764*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, upper, contiguous):
7765*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
7766*da0073e9SAndroid Build Coastguard Worker            if A.numel() > 0 and not contiguous:
7767*da0073e9SAndroid Build Coastguard Worker                A = A.mT
7768*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(A.is_contiguous())
7769*da0073e9SAndroid Build Coastguard Worker            L = torch.linalg.cholesky(A)
7770*da0073e9SAndroid Build Coastguard Worker            expected_inverse = torch.inverse(A)
7771*da0073e9SAndroid Build Coastguard Worker            L = L.mH if upper else L
7772*da0073e9SAndroid Build Coastguard Worker            actual_inverse = torch.cholesky_inverse(L, upper)
7773*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_inverse, expected_inverse)
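            # (Added note: with upper=False cholesky_inverse expects the lower factor L
            # with A = L @ L.mH; with upper=True it expects U = L.mH with A = U.mH @ U,
            # which is why L is conjugate-transposed above before the call.)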
7774*da0073e9SAndroid Build Coastguard Worker
7775*da0073e9SAndroid Build Coastguard Worker        shapes = (0, 3, 5)
7776*da0073e9SAndroid Build Coastguard Worker        batches = ((), (0,), (3, ), (2, 2))
7777*da0073e9SAndroid Build Coastguard Worker        for shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))):
7778*da0073e9SAndroid Build Coastguard Worker            run_test(shape, batch, upper, contiguous)
7779*da0073e9SAndroid Build Coastguard Worker
7780*da0073e9SAndroid Build Coastguard Worker        # check the out= variant
7781*da0073e9SAndroid Build Coastguard Worker        A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
7782*da0073e9SAndroid Build Coastguard Worker        L = torch.linalg.cholesky(A)
7783*da0073e9SAndroid Build Coastguard Worker
7784*da0073e9SAndroid Build Coastguard Worker        # There are two code paths currently for the out= variant
7785*da0073e9SAndroid Build Coastguard Worker        # 1. When 'out' tensor is in Fortran (column-major) memory format
7786*da0073e9SAndroid Build Coastguard Worker        # then the fast route is taken and the storage is reused directly in the computations
7787*da0073e9SAndroid Build Coastguard Worker        # 2. When 'out' tensor is not in Fortran format then a temporary tensor is allocated internally
7788*da0073e9SAndroid Build Coastguard Worker        # and the result is copied from the temporary tensor to 'out' tensor
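        # (Added pointer, an informal check rather than part of the test: the matrices
        # are in Fortran order per matrix exactly when out.mT is contiguous, which is
        # what the out_t.mT construction in the first path below sets up.)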
7789*da0073e9SAndroid Build Coastguard Worker
7790*da0073e9SAndroid Build Coastguard Worker        # This test checks the first code path
7791*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(A)
7792*da0073e9SAndroid Build Coastguard Worker        out_t = out.mT.clone(memory_format=torch.contiguous_format)
7793*da0073e9SAndroid Build Coastguard Worker        out = out_t.mT
7794*da0073e9SAndroid Build Coastguard Worker        ans = torch.cholesky_inverse(L, out=out)
7795*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ans, out)
7796*da0073e9SAndroid Build Coastguard Worker        expected = torch.inverse(A)
7797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, out)
7798*da0073e9SAndroid Build Coastguard Worker
7799*da0073e9SAndroid Build Coastguard Worker        # This test checks the second code path
7800*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(A)
7801*da0073e9SAndroid Build Coastguard Worker        ans = torch.cholesky_inverse(L, out=out)
7802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ans, out)
7803*da0073e9SAndroid Build Coastguard Worker        expected = torch.inverse(A)
7804*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, out)
7805*da0073e9SAndroid Build Coastguard Worker
7806*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
7807*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
7808*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
7809*da0073e9SAndroid Build Coastguard Worker    def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
7810*da0073e9SAndroid Build Coastguard Worker        # cholesky_inverse requires the input to be a tensor with at least 2 dimensions
7811*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, device=device, dtype=dtype)
7812*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
7813*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_inverse(a)
7814*da0073e9SAndroid Build Coastguard Worker
7815*da0073e9SAndroid Build Coastguard Worker        # cholesky_inverse requires a square matrix
7816*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device, dtype=dtype)
7817*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
7818*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_inverse(a)
7819*da0073e9SAndroid Build Coastguard Worker
7820*da0073e9SAndroid Build Coastguard Worker        # if non-empty out tensor with wrong shape is passed a warning is given
7821*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3, device=device, dtype=dtype)
7822*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(2, 3, device=device, dtype=dtype)
7823*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
7824*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
7825*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_inverse(a, out=out)
7826*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
7827*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
7828*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
7829*da0073e9SAndroid Build Coastguard Worker
7830*da0073e9SAndroid Build Coastguard Worker        # dtypes should be safely castable
7831*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(*a.shape, dtype=torch.int, device=device)
7832*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
7833*da0073e9SAndroid Build Coastguard Worker            torch.cholesky_inverse(a, out=out)
7834*da0073e9SAndroid Build Coastguard Worker
7835*da0073e9SAndroid Build Coastguard Worker        # device should match
7836*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7837*da0073e9SAndroid Build Coastguard Worker            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
7838*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(0, device=wrong_device, dtype=dtype)
7839*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
7840*da0073e9SAndroid Build Coastguard Worker                torch.cholesky_inverse(a, out=out)
7841*da0073e9SAndroid Build Coastguard Worker
7842*da0073e9SAndroid Build Coastguard Worker        # cholesky_inverse raises an error for invalid inputs on CPU
7843*da0073e9SAndroid Build Coastguard Worker        # for example if at least one diagonal element is zero
7844*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3, device=device, dtype=dtype)
7845*da0073e9SAndroid Build Coastguard Worker        a[1, 1] = 0
7846*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cpu':
7847*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
7848*da0073e9SAndroid Build Coastguard Worker                torch.cholesky_inverse(a)
7849*da0073e9SAndroid Build Coastguard Worker        # cholesky_inverse on GPU does not raise an error for this case
7850*da0073e9SAndroid Build Coastguard Worker        elif self.device_type == 'cuda':
7851*da0073e9SAndroid Build Coastguard Worker            out = torch.cholesky_inverse(a)
7852*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.isinf().any() or out.isnan().any())
7853*da0073e9SAndroid Build Coastguard Worker
7854*da0073e9SAndroid Build Coastguard Worker    def _select_broadcastable_dims(self, dims_full=None):
7855*da0073e9SAndroid Build Coastguard Worker        # select full dimensionality
7856*da0073e9SAndroid Build Coastguard Worker        if dims_full is None:
7858*da0073e9SAndroid Build Coastguard Worker            ndims = random.randint(1, 4)
7859*da0073e9SAndroid Build Coastguard Worker            dims_full = [random.randint(1, 8) for _ in range(ndims)]
7860*da0073e9SAndroid Build Coastguard Worker        else:
7861*da0073e9SAndroid Build Coastguard Worker            ndims = len(dims_full)
7862*da0073e9SAndroid Build Coastguard Worker
7863*da0073e9SAndroid Build Coastguard Worker        # select actual dimensions for ops:
7864*da0073e9SAndroid Build Coastguard Worker        # larger: full ndims, individual sizes may be reduced
7865*da0073e9SAndroid Build Coastguard Worker        # smaller: possibly reduced ndims, sizes may be reduced
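        # e.g. one possible (randomly drawn, illustrative) outcome for dims_full = [2, 3, 4]
        # is dims_small = [3, 1] and dims_large = [2, 1, 4]: every entry either matches
        # dims_full or collapses to 1, so both broadcast back to dims_full.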
7866*da0073e9SAndroid Build Coastguard Worker        smaller_ndims = random.randint(1, ndims)
7867*da0073e9SAndroid Build Coastguard Worker        dims_small = []
7868*da0073e9SAndroid Build Coastguard Worker        dims_large = []
7869*da0073e9SAndroid Build Coastguard Worker        for i in range(ndims - 1, -1, -1):
7870*da0073e9SAndroid Build Coastguard Worker            j = random.randint(1, 3)
7871*da0073e9SAndroid Build Coastguard Worker            if j == 1:  # no reduced singleton dimension
7872*da0073e9SAndroid Build Coastguard Worker                ds = dims_full[i]
7873*da0073e9SAndroid Build Coastguard Worker                dl = dims_full[i]
7874*da0073e9SAndroid Build Coastguard Worker            elif j == 2:  # larger may have reduced singleton dimension
7875*da0073e9SAndroid Build Coastguard Worker                ds = dims_full[i]
7876*da0073e9SAndroid Build Coastguard Worker                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
7877*da0073e9SAndroid Build Coastguard Worker            elif j == 3:  # smaller may have reduced singleton dimension
7878*da0073e9SAndroid Build Coastguard Worker                ds = 1
7879*da0073e9SAndroid Build Coastguard Worker                dl = dims_full[i]
7880*da0073e9SAndroid Build Coastguard Worker            dims_large = [dl] + dims_large
7881*da0073e9SAndroid Build Coastguard Worker            if len(dims_small) < smaller_ndims:
7882*da0073e9SAndroid Build Coastguard Worker                dims_small = [ds] + dims_small
7883*da0073e9SAndroid Build Coastguard Worker        return (dims_small, dims_large, dims_full)
7884*da0073e9SAndroid Build Coastguard Worker
7885*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_fused_matmul(self, device):
7886*da0073e9SAndroid Build Coastguard Worker        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]
7887*da0073e9SAndroid Build Coastguard Worker
7888*da0073e9SAndroid Build Coastguard Worker        for fn in fns:
7889*da0073e9SAndroid Build Coastguard Worker            batch_dim = random.randint(1, 8)
7890*da0073e9SAndroid Build Coastguard Worker            n_dim = random.randint(1, 8)
7891*da0073e9SAndroid Build Coastguard Worker            m_dim = random.randint(1, 8)
7892*da0073e9SAndroid Build Coastguard Worker            p_dim = random.randint(1, 8)
7893*da0073e9SAndroid Build Coastguard Worker
7894*da0073e9SAndroid Build Coastguard Worker            def dims_full_for_fn():
7895*da0073e9SAndroid Build Coastguard Worker                if fn == "baddbmm":
7896*da0073e9SAndroid Build Coastguard Worker                    return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
7897*da0073e9SAndroid Build Coastguard Worker                elif fn == "addbmm":
7898*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
7899*da0073e9SAndroid Build Coastguard Worker                elif fn == "addmm":
7900*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
7901*da0073e9SAndroid Build Coastguard Worker                elif fn == "addmv":
7902*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim], [n_dim, m_dim], [m_dim])
7903*da0073e9SAndroid Build Coastguard Worker                elif fn == "addr":
7904*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim, m_dim], [n_dim], [m_dim])
7905*da0073e9SAndroid Build Coastguard Worker                else:
7906*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError("unknown function")
7907*da0073e9SAndroid Build Coastguard Worker
7908*da0073e9SAndroid Build Coastguard Worker            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
7909*da0073e9SAndroid Build Coastguard Worker            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
7910*da0073e9SAndroid Build Coastguard Worker
7911*da0073e9SAndroid Build Coastguard Worker            t0_small = torch.randn(*t0_dims_small, device=device).float()
7912*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(*t1_dims, device=device).float()
7913*da0073e9SAndroid Build Coastguard Worker            t2 = torch.randn(*t2_dims, device=device).float()
7914*da0073e9SAndroid Build Coastguard Worker
7915*da0073e9SAndroid Build Coastguard Worker            t0_full = t0_small.expand(*t0_dims_full).to(device)
7916*da0073e9SAndroid Build Coastguard Worker
7917*da0073e9SAndroid Build Coastguard Worker            fntorch = getattr(torch, fn)
7918*da0073e9SAndroid Build Coastguard Worker            r0 = fntorch(t0_small, t1, t2)
7919*da0073e9SAndroid Build Coastguard Worker            r1 = fntorch(t0_full, t1, t2)
7920*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r0, r1)
7921*da0073e9SAndroid Build Coastguard Worker
7922*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.001)
7923*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.001)
7924*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_batched_matmul(self, device):
7925*da0073e9SAndroid Build Coastguard Worker        n_dim = random.randint(1, 8)
7926*da0073e9SAndroid Build Coastguard Worker        m_dim = random.randint(1, 8)
7927*da0073e9SAndroid Build Coastguard Worker        p_dim = random.randint(1, 8)
7928*da0073e9SAndroid Build Coastguard Worker        full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
7929*da0073e9SAndroid Build Coastguard Worker        (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)
7930*da0073e9SAndroid Build Coastguard Worker
7931*da0073e9SAndroid Build Coastguard Worker        def verify_batched_matmul(full_lhs, one_dimensional):
7932*da0073e9SAndroid Build Coastguard Worker            if not one_dimensional:
7933*da0073e9SAndroid Build Coastguard Worker                lhs_dims = [n_dim, m_dim]
7934*da0073e9SAndroid Build Coastguard Worker                rhs_dims = [m_dim, p_dim]
7935*da0073e9SAndroid Build Coastguard Worker                result_dims = [n_dim, p_dim]
7936*da0073e9SAndroid Build Coastguard Worker            else:
7937*da0073e9SAndroid Build Coastguard Worker                lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
7938*da0073e9SAndroid Build Coastguard Worker                rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
7939*da0073e9SAndroid Build Coastguard Worker                result_dims = [n_dim] if full_lhs else [p_dim]
7940*da0073e9SAndroid Build Coastguard Worker
7941*da0073e9SAndroid Build Coastguard Worker            lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
7942*da0073e9SAndroid Build Coastguard Worker            rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
7943*da0073e9SAndroid Build Coastguard Worker            full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
7944*da0073e9SAndroid Build Coastguard Worker            dim0_dims = rhs_dims if full_lhs else lhs_dims
7945*da0073e9SAndroid Build Coastguard Worker            small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)
7946*da0073e9SAndroid Build Coastguard Worker
7947*da0073e9SAndroid Build Coastguard Worker            small = torch.randn(*(small_dims), device=device).float()
7948*da0073e9SAndroid Build Coastguard Worker            dim0 = torch.randn(*(dim0_dims), device=device).float()
7949*da0073e9SAndroid Build Coastguard Worker            full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float()
7950*da0073e9SAndroid Build Coastguard Worker            if not one_dimensional:
7951*da0073e9SAndroid Build Coastguard Worker                (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
7952*da0073e9SAndroid Build Coastguard Worker            else:
7953*da0073e9SAndroid Build Coastguard Worker                (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))
7954*da0073e9SAndroid Build Coastguard Worker
7955*da0073e9SAndroid Build Coastguard Worker            def maybe_squeeze_result(l, r, result):
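                # (Added note: torch.matmul drops the prepended/appended matrix dimension
                # when an operand is genuinely 1-D; when the expanded >=2-D version of that
                # operand is used instead, the result keeps a singleton dimension, so it is
                # squeezed here to make both variants directly comparable.)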
7956*da0073e9SAndroid Build Coastguard Worker                if len(lhs_dims) == 1 and l.dim() != 1:
7957*da0073e9SAndroid Build Coastguard Worker                    return result.squeeze(-2)
7958*da0073e9SAndroid Build Coastguard Worker                elif len(rhs_dims) == 1 and r.dim() != 1:
7959*da0073e9SAndroid Build Coastguard Worker                    return result.squeeze(-1)
7960*da0073e9SAndroid Build Coastguard Worker                else:
7961*da0073e9SAndroid Build Coastguard Worker                    return result
7962*da0073e9SAndroid Build Coastguard Worker
7963*da0073e9SAndroid Build Coastguard Worker            for lhs in lhsTensors:
7964*da0073e9SAndroid Build Coastguard Worker                lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
7965*da0073e9SAndroid Build Coastguard Worker                lhs_expanded_matmul_fn = lhs_expanded.matmul
7966*da0073e9SAndroid Build Coastguard Worker                for rhs in rhsTensors:
7967*da0073e9SAndroid Build Coastguard Worker                    rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
7968*da0073e9SAndroid Build Coastguard Worker                                    expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
7969*da0073e9SAndroid Build Coastguard Worker                    truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
7970*da0073e9SAndroid Build Coastguard Worker                    for l in (lhs, lhs_expanded):
7971*da0073e9SAndroid Build Coastguard Worker                        for r in (rhs, rhs_expanded):
7972*da0073e9SAndroid Build Coastguard Worker                            l_matmul_fn = l.matmul
7973*da0073e9SAndroid Build Coastguard Worker                            result = maybe_squeeze_result(l, r, l_matmul_fn(r))
7974*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(truth, result)
7975*da0073e9SAndroid Build Coastguard Worker                            # test torch.matmul function as well
7976*da0073e9SAndroid Build Coastguard Worker                            torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
7977*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(truth, torch_result)
7978*da0073e9SAndroid Build Coastguard Worker                            # test torch.matmul with out
7979*da0073e9SAndroid Build Coastguard Worker                            out = torch.zeros_like(torch_result)
7980*da0073e9SAndroid Build Coastguard Worker                            torch.matmul(l, r, out=out)
7981*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(truth, maybe_squeeze_result(l, r, out))
7982*da0073e9SAndroid Build Coastguard Worker
7983*da0073e9SAndroid Build Coastguard Worker                # compare to bmm
7984*da0073e9SAndroid Build Coastguard Worker                bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
7985*da0073e9SAndroid Build Coastguard Worker                                        rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
7986*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))
7987*da0073e9SAndroid Build Coastguard Worker
7988*da0073e9SAndroid Build Coastguard Worker        for indices in itertools.product((True, False), repeat=2):
7989*da0073e9SAndroid Build Coastguard Worker            verify_batched_matmul(*indices)
7990*da0073e9SAndroid Build Coastguard Worker
7991*da0073e9SAndroid Build Coastguard Worker    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
7992*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7993*da0073e9SAndroid Build Coastguard Worker        make_A = partial(make_fullrank, device=device, dtype=dtype)
7994*da0073e9SAndroid Build Coastguard Worker
7995*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(*b_dims, dtype=dtype, device=device)
7996*da0073e9SAndroid Build Coastguard Worker        A = make_A(*A_dims)
7997*da0073e9SAndroid Build Coastguard Worker        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
7998*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(info, torch.zeros_like(info))
7999*da0073e9SAndroid Build Coastguard Worker        return b, A, LU_data, LU_pivots
8000*da0073e9SAndroid Build Coastguard Worker
8001*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8002*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
8003*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8004*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
8005*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
8006*da0073e9SAndroid Build Coastguard Worker    def test_lu_solve(self, device, dtype):
8007*da0073e9SAndroid Build Coastguard Worker        def sub_test(pivot):
8008*da0073e9SAndroid Build Coastguard Worker            for k, n in zip([2, 3, 5], [3, 5, 7]):
8009*da0073e9SAndroid Build Coastguard Worker                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
8010*da0073e9SAndroid Build Coastguard Worker                x = torch.lu_solve(b, LU_data, LU_pivots)
8011*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
8012*da0073e9SAndroid Build Coastguard Worker
8013*da0073e9SAndroid Build Coastguard Worker        sub_test(True)
8014*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8015*da0073e9SAndroid Build Coastguard Worker            sub_test(False)
8016*da0073e9SAndroid Build Coastguard Worker
8017*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8018*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
8019*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8020*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
8021*da0073e9SAndroid Build Coastguard Worker                        torch.float64: 1e-8, torch.complex128: 1e-8})
8022*da0073e9SAndroid Build Coastguard Worker    def test_lu_solve_batched(self, device, dtype):
8023*da0073e9SAndroid Build Coastguard Worker        def sub_test(pivot):
8024*da0073e9SAndroid Build Coastguard Worker            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
8025*da0073e9SAndroid Build Coastguard Worker                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
8026*da0073e9SAndroid Build Coastguard Worker                x_exp_list = []
8027*da0073e9SAndroid Build Coastguard Worker                for i in range(b_dims[0]):
8028*da0073e9SAndroid Build Coastguard Worker                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
8029*da0073e9SAndroid Build Coastguard Worker                x_exp = torch.stack(x_exp_list)  # Stacked output
8030*da0073e9SAndroid Build Coastguard Worker                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
8031*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_exp, x_act)  # Equality check
8032*da0073e9SAndroid Build Coastguard Worker                Ax = np.matmul(A.cpu(), x_act.cpu())
8033*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, Ax)
8034*da0073e9SAndroid Build Coastguard Worker
8035*da0073e9SAndroid Build Coastguard Worker            for batchsize in [1, 3, 4]:
8036*da0073e9SAndroid Build Coastguard Worker                lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)
8037*da0073e9SAndroid Build Coastguard Worker
8038*da0073e9SAndroid Build Coastguard Worker        # Tests tensors with 0 elements
8039*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 0, 3, dtype=dtype, device=device)
8040*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(3, 0, 0, dtype=dtype, device=device)
8041*da0073e9SAndroid Build Coastguard Worker        LU_data, LU_pivots = torch.linalg.lu_factor(A)
8042*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
8043*da0073e9SAndroid Build Coastguard Worker
8044*da0073e9SAndroid Build Coastguard Worker        sub_test(True)
8045*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8046*da0073e9SAndroid Build Coastguard Worker            sub_test(False)
8047*da0073e9SAndroid Build Coastguard Worker
8048*da0073e9SAndroid Build Coastguard Worker    @slowTest
8049*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8050*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
8051*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8052*da0073e9SAndroid Build Coastguard Worker    def test_lu_solve_batched_many_batches(self, device, dtype):
8053*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, b_dims):
8054*da0073e9SAndroid Build Coastguard Worker            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
8055*da0073e9SAndroid Build Coastguard Worker            x = torch.lu_solve(b, LU_data, LU_pivots)
8056*da0073e9SAndroid Build Coastguard Worker            Ax = torch.matmul(A, x)
8057*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Ax, b.expand_as(Ax))
8058*da0073e9SAndroid Build Coastguard Worker
8059*da0073e9SAndroid Build Coastguard Worker        run_test((65536, 5, 5), (65536, 5, 10))
8060*da0073e9SAndroid Build Coastguard Worker        run_test((262144, 5, 5), (262144, 5, 10))
8061*da0073e9SAndroid Build Coastguard Worker
8062*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8063*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
8064*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8065*da0073e9SAndroid Build Coastguard Worker    def test_lu_solve_batched_broadcasting(self, device, dtype):
8066*da0073e9SAndroid Build Coastguard Worker        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
8067*da0073e9SAndroid Build Coastguard Worker        make_A = partial(make_fullrank, device=device, dtype=dtype)
8068*da0073e9SAndroid Build Coastguard Worker
8069*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, b_dims, pivot=True):
8070*da0073e9SAndroid Build Coastguard Worker            A_matrix_size = A_dims[-1]
8071*da0073e9SAndroid Build Coastguard Worker            A_batch_dims = A_dims[:-2]
8072*da0073e9SAndroid Build Coastguard Worker            A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
8073*da0073e9SAndroid Build Coastguard Worker            b = make_tensor(b_dims, dtype=dtype, device=device)
8074*da0073e9SAndroid Build Coastguard Worker            x_exp = np.linalg.solve(A.cpu(), b.cpu())
8075*da0073e9SAndroid Build Coastguard Worker            LU_data, LU_pivots = torch.linalg.lu_factor(A)
8076*da0073e9SAndroid Build Coastguard Worker            x = torch.lu_solve(b, LU_data, LU_pivots)
8077*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x_exp)
8078*da0073e9SAndroid Build Coastguard Worker
8079*da0073e9SAndroid Build Coastguard Worker        # test against numpy.linalg.solve
8080*da0073e9SAndroid Build Coastguard Worker        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
8081*da0073e9SAndroid Build Coastguard Worker        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
8082*da0073e9SAndroid Build Coastguard Worker        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
8083*da0073e9SAndroid Build Coastguard Worker        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b
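        # In the cases above the batch dimensions of A and b broadcast against each other
        # (NumPy-style), while the trailing two dimensions keep their matrix/solve meaning.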
8084*da0073e9SAndroid Build Coastguard Worker
8085*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8086*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8087*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8088*da0073e9SAndroid Build Coastguard Worker    # this tests https://github.com/pytorch/pytorch/issues/36921
8089*da0073e9SAndroid Build Coastguard Worker    def test_lu_solve_large_matrices(self, device, dtype):
8090*da0073e9SAndroid Build Coastguard Worker        def run_test(A_dims, b_dims):
8091*da0073e9SAndroid Build Coastguard Worker            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
8092*da0073e9SAndroid Build Coastguard Worker            x = torch.lu_solve(b, LU_data, LU_pivots)
8093*da0073e9SAndroid Build Coastguard Worker            Ax = torch.matmul(A, x)
8094*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Ax, b.expand_as(Ax))
8095*da0073e9SAndroid Build Coastguard Worker
8096*da0073e9SAndroid Build Coastguard Worker        run_test((1, 1), (1, 1, 1025))
8097*da0073e9SAndroid Build Coastguard Worker
8098*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
8099*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8100*da0073e9SAndroid Build Coastguard Worker    def test_pca_lowrank(self, device):
8101*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
8102*da0073e9SAndroid Build Coastguard Worker
8103*da0073e9SAndroid Build Coastguard Worker        dtype = torch.double
8104*da0073e9SAndroid Build Coastguard Worker
8105*da0073e9SAndroid Build Coastguard Worker        def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options):
8106*da0073e9SAndroid Build Coastguard Worker            density = options.pop('density', 1)
8107*da0073e9SAndroid Build Coastguard Worker            use_svd_lowrank = options.pop('use_svd_lowrank', False)
8108*da0073e9SAndroid Build Coastguard Worker            if isinstance(matrix_size, int):
8109*da0073e9SAndroid Build Coastguard Worker                rows = columns = matrix_size
8110*da0073e9SAndroid Build Coastguard Worker            else:
8111*da0073e9SAndroid Build Coastguard Worker                rows, columns = matrix_size
8112*da0073e9SAndroid Build Coastguard Worker            if density == 1:
8113*da0073e9SAndroid Build Coastguard Worker                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
8114*da0073e9SAndroid Build Coastguard Worker                a = a_input
8115*da0073e9SAndroid Build Coastguard Worker            else:
8116*da0073e9SAndroid Build Coastguard Worker                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
8117*da0073e9SAndroid Build Coastguard Worker                a = a_input.to_dense()
8118*da0073e9SAndroid Build Coastguard Worker
8119*da0073e9SAndroid Build Coastguard Worker            if use_svd_lowrank:
8120*da0073e9SAndroid Build Coastguard Worker                m = a_input.mean(dim=-2, keepdim=True)
8121*da0073e9SAndroid Build Coastguard Worker                u, s, v = pca(a_input, q=guess_rank, M=m, **options)
8122*da0073e9SAndroid Build Coastguard Worker            else:
8123*da0073e9SAndroid Build Coastguard Worker                u, s, v = pca(a_input, q=guess_rank, **options)
8124*da0073e9SAndroid Build Coastguard Worker
8125*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s.shape[-1], guess_rank)
8126*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(u.shape[-2], rows)
8127*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(u.shape[-1], guess_rank)
8128*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v.shape[-1], guess_rank)
8129*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v.shape[-2], columns)
8130*da0073e9SAndroid Build Coastguard Worker
8131*da0073e9SAndroid Build Coastguard Worker            A1 = u.matmul(s.diag_embed()).matmul(v.mT)
8132*da0073e9SAndroid Build Coastguard Worker            ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device)
8133*da0073e9SAndroid Build Coastguard Worker            c = a.sum(axis=-2) / rows
8134*da0073e9SAndroid Build Coastguard Worker            c = c.reshape(batches + (1, columns))
8135*da0073e9SAndroid Build Coastguard Worker            A2 = a - ones_m1.matmul(c)
8136*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A1, A2)
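            # Note: pca_lowrank works on the column-mean-centered input, so u @ diag(s) @ v^T
            # should reconstruct `a` with its per-column mean removed, which is exactly A2 above.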
8137*da0073e9SAndroid Build Coastguard Worker
8138*da0073e9SAndroid Build Coastguard Worker            if density == 1:
8139*da0073e9SAndroid Build Coastguard Worker                # actual rank is known only for dense input
8140*da0073e9SAndroid Build Coastguard Worker                detect_rank = (s.abs() > 1e-5).sum(axis=-1)
8141*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
8142*da0073e9SAndroid Build Coastguard Worker                S = torch.linalg.svdvals(A2)
8143*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
8144*da0073e9SAndroid Build Coastguard Worker
8145*da0073e9SAndroid Build Coastguard Worker        all_batches = [(), (1,), (3,), (2, 3)]
8146*da0073e9SAndroid Build Coastguard Worker        for actual_rank, size, all_batches in [  # noqa: B020
8147*da0073e9SAndroid Build Coastguard Worker                (2, (17, 4), all_batches),
8148*da0073e9SAndroid Build Coastguard Worker                (2, (100, 4), all_batches),
8149*da0073e9SAndroid Build Coastguard Worker                (6, (100, 40), all_batches),
8150*da0073e9SAndroid Build Coastguard Worker                (12, (1000, 1000), [()]),
8151*da0073e9SAndroid Build Coastguard Worker        ]:
8152*da0073e9SAndroid Build Coastguard Worker            for batches in all_batches:
8153*da0073e9SAndroid Build Coastguard Worker                for guess_rank in [
8154*da0073e9SAndroid Build Coastguard Worker                        actual_rank,
8155*da0073e9SAndroid Build Coastguard Worker                        actual_rank + 2,
8156*da0073e9SAndroid Build Coastguard Worker                        actual_rank + 6,
8157*da0073e9SAndroid Build Coastguard Worker                ]:
8158*da0073e9SAndroid Build Coastguard Worker                    if guess_rank <= min(*size):
8159*da0073e9SAndroid Build Coastguard Worker                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank)
8160*da0073e9SAndroid Build Coastguard Worker                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank)
8161*da0073e9SAndroid Build Coastguard Worker                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True)
8162*da0073e9SAndroid Build Coastguard Worker                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True)
8163*da0073e9SAndroid Build Coastguard Worker
8164*da0073e9SAndroid Build Coastguard Worker        # sparse input
8165*da0073e9SAndroid Build Coastguard Worker        for guess_rank, size in [
8166*da0073e9SAndroid Build Coastguard Worker                (4, (17, 4)), (4, (4, 17)), (16, (17, 17)),
8167*da0073e9SAndroid Build Coastguard Worker                (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]:
8168*da0073e9SAndroid Build Coastguard Worker            for density in [0.005, 0.1]:
8169*da0073e9SAndroid Build Coastguard Worker                run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density)
8170*da0073e9SAndroid Build Coastguard Worker
8171*da0073e9SAndroid Build Coastguard Worker        # jitting support
8172*da0073e9SAndroid Build Coastguard Worker        jitted = torch.jit.script(torch.pca_lowrank)
8173*da0073e9SAndroid Build Coastguard Worker        guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
8174*da0073e9SAndroid Build Coastguard Worker        run_subtest(guess_rank, actual_rank, size, batches, device, jitted)
8175*da0073e9SAndroid Build Coastguard Worker
8176*da0073e9SAndroid Build Coastguard Worker    # Ensure that nuclear_norm's out variant gives the same result as the non-out variant
8177*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8178*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8179*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8180*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
8181*da0073e9SAndroid Build Coastguard Worker    def test_nuclear_norm_out(self, device, dtype):
8182*da0073e9SAndroid Build Coastguard Worker        test_cases = [
8183*da0073e9SAndroid Build Coastguard Worker            # input size, dim
8184*da0073e9SAndroid Build Coastguard Worker            ((25, 25), None),
8185*da0073e9SAndroid Build Coastguard Worker            ((25, 25), (0, 1)),
8186*da0073e9SAndroid Build Coastguard Worker            ((25, 25), (1, 0)),
8187*da0073e9SAndroid Build Coastguard Worker            ((25, 25, 25), (2, 0)),
8188*da0073e9SAndroid Build Coastguard Worker            ((25, 25, 25), (0, 1)),
8189*da0073e9SAndroid Build Coastguard Worker        ]
8190*da0073e9SAndroid Build Coastguard Worker        for keepdim in [False, True]:
8191*da0073e9SAndroid Build Coastguard Worker            for input_size, dim in test_cases:
8192*da0073e9SAndroid Build Coastguard Worker                msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}'
8193*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(*input_size, device=device, dtype=dtype)
8194*da0073e9SAndroid Build Coastguard Worker                result_out = torch.empty(0, device=device, dtype=dtype)
8195*da0073e9SAndroid Build Coastguard Worker                if dim is None:
8196*da0073e9SAndroid Build Coastguard Worker                    result = torch.nuclear_norm(x, keepdim=keepdim)
8197*da0073e9SAndroid Build Coastguard Worker                    torch.nuclear_norm(x, keepdim=keepdim, out=result_out)
8198*da0073e9SAndroid Build Coastguard Worker                else:
8199*da0073e9SAndroid Build Coastguard Worker                    result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim)
8200*da0073e9SAndroid Build Coastguard Worker                    torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out)
8201*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_out, msg=msg)
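        # For reference: the nuclear norm is the sum of singular values over the two reduced
        # dimensions; for a plain matrix it is equivalent to torch.linalg.svdvals(x).sum(-1).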
8202*da0073e9SAndroid Build Coastguard Worker
8203*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagmaAndNoCusolver
8204*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8205*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8206*da0073e9SAndroid Build Coastguard Worker    def test_geqrf(self, device, dtype):
8207*da0073e9SAndroid Build Coastguard Worker
8208*da0073e9SAndroid Build Coastguard Worker        def run_test(shape):
8209*da0073e9SAndroid Build Coastguard Worker            # numpy.linalg.qr with mode = 'raw' computes the same operation as torch.geqrf
8210*da0073e9SAndroid Build Coastguard Worker            # so this test compares against that function
8211*da0073e9SAndroid Build Coastguard Worker            A = make_tensor(shape, dtype=dtype, device=device)
8212*da0073e9SAndroid Build Coastguard Worker
8213*da0073e9SAndroid Build Coastguard Worker            # numpy.linalg.qr doesn't work with batched input
8214*da0073e9SAndroid Build Coastguard Worker            m, n = A.shape[-2:]
8215*da0073e9SAndroid Build Coastguard Worker            tau_size = "n" if m > n else "m"
8216*da0073e9SAndroid Build Coastguard Worker            np_dtype = A.cpu().numpy().dtype
8217*da0073e9SAndroid Build Coastguard Worker            ot = [np_dtype, np_dtype]
8218*da0073e9SAndroid Build Coastguard Worker            numpy_geqrf_batched = np.vectorize(
8219*da0073e9SAndroid Build Coastguard Worker                lambda x: np.linalg.qr(x, mode='raw'),
8220*da0073e9SAndroid Build Coastguard Worker                otypes=ot,
8221*da0073e9SAndroid Build Coastguard Worker                signature=f'(m,n)->(n,m),({tau_size})')
8222*da0073e9SAndroid Build Coastguard Worker
8223*da0073e9SAndroid Build Coastguard Worker            expected = numpy_geqrf_batched(A.cpu())
8224*da0073e9SAndroid Build Coastguard Worker            actual = torch.geqrf(A)
8225*da0073e9SAndroid Build Coastguard Worker
8226*da0073e9SAndroid Build Coastguard Worker            # numpy.linalg.qr returns transposed result
8227*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected[0].swapaxes(-2, -1), actual[0])
8228*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected[1], actual[1])
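            # Background (illustrative): geqrf packs the Householder reflectors into the returned
            # matrix and also returns the scalar factors tau; the explicit Q can be recovered via
            #   a, tau = torch.geqrf(A)
            #   Q = torch.linalg.householder_product(a, tau)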
8229*da0073e9SAndroid Build Coastguard Worker
8230*da0073e9SAndroid Build Coastguard Worker        batches = [(), (0, ), (2, ), (2, 1)]
8231*da0073e9SAndroid Build Coastguard Worker        ns = [5, 2, 0]
8232*da0073e9SAndroid Build Coastguard Worker        for batch, (m, n) in product(batches, product(ns, ns)):
8233*da0073e9SAndroid Build Coastguard Worker            run_test((*batch, m, n))
8234*da0073e9SAndroid Build Coastguard Worker
8235*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8236*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8237*da0073e9SAndroid Build Coastguard Worker    def test_lapack_empty(self, device):
8238*da0073e9SAndroid Build Coastguard Worker        # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
8239*da0073e9SAndroid Build Coastguard Worker        # The LAPACK functions themselves generally do NOT work with zero-sized dimensions, although
8240*da0073e9SAndroid Build Coastguard Worker        # numpy/scipy often have both a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing"
8241*da0073e9SAndroid Build Coastguard Worker        # (e.g. lu).  We often name our functions identically to the LAPACK function, so it will take work
8242*da0073e9SAndroid Build Coastguard Worker        # to name / migrate to better wrappers.
8243*da0073e9SAndroid Build Coastguard Worker        def fn(torchfn, *args):
8244*da0073e9SAndroid Build Coastguard Worker            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
8245*da0073e9SAndroid Build Coastguard Worker                                  for shape in args))
8246*da0073e9SAndroid Build Coastguard Worker
8247*da0073e9SAndroid Build Coastguard Worker        # inverse, pinverse
8248*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
8249*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
8250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
8251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)
8252*da0073e9SAndroid Build Coastguard Worker
8253*da0073e9SAndroid Build Coastguard Worker        # det, logdet, slogdet
8254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
8255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
8256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
8257*da0073e9SAndroid Build Coastguard Worker                         fn(torch.slogdet, (0, 0)))
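        # Note: the determinant of a 0x0 matrix is 1 by the empty-product convention,
        # hence logdet == 0 and slogdet == (sign=1, logabsdet=0) above.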
8258*da0073e9SAndroid Build Coastguard Worker
8259*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
8260*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
8261*da0073e9SAndroid Build Coastguard Worker    def test_tensordot(self, device):
8262*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(60., device=device).reshape(3, 4, 5)
8263*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(24., device=device).reshape(4, 3, 2)
8264*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
8265*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
8266*da0073e9SAndroid Build Coastguard Worker                                           axes=([1, 0], [0, 1])))
8267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
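        # dims=([1, 0], [0, 1]) contracts a's dims 1 and 0 with b's dims 0 and 1 respectively,
        # leaving the uncontracted dims, so the result has shape (5, 2) here.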
8268*da0073e9SAndroid Build Coastguard Worker
8269*da0073e9SAndroid Build Coastguard Worker        cout = torch.zeros((5, 2), device=device)
8270*da0073e9SAndroid Build Coastguard Worker        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
8271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cout)
8272*da0073e9SAndroid Build Coastguard Worker
8273*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4, 5, device=device)
8274*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4, 5, 6, 7, device=device)
8275*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b, dims=2).cpu()
8276*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
8277*da0073e9SAndroid Build Coastguard Worker                                           axes=2))
8278*da0073e9SAndroid Build Coastguard Worker
8279*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
8280*da0073e9SAndroid Build Coastguard Worker            torch.tensordot(a, b, dims=-1)
8281*da0073e9SAndroid Build Coastguard Worker
8282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
8283*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b).cpu()
8284*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
8285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
8286*da0073e9SAndroid Build Coastguard Worker
8287*da0073e9SAndroid Build Coastguard Worker        a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
8288*da0073e9SAndroid Build Coastguard Worker        an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
8289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, an)
8290*da0073e9SAndroid Build Coastguard Worker
8291*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
8292*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8293*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8294*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("flaky, needs investigation")
8295*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8296*da0073e9SAndroid Build Coastguard Worker    def test_ldl_factor(self, device, dtype):
8297*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
8298*da0073e9SAndroid Build Coastguard Worker
8299*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, hermitian):
8300*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
8301*da0073e9SAndroid Build Coastguard Worker            actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
8302*da0073e9SAndroid Build Coastguard Worker            actual_L = torch.tril(actual_factors, diagonal=-1)
8303*da0073e9SAndroid Build Coastguard Worker            actual_L.diagonal(0, -2, -1).fill_(1.0)
8304*da0073e9SAndroid Build Coastguard Worker
8305*da0073e9SAndroid Build Coastguard Worker            # This test is designed only for inputs whose block-diagonal matrix D has 1x1 blocks,
8306*da0073e9SAndroid Build Coastguard Worker            # i.e. for positive definite input matrices the pivots tensor is always > 0.
8307*da0073e9SAndroid Build Coastguard Worker            # Negative pivots would mean that the input matrix is not positive definite
8308*da0073e9SAndroid Build Coastguard Worker            # and that D contains 2x2 diagonal blocks.
8309*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((actual_pivots > 0).all())
8310*da0073e9SAndroid Build Coastguard Worker
8311*da0073e9SAndroid Build Coastguard Worker            # Construct a 1x1 block diagonal matrix D from factors.
8312*da0073e9SAndroid Build Coastguard Worker            actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1))
8313*da0073e9SAndroid Build Coastguard Worker
8314*da0073e9SAndroid Build Coastguard Worker            def T(x):
8315*da0073e9SAndroid Build Coastguard Worker                return x.mH if hermitian else x.mT
8316*da0073e9SAndroid Build Coastguard Worker            A_reconstructed = actual_L @ actual_D @ T(actual_L)
8317*da0073e9SAndroid Build Coastguard Worker
8318*da0073e9SAndroid Build Coastguard Worker            def symmetric(A):
8319*da0073e9SAndroid Build Coastguard Worker                return A.tril() + A.tril(-1).mT
8320*da0073e9SAndroid Build Coastguard Worker
8321*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed)
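            # The check above verifies the defining property of the factorization:
            # A == L @ D @ L^H when hermitian=True and A == L @ D @ L^T otherwise,
            # where only the lower triangle of A is considered in the non-hermitian case.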
8322*da0073e9SAndroid Build Coastguard Worker
8323*da0073e9SAndroid Build Coastguard Worker            # Now test against SciPy implementation
8324*da0073e9SAndroid Build Coastguard Worker            if TEST_SCIPY:
8325*da0073e9SAndroid Build Coastguard Worker                from scipy.linalg import ldl as scipy_ldl
8326*da0073e9SAndroid Build Coastguard Worker                A_np = A.cpu().numpy()
8327*da0073e9SAndroid Build Coastguard Worker                np_dtype = A_np.dtype
8328*da0073e9SAndroid Build Coastguard Worker                scipy_ldl_batched = np.vectorize(
8329*da0073e9SAndroid Build Coastguard Worker                    lambda x: scipy_ldl(x, hermitian=hermitian, lower=True),
8330*da0073e9SAndroid Build Coastguard Worker                    otypes=[np_dtype, np_dtype, np.dtype('int64')],
8331*da0073e9SAndroid Build Coastguard Worker                    signature='(m,m)->(m,m),(m,m),(m)')
8332*da0073e9SAndroid Build Coastguard Worker
8333*da0073e9SAndroid Build Coastguard Worker                expected = scipy_ldl_batched(A_np)
8334*da0073e9SAndroid Build Coastguard Worker                expected_L, expected_D, expected_pivots = expected
8335*da0073e9SAndroid Build Coastguard Worker
8336*da0073e9SAndroid Build Coastguard Worker                if expected_pivots.ndim > 1:
8337*da0073e9SAndroid Build Coastguard Worker                    permuted_expected_L = np.stack(
8338*da0073e9SAndroid Build Coastguard Worker                        [expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])]
8339*da0073e9SAndroid Build Coastguard Worker                    )
8340*da0073e9SAndroid Build Coastguard Worker                else:
8341*da0073e9SAndroid Build Coastguard Worker                    permuted_expected_L = expected_L[expected_pivots, :]
8342*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_L, permuted_expected_L)
8343*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_D, expected_D)
8344*da0073e9SAndroid Build Coastguard Worker            else:
8345*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_factors.shape, A.shape)
8346*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_pivots.shape, A.shape[:-1])
8347*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(info.shape, A.shape[:-2])
8348*da0073e9SAndroid Build Coastguard Worker
8349*da0073e9SAndroid Build Coastguard Worker        # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
8350*da0073e9SAndroid Build Coastguard Worker        magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4)
8351*da0073e9SAndroid Build Coastguard Worker        hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,)
8352*da0073e9SAndroid Build Coastguard Worker
8353*da0073e9SAndroid Build Coastguard Worker        shapes = (5,)
8354*da0073e9SAndroid Build Coastguard Worker        batches = ((), (4,),)
8355*da0073e9SAndroid Build Coastguard Worker        for shape, batch, hermitian in itertools.product(shapes, batches, hermitians):
8356*da0073e9SAndroid Build Coastguard Worker            run_test(shape, batch, hermitian)
8357*da0073e9SAndroid Build Coastguard Worker
8358*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
8359*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8360*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoLapack
8361*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
8362*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1")
8363*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
8364*da0073e9SAndroid Build Coastguard Worker    def test_ldl_solve(self, device, dtype):
8365*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
8366*da0073e9SAndroid Build Coastguard Worker
8367*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, batch, nrhs, hermitian):
8368*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
8369*da0073e9SAndroid Build Coastguard Worker            B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device)
8370*da0073e9SAndroid Build Coastguard Worker            factors, pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
8371*da0073e9SAndroid Build Coastguard Worker            X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian)
8372*da0073e9SAndroid Build Coastguard Worker
8373*da0073e9SAndroid Build Coastguard Worker            def symmetric(A):
8374*da0073e9SAndroid Build Coastguard Worker                return A.tril() + A.tril(-1).mT
8375*da0073e9SAndroid Build Coastguard Worker
8376*da0073e9SAndroid Build Coastguard Worker            # verify A @ X == B
8377*da0073e9SAndroid Build Coastguard Worker            expected_B = symmetric(A) @ X if not hermitian else A @ X
8378*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(B, expected_B)
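            # ldl_factor_ex uses only the lower triangular part of A here, which is why the
            # non-hermitian case compares against symmetric(A) (rebuilt from A's lower triangle)
            # rather than against A itself.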
8379*da0073e9SAndroid Build Coastguard Worker
8380*da0073e9SAndroid Build Coastguard Worker        # hermitian=True is not supported on CUDA yet
8381*da0073e9SAndroid Build Coastguard Worker        hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,)
8382*da0073e9SAndroid Build Coastguard Worker
8383*da0073e9SAndroid Build Coastguard Worker        shapes = (5,)
8384*da0073e9SAndroid Build Coastguard Worker        batches = ((), (4,), (2, 2))
8385*da0073e9SAndroid Build Coastguard Worker        nrhss = (1, 7)
8386*da0073e9SAndroid Build Coastguard Worker        for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians):
8387*da0073e9SAndroid Build Coastguard Worker            run_test(shape, batch, nrhs, hermitian)
8388*da0073e9SAndroid Build Coastguard Worker
8389*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8390*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoMagma
8391*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNoCusolver
8392*da0073e9SAndroid Build Coastguard Worker    @setLinalgBackendsToDefaultFinally
8393*da0073e9SAndroid Build Coastguard Worker    def test_preferred_linalg_library(self):
8394*da0073e9SAndroid Build Coastguard Worker        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
8395*da0073e9SAndroid Build Coastguard Worker        x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double)
8396*da0073e9SAndroid Build Coastguard Worker
8397*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.preferred_linalg_library('cusolver')
8398*da0073e9SAndroid Build Coastguard Worker        out1 = torch.linalg.inv(x)
8399*da0073e9SAndroid Build Coastguard Worker
8400*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.preferred_linalg_library('magma')
8401*da0073e9SAndroid Build Coastguard Worker        out2 = torch.linalg.inv(x)
8402*da0073e9SAndroid Build Coastguard Worker
8403*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.preferred_linalg_library('default')
8404*da0073e9SAndroid Build Coastguard Worker        # Although the preferred linalg backend flag doesn't affect CPU currently,
8405*da0073e9SAndroid Build Coastguard Worker        # we set it back to 'default' here to make sure that switch works normally.
8406*da0073e9SAndroid Build Coastguard Worker        out_ref = torch.linalg.inv(x.cpu())
8407*da0073e9SAndroid Build Coastguard Worker
8408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out1.cpu())
8409*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
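        # For context: torch.backends.cuda.preferred_linalg_library accepts 'default',
        # 'cusolver' and 'magma'; the inverse should agree across backends, as checked above.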
8410*da0073e9SAndroid Build Coastguard Worker
8411*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8412*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
8413*da0073e9SAndroid Build Coastguard Worker    @setBlasBackendsToDefaultFinally
8414*da0073e9SAndroid Build Coastguard Worker    def test_preferred_blas_library(self):
8415*da0073e9SAndroid Build Coastguard Worker        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
8416*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
8417*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)
8418*da0073e9SAndroid Build Coastguard Worker
8419*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.preferred_blas_library('cublaslt')
8420*da0073e9SAndroid Build Coastguard Worker        out1 = torch.nn.functional.linear(m1, m2)
8421*da0073e9SAndroid Build Coastguard Worker
8422*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.preferred_blas_library('cublas')
8423*da0073e9SAndroid Build Coastguard Worker        out2 = torch.nn.functional.linear(m1, m2)
8424*da0073e9SAndroid Build Coastguard Worker
8425*da0073e9SAndroid Build Coastguard Worker        # Although the preferred blas backend flag doesn't affect CPU currently,
8426*da0073e9SAndroid Build Coastguard Worker        # we compute the CPU reference here to check that both backends agree with it.
8427*da0073e9SAndroid Build Coastguard Worker        out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())
8428*da0073e9SAndroid Build Coastguard Worker
8429*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
8430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out2.cpu())
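        # Similarly, torch.backends.cuda.preferred_blas_library switches between the cuBLAS
        # and cuBLASLt backends; the linear outputs should match each other and the CPU reference.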
8431*da0073e9SAndroid Build Coastguard Worker
8432*da0073e9SAndroid Build Coastguard Worker    def test_permute_matmul(self):
8433*da0073e9SAndroid Build Coastguard Worker        a = torch.ones([2, 5, 24, 24])
8434*da0073e9SAndroid Build Coastguard Worker        b = torch.ones([3, 2, 5, 24, 24])
8435*da0073e9SAndroid Build Coastguard Worker        c = a.permute(0, 1, 3, 2).matmul(b)
8436*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
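        # Expected values: each output element is a dot product of 24 ones, so min == max == 24,
        # and the broadcasted result holds 3 * 2 * 5 * 24 * 24 elements, for a total sum of
        # 3 * 2 * 5 * 24 * 24 * 24 = 414720.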
8437*da0073e9SAndroid Build Coastguard Worker
8438*da0073e9SAndroid Build Coastguard Worker    def test_lower_precision_accumulation_with_ref_path(self):
8439*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/95125
8440*da0073e9SAndroid Build Coastguard Worker        # and https://github.com/pytorch/pytorch/issues/83863
8441*da0073e9SAndroid Build Coastguard Worker        # covering bf16 accumulation in the gemm reference path
8442*da0073e9SAndroid Build Coastguard Worker        def check_correctness(fn, dtype, *args):
8443*da0073e9SAndroid Build Coastguard Worker            expected = fn(*args).to(dtype=dtype)
8444*da0073e9SAndroid Build Coastguard Worker            with torch.backends.mkldnn.flags(enabled=False):
8445*da0073e9SAndroid Build Coastguard Worker                def test():
8446*da0073e9SAndroid Build Coastguard Worker                    lower_args = (arg.to(dtype=dtype) for arg in args)
8447*da0073e9SAndroid Build Coastguard Worker                    tmp_result = fn(*lower_args)
8448*da0073e9SAndroid Build Coastguard Worker                    return tmp_result
8449*da0073e9SAndroid Build Coastguard Worker                c = test()
8450*da0073e9SAndroid Build Coastguard Worker                assert (torch.all(c == expected)), "Incorrect result with\n" \
8451*da0073e9SAndroid Build Coastguard Worker                                                   f"expected: {expected}\n" \
8452*da0073e9SAndroid Build Coastguard Worker                                                   f"got: {c}\n"
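        # Intent: even with bf16/fp16 inputs the gemm reference path should accumulate in fp32,
        # so the low-precision result must match the fp32 result downcast to that dtype.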
8453*da0073e9SAndroid Build Coastguard Worker        # test matmul
8454*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.half]:
8455*da0073e9SAndroid Build Coastguard Worker            for transa in [True, False]:
8456*da0073e9SAndroid Build Coastguard Worker                for transb in [True, False]:
8457*da0073e9SAndroid Build Coastguard Worker                    a = torch.ones(300, 300)
8458*da0073e9SAndroid Build Coastguard Worker                    b = torch.ones(300, 300)
8459*da0073e9SAndroid Build Coastguard Worker                    if transa:
8460*da0073e9SAndroid Build Coastguard Worker                        a = a.transpose(0, 1).contiguous().transpose(0, 1)
8461*da0073e9SAndroid Build Coastguard Worker                    if transb:
8462*da0073e9SAndroid Build Coastguard Worker                        b = b.transpose(0, 1).contiguous().transpose(0, 1)
8463*da0073e9SAndroid Build Coastguard Worker                    check_correctness(torch.matmul, dtype, a, b)
8464*da0073e9SAndroid Build Coastguard Worker        # test bmm
8465*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 300)
8466*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(1, 300, 1)
8467*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.bmm, torch.bfloat16, a, b)
8468*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.bmm, torch.half, a, b)
8469*da0073e9SAndroid Build Coastguard Worker        # test baddbmm
8470*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 300)
8471*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(1, 300, 1)
8472*da0073e9SAndroid Build Coastguard Worker        c = torch.ones(1, 1, 1)
8473*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
8474*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.baddbmm, torch.half, c, a, b)
8475*da0073e9SAndroid Build Coastguard Worker        # test mv/addmv
8476*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.half]:
8477*da0073e9SAndroid Build Coastguard Worker            for trans in [True, False]:
8478*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(300) * -300
8479*da0073e9SAndroid Build Coastguard Worker                a = torch.ones(300, 300)
8480*da0073e9SAndroid Build Coastguard Worker                if trans:
8481*da0073e9SAndroid Build Coastguard Worker                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
8482*da0073e9SAndroid Build Coastguard Worker                b = torch.ones(300)
8483*da0073e9SAndroid Build Coastguard Worker                check_correctness(torch.mv, dtype, a, b)
8484*da0073e9SAndroid Build Coastguard Worker                check_correctness(torch.addmv, dtype, c, a, b)
8485*da0073e9SAndroid Build Coastguard Worker        # test dot
8486*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(300)
8487*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(300)
8488*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.dot, torch.bfloat16, a, b)
8489*da0073e9SAndroid Build Coastguard Worker        check_correctness(torch.dot, torch.half, a, b)
8490*da0073e9SAndroid Build Coastguard Worker
8491*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.half, torch.bfloat16)
8492*da0073e9SAndroid Build Coastguard Worker    @parametrize("transpose_a", [True, False])
8493*da0073e9SAndroid Build Coastguard Worker    @parametrize("transpose_b", [True, False])
8494*da0073e9SAndroid Build Coastguard Worker    @parametrize("alpha", [0.0, 0.2, 1.0])
8495*da0073e9SAndroid Build Coastguard Worker    @parametrize("beta", [0.0, 0.5, 1.0])
8496*da0073e9SAndroid Build Coastguard Worker    def test_addmm_mv(self, device, dtype, transpose_a, transpose_b, alpha, beta):
8497*da0073e9SAndroid Build Coastguard Worker        def gen_mat(w, h, use_transpose: bool = False):
8498*da0073e9SAndroid Build Coastguard Worker            if not use_transpose:
8499*da0073e9SAndroid Build Coastguard Worker                return torch.rand(w, h, dtype=dtype, device=device)
8500*da0073e9SAndroid Build Coastguard Worker            return torch.rand(h, w, dtype=dtype, device=device).t()
8501*da0073e9SAndroid Build Coastguard Worker        # Regression tests for https://github.com/pytorch/pytorch/issues/136299
8502*da0073e9SAndroid Build Coastguard Worker        # Should only expose problems on aarch64, but let's be thorough
8503*da0073e9SAndroid Build Coastguard Worker        m, n, k = 1, 8, 32
8504*da0073e9SAndroid Build Coastguard Worker        A = gen_mat(m, k, transpose_a)
8505*da0073e9SAndroid Build Coastguard Worker        B = gen_mat(k, n, transpose_b)
8506*da0073e9SAndroid Build Coastguard Worker        C = torch.ones(m, n, dtype=dtype, device=device)
8507*da0073e9SAndroid Build Coastguard Worker        rc = torch.addmm(C, A, B, alpha=alpha, beta=beta)
8508*da0073e9SAndroid Build Coastguard Worker        ref = alpha * A @ B + beta * C
8509*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rc, ref)
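        # addmm computes beta * C + alpha * (A @ B), which the reference line above spells out
        # explicitly so that the degenerate alpha/beta values (0.0) are covered as well.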
8510*da0073e9SAndroid Build Coastguard Worker
8511*da0073e9SAndroid Build Coastguard Worker
8512*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
8513*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-4})
8514*da0073e9SAndroid Build Coastguard Worker    def test_1_sized_with_0_strided(self, device, dtype):
8515*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
8516*da0073e9SAndroid Build Coastguard Worker        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
8517*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
8518*da0073e9SAndroid Build Coastguard Worker        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
8519*da0073e9SAndroid Build Coastguard Worker        res = torch.bmm(a_strided, b_strided)
8520*da0073e9SAndroid Build Coastguard Worker        expect = torch.from_numpy(
8521*da0073e9SAndroid Build Coastguard Worker            a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
8522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, res)
8523*da0073e9SAndroid Build Coastguard Worker
8524*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestLinalg, globals())
8525*da0073e9SAndroid Build Coastguard Worker
8526*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
8527*da0073e9SAndroid Build Coastguard Worker    TestCase._default_dtype_check_enabled = True
8528*da0073e9SAndroid Build Coastguard Worker    run_tests()